封面图片:ThisisEngineering RAEng on Unsplash


本文是对论文“To prune, or not to prune: exploring the efficacy ofpruning for model compression”的摘抄。这篇文章是TensorFlow模型优化工具文档中推荐的,作者Michael H. Zhu,来自斯坦福大学。在这里可以找到论文原文。背景

对于资源有限的移动终端设备来说,内容带宽通常是一个重要的限制因素。模型压缩至少有两点好处:减少耗电的内存访问次数;同等带宽下提升压缩模型参数的获取效率。剪枝将不重要的模型权重归零,实现了模型压缩的同时只带来了较小的质量损失。剪枝之后的模型是稀疏的,在支持稀疏矩阵加速运算的硬件上可以进一步获得加速效果。

国冰提示:英伟达的第三代张量核心(Tensor Core)对于稀疏矩阵的运算有约5倍的性能提升。这一点我们在文章“RTX30系列,香吗”中有过介绍。

在模型内存足迹(memory footprint)一定的前提下,如何获得最准确的模型,是本文的核心内容。作者通过对比两种模型来回答这个问题。第一种,先训练一个大模型,然后通过剪枝将其转换为一个强稀疏模型;第二种,直接训练一个非稀疏模型,尺寸与稀疏模型相当。在具体的模型架构与任务上,作者做出以下选择:

  • 图像分类:Inception V3与MobileNets

  • 序列分析:stacked LSTMs与seq2seq

相关工作

90年代的剪枝通过将权重置零时网络损失函数增量的二阶泰勒级数近似来实现。最近的工作中,基于数量级的权重剪枝开始流行。这种方式简单易行,并且适用于大型网络与数据集。本文通过剪去最小数量级的权重来控制模型稀疏程度。这种策略不需要人工选择权重阈值,不仅适用于CNN也可以用于LSTM。

方案

作者在TensorFlow的基础上实现了训练中剪枝。针对每一个选定的layer,增加同尺寸同形状的二元mask变量作为该layer的权重张量,并决定在哪个权重参数参与网络的前向传播。同时,在训练图中注入算子对该层的权重按照绝对值大小排序,将最小的权重置零直到该层的稀疏程度达到预定指标。反向传播的梯度同样会经过该二元mask,但是不会更新被置零的权重。

稀疏程度随着训练的进程逐渐增加,并满足公式:

其中si为初始稀疏率,sf为最终稀疏率,n为修剪总步数,t为训练步数,Δt为修剪频率。

二元mask每Δt更新一次,这有助于网络的准确率从修剪后的状态逐步恢复。作者的实验表明当Δt的取值在100到1000之间时剪枝对最终的模型质量影响可以忽略。一旦模型达到预定的稀疏指标,权重mask停止更新。按照稀疏率变化公式,训练早期频繁剪枝,随着训练进展,剪枝的频率越来越低——因为可供剪枝的权重越来越少。如下图所示。

实际剪枝过程中,网络会先训练若干epoch或者加载一个已经训练好的网络,这就决定了t0。参数n则很大程度上取决于学习率曲线。作者观察到随着学习率的下降,剪枝后的模型准确率可能会很难恢复过来。反过来,过高的学习率则可能导致权重在收敛到较优值之前被剪掉。因此需要将两者紧密结合起来。例如上图中,对Inception V3剪枝的过程安排在学习率相对较大的阶段。

下图则展示了训练过程中模型准确率的变化。对于稀疏率达87.5%的模型,随着稀疏程度的上升,模型经历了“几乎灾难性”的衰退,但是随后又很快恢复了过来。这种现象在高稀疏率的模型中更加常见。

下表展示了稀疏程度与模型准确率之间的关系。随着稀疏程度的增加,模型的准确率开始下降。不过即便有一半的权重被裁剪,模型的准确率也只下降了一点。

比较“大稀疏”与“小密实”模型

对于紧凑模型来说剪枝仍然是有效的,可以与width multiplier相比。同等大小的模型,稀疏模型优于非稀疏模型。训练也很简单,只是初始学习率小10倍。

国冰提示:

这一章节的模型都为常见模型,且数据较多,因此只说结论。感兴趣的读者可以查阅原文。另外两个子章节涉及模型与图像无关,故此跳过。

讨论

稀疏模型的内存足迹包括非零参数的保存以及索引它们所需的数据结构。模型剪枝可以减少非零参数之间的连接数,但是稀疏矩阵的存储不可避免的减小了压缩率。无论权重是否为0,二元mask稀疏矩阵都需要为之存储1bit,同时还需要一个向量存储非0值。无论稀疏率的大小,这部分开销无法避免。

