大家好,欢迎来到专栏《调参实战》,虽然当前自动化调参研究越来越火,但那其实只是换了一些参数来调,对参数的理解和调试在机器学习相关任务中是最基本的素质,在这个专栏中我们会带领大家一步一步理解和学习调参。

本次主要讲述图像分类项目中的优化方法的调参实践,这次与上一期学习率调参的内容是一脉相承的

作者&编辑 | 言有三

本文资源与结果展示

本文篇幅:3100字

背景要求:会使用Python和任一深度学习开源框架

附带资料:Caffe代码和数据集一份

同步平台:有三AI知识星球(一周内)

1 项目背景与准备工作

上一期我们基于图像分类任务,对学习率的几种常见的方式进行了调参学习,这一期我们来实验不同的优化方法的原理差异以及它们的性能比较。

本次项目开发需要以下环境:

(1) Linux系统,推荐ubuntu16.04或者ubuntu18.04。使用windows系统也可以完成,但是使用Linux效率更高。

(2) 最好拥有一块显存不低于6G的GPU显卡,如果没有使用CPU进行训练速度较慢。

(3) 安装好的Caffe开源框架。

本次的数据集和基准模型与上一期内容相同,大家如果不熟悉就去查看上一期的内容,链接如下:

【调参实战】如何开始你的第一个深度学习调参任务?不妨从图像分类中的学习率入手。

2 优化方法原理与实践

下面我们对各类优化算法的基本原理进行讲解,并进行实践。由于本文目标不是为了从零开始讲清楚优化算法,所以有些细节会略过。

2.1 标准梯度下降算法

梯度下降算法,即通过梯度的反方向来进行优化,批量梯度下降(Batch gradient descent)用公式表述如下:

写成伪代码如下:

for i in range(nb_epochs):
    params_grad = evaluate_gradient(loss_function, data, params)
    params = params - learning_rate * params_grad

上面的梯度下降算法公式用到了数据集所有的数据,这在解决实际问题时通常是不可能,比如ImageNet1000有100G以上的图像,内存装不下,速度也很慢。

我们需要在线能够实时计算,比如一次取一个样本,于是就有了随机梯度下降(Stochastic gradient descent),简称SGD,公式如下:

写成伪代码如下:

for i in range(nb_epochs):
     np.random.shuffle(data)
     for example in data:
          params_grad = evaluate_gradient(loss_function example , params)
          params = params - learning_rate * params_grad

SGD方法缺点很明显,梯度震荡,所以就有了后来大家常用的小批量梯度下降算法(Mini-batch gradient descent),公式如下:

伪代码如下:

for i in range(nb_epochs):
    np.random.shuffle(data)
    for batch in get_batches(data, batch_size=50):
        params_grad = evaluate_gradient(loss_function, batch, params)
        params = params - learning_rate * params_grad

平时当我们说SGD算法,实际上指的就是mini-batch gradient descent算法。

SGD算法的主要问题是学习率大小和策略需要手动选择,优化迭代比较慢,因此有很多方法对其进行改进。

2.2 动量法(momentum)

梯度下降算法是按照梯度的反方向进行参数更新,但是刚开始的时候梯度不稳定,方向改变是很正常的,梯度有时候一下正一下反,导致做了很多无用的迭代。而动量法做的很简单,相信之前的梯度。如果梯度方向不变,就越发更新的快,反之减弱当前梯度。

公式表达如下:

与SGD的对比如下:

动量法至今仍然是我觉得最为有用的学习率改进算法。那它和SGD的对比究竟如何呢?下面我们来实验不同的参数,需要在solver.prototxt中修改配置,完整的solver如下,需要修改的地方为标粗橙色部分,后面的实验同理

net: "allconv6.prototxt"

test_interval:100

test_iter:15

base_lr: 0.01

lr_policy: "step"

stepsize: 10000

gamma: 0.1

momentum: 0.9 ##动量项配置

weight_decay: 0.005

display: 100

max_iter: 100000

snapshot: 10000

snapshot_prefix: "models/allconv6_"

solver_mode: GPU

下图是实验结果对比:

我们可以发现,m=0.9时确实取得了最好的效果,m=0时效果最差,对于大部分的任务,我们在配置这个参数时也不需要修改,就采用m=0.9。

