博雯 发自 凹非寺
量子位 报道 | 公众号 QbitAI

在MNIST上进行训练,可以说是计算机视觉里的“Hello World”任务了。

而如果使用PyTorch的标准代码训练CNN,一般需要3分钟左右。

但现在,在一台笔记本电脑上就能将时间缩短200多倍。

速度直达0.76秒

那么,到底是如何仅在一次epoch的训练中就达到99%的准确率的呢?

八步提速200倍

这是一台装有GeForce GTX 1660 Ti GPU的笔记本。

我们需要的还有Python3.x和Pytorch 1.8。

先下载数据集进行训练,每次运行训练14个epoch。

这时两次运行的平均准确率在测试集上为99.185%,平均运行时间为2min 52s ± 38.1ms。

接下来,就是一步一步来减少训练时间:

一、提前停止训练

在经历3到5个epoch,测试准确率达到99%时就提前停止训练。

这时的训练时间就减少了1/3左右,达到了57.4s±6.85s。

二、缩小网络规模,采用正则化的技巧来加快收敛速度

具体的,在第一个conv层之后添加一个2x2的最大采样层(max pool layer),将全连接层的参数减少4倍以上。

然后再将2个dropout层删掉一个。

这样,需要收敛的epoch数就降到了3个以下,训练时间也减少到30.3s±5.28s。

三、优化数据加载

使用data_loader.save_data(),将整个数据集以之前的处理方式保存到磁盘的一个pytorch数组中。

也就是不再一次一次地从磁盘上读取数据,而是将整个数据集一次性加载并保存到GPU内存中。

这时,我们只需要一次epoch,就能将平均训练时间下降到7.31s ± 1.36s。

四、增加Batch Size

将Batch Size从64增加到128,平均训练时间减少到4.66s ± 583ms。

五、提高学习率

使用Superconvergence来代替指数衰减。

在训练开始时学习率为0,到中期线性地最高值(4.0),再慢慢地降到0。

这使得我们的训练时间下降到3.14s±4.72ms。

六、再次增加Batch Size、缩小缩小网络规模

重复第二步,将Batch Size增加到256。

重复第四步,去掉剩余的dropout层,并通过减少卷积层的宽度来进行补偿。

最终将平均时间降到1.74s±18.3ms。

七、最后的微调

首先,将最大采样层移到线性整流函数(ReLU)激活之前。

然后,将卷积核大小从3增加到5.

最后进行超参数调整:

使学习率为0.01(默认为0.001),beta1为0.7(默认为0.9),bata2为0.9(默认为0.999)。

到这时,我们的训练已经减少到一个epoch,在762ms±24.9ms的时间内达到了99.04%的准确率。

“这只是一个Hello World案例”

对于这最后的结果,有人觉得司空见惯:

优化数据加载时间,缩小模型尺寸,使用ADAM而不是SGD等等,都是常识性的事情。

我想没有人会真的费心去加速运行MNIST,因为这是机器学习中的“Hello World”,重点只是像你展示最小的关键值,让你熟悉这个框架——事实上3分钟也并不长吧。

而也有网友觉得,大多数人的工作都不在像是MNIST这样的超级集群上。因此他表示:

我所希望的是工作更多地集中在真正最小化训练时间方面。

GitHub:
https://github.com/tuomaso/train_mnist_fast

参考链接:
[1]https://www.reddit.com/r/MachineLearning/comments/p1168k/p_training_cnn_to_99_on_mnist_in_less_than_1/

