点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达

金磊 发自 凹非寺
量子位 报道 | 公众号 QbitAI

面对数以亿计的图片数据,到底该用什么样的方法才能快速搞实验?

这样的问题,或许在做机器学习研究的你,也会经常遇到。

而就在最近,一个国外小哥就提出了一种建议:

在Pytorch lightning基础上,让深度学习pipeline速度提升10倍

用他自己的话来说就是——“爬楼时像给了你一个电梯”。

这般“酸爽”,到底是如何做到的呢?

优化机器学习pipeline,很重要

无论你是身处学术界还是工业界,时间资源等各种因素,往往会成为你在搞实验的枷锁

尤其是随着数据集规模和机器学习模型,变得越发庞大和复杂,让实验变得既费时又耗力。

提速这件事,就变得至关重要。

例如在2012年的时候,训练一个AlexNet,要花上5到6天的时间。

而现如今,只需要短短几分钟就可以在更大的数据集上训练更大的图像模型。

这位小哥认为,从某种角度上来说,这是得益于各种各样的“利器”的出现。

例如Pytorch Lingtning,就是其中一种。

于是,他便“死磕”pipeline,总结了六种“闪电加速”实验周期的方法。

并行数据加载

数据加载和增强(augmentation)往往被认为是训练pipeline时的瓶颈之一。

一个典型的数据pipeline包含以下步骤:

  • 从磁盘加载数据

  • 在运行过程中创建随机增强

  • 将每个样本分批整理

在这个过程中,倒是可以用多个CPU进程并行加载数据来优化。

但与此同时,还可以通过下面的操作来加速这一过程:

1、将DataLoader中的num_workers参数设置为CPU的数量。

2、当与GPU一起工作时,将DataLoader中的pin_memory参数设置为True。这可以将数据分配到页锁定的内存中,从而加快数据传输到GPU的速度。

使用分布式数据并行的多GPU训练

与CPU相比,GPU已经大大加速了训练和推理时间。

但有没有比一个GPU更好的方法?或许答案就是:

多个GPU!

在PyTorch中,有几种范式可以用多个GPU训练你的模型。

两个比较常见的范式是 “DataParallel ”和 “DistributedDataParallel”。

而小哥采用的方法是后者,因为他认为这是一种更可扩展的方法。

但在PyTorch(以及其他平台)中修改训练pipeline并非易事。

必须考虑以分布式方式加载数据以及权重、梯度和指标的同步等问题。

不过,有了PyTorch Lightning,就可以非常容易地在多个GPU上训练PyTorch模型,还是几乎不需要修改代码的那种!

混合精度

在默认情况下,输入张量以及模型权重是以单精度(float32)定义的。

然而,某些数学运算可以用半精度(float16)进行。

这样一来,就可以显著提升速度,并降低了模型的内存带宽,还不会牺牲模型的性能。

通过在PyTorch Lightning中设置混合精度标志(flag),它会在可能的情况下自动使用半精度,而在其他地方保留单精度。

通过最小的代码修改,模型训练的速度可以提升1.5至2倍。

早停法

当我们训练深度学习神经网络的时候,通常希望能获得最好的泛化性能。

但是所有的标准深度学习神经网络结构,比如全连接多层感知机都很容易过拟合。

当网络在训练集上表现越来越好,错误率越来越低的时候,实际上在某一刻,它在测试集的表现已经开始变差。

因此,早停法 (Early Stopping)便在训练过程中加入了进来。

具体来说,就是当验证损失在预设的评估次数(在小哥的例子中是10次评估)后停止训练。

这样一来,不仅防止了过拟合的现象,而且还可以在几十个 epoch内找到最佳模型。

Sharded Training

Sharded Training是基于微软的ZeRO研究和DeepSpeed库。

它显著的效果,就是让训练大模型变得可扩展和容易。

否则,这些模型就不适合在单个GPU上使用了。

而在Pytorch Lightning的1.2版本中,便加入了对Shared Training的支持。

虽然在小哥的实验过程中,并没有看到训练时间或内存占用方面有任何改善。

但他认为,这种方法在其它实验中可能会提供帮助,尤其是在不使用单一GPU的大模型方面。

模型评估和推理中的优化

在模型评估和推理期间,梯度不需要用于模型的前向传递。

因此,可以将评估代码包裹在一个torch.no_grad上下文管理器中。

这可以防止在前向传递过程中的存储梯度,从而减少内存占用。

如此一来,就可以将更大的batch送入模型,让评估和推理变得更快。

效果如何?

介绍了这么多,你肯定想知道上述这些方法,具体起到了怎样的作用。

小哥为此做了一张表格,详解了方法的加速效果。

那么这些方法,是否对在做机器学习实验的你有所帮助呢?

快去试试吧~

参考链接:

https://devblog.pytorchlightning.ai/how-we-used-pytorch-lightning-to-make-our-deep-learning-pipeline-10x-faster-731bd7ad318a

本文系网易新闻•网易号特色内容激励计划签约账号【量子位】原创内容,未经账号授权,禁止随意转载。

点个在看 paper不断!

