4. 降低损失 reducing loss

为了训练模型,我们需要一种可降低模型损失的好方法。迭代方法是一种广泛用于降低损失的方法,而且使用起来简单有效。

学习目标

  • 了解如何使用迭代方法来训练模型。
  • 全面了解梯度下降法和一些变体,包括:
    • 小批量梯度下降法
    • 随机梯度下降法
  • 尝试不同的学习速率。

-------------------------------------------------------

降低损失:迭代方法

下图 1 为ML算法用于训练模型的迭代试错过程:

图 1 用于训练模型的迭代方法

“模型”部分将一个或多个特征作为输入,然后返回一个预测 (y') 作为输出。为了简化,不妨考虑用一个特征并返回一个预测的模型:y' = b + w1x1

我们应该为b和w1设置哪些初始值?

对于线性回归问题,事实证明初始值并不重要。我们可以随机选择值,不过我们还是选择采用以下这些无关紧要的值:b=0,w1=0。假设第一个特征值是 10。将该特征值代入预测函数会得到以下结果:y'=0 + 0(10) —> y'=0.

图中的“计算损失”部分是模型将要使用的损失函数。假设我们使用平方损失函数。损失函数将采用两个输入值:y':模型对特征x的预测,y:特征x对应的正确标签。

最后,来看图的计算参数更新部分:机器学习系统就是在此部分检查损失函数的值,并为b和w1生成新值。

假设绿色框会产生新值,然后ML系统将根据所有标签重新评估所有特征,为损失函数生成一个新值,而该值又产生新的参数值。这种学习过程会持续迭代,直到该算法发现损失可能最低的模型参数。通常可以不断迭代,直到总体损失不再变化或至少变化极其缓慢为止。这时可以说该模型已收敛

-------------------------------------------------------

降低损失 (Reducing Loss):梯度下降法

我们将图1中的 计算参数更新 的绿框用更为实质的方法代替:

对于回归问题,所产生的损失与w1的图形始终是凸形图(碗状图),如下图 2 所示


图2 回归问题产生的损失和权重图为凸形

凸形问题只有一个最低点;即只存在一个斜率正好为 0 的位置。这个最小值就是损失函数收敛之处

通过计算整个数据集中w1每个可能值的损失函数来找到收敛点这种方法效率太低。

一种更好的机制:梯度下降法在ML领域非常热门。

梯度下降法第一个阶段:为w1选择一个起始值(起点)。

起点并不重要;因此很多算法就直接将w1设为0或随机选择一个值。下图3选择了一个稍大于0的起点:

图 3 梯度下降法的起点

然后,梯度下降法算法会计算损失曲线在起点处的梯度:

梯度是偏导数的矢量;它能让你了解哪个方向距离目标“更近”或“更远”。

请注意,损失相对于单个权重的梯度(如图3所示)就等于导数

梯度是一个矢量:具有方向大小

梯度始终指向损失函数中增长最为迅猛的方向梯度下降法算法会沿着负梯度的方向走一步,尽快降低损失。

图 4 梯度下降法依赖负梯度

个人理解,图4 红色箭头并不是梯度方向(梯度为w导数的方向),而是一个趋势。

为了确定损失函数曲线上的下一个点,梯度下降法算法会将梯度大小的一部分与起点相加,如下图所示:

图 5 一个梯度步长后移动到损失曲线的下一个点

然后,梯度下降法重复此过程,直到接近最低点。

-------------------------------------------------------

降低损失 (Reducing Loss):学习速率

梯度矢量具有方向和大小。梯度下降法算法梯度乘以一个称为学习速率(有时也称为步长的标量,以确定下一个点的位置。例:梯度大小为 2.5,学习速率为 0.01,则梯度下降法算法会选择距离前一个点 0.025 的位置作为下一个点。

超参数是编程人员在机器学习算法中用于调整的旋钮。大多数机器学习编程人员会花费相当多的时间来调整学习速率。

  • 选择的学习速率过小,就会花费太长的学习时间
  • 学习速率过大,下一个点将永远在 U 形曲线的底部随意弹跳
  • 每个回归问题都存在一个Goldilocks学习速率,Goldilocks值与损失函数的平坦程度相关。若损失函数梯度较小,则采用更大的学习速率,以补偿较小梯度并获得更大的步长。

-------------------------------------------------------

降低损失 (Reducing Loss):随机梯度下降法

在梯度下降法中,批量指的是用于在单次迭代中计算梯度的样本总数。到目前为止,我们一直假定批量是指整个数据集。如果是超大批量,则单次迭代就可能要花费很长时间进行计算,冗余数据也会增多。

从数据集中随机选择样本,通过小而多的数据集估算(尽管过程非常杂乱)出较大的平均值。 随机梯度下降法 (SGD) 将这种想法运用到极致,它每次迭代只使用一个样本(批量大小为 1)。进行足够的迭代,SGD可以发挥作用,但过程会非常杂乱。“随机”这一术语表示构成各个批量的一个样本都是随机选择的。

小批量随机梯度下降法(小批量 SGD是介于全批量迭代SGD之间的折衷方案。小批量通常包含10-1000个随机选择的样本。小批量SGD可以减少SGD中的杂乱样本数量,但仍然比全批量更高效。

-------------------------------------------------------

降低损失 (Reducing Loss):Playground 练习

点击打开链接

以上是链接,可以更直观的理解学习速率

-------------------------------------------------------

以上整理转载在谷歌出品的机器学习速成课程点击打开链接 侵删!

Google---机器学习速成课程(二)-SGD相关推荐

  1. Google机器学习速成课程 - 视频笔记整理汇总 - 基础篇核心部分

    Google机器学习速成课程 - 视频笔记整理 - 基础篇核心部分 课程网址: https://developers.google.com/machine-learning/crash-course/ ...

  2. 机器学习速成课程 | 练习 | Google Development——编程练习:提高神经网络的性能

    提高神经网络性能 学习目标:通过将特征标准化并应用各种优化算法来提高神经网络的性能 注意:本练习中介绍的优化方法并非专门针对神经网络:这些方法可有效改进大多数类型的模型. 设置 首先,我们将加载数据. ...

  3. 机器学习速成课程 | 练习 | Google Development——编程练习:使用 TensorFlow 的起始步骤

    使用 TensorFlow 的基本步骤 学习目标: 学习基本的 TensorFlow 概念 在 TensorFlow 中使用 LinearRegressor 类并基于单个输入特征预测各城市街区的房屋价 ...

  4. 机器学习速成课程 | 练习 | Google Development——编程练习:创建和操控张量

    创建和操控张量 学习目标: 初始化 TensorFlow 变量并赋值 创建和操控张量 回忆线性代数中的加法和乘法知识(如果这些内容对您来说很陌生,请参阅矩阵加法和乘法简介) 熟悉基本的 TensorF ...

  5. 机器学习速成课程 | 练习 | Google Development——编程练习:TensorFlow 编程概念

    TensorFlow 编程概念 学习目标: 学习 TensorFlow 编程模型的基础知识,重点了解以下概念: 张量 指令 图 会话 构建一个简单的 TensorFlow 程序,使用该程序绘制一个默认 ...

  6. 机器学习速成课程 | 练习 | Google Development——编程练习:稀疏数据和嵌套简介

    稀疏数据和嵌入简介 学习目标: 将影评字符串数据转换为稀疏特征矢量 使用稀疏特征矢量实现情感分析线性模型 通过将数据投射到二维空间的嵌入来实现情感分析 DNN 模型 将嵌入可视化,以便查看模型学到的词 ...

  7. 机器学习速成课程 | 练习 | Google Development——编程练习:使用神经网络对手写数字进行分类

    使用神经网络对手写数字进行分类 学习目标: 训练线性模型和神经网络,以对传统 MNIST 数据集中的手写数字进行分类 比较线性分类模型和神经网络分类模型的效果 可视化神经网络隐藏层的权重 我们的目标是 ...

  8. 机器学习速成课程 | 练习 | Google Development——编程练习:神经网络简介

    神经网络简介 学习目标: 使用 TensorFlow DNNRegressor 类定义神经网络 (NN) 及其隐藏层 训练神经网络学习数据集中的非线性规律,并实现比线性回归模型更好的效果 在之前的练习 ...

  9. 机器学习速成课程 | 练习 | Google Development——编程练习:稀疏性和 L1 正则化

    稀疏性和 L1 正则化 学习目标: 计算模型大小 通过应用 L1 正则化来增加稀疏性,以减小模型大小 降低复杂性的一种方法是使用正则化函数,它会使权重正好为零.对于线性模型(例如线性回归),权重为零就 ...

最新文章

  1. 独家 | 一文读懂神经网络(附解读案例)
  2. 深入理解JavaScript系列(5):强大的原型和原型链
  3. 深度学习笔记第二门课 改善深层神经网络 第一周:深度学习的实践层面
  4. postgresql的系统列(system cloumns)
  5. Windows 非阻塞或异步 socket
  6. 如何在 ASP.NET Core 中使用 API 分析器
  7. 博导眼里本科生的科研能力:“他们还在玩泥巴”
  8. 敏捷项目管理流程-Scrum框架最全总结
  9. 基于Vue.js的表格分页组件
  10. golang编译之vendor机制
  11. FTP连接报530错误(FTP Error: 530 User cannot log in, home directory inaccessible)
  12. SQL面试题(1-10)oracle写的
  13. 赚外快—常见编程接单的网站集合(20余个)
  14. web全栈工程师的自我修养(实际操作方面)
  15. SetWindowsHookEx全局钩子
  16. Tableau可视化---Tableau简介
  17. java和3d建模_基于Java3D技术和Swing技术的3D建模开发
  18. 熊猫烧香版《菊花台》pk《菊花台》
  19. 高德 php,高德地图WEB版的使用
  20. 从懵逼到恍然大悟之Java中RMI的使用

热门文章

  1. “脚本错误”到底意味着什么?
  2. ORACLE查询删除重复记录三种方法
  3. 厌倦了各种app推送广告?用RSS来订阅自己想看的内容吧
  4. AdminEx响应式Bootstrap后台管理模板
  5. 数学笔记26——参数方程
  6. Java Socket实现简易多人聊天室传输聊天内容或文件
  7. ASP.NET Core下FreeSql的仓储事务
  8. 钛媒体独家对话叶军:低代码到酷应用到底改变了什么?
  9. android 调用第三方应用市场,给自身应用评分
  10. 计算机考试 flash 世界杯,我的世界杯:FLash关于滚动足球动画制作... -电脑资料...