softmax函数

softmax函数接收一个N维向量作为输入,然后把每一维的值转换到(0, 1)之间的一个实数。假设模型全连接网络输出为a,有C个类别,则输出为a1,a2,...,aC,对于每个样本,属于类别i的输出概率为:

属于各个类别的概率和为1。

贴一张形象的说明图:

如图将原来输入的3,1,-3通过softmax函数的作用,映射成为(0,1)的值,而这些值的累和为1(满足概率的性质),我们可以将它理解成概率,在最后选取输出结点的时候,我们就可以选取概率值最大的结点,作为我们的预测目标。

softmax导数

对softmax求导即:

当i = j 时:

当i ≠ j时:

softmax数值稳定性

传入数据[1, 2, 3, 4, 5]时

传入数据[1000, 2000, 3000, 4000, 5000]时

导致输出是nan的原因是exp(x)对较大的数求指数溢出的问题。

一般的做法是额外加上一个非零常数,使所有的输入在0的附近。

比如:

def softmax(x):shift_x = x - np.max(x)exp_x = np.exp(shift_x)return exp_x / np.sum(exp_x)

交叉熵:用来判定实际的输出与期望的输出的接近程度!

刻画的是实际输出与期望输出的距离,也就是交叉熵的值越小,两个概率分布就越接近,假设概率分布p为期望输出,概率分布q为实际输出,H(p,q)为交叉熵,则:

或者:

Tensorflow中对交叉熵的计算可以采用两种方式

1.手动实现:

import tensorflow as tfinput = tf.placeholder(dtype=tf.float32, shape=[None, 28*28])
output = tf.placeholder(dtype=tf.float32, shape=[None, 10])w_fc1 = tf.Variable(tf.truncated_normal([28*28, 1024], stddev=0.1))
b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
h_fc1 = tf.matmul(input, w_fc1) + b_fc1w_fc2 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))
logits = tf.nn.softmax(tf.matmul(h_fc1, w_fc2) + b_fc2)cross_entropy = -tf.reduce_sum(output * tf.log(logits))

output是one-hot类型的实际输出,logits是对全连接的输出用softmax进行转换为概率值的预测,最后通过cross_entropy = -tf.reduce_sum(label * tf.log(y))求出交叉熵的。

2.tf.nn.softmax_cross_entropy_with_logits:

tensorflow已经对softmax和交叉熵进行了封装

import tensorflow as tfinput = tf.placeholder(dtype=tf.float32, shape=[None, 28*28])
output = tf.placeholder(dtype=tf.float32, shape=[None, 10])w_fc1 = tf.Variable(tf.truncated_normal([28*28, 1024], stddev=0.1))
b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
h_fc1 = tf.matmul(input, w_fc1) + b_fc1w_fc2 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))
logits = tf.matmul(h_fc1, w_fc2) + b_fc2cross_entropy = -tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=output, logits=logits))

函数的参数logits在函数内会用softmax进行处理,所以传进来时不能是softmax的输出。

官方的封装函数会在内部处理数值不稳定等问题,如果选择方法1,需要自己在softmax函数里面添加trick。