另外,大稀疏模型比同体积非稀疏模型的表现更好。而且随着模型规模的扩大,这种差异越发明显。

总结

经过剪枝之后的稀疏大模型要优于同体积的非稀疏模型。作者提出的递进剪枝策略可以广泛的应用于各种模型。有限资源环境下,剪枝是有效的模型压缩策略。深度学习加速硬件应当将稀疏矩阵的存储与运算提供支持。


TensorFlow支持剪枝,请参阅官方文档:https://www.tensorflow.org/model_optimization/guide/pruning/针对HRNet的剪枝可能会晚一点。因为目前剪枝不支持 subclassed model。参见issue:https://github.com/tensorflow/model-optimization/issues/155


⊱ 推荐阅读 ⊰

tensorflow 训练权重不更新_TensorFlow模型剪枝原理相关推荐

  1. PyTorch载入预训练权重方法和冻结权重方法

    载入预训练权重 1. 直接载入预训练权重 简单粗暴法: pretrain_weights_path = "./resnet50.pth" net.load_state_dict(t ...

  2. 深度学习加载预训练权重好处

    深度学习加载预训练权重好处: 在模型开始训练前,使模型参数得到一个好的初始化,对于后面的训练学习有非常大的帮助.

  3. 使用TensorFlow训练WDL模型性能问题定位与调优

    简介 TensorFlow是Google研发的第二代人工智能学习系统,能够处理多种深度学习算法模型,以功能强大和高可扩展性而著称.TensorFlow完全开源,所以很多公司都在使用,但是美团点评在使用 ...

  4. 使用PaddleFluid和TensorFlow训练序列标注模型

    专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...

  5. Pytorch基础训练库Pytorch-Base-Trainer(支持模型剪枝 分布式训练)

    Pytorch基础训练库Pytorch-Base-Trainer(支持模型剪枝 分布式训练) 目录 Pytorch基础训练库Pytorch-Base-Trainer(PBT)(支持分布式训练) 1.I ...

  6. 使用tensorflow object detection API 训练自己的目标检测模型 (三)

    在上一篇博客"使用tensorflow object detection API 训练自己的目标检测模型 (二)"中介绍了如何使用LabelImg标记数据集,生成.xml文件,经过 ...

  7. 将TensorFlow训练的模型移植到Android手机

    2019独角兽企业重金招聘Python工程师标准>>> 前言 本文中出现的TF皆为TensorFlow的简称. 先说两句题外话吧,TensorFlow 前两天热热闹闹的发布了正式版r ...

  8. 基于pytorch的模型稀疏训练与模型剪枝示例

    基于pytorch的模型稀疏训练与模型剪枝示例 稀疏训练+模型剪枝代码下载地址:下载地址 CIFAR10-VGG16BN Baseline Trained with Sparsity (1e-4) P ...

  9. 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

    将tensorflow训练好的模型移植到Android (MNIST手写数字识别) [尊重原创,转载请注明出处]https://blog.csdn.net/guyuealian/article/det ...

最新文章

  1. kafka安装、配置、启动、常用命令及shell启动脚本编写
  2. ORC文件存储格式的深入探究
  3. Telegraf和Grafana监控多平台上的SQL Server
  4. 高仿wx钱包页H5网站源码
  5. 【软考2】Java语言的基本知识汇总
  6. hyperledger fabric PBFT算法简要解析
  7. Node.js开发框架Express4.x
  8. Spring定时器cron表达式
  9. 《工业设计史》 第二章:手工艺设计阶段
  10. st7789 旋转_st7789v spi通信
  11. java中将汉字转拼音,解决pinyin4j多音节问题
  12. 自控力:和压力做朋友(斯坦福大学实用的心理学课程) 读后感
  13. grep -A -B -C
  14. 传教士 野人 过河问题
  15. iOS开发-思维导图(初级)
  16. 软件测试面试题:一个输入手机号获取验证码的页面,说出测试过程
  17. 09 conventional exercise
  18. 张氏华孙公 福建省上杭县张氏第一代开基祖宗
  19. createBuilderConfig 0XFFFF异常
  20. 小米、360、盛大路由器?居然还有这么多人趋之若鹜!!!想不通!

热门文章

  1. 华为FusionCloud 云计算解决方案及相关资料下载
  2. 谷歌升级Android分析应用程序
  3. Rexsee API介绍:Android传感器系列之 - 磁场传感器Magnetic Field源码
  4. 高性能WEB开发(6) - web性能测试工具推荐
  5. swoole websocket服务
  6. php abstract
  7. 细说plsql中的空值表达式
  8. 蓝牙BLE4.0的LL层数据和L2CAP层数据的区分与理解
  9. Django发HTML邮件
  10. 2018年6月2号(线段树(2))