本系列文章已转至

机器学习的优化器​zhuanlan.zhihu.com

优化算法在机器学习中扮演着至关重要的角色,了解常用的优化算法对于机器学习爱好者和从业者有着重要的意义。

这系列文章先讲述优化算法和机器学习的关系,然后罗列优化算法分类,尤其是机器学习中常用的几类.接下来明确下数学符号,开始按照历史和逻辑顺序,依次介绍各种机器学习中常用的优化算法.

这篇先讲其中基于一阶导数的标准梯度下降法和Momentum,其中穿插学习率退火方法和基于二阶导数的优化算法来辅助说明各算法的意义和背后的想法.

优化算法和机器学习的关系

机器学习的过程往往是

  1. 建模实际问题,定义损失函数
  2. 代入训练数据,利用优化算法来优化损失函数并更新参数,直到终止条件(比如迭代数或者更新的收益或者损失函数的大小)

可见优化算法和损失函数在机器学习中占有重要的地位.

损失函数比较的一个例子请参看

郝曌骏:MSE vs 交叉熵​zhuanlan.zhihu.com

优化算法分类

优化算法有很多种,常见的包括

  • 基于导数的,比如基于一阶导数的梯度下降法(GD, Grandient Descent)和基于二阶导数的牛顿法等,要求损失函数(运筹学中更多叫做目标函数)可导
  • 群体方法(population method),比如遗传算法(Genetic Algo)和蚁群算法(Ant Colony Optimization),不依赖于问题(problem-independent),不需要对目标函数结构有太多的了解
  • 单体方法(single-state method),比如模拟退火算法(Simulated Annealing),同样,不依赖于问题(problem-independent),不需要对目标函数结构有太多的了解

等.

机器学习中常用的是基于导数,尤其是基于一阶导数的优化算法,包括

  • 标准梯度下降法(GD, standard Gradient Descent)
  • 带有momentum的GD
  • RMSProp (Root Mean Square Propagation)
  • AdaM (Adaptive Moment estimates)
  • AdaGrad (Adaptive Gradient Algo)
  • AdaDelta

符号规定

