pytorch框架下实现top-k剪枝

这篇博客,以MNIST数据集为例,对LSTM的权重矩阵实现top-k剪枝(7,2),介绍了如何在pytorch框架下实现top-k剪枝。

文章目录

  • pytorch框架下实现top-k剪枝
  • 一. top-k剪枝
  • 二. 生成掩模(mask)矩阵
  • 三. 定义剪枝函数
  • 总结
  • 参考文献

一. top-k剪枝

  • LSTM常被应用自然语言处理(NLP)相关的应用,由于引入了memory cell和gate unit,其含有大量参数,即使被剪枝90%的参数,仍然不会给模型带来太大的精度损失,较多的冗余参数带来很多不必要的资源消耗,因此需要被剪枝。随机剪枝产生的稀疏矩阵,需要额外的资源去存储位置信息,因此,规则剪枝更占优势。

  • 这篇博客采用MNIST数据集,搭建了一个含有双层LSTM,线性层的RNN模型,其中LSTM的输入,隐藏层输出维度均为28,采用的top-k为,lstm的权重矩阵的每一行,7个分为一组,每组只保留最大的2个,其余的均为0。top-k剪枝的文献

  • 这样剪枝获得的权重矩阵每一行数量都相等,且保留下来的权重的位置信息,只需要3个2进制数就可以表示,符合FPGA运算时对负载平衡和减少参数的需求。

二. 生成掩模(mask)矩阵

  • Pytorch剪枝时,需要一个掩模矩阵,该矩阵和待剪枝的矩阵维度大小相等,只包含1,0两个数值,1表示该位置的数据保留,0表示该位置的数据被剪枝;

可以使用如下代码,查看模型都含有哪些权重矩阵:

 for name, _  in model.named_parameters():print(name)
  • 我定义的rnn模型,lstm(双层)含有的权重参数为rnn.lstm.weight_ih_l0,rnn.lstm.weight_hh_l0, rnn.lstm.weight_ih_l1, rnn.lstm.weight_hh_l1.

矩阵每行含有28个参数,将其分为4组,每组7个元素,只保留最大的2个:

def topk(para, k):c = torch.zeros(para.size()[0], para.size()[1],dtype = torch.int) #初始化一个和权值矩阵相同大小的掩膜矩阵l = int(para.size()[1]/7) #将每行的每7个权值分为一组,l为分组的数量parameter = torch.abs(para)  #将权值矩阵取绝对值_, b = torch.topk(parameter[:,:7], k, 1, largest = True) #b为0~6之间的k个数,表示该组最大的前k个权值的位置for i in range(1,l):_, b1 = torch.topk(parameter[:,i*7:(i+1)*7], k, 1, largest = True) #遍历每一组最大的前k个值的位置b1 = b1 + i * 7  #得到每一行中保留的权值位置信息的绝对位值b = torch.cat((b,b1),dim=1) #将每一段拼接起来for j in range(c.size()[0]):c[j, b[j, :]] = 1 #将c中,b中位置信息的对应的位置,置1(保留),其他部分为0return c

c1,c2,c3,c4是根据四个权重矩阵生成的四个掩模矩阵(我定义的双层LSTM有四个权重矩阵),生成的掩模矩阵元素均为0或1

c1 = topk(rnn.lstm.weight_ih_l0.data, 2)
c2 = topk(rnn.lstm.weight_hh_l0.data, 2)
c3 = topk(rnn.lstm.weight_ih_l1.data, 2)
c4 = topk(rnn.lstm.weight_hh_l1.data, 2)

生成的掩模矩阵如图所示:

三. 定义剪枝函数

pytorch提供的自定义剪枝的模板,这里分别将c1,c2,c3,c4作为掩模矩阵,这段代码的意思就是,rnn模型中的lstm层的权重矩阵weight_ih_l0对应掩模矩阵c1, c1元素为1的位置,保留;c1为0的,weight_ih_l0对应的位置被剪枝掉,以此类推;

class FooBarPruningMethod1(prune.BasePruningMethod):"""Prune every other entry in a tensor"""PRUNING_TYPE = 'unstructured'def compute_mask(self, t, default_mask):mask = c1return mask
class FooBarPruningMethod2(prune.BasePruningMethod):"""Prune every other entry in a tensor"""PRUNING_TYPE = 'unstructured'def compute_mask(self, t, default_mask):mask = c2return mask
class FooBarPruningMethod3(prune.BasePruningMethod):"""Prune every other entry in a tensor"""PRUNING_TYPE = 'unstructured'def compute_mask(self, t, default_mask):mask = c3return mask
class FooBarPruningMethod4(prune.BasePruningMethod):"""Prune every other entry in a tensor"""PRUNING_TYPE = 'unstructured'def compute_mask(self, t, default_mask):mask = c4return mask
def foobar_unstructured(model):FooBarPruningMethod1.apply(model.lstm, 'weight_ih_l0')FooBarPruningMethod2.apply(model.lstm, 'weight_hh_l0')FooBarPruningMethod3.apply(model.lstm, 'weight_ih_l1')FooBarPruningMethod3.apply(model.lstm, 'weight_hh_l1')return model
rnn = foobar_unstructured(rnn) #对预训练完成的模型进行top-k剪枝

