点上方计算机视觉联盟获取更多干货

仅作学术分享,不代表本公众号立场,侵权联系删除

转载于:量子位

AI博士笔记系列推荐

周志华《机器学习》手推笔记正式开源!可打印版本附pdf下载链接

在MNIST上进行训练,可以说是计算机视觉里的“Hello World”任务了。

而如果使用PyTorch的标准代码训练CNN,一般需要3分钟左右。

但现在,在一台笔记本电脑上就能将时间缩短200多倍。

速度直达0.76秒

那么,到底是如何仅在一次epoch的训练中就达到99%的准确率的呢?

八步提速200倍

这是一台装有GeForce GTX 1660 Ti GPU的笔记本。

我们需要的还有Python3.x和Pytorch 1.8。

先下载数据集进行训练,每次运行训练14个epoch。

这时两次运行的平均准确率在测试集上为99.185%,平均运行时间为2min 52s ± 38.1ms。

接下来,就是一步一步来减少训练时间:

一、提前停止训练

在经历3到5个epoch,测试准确率达到99%时就提前停止训练。

这时的训练时间就减少了1/3左右,达到了57.4s±6.85s。

二、缩小网络规模,采用正则化的技巧来加快收敛速度

具体的,在第一个conv层之后添加一个2x2的最大采样层(max pool layer),将全连接层的参数减少4倍以上。

然后再将2个dropout层删掉一个。

这样,需要收敛的epoch数就降到了3个以下,训练时间也减少到30.3s±5.28s。

三、优化数据加载

使用data_loader.save_data(),将整个数据集以之前的处理方式保存到磁盘的一个pytorch数组中。

也就是不再一次一次地从磁盘上读取数据,而是将整个数据集一次性加载并保存到GPU内存中。

这时,我们只需要一次epoch,就能将平均训练时间下降到7.31s ± 1.36s。

四、增加Batch Size

将Batch Size从64增加到128,平均训练时间减少到4.66s ± 583ms。

五、提高学习率

使用Superconvergence来代替指数衰减。

在训练开始时学习率为0,到中期线性地最高值(4.0),再慢慢地降到0。

这使得我们的训练时间下降到3.14s±4.72ms。

六、再次增加Batch Size、缩小缩小网络规模

重复第二步,将Batch Size增加到256。

重复第四步,去掉剩余的dropout层,并通过减少卷积层的宽度来进行补偿。

最终将平均时间降到1.74s±18.3ms。

七、最后的微调

首先,将最大采样层移到线性整流函数(ReLU)激活之前。

然后,将卷积核大小从3增加到5.

最后进行超参数调整:

使学习率为0.01(默认为0.001),beta1为0.7(默认为0.9),bata2为0.9(默认为0.999)。

到这时,我们的训练已经减少到一个epoch,在762ms±24.9ms的时间内达到了99.04%的准确率。

“这只是一个Hello World案例”

对于这最后的结果,有人觉得司空见惯:

优化数据加载时间,缩小模型尺寸,使用ADAM而不是SGD等等,都是常识性的事情。

我想没有人会真的费心去加速运行MNIST,因为这是机器学习中的“Hello World”,重点只是像你展示最小的关键值,让你熟悉这个框架——事实上3分钟也并不长吧。

而也有网友觉得,大多数人的工作都不在像是MNIST这样的超级集群上。因此他表示:

我所希望的是工作更多地集中在真正最小化训练时间方面。

GitHub:
https://github.com/tuomaso/train_mnist_fast

参考链接:
[1]https://www.reddit.com/r/MachineLearning/comments/p1168k/p_training_cnn_to_99_on_mnist_in_less_than_1/

-------------------

END

--------------------

我是王博Kings,985AI博士,华为云专家、CSDN博客专家(人工智能领域优质作者)。单个AI开源项目现在已经获得了2100+标星。现在在做AI相关内容,欢迎一起交流学习、生活各方面的问题,一起加油进步!

我们微信交流群涵盖以下方向(但并不局限于以下内容):人工智能,计算机视觉,自然语言处理,目标检测,语义分割,自动驾驶,GAN,强化学习,SLAM,人脸检测,最新算法,最新论文,OpenCV,TensorFlow,PyTorch,开源框架,学习方法...

这是我的私人微信,位置有限,一起进步!

王博的公众号,欢迎关注,干货多多

王博Kings的系列手推笔记(附高清PDF下载):

博士笔记 | 周志华《机器学习》手推笔记第一章思维导图

博士笔记 | 周志华《机器学习》手推笔记第二章“模型评估与选择”

博士笔记 | 周志华《机器学习》手推笔记第三章“线性模型”

博士笔记 | 周志华《机器学习》手推笔记第四章“决策树”

博士笔记 | 周志华《机器学习》手推笔记第五章“神经网络”

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(上)

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(下)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(上)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(下)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(上)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(下)

博士笔记 | 周志华《机器学习》手推笔记第九章聚类

博士笔记 | 周志华《机器学习》手推笔记第十章降维与度量学习

博士笔记 | 周志华《机器学习》手推笔记第十一章稀疏学习

博士笔记 | 周志华《机器学习》手推笔记第十二章计算学习理论

博士笔记 | 周志华《机器学习》手推笔记第十三章半监督学习

博士笔记 | 周志华《机器学习》手推笔记第十四章概率图模型

