向AI转型的程序员都关注了这个号????????????

机器学习AI算法工程   公众号:datayx

入门pytorch似乎不慢,写好dataloader和model就可以跑起来了,然而把模型搭好用起来时,却往往发觉自己的程序运行效率并不高,GPU使用率宛如舞动的妖精...忽高忽低,影响模型迭代不说,占着显存还浪费人家的计算资源hh 我最近就是遇到这个困难,花了一些精力给模型提速,这里总结一下(有些描述可能并不准确,但至少这些point可以借鉴hh,不妥之处恳请大家指正/补充啦)

dataloader方面:

  1. DataLoader的参数中,

    1. num_workers与batch_size调到合适值,并非越大越快(注意后者也影响模型性能)

    2. eval/test时shuffle=False

    3. 内存够大的情况下,dataloader的pin_memory设为True,并用data_prefetcher技巧(方法)

  2. https://zhuanlan.zhihu.com/p/80695364

  3. 把不怎么会更改的预处理提前做好 保存到硬盘上,不要放dataloader里做,尤其对图像视频的操作

  4. 减少IO操作,服务器如果是hdd根本架不住多人对磁盘的折磨,曾经几个小伙伴把服务器弄得卡到无法login...可先多线程把数据放到内存,不太会的话可以先用个dataloader专门把数据先读到list中(即内存),然后把这个list作为每次取数据的池。还可以用tmpfs把内存当硬盘,不过需要权限

  5. prefetch_generator(方法)让读数据的worker能在运算时预读数据,而默认是数据清空时才读

model方面:

  1. 用float16代替默认的float32运算(方法参考,搜索"fp16"可以看到需要修改之处,包括model、optimizer、backward、learning rate)

  2. 优化器以及对应参数的选择,如learning rate,不过它对性能的影响似乎更重要【占坑】

  3. 少用循环,多用向量化操作

  4. 经典操作尽量用别人优化好的库,别自己写(想自己实现锻炼能力除外)

  5. 数据很多时少用append,虽然使用很方便,不过它每次都会重新分配空间?所以数据很大的话,光一次append就要几秒(测过),可以先分配好整个容器大小,每次用索引去修改内容,这样一步只要0.0x秒

  6. 固定对模型影响不大的部分参数,还能节约显存,可以用detach()切断反向传播,注意若仅仅给变量设置required_grad=False 还是会计算梯度的

  7. eval/test的时候,加上model.eval()和torch.no_grad(),前者固定batch-normalization和dropout 但是会影响性能,后者关闭autograd

  8. 提高程序并行度,例如 我想train时对每个epoch都能test一下以追踪模型性能变化,但是test时间成本太高要一个小时,所以写了个socket,设一个127.0.0.1的端口,每次train完一个epoch就发个UDP过去,那个进程就可以自己test,同时原进程可以继续train下一个epoch(对 这是自己想的诡异方法hhh)

Using CUDA in correct way:

  • 设置torch.backends.cudnn.benchmark = True

使用benchmark以启动CUDNN_FIND自动寻找最快的操作,当计算图不会改变的时候(每次输入形状相同,模型不改变)的情况下可以提高性能,反之则降低性能

  • 另外,还可以采用确定性卷积:(相当于把所有操作的seed=0,以便重现,会变慢)

torch.backends.cudnn.deterministic

添加torch.cuda.get_device_name和torch.cuda.get_device_capability实现如下功能。例:

如果设置torch.backends.cudnn.deterministic = True,则CuDNN卷积使用确定性算法

torch.cuda_get_rng_state_all并torch.cuda_set_rng_state_all引入,让您一次保存/加载随机数生成器的状态在所有GPU上

torch.cuda.emptyCache()释放PyTorch的缓存分配器中的缓存内存块。当与其他进程共享GPU时特别有用。

