【TensorFlow】tf.nn.softmax_cross_entropy_with_logits的用法_xf__mao的博客-CSDN博客 https://blog.csdn.net/mao_xiao_feng/article/details/53382790#

在计算loss的时候,最常见的一句话就是tf.nn.softmax_cross_entropy_with_logits,那么它到底是怎么做的呢?

首先明确一点,loss是代价值,也就是我们要最小化的值

tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)

与方法有关的一共两个参数 :

  1. 第一个参数logits:就是神经网络最后一层的输出,如果有batch(批处理)的话,它的大小就是[batchsize,num_classes],单样本的话,大小就是num_classes
  2. 第二个参数labels实际的标签,大小同上

具体的执行流程大概分为两步:
第一步是先对网络最后一层的输出做一个softmax(归一化处理),这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个num_classes大小的向量([Y1,Y2,Y3…]其中Y1,Y2,Y3…分别代表了是属于该类的概率

softmax的公式是:

至于为什么是用的这个公式?这里不介绍了,涉及到比较多的理论证明

第二步是softmax的输出向量[Y1,Y2,Y3…]和样本的实际标签,做一个交叉熵,公式如下:

其中yi'指代实际的标签中第 i 个的值
(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)(感觉就是one hot类型的数据)

yi就是softmax的输出向量[Y1,Y2,Y3…]中,第i个元素的值
显而易见,预测yi越准确,结果的值越小(别忘了前面还有负号),最后求一个平均,得到我们想要的loss

注意!!!

  • 这个函数的返回值:并不是一个数,而是一个向量

  • 如果要求交叉熵,我们要再做一步tf.reduce_sum操作,就是对向量里面所有元素求和,最后才得到,

  • 如果求loss,则要做一步tf.reduce_mean操作,对向量求均值

上代码:

import tensorflow as tf#our NN's output 假设为:神经网络的最后一层输出
logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
#step1:do softmax 使用softmax进行归一化处理
y=tf.nn.softmax(logits)
#true label 监督学习中,数据输入网络前的正确的标签
y_=tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]])
#step2:do cross_entropy 求交叉熵损坏的方法一
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#do cross_entropy just one step 求交叉熵损坏的方法二:使用本次讲解的函数
cross_entropy2=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits, y_))#dont forget tf.reduce_sum()!!with tf.Session() as sess:softmax=sess.run(y)c_e = sess.run(cross_entropy)c_e2 = sess.run(cross_entropy2)print("step1:softmax result=")print(softmax)print("step2:cross_entropy result=")print(c_e)print("Function(softmax_cross_entropy_with_logits) result=")print(c_e2)

输出结果是:

step1:softmax result=
[[ 0.09003057  0.24472848  0.66524094][ 0.09003057  0.24472848  0.66524094][ 0.09003057  0.24472848  0.66524094]]
step2:cross_entropy result=
1.22282
Function(softmax_cross_entropy_with_logits) result=
1.2228

最后大家可以试试e1/(e1+e2+e3)是不是0.09003057,发现确实一样!!这也证明了 我们的输出是符合公式逻辑的

还有一篇,建议结合起来看:
【TensorFlow】tf.nn.softmax_cross_entropy_with_logits中的“logits”到底是个什么意思?

【TensorFlow】tf.nn.softmax_cross_entropy_with_logits 函数:求交叉熵损失相关推荐

  1. tf.nn.softmax_cross_entropy_with_logits()笔记及交叉熵

    交叉熵 交叉熵可在神经网络(机器学习)中作为损失函数,p表示真实标记的分布,q则为训练后的模型的预测标记分布,交叉熵损失函数可以衡量p与q的相似性.交叉熵作为损失函数还有一个好处是使用sigmoid函 ...

  2. 【TensorFlow】TensorFlow函数精讲之tf.nn.softmax_cross_entropy_with_logits

    tf.nn.softmax_cross_entropy_with_logits()函数是TensorFlow中计算交叉熵常用的函数. 后续版本中,TensorFlow更新为:tf.nn.softmax ...

  3. 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits中的“logits”到底是个什么意思?

    tf.nn.softmax_cross_entropy_with_logits中的"logits"到底是个什么意思?_玉来愈宏的随笔-CSDN博客 https://blog.csd ...

  4. TensorFlow基础篇(三)——tf.nn.softmax_cross_entropy_with_logits

    tf.nn.softmax_cross_entropy_with_logits()函数是TensorFlow中计算交叉熵常用的函数. 后续版本中,TensorFlow更新为:tf.nn.softmax ...

  5. Softmax函数与交叉熵

    Softmax函数 背景与定义 导数 softmax的计算与数值稳定性 Loss function 对数似然函数 交叉熵 Loss function求导 TensorFlow 方法1手动实现不建议使用 ...

  6. tf.nn.softmax_cross_entropy_with_logits 和 tf.contrib.legacy_seq2seq.sequence_loss_by_example 的联系与区别

    文章目录 0.函数介绍 1.区别联系 1.1 tf.nn.softmax_cross_entropy_with_logits 1.2 tf.nn.sparse_softmax_cross_entrop ...

  7. 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits的用法

    [TensorFlow]tf.nn.softmax_cross_entropy_with_logits的用法 from:https://blog.csdn.net/mao_xiao_feng/arti ...

  8. Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作...

    使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...

  9. Softmax函数下的交叉熵损失含义与求导

    交叉熵损失函数(CrossEntropy Function)是分类任务中十分常用的损失函数,但若仅仅看它的形式,我们不容易直接靠直觉来感受它的正确性,因此我查阅资料写下本文,以求彻底搞懂. 1.Sof ...

最新文章

  1. [UE4]RetainerBox,控制UI更新频率,把渲染后的UI当成Texture
  2. ML之SVM:基于sklearn的svm算法实现对支持向量的数据进行标注
  3. VMware vSphere 服务器虚拟化之十七 桌面虚拟化之安装View链接服务器
  4. c语言:宏里面参数不加括号容易出错,在使用时尽量加括号及举例
  5. java.lang.OutOfMemoryError 解决方式
  6. Redis的数据类型之String
  7. ue4缓存位置怎么改_怎么从蓝图节点跳转到C++源码?
  8. 什么是activemq_什么是ActiveMQ?
  9. 轻松部署IE7(下),SMS2003系列之六
  10. Python利用双端队列判断回文词
  11. MTK 驱动(80)---MTK平台User版本开机异常/无法开机,如何抓取log
  12. ExtJs 设置GridPanel表格文本垂直居中
  13. python中的shallow copy 和 deep copy
  14. python界面颜色设置_pycharm修改界面主题颜色的方法
  15. 风尚云网学习-vue项目的构建/打包/发布
  16. 联想小新i1000拆机图解_联想小新笔记本拆机解析
  17. 游戏研发人才学校培养、企业需求与个人快速成长,华科校友分享了这些实用观点
  18. iOS - 距离传感器
  19. 请假工资扣费总额计算机公式,病假扣款计算公式excel
  20. 大数据在职研究生哪个好_大数据在职研究生

热门文章

  1. wxWidgets:显示 wxDebugReport 和相关类的最小示例
  2. wxWidgets:通过组合现有小部件制作新的可重用小部件
  3. boost::mpi::wait_all相关用法的测试程序
  4. boost::hana模块将 reference_wrappers 保存到其元素的元组
  5. boost::hana::test::TestGroup用法的测试程序
  6. boost::intrusive::auto_unlink_hook用法的测试程序
  7. VTK:可视化算法之LOxSeeds
  8. VTK:Utilities之2DArray
  9. VTK:Rendering之Rainbow
  10. VTK:PolyData之ExtractCellsUsingPoints