tensoflow随笔——softmax和交叉熵相关推荐

  1. 神经网络学习中的SoftMax与交叉熵

    简 介: 对于在深度学习中的两个常见的函数SoftMax,交叉熵进行的探讨.在利用paddle平台中的反向求微分进行验证的过程中,发现结果 与数学定义有差别.具体原因还需要之后进行查找. 关键词: 交 ...

  2. 人脸识别-Loss-2010:Softmax Loss(Softmax激活函数 + “交叉熵损失函数”)【样本3真实标签为c_5,则样本3的损失:loss_3=-log(\hat{y}_5^3)】

    一般一个CNN网络主要包含卷积层,池化层(pooling),全连接层,损失层等. 全连接层:等号左边部分就是全连接层做的事, W W W 是全连接层的参数,我们也称为权值, X X X 是全连接层的输 ...

  3. pytoch人工神经网络基础:最简单的分类(softmax回归+交叉熵分类)

    softmax回归分类原理 对于回归问题,可以用模型预测值与真实值比较,用均方误差这样的损失函数表示误差,迭代使误差最小训练模型. 那么分类问题是否可以用线性回归模型预测呢.最简单的方法就是用soft ...

  4. 度量学习(Metric learning)—— 基于分类损失函数(softmax、交叉熵、cosface、arcface)

    概述 首先,我们把loss归为两类:一类是本篇讲述的基于softmax的,一类是基于pair对的(如对比损失.三元损失等). 基于pair对的,参考我的另一篇博客: https://blog.csdn ...

  5. 【机器学习基础】Softmax与交叉熵的数学意义(信息论与概率论视角)

    经过我长时间的观察,发现很多人对人工智能/机器学习的理解只停留在"这是个经典/最新棒槌,我拿这个棒槌敲钉子贼6--"这个level.当然,如果真的敲得很6,那也是一个很厉害的大佬了 ...

  6. 图像分类_03分类器及损失:线性分类+ SVM损失+Softmax 分类+交叉熵损失

    2.3.1 线性分类 2.3.1.1 线性分类解释 上图图中的权重计算结果结果并不好,权重会给我们的猫图像分配⼀个⾮常低的猫分数.得出的结果偏向于狗. 如果可视化分类,我们为了⽅便,将⼀个图⽚理解成⼀ ...

  7. 动手学深度学习——softmax回归之OneHot、softmax与交叉熵

    目录 一.从回归到多类分类 1. 回归估计一个连续值 2. 分类预测一个离散类别 二.独热编码OneHot 三.校验比例--激活函数softmax 四.损失函数--交叉熵 五.总结 回归可以用于预测多 ...

  8. 神经网络适用于分类问题的最后一层-Softmax和交叉熵损失介绍及梯度推导

    前言 传统机器学习中两大经典任务就是回归与分类.分类在深度学习中也很常见,令我印象最深的是图像分类.当然,在NLP中,分类也无处不在.从RNN与其变体,到Transformer.Bert等预训练模型, ...

  9. CNN基础知识 || softmax与交叉熵

    一.函数 1 sigmoid 将一个数值通过函数映射到0-1之间. sigmoid函数表达式如下                                                     ...

最新文章

  1. 成功解决SyntaxError: (unicode error) ‘unicodeescape‘ codec can‘t decode bytes in position 0-1: malformed
  2. Android短视频中如何实现720P磨皮美颜录制
  3. linux怎么修改sftp默认端口,转:linux 修改sftp服务默认提供者sshd的session timeout
  4. mysql 分析服务_MySQL分析服务器状态_MySQL
  5. Javascript document对象常用的方法和属性
  6. paip.调试js 查看元素事件以及事件断点
  7. 算法设计与分析:Jewels and Stones(Week 1)
  8. 基于SSM的宠物医院信息管理系统javaweb毕业设计项目源码论文
  9. Win10指定用户访问共享文件及“无法访问。你可能没有权限使用网络资源。”问题解决
  10. HIT CS:APP 计算机系统大作业 《程序人生-Hello’s P2P》
  11. 给父母的礼物!一键让Android变身老人机
  12. 直击JDD | 京东开启技术服务元年:携手合作伙伴,共创产业未来
  13. 数据库关系代数中除运算讲解和SQL语句的实现
  14. 区分: 间宾直宾(双宾语) 宾补(复合宾语)
  15. ubuntu16.04下解决wps无法使用五笔输入中文的问题
  16. C#关闭、启动、重启IIS
  17. 赛尔102S助力云南开展2020年白马雪山国家级自然保护区低空无人机生态监测
  18. 视觉检测应用之电路板二维码读取
  19. 区块链需要学习哪些东西_2020年学区块链需要什么基础?(非常详细)
  20. 自己如何建立一个公司网站?公司网站建设策划书怎么写?

热门文章

  1. bzoj 1648: [Usaco2006 Dec]Cow Picnic 奶牛野餐(暴力DFS)
  2. python库skimage 将针对灰度图像的滤波器用于RGB图像
  3. torch.utils.data.DataLoader 详解
  4. k8s优先级priority的使用
  5. matlab2c使用c++实现matlab函数系列教程-rot90函数
  6. 回顾InfoGAN与隐变量
  7. 使用Git将本地文件夹同步至github
  8. VS2012统计代码量
  9. 消息中间件解决方案-JMS-ActiveMQ
  10. c++ 输出格式控制