正如上篇博客中所讲,在Keras框架下执行深度学习任务时,一般会先根据训练数据集训练出模型,然后拿训练好的模型到生产环境(测试集)中部署并生产。以分类问题为例,当训练好了分类模型之后,我们要用这个模型预测测试集中每一个样本的标签,这里有两个常用的方法:predict()方法和predict_classes()方法,下面将以具体例子说明二者的区别。

相关文件的下载地址如下:
链接:https://pan.baidu.com/s/1qQU3FTtGeLYK5jXn0cRpQQ
提取码:qn8r

1 predict()方法
        当使用predict()方法进行预测时,返回值是数值,表示样本属于每一个类别的概率,我们可以使用numpy.argmax()方法找到样本以最大概率所属的类别作为样本的预测标签。下面以卷积神经网络中的图片分类为例说明,代码如下:

import numpy as np
import scipy.misc
from keras.models import model_from_json
from keras.optimizers import SGD# 加载模型
model_architecture = 'cifar10_architecture.json'
model_weights = 'cifar_weights.h5'
model = model_from_json(open(model_architecture).read())       # 加载模型结构
model.load_weights(model_weights)       # 加载模型权重# 加载图片
img_names = ['cat.jpg', 'deer.jpg', 'dog.jpg']
# np.transpose:对数组进行转置,返回原数组的视图
# scipy.misc.imresize:重新调整图片的形状
# scipy.misc.imread:将图片读取出来,返回np.array类型
imgs = [np.transpose(scipy.misc.imresize(scipy.misc.imread(img_name), (32, 32)), (1, 0, 2)).astype('float32')for img_name in img_names]
imgs = np.array(imgs) / 255        # 归一化# 训练
optim = SGD()
model.compile(loss='categorical_crossentropy', optimizer=optim, metrics=['accuracy'])        # 编译模型# 预测样本属于每个类别的概率
print(model.predict(imgs))      # 打印概率
print(np.argmax(model.predict(imgs), axis=1))      # 打印最大概率对应的标签
# [[3.3745366e-01 2.2980917e-02 2.0197949e-03 1.2046755e-02 1.9850987e-03
#   1.3152690e-04 4.0220530e-03 1.3779138e-03 5.9722424e-01 2.0758053e-02]
#  [5.0913623e-06 5.6117901e-08 9.7215974e-01 2.0343825e-05 2.3693956e-02
#   1.6027538e-03 7.3659585e-06 2.5106100e-03 5.8250910e-10 1.4506637e-09]
#  [7.1339104e-03 6.1033275e-06 2.1771197e-03 9.7346401e-01 2.2141664e-06
#   1.6861971e-02 8.6817810e-05 1.2291509e-04 4.7768017e-06 1.4035056e-04]]
# [8 2 3]

效果截图如下:

2 predict_classes()方法
        当使用predict_classes()方法进行预测时,返回的是类别的索引,即该样本所属的类别标签。以卷积神经网络中的图片分类为例说明,代码如下:

import numpy as np
import scipy.misc
from keras.models import model_from_json
from keras.optimizers import SGD# 加载模型
model_architecture = 'cifar10_architecture.json'
model_weights = 'cifar_weights.h5'
model = model_from_json(open(model_architecture).read())       # 加载模型结构
model.load_weights(model_weights)       # 加载模型权重# 加载图片
img_names = ['cat.jpg', 'deer.jpg', 'dog.jpg']
# np.transpose:对数组进行转置,返回原数组的视图
# scipy.misc.imresize:重新调整图片的形状
# scipy.misc.imread:将图片读取出来,返回np.array类型
imgs = [np.transpose(scipy.misc.imresize(scipy.misc.imread(img_name), (32, 32)), (1, 0, 2)).astype('float32')for img_name in img_names]
imgs = np.array(imgs) / 255        # 归一化# 训练
optim = SGD()
model.compile(loss='categorical_crossentropy', optimizer=optim, metrics=['accuracy'])        # 编译模型# 预测样本类别
predictions = model.predict_classes(imgs)
print(predictions)
# [8 2 3]

结果截图如下:

欢迎交流! QQ:3408649893

