训练一个 Softmax 分类器(Training a Softmax classifier)

上一个笔记中我们学习了Softmax层和Softmax激活函数,在这个笔记中,你将更深入地了解Softmax分类,并学习如何训练一个使用了Softmax层的模型。

回忆一下我们之前举的的例子,输出层计算出的z^([l])如下,

我们有四个分类C=4,z^([l])可以是4×1维向量,我们计算了临时变量t

对元素进行幂运算,最后如果你的输出层的激活函数g^([L]) ()是Softmax激活函数,那么输出就会是这样的:

简单来说就是用临时变量t将它归一化,使总和为1,于是这就变成了a^([L]),你注意到向量z中,最大的元素是5,而最大的概率也就是第一种概率。

Softmax这个名称的来源是与所谓hardmax对比,hardmax会把向量z变成这个向量

,hardmax函数会观察z的元素,然后在z中最大元素的位置放上1,其它位置放上0,所这是一个hard max,也就是最大的元素的输出为1,其它的输出都为0。与之相反,Softmax所做的从z到这些概率的映射更为温和,我不知道这是不是一个好名字,但至少这就是softmax这一名称背后所包含的想法,与hardmax正好相反。

有一点我没有细讲,但之前已经提到过的,就是Softmax回归或Softmax激活函数将logistic激活函数推广到C类,而不仅仅是两类,结果就是如果C=2,那么C=2的Softmax实际上变回了logistic回归,我不会在这个笔记中给出证明,但是大致的证明思路是这样的,

如果C=2,并且你应用了Softmax,那么输出层a^([L])将会输出两个数字,如果C=2的话,也许输出0.842和0.158,对吧?这两个数字加起来要等于1,因为它们的和必须为1,其实它们是冗余的,也许你不需要计算两个,而只需要计算其中一个,结果就是你最终计算那个数字的方式又回到了logistic回归计算单个输出的方式。

这算不上是一个证明,但我们可以从中得出结论,Softmax回归将logistic回归推广到了两种分类以上。

接下来我们来看怎样训练带有Softmax输出层的神经网络,具体而言,我们先定义训练神经网络使会用到的损失函数。

举个例子,我们来看看训练集中某个样本的目标输出,真实标签是

,用上一个笔记中讲到过的例子,这表示这是一张猫的图片,因为它属于类1,现在我们假设你的神经网络输出的是^y,^y是一个包括总和为1的概率的向量,

,你可以看到总和为1,这就是a^([l]),

对于这个样本神经网络的表现不佳,这实际上是一只猫,但却只分配到20%是猫的概率,所以在本例中表现不佳。

那么你想用什么损失函数来训练这个神经网络?

在Softmax分类中,我们一般用到的损失函数是

我们来看上面的单个样本来更好地理解整个过程。

注意在这个样本中y_1=y_3=y_4=0,因为这些都是0,只有y_2=1,如果你看这个求和,所有含有值为0的y_j的项都等于0,最后只剩下-y_2 tlog^y_2,

因为当你按照下标j全部加起来,所有的项都为0,除了j=2时,又因为y_2=1,所以它就等于- log^y_2。

这就意味着,如果你的学习算法试图将它变小,因为梯度下降法是用来减少训练集的损失的,要使它变小的唯一方式就是使-log^y_2变小,要想做到这一点,就需要使^y_2尽可能大,因为这些是概率,所以不可能比1大,但这的确也讲得通,因为在这个例子中x是猫的图片,你就需要这项输出的概率尽可能地大

概括来讲,损失函数所做的就是它找到你的训练集中的真实类别,然后试图使该类别相应的概率尽可能地高,如果你熟悉统计学中最大似然估计,这其实就是最大似然估计的一种形式。但如果你不知道那是什么意思,也不用担心,用我们刚刚讲过的算法思维也足够了。

这是单个训练样本的损失,整个训练集的损失J又如何呢?

也就是设定参数的代价之类的,还有各种形式的偏差的代价,它的定义你大致也能猜到,就是整个训练集损失的总和,把你的训练算法对所有训练样本的预测都加起来,

因此你要做的就是用梯度下降法,使这里的损失最小化。

最后还有一个实现细节,注意因为C=4,y是一个4×1向量,y也是一个4×1向量,如果你实现向量化,矩阵大写Y就是[y^((1)) y^((2))…… y^((m) )],例如如果上面这个样本是你的第一个训练样本,那么矩阵

那么这个矩阵Y最终就是一个4×m维矩阵。类似的,^Y=[^y^((1)) ^y^((2))…… ^y^((m))],这个其实就是^y^((1))

,或是第一个训练样本的输出,那么

,^Y本身也是一个4×m维矩阵。

最后我们来看一下,在有Softmax输出层时如何实现梯度下降法,这个输出层会计算z^([l]),它是C×1维的,在这个例子中是4×1,然后你用Softmax激活函数来得到a^([l])或者说y,然后又能由此计算出损失。

我们已经讲了如何实现神经网络前向传播的步骤,来得到这些输出,并计算损失,那么反向传播步骤或者梯度下降法又如何呢?

其实初始化反向传播所需要的关键步骤或者说关键方程是这个表达式dz^([l])=^y-y,你可以用^y这个4×1向量减去y这个4×1向量,你可以看到这些都会是4×1向量,当你有4个分类时,在一般情况下就是C×1,这符合我们对dz的一般定义,这是对z^([l])损失函数的偏导数(dz^([l])=∂J/(∂z^([l]) )),如果你精通微积分就可以自己推导,或者说如果你精通微积分,可以试着自己推导,但如果你需要从零开始使用这个公式,它也一样有用。