在具体解释前先规定下符号

  • 损失函数为

    (很多地方也会写作
  • 梯度为
  • 表示第t次迭代的梯度,
  • 第t次迭代时,
  • 学习率为
  • 表示
    的高阶无穷小,也就是当
    无限接近0时,
    ,比如
    就是
    的高阶无穷小

标准梯度下降法(GD, standard Gradient Descent)

每次迭代的更新为

其中

表示第t次迭代的梯度,
为人工预先设定的学习率(learning rate).
图1

标准GD的想法来源于一阶泰勒展开

其中

叫做皮亚诺(Peano)余项,当
很小时,这个余项可以忽略不计.

和一阶导数也就是梯度相反方向时,
(在机器学习中指的的损失函数)下降最快.

一个经典的解释是:想象我们从山上下来,每步都沿着坡度最陡的方向.这时,水平面是我们的定义域,海拔是值域.

GD缺点

但GD有两个主要的缺点:

  1. 优化过程中,保持一定的学习率,并且这个学习率是人工设定.当学习率过大时,可能在靠近最优点附近震荡(想象一步子太大跨过去了);学习率过小时,优化的速度太慢
  2. 学习率对于每个维度都一样,而我们经常会遇到不同维度的曲率(二阶导数)差别比较大的情况,这时GD容易出现zig-zag路径.(参考图2,优化路径呈现zig-zag形状,该图绘制代码放在附录1中)
图2

考虑

所以人们考虑

  1. 动态选择更好的学习率,比如前期大些来加速优化,靠近低点了小些避免在低点附近来回震荡,甚至
  2. 为每个维度选择合适的学习率
    .

学习率退火 (Learning Rate Annealing)

出于考虑1,人们参考了单体优化方法中的模拟退火(Simulated Annealing),学习率随着迭代次数的增加或者损失函数在验证集上的表现变好而衰减(decay).

学习率退化可以直接加在GD上.

改进方向

AdaGrad等算法(

郝曌骏:机器学习中的优化算法(3)-AdaGrad, Adadelta​zhuanlan.zhihu.com

介绍)就借鉴了退火的学习率衰减的思想.不过这个不是这篇的重点.

牛顿法 (Newton's Method)

出于考虑2(为每个维度选择合适的学习率

),基于二阶导数的牛顿法被提出.它来源于泰勒二阶展开.

对于多元函数

,

其中

黑塞矩阵_百度百科​baike.baidu.com

我们有

.

这样每次迭代都会考虑损失函数的曲率(二阶导数)来选择步长.对比图2中的标准GD,牛顿法可以一步就到达最优点.

牛顿法缺点

但是牛顿法的计算复杂度很高,因为Hessian矩阵的维度是参数个数的平方,而参数的个数往往很多.

改进方向

不同的方法随即被提出,比如

  • Becker和LeCun提出的

用对角线元素来代替Hessian全矩阵​nyuscholars.nyu.edu

  • 依靠历史的梯度信息来模拟二阶方法,包括Momentum,RMSProp(用二阶距来模拟二阶导数),AdaM(用一阶矩和二阶矩的比例来模拟二阶导数)等.

我们先介绍Momentum

Momentum

Sutskeverd等人在2013年​proceedings.mlr.press

,借鉴了物理中动量(momentum)的概念,让

保留一部分之前的方向和速度.

Classical Momentum每次迭代的更新为

这样预期可以达到两个效果:

  1. 某个维度在近几次迭代中正负号总是改变时,说明二阶导数可能相对其他维度比较大或者说每次步子迈得太大了,需要改变幅度小些或者迈得小点来避免zig-zag路径
  2. 某个维度在近几次迭代中符号几乎不变,说明二阶导数可能相对其他维度比较小或者说大方向是正确的,这个维度改变的幅度可以扩大些,来加速改进.
图3

如图3所示,加入了Classical Momentum,前期的训练加快了,靠近低点时也减小了震荡.

关于NAG(Nesterov's Accelerated Gradient)可参看附录1中的代码.

附录1

import math
import numpy as np
import matplotlib.pyplot as pltRATIO = 3   # 椭圆的长宽比
LIMIT = 1.2 # 图像的坐标轴范围class PlotComparaison(object):"""多种优化器来优化函数 x1^2 + x2^2 * RATIO^2.每次参数改变为(d1, d2).梯度为(dx1, dx2)t+1次迭代,标准GD,d1_{t+1} = - eta * dx1d2_{t+1} = - eta * dx2带Momentum,d1_{t+1} = eta * (mu * d1_t - dx1_{t+1})d2_{t+1} = eta * (mu * d2_t - dx2_{t+1})    带Nesterov Momentum,d1_{t+1} = eta * (mu * d1_t - dx1^{nag}_{t+1})d2_{t+1} = eta * (mu * d2_t - dx2^{nag}_{t+1})其中(dx1^{nag}, dx2^{nag})为(x1 + eta * mu * d1_t, x2 + eta * mu * d2_t)处的梯度"""def __init__(self, eta=0.1, mu=0.9, angles=None, contour_values=None,stop_condition=1e-4):# 全部算法的学习率self.eta = eta# 启发式学习的终止条件self.stop_condition = stop_condition# Nesterov Momentum超参数self.mu = mu# 用正态分布随机生成初始点self.x1_init, self.x2_init = np.random.uniform(LIMIT / 2, LIMIT), np.random.uniform(LIMIT / 2, LIMIT) / RATIOself.x1, self.x2 = self.x1_init, self.x2_init# 等高线相关if angles == None:angles = np.arange(0, 2 * math.pi, 0.01)self.angles = anglesif contour_values == None:contour_values = [0.25 * i for i in range(1, 5)]self.contour_values = contour_valuessetattr(self, "contour_colors", None)def draw_common(self, title):"""画等高线,最优点和设置图片各种属性"""# 坐标轴尺度一致plt.gca().set_aspect(1)# 根据等高线的值生成坐标和颜色# 海拔越高颜色越深num_contour = len(self.contour_values)if not self.contour_colors:self.contour_colors = [(i / num_contour, i / num_contour, i / num_contour) for i in range(num_contour)]self.contour_colors.reverse()self.contours = [[list(map(lambda x: math.sin(x) * math.sqrt(val), self.angles)),list(map(lambda x: math.cos(x) * math.sqrt(val) / RATIO, self.angles))]for val in self.contour_values]# 画等高线for i in range(num_contour):plt.plot(self.contours[i][0],self.contours[i][1],linewidth=1,linestyle='-',color=self.contour_colors[i],label="y={}".format(round(self.contour_values[i], 2)))# 画最优点plt.text(0, 0, 'x*')# 图片标题plt.title(title)# 设置坐标轴名字和范围plt.xlabel("x1")plt.ylabel("x2")plt.xlim((-LIMIT, LIMIT))plt.ylim((-LIMIT, LIMIT))# 显示图例plt.legend(loc=1)def forward_gd(self):"""SGD一次迭代"""self.d1 = -self.eta * self.dx1self.d2 = -self.eta * self.dx2self.ite += 1def draw_gd(self, num_ite=5):"""画基础SGD的迭代优化.包括每次迭代的点,以及表示每次迭代改变的箭头"""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)# 画每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_gd()# 迭代的箭头plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐标为({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的点if self.loss < self.stop_condition:breakdef forward_momentum(self):"""带Momentum的SGD一次迭代"""self.d1 = self.eta * (self.mu * self.d1_pre - self.dx1)self.d2 = self.eta * (self.mu * self.d2_pre - self.dx2)self.ite += 1self.d1_pre, self.d2_pre = self.d1, self.d2def draw_momentum(self, num_ite=5):"""画带Momentum的迭代优化."""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)setattr(self, "d1_pre", 0)setattr(self, "d2_pre", 0)# 画每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_momentum()# 迭代的箭头plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐标为({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的点if self.loss < self.stop_condition:breakdef forward_nag(self):"""Nesterov Accelerated的SGD一次迭代"""self.d1 = self.eta * (self.mu * self.d1_pre - self.dx1_nag)self.d2 = self.eta * (self.mu * self.d2_pre - self.dx2_nag)self.ite += 1self.d1_pre, self.d2_pre = self.d1, self.d2def draw_nag(self, num_ite=5):"""画Nesterov Accelerated的迭代优化."""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)setattr(self, "d1_pre", 0)setattr(self, "d2_pre", 0)# 画每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_nag()# 迭代的箭头plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐标为({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的点if self.loss < self.stop_condition:break@propertydef dx1(self, x1=None):return self.x1 * 2@propertydef dx2(self):return self.x2 * 2 * (RATIO ** 2)@propertydef dx1_nag(self, x1=None):return (self.x1 + self.eta * self.mu * self.d1_pre) * 2@propertydef dx2_nag(self):return (self.x2 + self.eta * self.mu * self.d2_pre) * 2 * (RATIO ** 2)@propertydef loss(self):return self.x1 ** 2 + (RATIO * self.x2) ** 2def show(self):# 设置图片大小plt.figure(figsize=(20, 20))# 展示plt.show()def main_2():"""画图2"""xixi = PlotComparaison()xixi.draw_gd()xixi.draw_common("Optimize x1^2+x2^2*{} Using SGD".format(RATIO ** 2))xixi.show()def main_3(num_ite=15):"""画图3"""xixi = PlotComparaison()xixi.draw_gd(num_ite)xixi.draw_common("Optimize x1^2+x2^2*{} Using SGD".format(RATIO ** 2))xixi.show()xixi.draw_momentum(num_ite)xixi.draw_common("Optimize x1^2+x2^2*{} Using SGD With Momentum".format(RATIO ** 2))xixi.show()

附录2

带Momentum机制的GD在pytorch中的实现为

import torch
torch.optim.SGD(lr, momentum) # lr为学习率,momentum为可选参数

louvian算法 缺点 优化_机器学习中的优化算法(1)-优化算法重要性,SGD,Momentum(附Python示例)...相关推荐

  1. tensorflow超参数优化_机器学习模型的超参数优化

    引言 模型优化是机器学习算法实现中最困难的挑战之一.机器学习和深度学习理论的所有分支都致力于模型的优化. 机器学习中的超参数优化旨在寻找使得机器学习算法在验证数据集上表现性能最佳的超参数.超参数与一般 ...

  2. python决策树实例_机器学习中的决策树及python实例

    一棵树在现实生活中有许多枝叶,事实上树的概念在机器学习也有广泛应用,涵盖了分类和回归.在决策分析中,决策树可用于直观地决策和作出决策.决策树,顾名思义,一个树状的决策模型.尽管数据挖掘与机器学习中常常 ...

  3. 傅里叶描述子欧氏距离_机器学习中的各种距离

    让我们一起 改变智造 hi,大家国庆都玩的怎么样啊? 是不是很诧异我现在才问候国庆的事情? 因为我今天才刚刚上班(呵呵,怎么可能) 加班加到"秃"起~ 即使这样我也要继续伴随大家去 ...

  4. 机器学习 贝叶斯方法_机器学习中的常客与贝叶斯方法

    机器学习 贝叶斯方法 There has always been a debate between Bayesian and frequentist statistical inference. Fr ...

  5. python分类分析模型_机器学习中最常见的四种分类模型

    作者:Jason Brownlee 翻译:候博学 前言 机器学习是一个从训练集中学习出算法的研究领域. 分类是一项需要使用机器学习算法的任务,该算法学习如何为数据集分配类别标签. 举一个简单易懂的例子 ...

  6. 机器学习朴素贝叶斯算法_机器学习中的朴素贝叶斯算法

    机器学习朴素贝叶斯算法 朴素贝叶斯算法 (Naive Bayes Algorithm) Naive Bayes is basically used for text learning. Using t ...

  7. 算法的优缺点_机器学习算法优缺点 amp; 如何选择

    (点击上方公众号,可快速关注) 转自: 算法与数学之美 主要回顾下几个常用算法的适应场景及其优缺点! 机器学习算法太多了,分类.回归.聚类.推荐.图像识别领域等等,要想找到一个合适算法真的不容易,所以 ...

  8. 机器学习线性回归学习心得_机器学习中的线性回归

    机器学习线性回归学习心得 机器学习中的线性回归 (Linear Regression in Machine Learning) There are two types of supervised ma ...

  9. 机器学习集群_机器学习中的多合一集群技术在无监督学习中应该了解

    机器学习集群 Clustering algorithms are a powerful technique for machine learning on unsupervised data. The ...

最新文章

  1. 用结构体实现一个电话本
  2. Asp.net SignalR 应用并实现群聊功能 开源代码
  3. 网络扫描工具Nmap常用命令
  4. Beaglebone Back学习七(URAT串口测试)
  5. Android中的复制粘贴
  6. 摄像头poe供电原理_什么是POE供电,这种POE套装有什么优势呢?
  7. Java数据库连接池c3p0和druid
  8. xcode7: Undefined symbols for architecture i386: _iconv_open, referenced from:
  9. ionic3 百度地图插件定位 问题
  10. SpingMVC之拦截器
  11. 计算机远程桌面连接命令行,远程桌面连接命令,小编教你win7远程桌面连接命令的使用教程...
  12. html制作个人简历
  13. java 打包exe_Java项目打包成exe的详细教程
  14. 网易卡搭python_网易卡搭编程
  15. 域名不要www如何解析
  16. iwifi 技术规范
  17. RapidASR项目(语音转文本):更快、更容易部署、开箱即用
  18. 实现数智内控,数据分析创造价值——辽宁烟草智能风险体检系统
  19. mysql数据库巡检方案_Mysql数据库巡检
  20. python:凯撒密码

热门文章

  1. Knockout js 绑定 radio 时,当绑定整形的时候,绑定不生效
  2. Win7访问局域网内共享文件夹
  3. 匿名内部类和局部内部类访问的外部类的局部变量必须是final的
  4. 计算机视觉开源库OpenCV之照明和色彩空间
  5. 英特尔AIDC秀肌肉:展示AI软硬件+生态全景图
  6. tensorflow函数总结
  7. 多生产者_多线程必考的「生产者 - 消费者」模型,看齐姐这篇文章就够了
  8. 大专学完出来学计算机,浙江2021年计算机学校读出来是什么文凭
  9. 在python中、列表中的元素可以是_在Python中存储一个列表的元素,在另一个列表中 – 通过引用?...
  10. shiroConfig配置中要注意的事项