Keras中predict()方法和predict_classes()方法的区别相关推荐

  1. python predict_对Keras中predict()方法和predict_classes()方法的区别说明

    1 predict()方法 当使用predict()方法进行预测时,返回值是数值,表示样本属于每一个类别的概率,我们可以使用numpy.argmax()方法找到样本以最大概率所属的类别作为样本的预测标 ...

  2. Keras中predict()方法和predict_classes()方法和evaluate()方法

     predict()方法         当使用predict()方法进行预测时,返回值是数值,表示样本属于每一个类别的概率,我们可以使用numpy.argmax()方法找到样本以最大概率所属的类别作 ...

  3. sklearn 中 predict 方法和 predict_proba 方法的区别和使用

    一.predict 和 predict_proba的概念和区别     1.predict和predict_proba都是用于模型的预测.     2.predict返回的是一个预测的值,predic ...

  4. java——Scanner中nextLine()方法和next()方法的区别

    遇到一个有意思的东西,在整理字符串这块知识的时候,发现我在用Scanner函数时,在字符串中加入空格,结果空格后面的东西没有输出来(/尴尬),不多说直接上代码: import java.util.Sc ...

  5. ExtJS中listener方法和handler方法的区别

    listener方法和handler方法的区别在文档中的说明的太玄乎了,看不懂 listeners监听能够对一个click Event事件添加任意多个的事件响应处理函数 而handler处理只能够通过 ...

  6. VBA中Activate方法和Select方法的区别

    VBA中的Activate方法和Select方法看起来似乎相同,其实二者是有区别的.Activate方法的作用是激活,而Select方法的作用是选择.其区别如下: 1.对于"Sheets&q ...

  7. jquery中prop()方法和attr()方法的区别浅析

    引用:http://www.jb51.net/article/41170.htm 官方例举的例子感觉和attr()差不多,也不知道有什么区别,既然有了prop()这个新方法,不可能没用吧,那什么时候该 ...

  8. Scanner中nextLine()方法和next()方法的区别

    我们在使用扫描器Scanner时,遇到了字符串肯定会使用API中定义好的next()和nextLine()方法.两者一个是能读取空格一个是不能读取空格就像下面的样子 当我们把二者交换位置,再来看一下效 ...

  9. Hibernate中get方法和load方法的区别

    一.get和load方法都是根据id去获得对应数据的,但是获得机制不同:如果使用get方法,hibernate会去确认该id对应的数据是否存在,它首先会去session中去查询(session缓存其实 ...

最新文章

  1. BZOJ 3585: mex( 离线 + 线段树 )
  2. 版是什么_雕版研习 | 什么是版画?版是画的母亲,画是版的子女
  3. 如何查看linux下的环境变量
  4. 如何使用以下命令 ls cat mv touch 以及如何使用 explainshell.com 这个网站
  5. python是什么课程-请问自学 Python 有必要买课程吗?
  6. 在EXCEL中进行趋势拟合与预测的方法
  7. 第四篇:在MVPArms中报错error: cannot find symbol class DaggerXXXComponent的问题
  8. 计算椭圆运动轨迹的算法
  9. Allegro 走高速线等长线时怎么画成椭圆的走线
  10. 【Git】clone项目push项目没反应,Cloning into...没下载
  11. 密码强度正则表达式 – 必须包含大写字母,小写字母和数字,至少8个字符等...
  12. 五星大饭店续集剧情大放送(最新更新)
  13. Contiki-NG在GD32F310的移植
  14. debian apache2不执行php,Debian下Apache2的安装与配置
  15. 学习笔记(15):C++编程FFMpeg(QT5+OpenCV)实战--实时美颜直播推流-opencv播放rtsp海康摄像头和播放系统摄像头...
  16. CTF-密码学-bacon
  17. 欢乐的票圈重构——九宫格控件(上)
  18. GridView ---->Indicator
  19. P1345 [USACO5.4]奶牛的电信Telecowmunication
  20. Lumerical官方案例、FDTD时域有限差分法仿真学习(九)——布拉格光栅(Bragg gratings)

热门文章

  1. 【JavaWeb】Request对象详解
  2. linux vim 编辑 保存 退出
  3. t3软件怎么生成报表_t3财务报表
  4. 计算机对音乐课堂的帮助,电脑音乐在音乐教学中的应用
  5. [webView stopLoading]; 和 [webView release];
  6. 一篇博客教会你写序列化工具
  7. speedoffice(PPT)插入的图片如何自动适合幻灯片页面大小呢?
  8. Pytorch函数之topk()方法
  9. 高盐废水如何处理,离子交换树脂在高盐废水中的应用
  10. 云计算实训之项目3-基于微信实现自动化监控报警