训练模型个人的基本要求是deterministic/reproducible,或者说是可重复性。也就是说在随机种子固定的情况下,每次训练出来的模型要一样。之前遇到了两次不可重复的情况。第一次是训练CNN的时候,发现每次跑出来小数点后几位会有不一样。epoch越多,误差就越多,虽然结果大致上一样,但是强迫症真的不能忍。后来发现在0.3.0的时候已经修复了这个问题,可以用torch.backends.cudnn.deterministic = True
这样调用的CuDNN的卷积操作就是每次一样的了。

预先分配内存空间:pin_memory + non_blocking async GPU training

为了防止多GPU同时读取内存导致blocking,non_blocking需要对train data设置,否则,0.4.0版本中的DataParallel会自动尝试用async GPU training。

用比Adam更快的优化器

  • SGD with Momentum :该优化器在多项式时间内的收敛性已经明确被证明,更不用说所有的参数都已经像您的老朋友一样熟悉了

  • 【暂时不可用】使用AdamW or Adam with correct weight decay:

因为Adam在优化过程中有一个L2正则化参数,但在当前版本的Pytorch中,L2正则化没有根据学习率进行归一化,AdamW论文中提出的Adam修改方案解决了这一问题并证明收敛更快,而且适用于cosine学习率衰减等。

其他:

  1. torch.backends.cudnn.benchmark设为True,可以让cudnn根据当前训练各项config寻找优化算法,但这本身需要时间,所以input size在训练时会频繁变化的话,建议设为False

  • 使用pytorch时,训练集数据太多达到上千万张,Dataloader加载很慢怎么办?

https://www.zhihu.com/question/356829360

  • MrTian:给训练踩踩油门 —— Pytorch 加速数据读取

https://zhuanlan.zhihu.com/p/80695364


阅读过本文的人还看了以下文章:

【全套视频课】最全的目标检测算法系列讲解,通俗易懂!

《美团机器学习实践》_美团算法团队.pdf

《深度学习入门:基于Python的理论与实现》高清中文PDF+源码

特征提取与图像处理(第二版).pdf

python就业班学习视频,从入门到实战项目

2019最新《PyTorch自然语言处理》英、中文版PDF+源码

《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码

《深度学习之pytorch》pdf+附书源码

PyTorch深度学习快速实战入门《pytorch-handbook》

【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》

《Python数据分析与挖掘实战》PDF+完整源码

汽车行业完整知识图谱项目实战视频(全23课)

李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材

笔记、代码清晰易懂!李航《统计学习方法》最新资源全套!

《神经网络与深度学习》最新2018版中英PDF+源码

将机器学习模型部署为REST API

FashionAI服装属性标签图像识别Top1-5方案分享

重要开源!CNN-RNN-CTC 实现手写汉字识别

yolo3 检测出图像中的不规则汉字

同样是机器学习算法工程师,你的面试为什么过不了?

前海征信大数据算法:风险概率预测

【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类

VGG16迁移学习,实现医学图像识别分类工程项目

特征工程(一)

特征工程(二) :文本数据的展开、过滤和分块

特征工程(三):特征缩放,从词袋到 TF-IDF

特征工程(四): 类别特征

特征工程(五): PCA 降维

特征工程(六): 非线性特征提取和模型堆叠

特征工程(七):图像特征提取和深度学习

如何利用全新的决策树集成级联结构gcForest做特征工程并打分?

Machine Learning Yearning 中文翻译稿

蚂蚁金服2018秋招-算法工程师(共四面)通过

全球AI挑战-场景分类的比赛源码(多模型融合)

斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)

python+flask搭建CNN在线识别手写中文网站

中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程

不断更新资源

深度学习、机器学习、数据分析、python

 搜索公众号添加: datayx  


机器学习算法资源社群

不断上传电子版PDF资料

技术问题求解

 QQ群号: 333972581  

长按图片,识别二维码