点分享

点收藏

点点赞

点在看

笔记本上的CNN搞定了MNIST相关推荐

  1. word手写字体以假乱真_学会Word上下标,搞定公式输入

    点击"蓝字"关注我们 在Word文档输入时,你是否会遇到需要插入上下标的文本呢,像这样,这样,这样的.我们在Word文档中的数学公式.化学符号等经常需要为文本设置上标.下标符号.W ...

  2. 比用Pytorch框架快200倍!0.76秒后,笔记本上的CNN就搞定了MNIST | 开源

    博雯 发自 凹非寺 量子位 报道 | 公众号 QbitAI 在MNIST上进行训练,可以说是计算机视觉里的"Hello World"任务了. 而如果使用PyTorch的标准代码训练 ...

  3. mac 10.10 apache php,在Mac上10分钟搞定Apache服务器配置

    目的:创建一个专属的测试环境 一.Apache服务器使用最广的 Web 服务器 Mac自带Apache,只需要修改几个配置就可使用 有些特殊的服务器功能,Apache都能很好的支持 二.硬件要求 1. ...

  4. 如何搞定笔记本检测不到wifi,图标,Netkeeper链接不上

    如何搞定笔记本检测不到wifi,图标,Netkeeper链接不上 例如 像上面这种情况的两种解决办法: 问题解决了,把下面的代码复制粘贴到文本文档, 然后改后缀.reg,再运行最后重启就好了. Win ...

  5. 不仅搞定“梯度消失”,还让CNN更具泛化性:港科大开源深度神经网络训练新方法

    原文链接:不仅搞定"梯度消失",还让CNN更具泛化性:港科大开源深度神经网络训练新方法 paper: https://arxiv.org/abs/2003.10739 code: ...

  6. 酸爽!我用这套无人值守安装系统瞬间搞定上百台服务器

    来自:DBAplus社群 作者介绍: 季城希,甜橙金融运维工程师,多年IDC运维经验.擅长IDC中服务器批量高效快速集成交付,精通各品牌型号服务器硬件产品及维护. 一.前言 为啥要用无人值守安装系统? ...

  7. 一行代码快速搞定Flowable断点下载(上)

    一行代码快速搞定Flowable断点下载(上) 之前我们大致讲了讲,到底怎么完全将disposable相关代码完全隐藏. 然后到了这里,可能有些杠精就会说了,你那个方式,我们不是完全不能拿到Flowa ...

  8. wps合并所有sheet页_Python一键合并上千个Excel表,一天的工作量一小时搞定!下班...

    一.老板的需求总是莫名奇妙 老板需求:一天老板说,嘿!放牛娃,将这些excel表合并到一个总表里,下班前交给我 老板话刚讲完,我心里就想,这还不简单么,excel不就是有合并表的功能么!!简单的要死! ...

  9. 破解前端面试系列(3):如何搞定纸上代码环节?

    很多重视技术的互联网公司在工程师招聘的技术面环节都会要求候选人在纸上写代码(后文用"纸上代码"代称),面试官想通过这种方式考察哪些点?候选人该注意哪些点?本文基于美团早几年常用的一 ...

最新文章

  1. 内存分配算法 之 首次适应-最佳适应
  2. mysql一个用户SQL慢查询分析,原因及优化
  3. [转]unity3D游戏开发之GUI
  4. safari only css hack,css hack将Safari和Chrome同时作为目标单独使用
  5. MySQL(一)存储引擎
  6. opencv imshow plt imshow
  7. CoinFLEX的基本情况以及与Bakkt
  8. MySQLl数据量不一样,导致走不同的索引
  9. html返回顶部代码(简单)
  10. matlab实现长除法,【网安智库】基于长除法的BCH(15,7)译码算法
  11. anaconda下jupyter无法自动打开网页
  12. 【Linux实验】LINUX系统的文件操作命令
  13. win10无法复制文件到system32,提示需要权限操作
  14. 163邮箱提示: 535 Error: authentication failed
  15. 计算机教程打字方法,技巧:打字指法和关键位置教程_IT /计算机_信息
  16. Emulex光纤卡lpfc配置文件的修改
  17. Android 腾讯Bugly的应用升级热更新
  18. 数字IC设计工程师要具备哪些技能
  19. 服装企业在饱和的情况下,如何避免交期延误?
  20. linux系统英语词汇大全,linux系统中常命令和英语词汇.docx

热门文章

  1. 幼儿园故事导入语案例_幼儿园小班安全教案
  2. PHP抽象函数的依赖注入,laravel 抽象类实现接口,具体类继承抽象类,使用依赖注入,如何知道接口选择的是哪个具体实现类啊?...
  3. 怎么将tflite部署在安卓上_tensorflow从训练自定义CNN网络模型到Android端部署tflite...
  4. java中简单的if语句_java中if语句的写法
  5. math库是python语言的数学模块_Python 数学模块(Math)
  6. 原版98启动盘镜像.img_装机技巧系列(二):系统安装之Windows 10启动盘制作
  7. 市面上有哪几种门_市面上常见的木门种类有哪些呢?
  8. 计算机病毒怎么做图片解说,【虎子_游戏解说】计算机病毒防范的实施方法
  9. java 并行_Java 中不同的并行实现的性能比较
  10. iis如何处理并发请求