剪枝过后再训练,会发现,剪枝后的训练速度,明显快于剪枝前。
剪枝后的矩阵如图所示:

总结

这篇博客以MNIST数据集为例,搭建了一个含有双层LSTM,和FC层的模型,预训练后对其进行top-k剪枝,详细介绍了pytorch框架下的top-k剪枝过程;

  • 完整代码下载:pytorch-topk

参考文献

  • top-k剪枝的文献:E-LSTM: An Efficient Hardware Architecture for Long Short-Term Memory
  • pytorch官方剪枝教程:pytorch剪枝

pytorch实现topk剪枝相关推荐

  1. Pytorch torch.topk()的简单用法

    官方文档:https://pytorch.org/docs/stable/generated/torch.topk.html?highlight=topk#torch.topk 由于numpy本身是没 ...

  2. 基于pytorch的模型剪枝+模型量化+BN合并+TRT部署(cifar数据)(2)

    1)量化:High-Bit(>2b): QAT, PTQ, QAFT; Low-Bit(≤2b)/Ternary and Binary: QAT 2)剪枝:正常.规整和分组卷积结构剪枝 3)针对 ...

  3. pytorch 中的topk函数

    pytorch中topk() 函数用法 1. 函数介绍 最近在代码中看到这两个语句 maxk = max(topk) _, pred = output.topk(maxk, 1, True, True ...

  4. 剪枝PRUNING TUTORIAL

    最新的深度学习技术依赖于难以部署的过度参数化模型.相反,已知生物神经网络使用有效的稀疏连通性.为了减少内存,电池和硬件消耗,同时又不牺牲精度,在设备上部署轻量级模型并通过私有设备上计算来确保私密性,确 ...

  5. 模型压缩:量化、剪枝和蒸馏

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 编者荐语 近年来,BERT 系列模型成了应用最广的预训练语言模型, ...

  6. 2023年的深度学习入门指南(8) - 剪枝和量化

    2023年的深度学习入门指南(8) - 剪枝和量化 从这一节开始,我们要准备一些技术专项了.因为目前大模型技术还在快速更新迭代中,各种库和实现每天都在不停出现.因为变化快,所以难免会遇到一些问题.对于 ...

  7. 剪枝与重参第三课:常用剪枝工具

    目录 常用剪枝工具 前言 1.torch.nn.utils.prune 1.1 API简单示例 1.2 拓展之钩子函数 2.pytorch pruning functions 3.custom pru ...

  8. 基于FPGA的LSTM加速器设计(MNIST数据集为例)

    摘要 本文以MNIST手写数字识别任务为例,使用FPGA搭建了一个LSTM网络加速器,并选取MNIST数据集中的10张图片,通过vivado软件进行仿真验证.实验结果表明,本文设计的基于FPGA的LS ...

  9. 【深度学习】超强优化器如何与网络有机结合

    [深度学习]超强优化器如何与网络有机结合 1 Ranger优化器 2 一个例子(基于CNN和pytorch) 3 剪枝(减小优化器压力) 1 Ranger优化器 RAdam + Lookahead + ...

最新文章

  1. 理解离散傅立叶变换(一. 傅立叶变换的由来)
  2. learn_Day14 内置函数补充、反射、初识面向对象
  3. Public Sale【博弈】
  4. postgresql设置postgres密码_django项目时配置postgresql数据库的方法
  5. UI_UISlider控件
  6. Luogu2606[ZJOI2010] 排列计数
  7. 解决无法加载虚拟仿真实验unity3d插件的“failed to update unity web player”问题2019年12月27日
  8. 人体动作捕捉技术综述
  9. cf两边黑屏怎么解决win10_win10摄像机黑屏的解决方案!
  10. 树莓派存储方案_树莓派网络存储(NAS)
  11. MySQL之between and 临界值问题
  12. OCR最佳实践项目汇总
  13. java web argox打印机 用jna调用dll
  14. 建模--知名软件介绍
  15. Labview的下载地址
  16. MySQL使用gpfs共享磁盘_GPFS文件系统笔记
  17. 用Firefox的userChrome.css定制自己的Firefox界面
  18. 关于超分数据库的一个总结
  19. Python带你跨年!用Python送你一场跨年烟花秀
  20. conda 换成清华的源_[mcj]conda设置清华源以及更换删除源|conda常用命令集锦

热门文章

  1. U盘的工作原理(读取和存储数据)
  2. springBoot 整合 hikari
  3. 微软Power Platform正式商用
  4. 清华申请退学博士作品:完全用Linux工作,凸Windows
  5. pageInfo分页无效问题
  6. 传球游戏c语言,[蓝桥杯][算法训练VIP]传球游戏-题解(Java代码)
  7. 国家大环境施压,开曼公司来控股国内公司。
  8. android 常用颜色对照表
  9. 计算机制作节日贺卡教案,《节日贺卡自己做》说课稿
  10. 2023基于微信小程序的火锅店点餐订餐系统(SSM+mysql)-JAVA.VUE(论文+开题报告+运行)