用上Pytorch Lightning的这六招,深度学习pipeline提速10倍!相关推荐

  1. pytorch自带网络_使用PyTorch Lightning自动训练你的深度神经网络

    作者:Erfandi Maula Yusnu, Lalu 编译:ronghuaiyang 原文链接 使用PyTorch Lightning自动训练你的深度神经网络​mp.weixin.qq.com 导 ...

  2. <计算机视觉 六> 深度学习目标检测模型的评估标准

    鼠标点击下载     项目源代码免费下载地址 <计算机视觉一> 使用标定工具标定自己的目标检测 <计算机视觉二> labelme标定的数据转换成yolo训练格式 <计算机 ...

  3. [翻译] 神经网络与深度学习 第六章 深度学习 - Chapter 6 Deep learning

    目录: 首页 译序 关于本书 关于习题和难题 第一章 利用神经网络识别手写数字 第二章 反向传播算法是如何工作的 第三章 提升神经网络学习的效果 第四章 可视化地证明神经网络可以计算任何函数 第五章 ...

  4. GitHub 上 57 款最流行的开源深度学习项目【转】

    GitHub 上 57 款最流行的开源深度学习项目[转] 2017-02-19 20:09 334人阅读 评论(0) 收藏 举报 分类: deeplearning(28) from: https:// ...

  5. 李宏毅-机器学习深度学习-第六讲-深度学习介绍

    哔哩哔哩视频地址:https://www.bilibili.com/video/av94411666?p=10 (请自行拷贝到浏览器打开) 李宏毅深度学习–第六讲–深度学习介绍

  6. 检验 pytorch,tensorflow,paddle,mxnet 深度学习框架是否正确支持GPU功能

    检验 pytorch,tensorflow,paddle,mxnet 深度学习框架是否正确支持GPU功能 1.pytorch 框架 import torch a = torch.cuda.is_ava ...

  7. 让 PyTorch 更轻便,这款深度学习框架你值得拥有!在 GitHub 上斩获 6.6K 星

    白交 发自 凹非寺  量子位 报道 | 公众号 QbitAI 一直以来,PyTorch就以简单又好用的特点,广受AI研究者的喜爱. 但是,一旦任务复杂化,就可能会发生一系列错误,花费的时间更长. 于是 ...

  8. 让PyTorch更轻便,这款深度学习框架你值得拥有!在GitHub上斩获6.6k星

    白交 发自 凹非寺  量子位 报道 | 公众号 QbitAI 一直以来,PyTorch就以简单又好用的特点,广受AI研究者的喜爱. 但是,一旦任务复杂化,就可能会发生一系列错误,花费的时间更长. 于是 ...

  9. 使用PyTorch Lightning自动训练你的深度神经网络

    点击上方"AI公园",关注公众号,选择加"星标"或"置顶" 作者:Erfandi Maula Yusnu, Lalu 编译:ronghuai ...

最新文章

  1. 制作多域名(SAN/UCC)CSR(证书请求文件)
  2. 惠普服务器显示灯闪红灯,惠普打印机指示灯闪烁什么意思? 惠普2130打印机故障灯大全图解...
  3. linux 内核load addr,linux2.4启动分析(1)---内核启动地址的确定 vmlinux LOAD_ADDR ZRELADDR...
  4. 【Silverlight】解决DataTemplate绑定附加属性
  5. RocketMQ集群知识介绍
  6. pip install时发生raise ReadTimeoutError(self._pool, None, 'Read timed out.')的解决方案
  7. JSON 之 SuperObject(8): 关于乱码的几种情况 - 向 Henri Gourvest 大师报告
  8. 云智能资深专家崮德:谈谈我对华为HarmonyOS 2.0的看法
  9. Office 365网络链接概览(三)--专线express route
  10. 华为机试:机器人走迷宫
  11. html的3d效果怎么设置,HTML5如何在网页中实现3D效果?
  12. c语言指针的运用——回文单词与回文句子
  13. StrStrI 与 strstr
  14. ARM中断向量表与响应流程
  15. DXP软件使用快捷键
  16. 广义相加模型(GAM)与向前逐步选择算法(基于R语言)
  17. 转变自己的信仰——致少年的自己
  18. Python3.6获取QQ空间全部好友列表
  19. JavaScript设计模式综合应用案例
  20. 参加Kaggle比赛的流程

热门文章

  1. xcode 4.2 不再支持 Window-Based Application 的解决办法(转载)
  2. 【青少年编程竞赛交流】10月份微信图文索引
  3. 【NCEPU】徐韬:街景字符编码识别比赛
  4. Numpy入门教程:03.数组操作
  5. 刻意练习:LeetCode实战 -- Task03. 移除元素
  6. 简介+原理+绘制,详解 Python「瀑布图」的整个制作流程!
  7. 清明出游,你会“鸽”酒店吗?AI 早已看穿一切
  8. 限时早鸟票 | 2019 中国大数据技术大会(BDTC)超豪华盛宴抢先看!
  9. 六大主题报告,四大技术专题,AI开发者大会首日精华内容全回顾
  10. 拯救老电影——详解爱奇艺ZoomAI视频增强技术的应用