有了这个,你就可以计算dz^([l]),然后开始反向传播的过程,计算整个神经网络中所需要的所有导数。

在后面,我们会讲解一些深度学习编程框架,对于这些编程框架,通常你只需要专注于把前向传播做对,只要你将它指明为编程框架,前向传播,它自己会弄明白怎样反向传播,会帮你实现反向传播,所以这个表达式值得牢记(dz^([l])=^y-y),如果你需要从头开始,实现Softmax回归或者Softmax分类

训练softmax分类器实例_吴恩达深度学习笔记(56)-训练一个 Softmax 分类器相关推荐

  1. 深度学习如何提高训练集准确率_吴恩达深度学习笔记(61)-训练调参中的准确率和召回率...

    单一数字评估指标(Single number evaluation metric) 无论你是调整超参数,或者是尝试不同的学习算法,或者在搭建机器学习系统时尝试不同手段,你会发现,如果你有一个单实数评估 ...

  2. 准确率 召回率_吴恩达深度学习笔记(61)-训练调参中的准确率和召回率

    单一数字评估指标(Single number evaluation metric) 无论你是调整超参数,或者是尝试不同的学习算法,或者在搭建机器学习系统时尝试不同手段,你会发现,如果你有一个单实数评估 ...

  3. 吴恩达深度学习代码_吴恩达深度学习笔记(58)-深度学习框架Tensorflow

    TensorFlow 有很多很棒的深度学习编程框架,其中一个是TensorFlow,很期待帮助你开始学习使用TensorFlow,我想在这个笔记中向你展示TensorFlow程序的基本结构,然后让你自 ...

  4. yolo算法_吴恩达深度学习笔记(100)-目标检测之YOLO 算法讲解

    YOLO 算法(Putting it together: YOLO algorithm) 你们已经学到对象检测算法的大部分组件了,在这个笔记里,我们会把所有组件组装在一起构成YOLO对象检测算法. 我 ...

  5. 创建一列矩阵数字一样吗_吴恩达深度学习笔记(122) | NLP | 嵌入矩阵Embedding Matrix...

    嵌入矩阵(Embedding Matrix) 接下来我们要将学习词嵌入这一问题具体化,当你应用算法来学习词嵌入时,实际上是学习一个嵌入矩阵,我们来看一下这是什么意思. 和之前一样,假设我们的词汇表含有 ...

  6. 吴恩达深度学习笔记(四)

    吴恩达深度学习笔记(四) 卷积神经网络CNN-第二版 卷积神经网络 深度卷积网络:实例探究 目标检测 特殊应用:人脸识别和神经风格转换 卷积神经网络编程作业 卷积神经网络CNN-第二版 卷积神经网络 ...

  7. 吴恩达深度学习笔记——卷积神经网络(Convolutional Neural Networks)

    深度学习笔记导航 前言 传送门 卷积神经网络(Convolutional Neural Networks) 卷积神经网络基础(Foundations of Convolutional Neural N ...

  8. 吴恩达深度学习笔记——结构化机器学习项目(Structuring Machine Learning Projects)

    深度学习笔记导航 前言 传送门 结构化机器学习项目(Machine Learning Strategy) 机器学习策略概述 正交化(orthogonalization) 评价指标 数字评估指标的单一性 ...

  9. 吴恩达深度学习笔记1-Course1-Week1【深度学习概论】

    2018.5.7 吴恩达深度学习视频教程网址 网易云课堂:https://mooc.study.163.com/smartSpec/detail/1001319001.htm Coursera:htt ...

最新文章

  1. 计算机视觉 | 图像描述与注意力机制
  2. 用批处理查询电脑信息
  3. Java ConcurrentModificationException异常原因和解决方法
  4. android 关闭多点触控_Cocos Creator关闭多点触摸的问题
  5. SQL --分支取数据
  6. Python 基本数据类型 (二) - 字符串
  7. 网易财报暗藏玄机,不经意间已编织出电商大网
  8. js获取非行间样式--有bug,忧伤
  9. 关于ADO.Net连接池(Connection Pool)的一些个人见解
  10. Download SQL Server Management Studio (SSMS)下载地址
  11. java configuration类_使用@Configuration编写自定义配置类
  12. ASP.NET - 将 ASP.NET 用作高性能文件下载器
  13. Android调试出现问题:failed to connect to /10.0.2.2 (port 8080) from /192.168.31.150 (port 37592) after 300
  14. 金士顿DT100 G3 PS2251-07海力士U盘量产修复成功教程
  15. 如何营造游戏的打击感(一)
  16. Web前端和Web后端的区分
  17. matlab2019b重装导致mjs安装失败问题解决
  18. 一种基于STM32F1 MCU的增量型编码器测速的方法
  19. C语言五子棋的项目背景,五子棋项目源码!
  20. (有小案例)初始Mybatis框架及使用

热门文章

  1. GIT与SVN的比较
  2. Linux驱动之oops错误:addr2line工具定位错误
  3. matlab输入信号延迟simulink,Simulink仿真报错积分器不收敛存在奇异点的问题及Simulink仿真信号延迟问题...
  4. MFC-最简单的MFC程序
  5. ArcGis 中打开 shp 文件时 未知的空间参考 警告
  6. 基于微信在线教育视频学习小程序毕业设计毕设作品(6)开题答辩PPT
  7. java说课_Java说课演示稿.ppt
  8. ubuntu安装截图工具 flameshot(对标windows下snipaste)
  9. SQL 替换特定字符
  10. Shopify搞Dropshipping模板评测二 – Konversion