@创建于:2022.04.17
@修改于:2022.04.17

文章目录

  • 1、方法介绍
  • 2、建议使用predict(),进而改写
  • 3、predict()参数介绍
  • 4、参考链接

predict() 在tf.keras.Sequential 和 tf.keras.Model模块都有效;
predict_classes()、predict_proba()方法 在tf.keras.Sequential 模块下有效,在tf.keras.Model模块下无效。

1、方法介绍

predict()方法预测时,返回值是数值,表示样本属于每一个类别的概率。

predict_proba() 方法预测时,返回值是数值,表示样本属于每一个类别的概率。与predict输出结果一致。

predict_classes() 方法预测时,返回的是类别的索引,即该样本所属的类别标签。

predict_classes() 和 predict_proba()方法将逐渐被弃用了,不建议尝试了。虽然在tensorflow2.5.0中还存在,如果使用会报出如下警告。

UserWarning: `model.predict_proba()` is deprecated and will be removed after 2021-01-01. Please use `model.predict()` instead.warnings.warn('`model.predict_proba()` is deprecated and '

2、建议使用predict(),进而改写

predict_classes() 已被弃用并且将在 2021-01-01 之后删除,请改用下面做法。这样能够与tf.keras.Model模块保持一致,实现统一。

  • 如果你的模型进行多类分类(例如,如果它使用 softmax 最后一层激活),使用np.argmax获得最大值索引。因为深度学习的多分类模型里面需要使用独热向量编码,或参考keras中to_categorical函数解析
np.argmax(model.predict(x), axis=-1)
  • 如果你的模型进行二分类(例如,如果它使用 sigmoid 最后一层激活),使用下面代码获得分类
(model.predict(x) > 0.5).astype("int32")

3、predict()参数介绍

  def predict(self,x,batch_size=None,verbose=0,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False):"""Generates output predictions for the input samples.Computation is done in batches. This method is designed for performance inlarge scale inputs. For small amount of inputs that fit in one batch,directly using `__call__` is recommended for faster execution, e.g.,`model(x)`, or `model(x, training=False)` if you have layers such as`tf.keras.layers.BatchNormalization` that behaves differently duringinference. Also, note the fact that test loss is not affected byregularization layers like noise and dropout.Args:x: Input samples. It could be:- A Numpy array (or array-like), or a list of arrays(in case the model has multiple inputs).- A TensorFlow tensor, or a list of tensors(in case the model has multiple inputs).- A `tf.data` dataset.- A generator or `keras.utils.Sequence` instance.A more detailed description of unpacking behavior for iterator types(Dataset, generator, Sequence) is given in the `Unpacking behaviorfor iterator-like inputs` section of `Model.fit`.batch_size: Integer or `None`.Number of samples per batch.If unspecified, `batch_size` will default to 32.Do not specify the `batch_size` if your data is in theform of dataset, generators, or `keras.utils.Sequence` instances(since they generate batches).verbose: Verbosity mode, 0 or 1.steps: Total number of steps (batches of samples)before declaring the prediction round finished.Ignored with the default value of `None`. If x is a `tf.data`dataset and `steps` is None, `predict` willrun until the input dataset is exhausted.callbacks: List of `keras.callbacks.Callback` instances.List of callbacks to apply during prediction.See [callbacks](/api_docs/python/tf/keras/callbacks).max_queue_size: Integer. Used for generator or `keras.utils.Sequence`input only. Maximum size for the generator queue.If unspecified, `max_queue_size` will default to 10.workers: Integer. Used for generator or `keras.utils.Sequence` inputonly. Maximum number of processes to spin up when usingprocess-based threading. If unspecified, `workers` will defaultto 1.use_multiprocessing: Boolean. Used for generator or`keras.utils.Sequence` input only. If `True`, use process-basedthreading. If unspecified, `use_multiprocessing` will default to`False`. Note that because this implementation relies onmultiprocessing, you should not pass non-picklable arguments tothe generator as they can't be passed easily to children processes.

4、参考链接

Keras中predict()方法和predict_classes()方法的区别

Keras非顺序模型没有model.predict_classes()方法如何获取测试数据分类的标签

tensorflow.keras.models.Sequential——predict()、predict_classes()、predict_proba()方法的区别相关推荐

  1. tensorflow2.1学习--tf.keras学习之tf.keras.models.Sequential

    tf.keras.models.Sequential是描述网络层架构的一个api,是顺序的结构即一层一层的描述,但是对于跳跃式的就不行,需要使用自定义层,或者使用类实现 . import tensor ...

  2. sklearn 中 predict 方法和 predict_proba 方法的区别和使用

    一.predict 和 predict_proba的概念和区别     1.predict和predict_proba都是用于模型的预测.     2.predict返回的是一个预测的值,predic ...

  3. 【Python学习】 - TensorFlow.keras 不显示epochs进度条的方法

    一.概述 在我们使用TensorFlow进行神经网络的搭建时,难免遇到需要训练很多次来拟合数据的情况,假设需要拟合1000次数据,那么可能前800次的拟合效果都不是很好,所以显示进度条就会使得输出面板 ...

  4. keras.models导入Sequential错误

    刚开始pip的最新版本的keras,找不到keras.models. keras.layers from keras.models import Sequential from keras.layer ...

  5. [深度学习-实践]GAN入门例子-利用Tensorflow Keras与数据集CIFAR10生成新图片

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之基于CIFAR10数据集的例子; 深度学习GAN(三)之基于手写体Mnist数据集的例子; 深度学习GAN(四)之PIX2PIX G ...

  6. Tensorflow.Keras 时序回归模型的建立

    Learn from Keras官方网站 目前tensorflow 2.0版本已经集成了keras的所有功能,所以直接安装tensorflow就可以调用Keras,非常方便. 作为Keras入门可以看 ...

  7. tensorflow的容器Sequential 笔记心得

    tensorflow的Sequential不能够改名字,源代码处仅为可读,可以修改为读写,layer = model.layers,layer.name. Sequential可以添加模型,层,比较灵 ...

  8. 用Keras构建神经网络的3种方法

    作者|Orhan Gazi Yalçın 编译|VK 来源|Towards Datas Science 如果你看看不同的教程,搜索,花大量时间研究关于TensorFlow的Stack Overflow ...

  9. TensorFlow(keras)入门课程--06 CNN用于猫狗数据集

    目录 1 简介 在本节中,我们将学习如何使用卷积神经网络,并使用更大的数据集,这有助于避免过度拟合的问题! 2 使用更大的数据集进行训练-猫和狗 在之前的实验中,训练了一个马与人类数据的分类器.尽管在 ...

最新文章

  1. 三维点云分割综述(中)
  2. 不想被AI降维打击?美国“四院院士”写的DL科普书了解一下
  3. linux下JDK的安装
  4. WITH AS【原创】
  5. 一个历时五天的 Bug
  6. 关于火车票预定助手的声明
  7. springboot-自动配置流程
  8. greenplum 单表 数据扫描
  9. UVa340 Master-Mind Hints
  10. jquery事件绑定与事件委托
  11. 机器学习基础:模糊C均值聚类(Machine Learning Fundamentals: Fuzzy C-Means )Python实现
  12. NHibernate one-to-one 关系的几点说明
  13. 随便谈谈alphago与人机大战
  14. nginx三种负载均衡的方式
  15. Redis复习记录(二):数据类型与基本操作
  16. 微信公众号token验证问题
  17. 【JAVA】PAT 乙级 1059 C语言竞赛(测试点1、2超时) 内含1-10000的素数表和0-10000是否素数的boolean值
  18. 如何卸载Win10系统内置的应用
  19. mysql 单表关联_MySQL 基础之 单表、多表联查
  20. python中average什么意思_在Python3 numpy中mean和average的区别详解

热门文章

  1. 【已解决】找到无效的 Gradle JDK 配置(invalid Gradle JDK configuration found)
  2. css3选择器详细探索
  3. CSS3选择器及权重
  4. 【lwIP(第三章)】内存管理
  5. 前端开发实习面试题(JavaScript篇)
  6. mysql cast()与convert() 函数
  7. bzoj 2876: [Noi2012]骑行川藏 拉格朗日乘子法
  8. 如何创建数据链接文件
  9. 微信小程序之日期时间筛选器实现(支持年月日时分)
  10. 常用搜索引擎链接及参数