【调参实战】那些优化方法的性能究竟如何,各自的参数应该如何选择?相关推荐

  1. 【调参实战】BN和Dropout对小模型有什么影响?全局池化相比全连接有什么劣势?...

    大家好,欢迎来到专栏<调参实战>,虽然当前自动化调参研究越来越火,但那其实只是换了一些参数来调,对参数的理解和调试在机器学习相关任务中是最基本的素质,在这个专栏中我们会带领大家一步一步理解 ...

  2. 随机森林原理_机器学习(29):随机森林调参实战(信用卡欺诈预测)

    点击"机器学习研习社","置顶"公众号 重磅干货,第一时间送达 回复[大礼包]送你机器学习资料与笔记 回顾 推荐收藏>机器学习文章集合:1-20 机器学习 ...

  3. Lesson 14.3 Batch Normalization综合调参实战

    Lesson 14.3 Batch Normalization综合调参实战   根据Lesson 14.2最后一部分实验结果不难看出,带BN层的模型并不一定比不带BN层模型效果好,要充分发挥BN层的效 ...

  4. AIRec个性化推荐召回模型调参实战

    简介:本文是<AIRec个性化推荐召回模型调参实战(电商.内容社区为例)>的视频分享精华总结,主要由阿里巴巴的产品专家栀露向大家分享AIRec个性化推荐召回模型以及针对这些召回模型在电商和 ...

  5. 【调参实战】如何开始你的第一个深度学习调参任务?不妨从图像分类中的学习率入手。...

    大家好,欢迎来到专栏<调参实战>,虽然当前自动化调参研究越来越火,但那其实只是换了一些参数来调,对参数的理解和调试在机器学习相关任务中是最基本的素质,在这个专栏中我们会带领大家一步一步理解 ...

  6. ML之XGBoost:利用XGBoost算法对波士顿数据集回归预测(模型调参【2种方法,ShuffleSplit+GridSearchCV、TimeSeriesSplitGSCV】、模型评估)

    ML之XGBoost:利用XGBoost算法对波士顿数据集回归预测(模型调参[2种方法,ShuffleSplit+GridSearchCV.TimeSeriesSplitGSCV].模型评估) 目录 ...

  7. 推荐算法炼丹笔记:科学调参在模型优化中的意义

    作者:九羽 ,公众号:炼丹笔记 基于Embedding的推荐算法模型一直是近几年研究的热门,在各大国际会议期刊都能看到来自工业界研究与实践的成果.MF(Matrix Factorization)作为传 ...

  8. 机器学习调参神器——网格搜索方法

    网格搜索方法主要用于模型调参,也就是帮助我们找到一组最合适的模型设置参数,使得模型的预测达到更好的效果,这组参数于模型训练过程中学习到的参数不同,它是需要在训练前预设好的,我们称其为超参数. 超参数的 ...

  9. 基于scikit-learn的随机森林调参实战

    写在前面 在之前一篇机器学习算法总结之Bagging与随机森林中对随机森林的原理进行了介绍.还是老套路,学习完理论知识需要实践来加深印象.在scikit-learn中,RF的分类类是RandomFor ...

最新文章

  1. CF#212 Two Semiknights Meet
  2. mysql数据库杀掉堵塞_mysql数据库杀掉堵塞进程
  3. 力扣(LeetCode) 35. 搜索插入位置
  4. HTML 学习笔记 day one
  5. X210烧写linux系统
  6. udl 连mysql_自己如何正确获取MYSQL的ADO连接字符串
  7. 如果你也在学python,准备要学习python,希望这篇文章对你有用。
  8. web前端是什么?如何能成为一名合格的前端开发工程师?
  9. java 处理视频帧_如何将视频处理成每帧的图片?.最好是java实现..
  10. vue-touch不能上下滑动的问题【解决】
  11. python统计字典里面value出现的次数_python 统计list中各个元素出现的次数的几种方法...
  12. python模块安装位置_查看python模块的安装路径
  13. 计算机专业毕业设计选题与方向走势
  14. 迅为S5P6818核心板ARM Cortex-A53架构三星八核处理器
  15. Java 多线程设计模式
  16. 【行业报告】:低碳智能ALL “IN” | 印刷包装数智化转型之路
  17. echarts3d城市配置项
  18. oracle官怎么卸载网,Oracle终极彻底卸载
  19. torch.randn和torch.rand有什么区别
  20. 在 IIS 上构建静态网站

热门文章

  1. 面试中又被问到Redis如何实现抢购,赶快代码实现一波吧!
  2. 【拥抱大厂系列】百度面试官问过的 “JVM内存分配与回收策略原理”,我用这篇文章搞定了
  3. Oracle数据库之PL/SQL
  4. UI组件之TextView及其子类(四)AnalogClock,DigitalClock
  5. nivicat复制mysql数据库[Err] [Dtf] 1273 - Unknown collation: 'utf8mb4_0900_ai_ci'错误
  6. slim android7 nexus7,【畅玩7.0】加一直升pure nexus 7.0系统简单教程(1106更新)
  7. 苹果如何不显示云服务器照片,苹果云端照片怎么恢复到相册-互盾苹果恢复精灵...
  8. java堆中的组成部分,初识Java虚拟机的基本结构 | If Coding
  9. Serverless Kubernetes 再升级 | 全新的网关能力增强
  10. 托管节点池助力用户构建稳定自愈的 Kubernetes 集群