pytorch如何将训练提速?相关推荐

  1. pytorch 多GPU训练总结(DataParallel的使用)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_40087578/arti ...

  2. 基于pytorch量化感知训练(mnist分类)--浮点训练vs多bit后量化vs多bit量化感知训练效果对比

    基于pytorch量化感知训练–浮点训练vs多bit后量化vs多bit量化感知训练效果对比 代码下载地址:下载地址 灰色线是量化训练,橙色线是后训练量化,可以看到,在 bit = 2.3 的时候,量化 ...

  3. pytorch量化感知训练(QAT)示例---ResNet

    pytorch量化感知训练(QAT)示例---ResNet 训练浮点模型,测试浮点模式在CPU和GPU上的时间; BN层融合,测试融合前后精度和结果比对; 加入torch的量化感知API,训练一个QA ...

  4. PyTorch深度学习训练可视化工具tensorboardX

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 之前笔者提到了PyTorch的专属可视化工具visdom,参看Py ...

  5. pytorch读取文本训练

    2019独角兽企业重金招聘Python工程师标准>>> # References # https://github.com/yunjey/pytorch-tutorial/blob/ ...

  6. PyTorch 的预训练,是时候学习一下了

    前言 最近使用 PyTorch 感觉妙不可言,有种当初使用 Keras 的快感,而且速度还不慢.各种设计直接简洁,方便研究,比 tensorflow 的臃肿好多了.今天让我们来谈谈 PyTorch 的 ...

  7. pytorch 多GPU训练

    pytorch 多GPU训练 pytorch多GPU最终还是没搞通,可用的部分是前向计算,back propagation会出错,当时运行通过,也不太确定是如何通过了的.目前是这样,有机会再来补充 p ...

  8. pytorch 驱动不兼容_解决Pytorch 加载训练好的模型 遇到的error问题

    这是一个非常愚蠢的错误 debug的时候要好好看error信息 提醒自己切记好好对待error!切记!切记! -----------------------分割线---------------- py ...

  9. 【深度学习】基于PyTorch的模型训练实用教程之数据处理

    [深度学习]基于PyTorch的模型训练实用教程之数据处理 文章目录 1 transforms 的二十二个方法 2 数据加载和预处理教程 3 torchvision 4 如何用Pytorch进行文本预 ...

最新文章

  1. Linux Shell 截取字符串
  2. [视频教程] 如何在Linux深度系统deepin下安装docker
  3. Perl 之 use(), require(), do(), %INC and @INC
  4. AUTOSAR从入门到精通100讲(四十九)-AUTOSAR 通信服务Dcm篇-Dcm概念及DSL详解与实战案例
  5. centos 7新机使用前操作
  6. 受迫阻尼 matlab 仿真,MATLAB系统仿真报告——有阻尼受迫振动系统
  7. tuxedo连接mysql_9.5.3 Tuxedo与各种数据库的连接
  8. 仿QQ聊天室【方案】
  9. Struts2.0 xml文件的配置(package,namespace,action)
  10. for update引发了血案
  11. Atitit.月度计划日程表 每月流程表
  12. html网页添加友链,教你如何添加网站友情链接
  13. 手机查看云服务器文件夹,手机查看云服务器文件夹
  14. 数据分析试题集+答案
  15. 云计算机遇与挑战,中国云计算产业发展面临机遇与挑战
  16. torch.mul、matmul、mm、bmm的区别
  17. JAVA采用S7通信协议访问西门子PLC
  18. 数据分析师岗位热招!你也有希望进大厂~
  19. 低通滤波器转带通滤波器公式由来_开关电源电磁兼容进级EMI传导输入滤波器的设计理论(EDTEST上海)...
  20. 功利主义穆勒思维导图_约翰·穆勒功利主义教育思想概述

热门文章

  1. ZeroMQ(java)之负载均衡
  2. Java 的 ArrayList 的底层数据结构
  3. AngularJs ui-router 路由的简单介绍
  4. 剑灵灵动区服务器位置,盘点国服剑灵灵动内测4大玩家人气玩法(2)
  5. linux activemq修改端口号,linux下 activemq集群配置
  6. 佳能9100cdn故障_佳能 打印机故障代码大全
  7. java mina 大文件传输_mina 传输java对象
  8. oracle临时表经常被锁_这是一篇长篇入门级数据库讲解:oracle数据库数据导入导出步骤
  9. Linux-定时任务(Crontab)基本用法
  10. php curl加密获取数据,PHP利用Curl模拟登录并获取数据例子