这篇博客是在pytorch中基于apex使用混合精度加速的一个偏工程的描述,原理层面的解释并不是这篇博客的目的,不过在参考部分提供了非常有价值的资料,可以进一步研究。

一个关键原则:“仅仅在权重更新的时候使用fp32,耗时的前向和后向运算都使用fp16”。其中的一个技巧是:在反向计算开始前,将dloss乘上一个scale,人为变大;权重更新前,除去scale,恢复正常值。目的是为了减小激活gradient下溢出的风险。

apex是nvidia的一个pytorch扩展,用于支持混合精度训练和分布式训练。在之前的博客中,神经网络的Low-Memory技术梳理了一些low-memory技术,其中提到半精度,比如fp16。apex中混合精度训练可以通过简单的方式开启自动化实现,组里同学交流的结果是:一般情况下,自动混合精度训练的效果不如手动修改。分布式训练中,有社区同学心心念念的syncbn的支持。关于syncbn,在去年做CV的时候,我们就有一些来自民间的尝试,不过具体提升还是要考虑具体任务场景。

那么问题来了,如何在pytorch中使用fp16混合精度训练呢?

第零:混合精度训练相关的参数

parser.add_argument('--fp16',action='store_true',help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',type=float, default=0,help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n""0 (default value): dynamic loss scaling.\n""Positive power of 2: static loss scaling value.\n")

第一:模型参数转换为fp16

nn.Module中的half()方法将模型中的float32转化为float16,实现的原理是遍历所有tensor,而float32和float16都是tensor的属性。也就是说,一行代码解决,如下:

model.half()

第二:修改优化器

在pytorch下,当使用fp16时,需要修改optimizer。类似代码如下(代码参考这里):

# Prepare optimizerif args.do_train:param_optimizer = list(model.named_parameters())no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']optimizer_grouped_parameters = [{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]if args.fp16:try:from apex.optimizers import FP16_Optimizerfrom apex.optimizers import FusedAdamexcept ImportError:raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")optimizer = FusedAdam(optimizer_grouped_parameters,lr=args.learning_rate,bias_correction=False,max_grad_norm=1.0)if args.loss_scale == 0:optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)else:optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,t_total=num_train_optimization_steps)else:optimizer = BertAdam(optimizer_grouped_parameters,lr=args.learning_rate,warmup=args.warmup_proportion,t_total=num_train_optimization_steps)

第三:backward时做对应修改

 if args.fp16:optimizer.backward(loss)else:loss.backward()

第四:学习率修改

if args.fp16:# modify learning rate with special warm up BERT uses# if args.fp16 is False, BertAdam is used that handles this automaticallylr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)for param_group in optimizer.param_groups:param_group['lr'] = lr_this_stepoptimizer.step()optimizer.zero_grad()

根据参考3,值得重述一些重要结论:

(1)深度学习训练使用16bit表示/运算正逐渐成为主流。

(2)低精度带来了性能、功耗优势,但需要解决量化误差(溢出、舍入)。

(3)常见的避免量化误差的方法:为权重保持高精度(fp32)备份;损失放大,避免梯度的下溢出;一些特殊层(如BatchNorm)仍使用fp32运算。

参考资料:

1.nv官方repo给了一些基于pytorch的apex加速的实现

实现是基于fairseq实现的,可以直接对比代码1-apex版和代码2-非apex版(fairseq官方版),了解是如何基于apex实现加速的。

2.nv官方关于混合精度优化的原理介绍

按图索骥,可以get到很多更加具体地内容。

3.低精度表示用于深度学习 训练与推断

感谢团队同学推荐。

[Pytorch]基于混和精度的模型加速相关推荐

  1. Ultralytics公司YOLOv8来了(训练自己的数据集并基于NVIDIA TensorRT和华为昇腾端到端模型加速)--跟不上“卷“的节奏

    Official YOLOv8 训练自己的数据集并基于NVIDIA TensorRT和华为昇腾端到端模型加速 说明: 本项目支持YOLOv8的对应的package的版本是:ultralytics-8. ...

  2. 从零开始构建基于textcnn的文本分类模型(上),word2vec向量训练,预训练词向量模型加载,pytorch Dataset、collete_fn、Dataloader转换数据集并行加载

    伴随着bert.transformer模型的提出,文本预训练模型应用于各项NLP任务.文本分类任务是最基础的NLP任务,本文回顾最先采用CNN用于文本分类之一的textcnn模型,意在巩固分词.词向量 ...

  3. pytorch基于web端和C++的两种深度学习模型部署方式

    本文对深度学习两种模型部署方式进行总结和梳理.一种是基于web服务端的模型部署,一种是基于C++软件集成的方式进行部署. 基于web服务端的模型部署,主要是通过REST API的形式来提供接口方便调用 ...

  4. 【模型加速】关于模型加速的总结

    概述 ● 模型加速的目标: a. Increase inference speed:加快推理速度(应用层面). b. Reduce model size:压缩模型. ● 关于模型的加速大致可以分为三个 ...

  5. 模型转换、模型压缩、模型加速工具汇总

    点击上方"计算机视觉工坊",选择"星标" 干货第一时间送达 编辑丨机器学习AI算法工程 一.场景需求解读   在现实场景中,我们经常会遇到这样一个问题,即某篇论 ...

  6. 【Pytorch】运用英伟达DALI加速技巧可使PyTorch运算速度快4倍

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 你的数据处理影响整个训练速度,如果加上英伟达 DALI 库,处理速度比原生 PyT ...

  7. 卡内基梅隆大学提出基于学习的动作捕捉模型,用自监督学习实现人类3D动作追踪

    原文来源:Cornell University Library 作者:Hsiao-Yu Fish Tung.Hsiao-Wei Tung.Ersin Yumer. Katerina Fragkiada ...

  8. pytorch基于卷积层通道剪枝的方法

    pytorch基于卷积层通道剪枝的方法 原文:https://blog.csdn.net/yyqq7226741/article/details/78301231 本文基于文章:Pruning Con ...

  9. 10倍加速!爱奇艺超分辨模型加速实践

    关注公众号,发现CV技术之美 随着终端播放设备的升级,观众对于视频的品质需求也逐步提升.需求从最开始的高清过渡到4K,最近8K也有开始流行的趋势.除了对于分辨率提升的需求之外,视频在采集的过程中,也难 ...

最新文章

  1. libjpeg-turbo介绍及测试代码
  2. Oracle RAC错误之--oifcfg错误案例
  3. docker oracle navicat_拥抱开源从零开始 Docker、Mysql amp; JPA
  4. 08--swift之类与结构体
  5. 常用PHP array数组函数
  6. RecyclerView因版本问题无法加载
  7. python基础(part5)--容器类型之字符串
  8. ios 开发账号 退出协作_如何在iOS 10中的Notes上进行协作
  9. [ASP.NET Core 2.0 前方速报].NET Core 2.0.3 已经支持引用第三方程序集了
  10. Uva 11354 LCA 倍增祖先
  11. Sublime Text 2/3如何支持中文GBK编码(亲测实现)
  12. 利用VBB仿真——实现摇杆时钟
  13. PDMS Pipeline Tool 教程(三):材料表
  14. matlab基于视频的车流量检测,基于视频的车流量统计——matlab代码.docx
  15. 中国科技统计年鉴Excel版本(1991-2021年)
  16. 自备一个刷BIOS神器
  17. block的名词形式_block是什么意思_block在线翻译_英语_读音_用法_例句_海词词典
  18. 手机地图导航哪个好?手机导航地图推荐
  19. 医院时钟系统,NTP子钟,网络子母钟系统,ntp子母钟,网络子母钟——为您的系统保驾护航
  20. 成都Uber优步司机奖励政策(2月21日)

热门文章

  1. 2022-2028年中国锂电材料产业投资分析及前景预测报告
  2. 只有变强大,才能照亮他人
  3. 利用pandas读写HDF5文件
  4. GPU与CPU交互技术
  5. MindSpore平台系统类
  6. 【杂】LaTeX中一些符号的输入方法
  7. [JavaScript] JavaScript 运算符与流程控制
  8. [JAVA EE] 内联用法
  9. php mongodb execute,php简单操作mongodb
  10. the server responded with a status of 404 (HTTP/1.1 404 Not Found)