softmax函数:

交叉熵:


其中指代实际的标签中第i个的值(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)

就是softmax的输出向量[Y1,Y2,Y3...]中,第i个元素的值

显而易见,预测越准确,结果的值越小(别忘了前面还有负号),最后求一个平均,得到我们想要的loss。

过程如下:
第一步是先对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个num_classes大小的向量([Y1,Y2,Y3...]其中Y1,Y2,Y3...分别代表了是属于该类的概率).

第二步是softmax的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵.

原始值[1, 2, 3] 
softmax后[ 0.09003057  0.24472848  0.66524094]
和标签[0, 0, 1]做交叉熵,实际就是-1 * log(0.66524094)

代码示例:
在计算loss的时候,最常见的一句话就是tf.nn.softmax_cross_entropy_with_logits,那么它到底是怎么做的呢?

首先明确一点,loss是代价值,也就是我们要最小化的值

tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)
除去name参数用以指定该操作的name,与方法有关的一共两个参数:
第一个参数logits:就是神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes],单样本的话,大小就是num_classes

第二个参数labels:实际的标签,大小同上。

注意,这个函数的返回值并不是一个数,而是一个向量。如果要求交叉熵,我们要再做一步tf.reduce_sum操作,就是对向量里面所有元素求和,最后才得到;如果求loss,则要做一步tf.reduce_mean操作,对向量求均值!

# -*-coding: utf-8-*-

import tensorflow as tf

# our NN's output
logits = tf.constant([
[1.0, 2.0, 3.0],
[1.0, 2.0, 3.0],
[1.0, 2.0, 3.0]])

# steps1 do softmax
y = tf.nn.softmax(logits)

# ture label
y_ = tf.constant([[0.0, 0.0, 1.0],
[0.0, 0.0, 1.0],
[0.0, 0.0, 1.0]])

# step2: do cross_entopy
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

# do cross_entropy just one step
cross_entropy2 = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_)) # dont forget tf.reduce_sum()!!

with tf.Session() as sess:
softmax = sess.run(y)
c_e = sess.run(cross_entropy)
c_e2 = sess.run(cross_entropy2)
print("step1:softmax result=")
print(softmax)
print("step2:cross_entropy result=")
print(c_e)
print("Function(softmax_cross_entropy_with_logits) result=")
print(c_e2)

输出:
step1:softmax result=
[[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]]
step2:cross_entropy result=
1.22282
Function(softmax_cross_entropy_with_logits) result=
1.22282

自己计算:
soft_max:e^1/(e^1+e^2+e^3)=0.09003057

cross_entropy: -1 * log(0.66524094) * 3= 1.2228

---------------------
作者:大师鲁
来源:CSDN
原文:https://blog.csdn.net/laolu1573/article/details/60138455
版权声明:本文为博主原创文章,转载请附上博文链接!

转载于:https://www.cnblogs.com/andy-0212/p/10242392.html

softmax和cross_entropy相关推荐

  1. TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别

    TF之NN:利用DNN算法(SGD+softmax+cross_entropy)对mnist手写数字图片识别训练集(TF自带函数下载)实现87.4%识别 目录 输出结果 代码设计 输出结果 代码设计 ...

  2. 【深度学习】实验1答案:Softmax实现手写数字识别

    DL_class 学堂在线<深度学习>实验课代码+报告(其中实验1和实验6有配套PPT),授课老师为胡晓林老师.课程链接:https://www.xuetangx.com/training ...

  3. CNN基础知识 || softmax与交叉熵

    一.函数 1 sigmoid 将一个数值通过函数映射到0-1之间. sigmoid函数表达式如下                                                     ...

  4. tensorflow问题

    20210121 ImportError: No module named 'tensorflow.python' https://stackoverflow.com/questions/414156 ...

  5. Tensorflow【实战Google深度学习框架】全连接神经网络以及可视化

    文章目录 1 可视化 神经网络的二元分类效果 2 全连接神经网络 3 TensorFlow搭建一个全连接神经网络 3.1 读取MNIST数据 3.2 建立占位符 3.3 建立模型 3.4 正确率 3. ...

  6. 《Tensorflow实战》之6.3VGGnet学习

    这是我改写的代码,可以运行,但是过拟合现象严重,不知道怎么修改比较好 # -*- coding: utf-8 -*- """ Created on Wed Dec 20 ...

  7. tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)

    mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...

  8. 卷积神经网络原理及实现

    卷积神经网络的主要结构是卷积层+池化层,该算法在图像上有较好的效果 小知识:图片有彩色图片和黑白图片,颜色都是有RGB三种颜色调和而成,所以彩色图片有三层通道,黑白图片有一层通道 咱们拿黑白图片说事: ...

  9. TF之DNN:利用DNN【784→500→10】对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程

    TF之DNN:利用DNN[784→500→10]对MNIST手写数字图片识别数据集(TF自带函数下载)预测(98%)+案例理解DNN过程 目录 输出结果 案例理解DNN过程思路 代码设计 输出结果 案 ...

  10. tensorflow-Inception-v3模型训练自己的数据代码示例

    一.声明 本代码非原创,源网址不详,仅做学习参考. 二.代码 1 # -*- coding: utf-8 -*- 2 3 import glob # 返回一个包含有匹配文件/目录的数组 4 impor ...

最新文章

  1. Centos7下安装MongoDB
  2. 千言万语汇总的Mybatis-plus常用API全套教程
  3. 又一个可以弄垮iPhone手机的漏洞...
  4. linux arp 老化时间,Linux实现的ARP缓存老化时间原理解析
  5. 一份非常完整的 MySQL 规范
  6. apache camel_使用Apache Camel进行负载平衡
  7. OpenCV手部关键点检测(手势识别)代码示例
  8. Zookeeper数据一致性原理
  9. 计算机科学与技术的应用图,安徽农业大学计算机科学与技术视图及其应用.ppt...
  10. Oracle 隐含参数的查询
  11. verilog中generate用法及参数传递(转)
  12. ESD静电二极管封装规格,详细介绍
  13. OTSU算法实现二值化
  14. 小鸡腿U R III 杭州2013
  15. Docker Mysql使用学习
  16. 轻量级程序编辑器的选择:EmEditor、Editplus等---Web开发系列之工具篇
  17. 视频教程-WebGL 可视化3D绘图框架:Three.js 零基础上手实战-其他
  18. 利用ELK技术栈收集nginx日志
  19. 离线安装VS2019教程
  20. Muli3D源码分析(1) - 框架概览

热门文章

  1. Converter使用及其原理
  2. mysql的root用户无法给普通用户授权问题处理
  3. 如何在C#窗体中定义全局变量
  4. JavaWeb之路径问题
  5. 如何获取html输入框的值,jQuery如何获取各种input输入框的值
  6. java lombok 插件_idea 安装 lombok 插件
  7. 位置度标注方法图解_追踪主力-散户操盘实战图解:操盘手法分析
  8. 鸿蒙系统公布名单,鸿蒙系统支持名单曝光,有你的手机吗?
  9. 如何使用python提高办公效率-提升Python程序运行效率的6个方法
  10. bde连接mysql设置,delphi通过BDE方式连接数据库以及程序Demo