比用Pytorch框架快200倍!0.76秒后,笔记本上的CNN就搞定了MNIST | 开源相关推荐

  1. Docker安装Mysql8.0,并配置忽略大小写,一句命令搞定

    Docker安装Mysql8.0,并配置忽略大小写,一句命令搞定 docker run --name mysql8.db -p 3307:3306 -e MYSQL_ROOT_PASSWORD=Csd ...

  2. php 数据 缓存,php终极数据缓存,比redis、GlobalData等快200倍以上,极致性能

    一.效果:每秒读取2000万条.写入2200万条.cpu开销很小二.原理:1.将数据以数组方式存储在内存中,php进程需要数据时直接通过内存地址访问数据,没有任何IO开销以及CPU开销. 三.具体实现 ...

  3. 超级计算机summit存储容量,天河3号超级计算机 我国正在开发超级计算机 将比”天河一号”快200倍...

    核心提示:国家超级计算天津中心应用研发工程师张婷表示,新开发的超级计算机系统将比我国2010年开通运行的第一台千万亿次超级计算机"天河一号"运算速度要快200倍,存储容量高达100 ...

  4. 它是谁?一个比 c3p0 快200倍的数据库连接池!

    点击上方"方志朋",选择"设为星标" 回复"666"获取新整理的面试文章 什么是数据库连接池 连接池是一种常用的技术,为什么需要连接池呢?这 ...

  5. c#打开数据库连接池的工作机制_它是谁?一个比 c3p0 快 200 倍的数据库连接池!...

    什么是数据库连接池 连接池是一种常用的技术,为什么需要连接池呢?这个需要从 TCP 说起.假如我们的服务器跟数据库没有部署在同一台机器,那么,服务器每次查询数据库都要先建立连接,一般都是 TCP 链接 ...

  6. 快1倍,我在 M1 Max 上开发 iOS 应用有了这些发现

    整理 | 章雨铭 责编 | 屠敏 出品 | CSDN(ID:CSDNnews) 科技的进步.资源的共享使得进入iOS开发变得前所未有的容易.很多开发工具都是免费的,网上的学习资料应有尽有.然而,随着代 ...

  7. 没有最快,只有更快!富士通74.7秒在ImageNet上训练完ResNet-50

    https://www.toutiao.com/a6675198538592289288/ 大数据文摘出品 编译:林安安.蒋宝尚 74.7秒! 根据日本富士通实验室最新研究.他们应用了一种优化方法,在 ...

  8. 【快应用】菜单遮挡内容?教你一招快速搞定!

    快应用规范从1070版本开始强制设置显示菜单,但是在有些快应用页面,菜单会遮挡住应用自身的内容,例如下图菜单便遮挡住了登录功能,虽然可以将菜单配置为可移动,但是用户却不知道可以移动,从而影响用户的使用 ...

  9. RedisJson 是什么?比ES快 500 倍?

    -     概述    - 近期官网给出了RedisJson(RedisSearch)的性能测试报告,可谓碾压其他NoSQL,下面是核心的报告内容,先上结论: 对于隔离写入(isolated writ ...

最新文章

  1. SpringBoot conditional注解和自定义conditional注解使用
  2. celery源码分析:multi命令分析
  3. SAP S/4 HANA的物料编码40位设置
  4. Python | [a for b in c for a in b]的用法
  5. jquery学习(六)-jquery中的动画
  6. Winform判断一个窗口是否以模态化方式打开
  7. 通过gps给定的两个经纬度坐标,计算两点之间的距离
  8. jsp动作元素include学习
  9. 【程序员の英文听写】Trump’s Totally Not Weird Way of Standing | The Daily Social Distancing Show
  10. RBAC、控制权限设计、权限表设计 基于角色权限控制和基于资源权限控制的区别优劣
  11. linux 信号_Linux中的信号处理机制 [四]
  12. 惊艳!28岁就任副教授,“最美女教授”年纪轻轻已是博导、院长
  13. 疫情中的2021,云原生会走向哪里
  14. MyEclipse在删除文件后servers报错问题解决
  15. springboot配置请求头大小
  16. 【Proteus仿真】74LS138译码器流水灯
  17. 利用python进行数据分析(4)
  18. java引用类型内存_Java的引用类型的内存分析
  19. android:一套默认头像的封装
  20. Lync客户端证书安装

热门文章

  1. Windows Server2008 R2安装wampserver缺少api-ms-win-crt-runtime-l1-1-0.dll解决方案
  2. 如何使flexbox子代的父母高度为100%?
  3. 如何在Javascript中访问对象的第一个属性?
  4. 反映参数名称:滥用C#lambda表达式还是语法亮度?
  5. 如何从Android中的另一个应用程序启动活动
  6. 如何在Git中克隆单个分支?
  7. DockerSwarm 微服务部署
  8. k-means算法的理解与实现
  9. PHP变量在内存中的存储方式
  10. PHP之文件上传: 参数enctype