【TensorFlow】tf.nn.softmax_cross_entropy_with_logits的用法

from:https://blog.csdn.net/mao_xiao_feng/article/details/53382790

在计算loss的时候,最常见的一句话就是tf.nn.softmax_cross_entropy_with_logits,那么它到底是怎么做的呢?

首先明确一点,loss是代价值,也就是我们要最小化的值

tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)

除去name参数用以指定该操作的name,与方法有关的一共两个参数

第一个参数logits:就是神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes],单样本的话,大小就是num_classes

第二个参数labels:实际的标签,大小同上

 

具体的执行流程大概分为两步:

第一步是先对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个num_classes大小的向量([Y1,Y2,Y3...]其中Y1,Y2,Y3...分别代表了是属于该类的概率)

softmax的公式是:

至于为什么是用的这个公式?这里不介绍了,涉及到比较多的理论证明

 

第二步是softmax的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵,公式如下:


其中指代实际的标签中第i个的值(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)

就是softmax的输出向量[Y1,Y2,Y3...]中,第i个元素的值

显而易见,预测越准确,结果的值越小(别忘了前面还有负号),最后求一个平均,得到我们想要的loss

注意!!!这个函数的返回值并不是一个数,而是一个向量,如果要求交叉熵,我们要再做一步tf.reduce_sum操作,就是对向量里面所有元素求和,最后才得到,如果求loss,则要做一步tf.reduce_mean操作,对向量求均值!

 

理论讲完了,上代码

[python] view plaincopy
  1. import tensorflow as tf
  2. #our NN's output
  3. logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
  4. #step1:do softmax
  5. y=tf.nn.softmax(logits)
  6. #true label
  7. y_=tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]])
  8. #step2:do cross_entropy
  9. cross_entropy = -tf.reduce_sum(y_*tf.log(y))
  10. #do cross_entropy just one step
  11. cross_entropy2=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits, y_))#dont forget tf.reduce_sum()!!
  12. with tf.Session() as sess:
  13. softmax=sess.run(y)
  14. c_e = sess.run(cross_entropy)
  15. c_e2 = sess.run(cross_entropy2)
  16. print("step1:softmax result=")
  17. print(softmax)
  18. print("step2:cross_entropy result=")
  19. print(c_e)
  20. print("Function(softmax_cross_entropy_with_logits) result=")
  21. print(c_e2)

输出结果是:

[python] view plaincopy
  1. step1:softmax result=
  2. [[ 0.09003057  0.24472848  0.66524094]
  3. [ 0.09003057  0.24472848  0.66524094]
  4. [ 0.09003057  0.24472848  0.66524094]]
  5. step2:cross_entropy result=
  6. 1.22282
  7. Function(softmax_cross_entropy_with_logits) result=
  8. 1.2228

最后大家可以试试e^1/(e^1+e^2+e^3)是不是0.09003057,发现确实一样!!这也证明了我们的输出是符合公式逻辑的

转载于:https://www.cnblogs.com/bonelee/p/8995936.html

【TensorFlow】tf.nn.softmax_cross_entropy_with_logits的用法相关推荐

  1. 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits 函数:求交叉熵损失

    [TensorFlow]tf.nn.softmax_cross_entropy_with_logits的用法_xf__mao的博客-CSDN博客 https://blog.csdn.net/mao_x ...

  2. 【TensorFlow】tf.nn.softmax_cross_entropy_with_logits中的“logits”到底是个什么意思?

    tf.nn.softmax_cross_entropy_with_logits中的"logits"到底是个什么意思?_玉来愈宏的随笔-CSDN博客 https://blog.csd ...

  3. 【TensorFlow】TensorFlow函数精讲之tf.nn.softmax_cross_entropy_with_logits

    tf.nn.softmax_cross_entropy_with_logits()函数是TensorFlow中计算交叉熵常用的函数. 后续版本中,TensorFlow更新为:tf.nn.softmax ...

  4. TensorFlow基础篇(三)——tf.nn.softmax_cross_entropy_with_logits

    tf.nn.softmax_cross_entropy_with_logits()函数是TensorFlow中计算交叉熵常用的函数. 后续版本中,TensorFlow更新为:tf.nn.softmax ...

  5. tf.nn.softmax_cross_entropy_with_logits()笔记及交叉熵

    交叉熵 交叉熵可在神经网络(机器学习)中作为损失函数,p表示真实标记的分布,q则为训练后的模型的预测标记分布,交叉熵损失函数可以衡量p与q的相似性.交叉熵作为损失函数还有一个好处是使用sigmoid函 ...

  6. tf.nn.sparse_softmax_cross_entropy_with_logits()与tf.nn.softmax_cross_entropy_with_logits的差别

    这两个函数的用法类似 sparse_softmax_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, name=N ...

  7. tf.nn.softmax_cross_entropy_with_logits 和 tf.contrib.legacy_seq2seq.sequence_loss_by_example 的联系与区别

    文章目录 0.函数介绍 1.区别联系 1.1 tf.nn.softmax_cross_entropy_with_logits 1.2 tf.nn.sparse_softmax_cross_entrop ...

  8. tf.nn.rnn_cell.DropoutWrapper用法细节案例2

    -- coding: utf-8 -- import tensorflow as tf from tensorflow.contrib import rnn 导入 MINST 数据集 from ten ...

  9. tf.nn.embedding_lookup()的用法

    函数: tf.nn.embedding_lookup( params, ids, partition_strategy='mod', name=None, validate_indices=True, ...

最新文章

  1. 怎么修改网页服务器数据库连接,如何修改网页服务器数据库连接
  2. Linux启动网卡时出现RTNETLINK answers: File exists错误解决方法
  3. WMS仓储管理系统有那些功能?
  4. Part Six 地理定位API
  5. LearnOpenGL之OpenGL特性
  6. 用SqlConnectionStringBuilder修改连接超时时间
  7. 新概念51单片机c语言教程doc,新概念51单片机C语言教程实例代码.doc
  8. 虚拟机系统和windows主机系统的文件互传方法 ——WinSCP使用
  9. 雷达信号处理基础ppt
  10. 【C语言】案例五十一 员工档案管理系统
  11. Rk3288 android USB双摄像头录制视频
  12. c语言 计算平均分
  13. js动态点击放大缩小图片
  14. matlab中 %d,%f,%c,%s代表什么意思
  15. 浅谈javascript面向对象理解
  16. 【turtle库】Python绘制圣诞树
  17. python 做网站_怎么用python做网站
  18. 前端面试笔试错题指南(四)
  19. 解决ini-parser解析ini文件中文乱码问题
  20. 如何将以前wm手机所备份的bkg文件导入android手机,【极光ROM】-【三星S20(国行/港版/台版/韩版/美版) G981X-高通865】-【V5.0 Android-Q-TI8】...

热门文章

  1. Selenium3自动化测试——23.自动发送邮件功能
  2. 和linux关系_Linux内核Page Cache和Buffer Cache关系及演化历史
  3. linux7.4 配置yum,Centos7.4重装yum
  4. 将线程pid转成16进制_硬件资讯 | AMD 线程撕裂者 5000 系 CPU 将包含 16 核版本
  5. 台式计算机怎么加一个硬盘,如何再安装一个台式计算机硬盘驱动器?如何在计算机安装中添加额外的硬盘...
  6. 语法分析器 java实验报告_词法分析器实验报告.doc
  7. 凭借这份Java面试题集,java上传文件夹
  8. 【Ubuntu入门到精通系列讲解】文件和目录常用命令速查
  9. 【机器学习入门到精通系列】机器学习系统设计(Precision Recall)
  10. 使用keras时下载VGG19过慢的解决方法