论文理解记录:The Lottery Ticket Hypothesis
论文核心
传统非结构化剪枝虽然较大幅度减少了模型参数量,但是由于非结构化的原因导致网络稀疏化,因此很难对剪枝后的模型进行训练。这是非结构化剪枝相对于结构化剪枝不被看好的原因之一。 而本文中提出可以在任意初始化后的原始模型中找到子模块(即非结构化剪枝后的模型),该初始化后的子模块训练时间不会超过原始模型,并达到或超过原始模型的精度,该子模块在论文中被称作“中彩票”。
注意事项
论文中经过实验发现第一次随机初始化至关重要,中彩票的子模块结构和随机初始化参数有关。实验找到子模块后对子模块重新初始化,此时子模块的性能明显下降。对此,论文给出的理由:子模块初始化的参数与优化算法、数据集和模型有关,如中奖彩票的最初初始化的参数特别适合所选择的优化算法,因此优化效果很好。
不足之处
- 本论文只在较小的数据集上进行相关研究测试
- 非结构化修剪是论文作者找到唯一的能找到中奖彩票的方法,未能在结构化剪枝的方法中找到中奖彩票
- 未能明确解释初始化为何对中奖彩票如此重要
- 在更深的网络中,迭代修剪无法找到中奖彩票,除非用超参数learning rate warmup来调节
寻找中彩票的子模块算法
算法一:
- 随机初始化神经网络,保存初始化值,并创建掩码
- 训练迭代
- 裁剪参数,并更新掩码
- 重置裁剪后的神经网络权重,使权重值恢复到初始化值
- 重复2-4步骤,直到得到裁剪充分的网络
算法二:
- 随机初始化神经网络,保存初始化值,并创建掩码
- 训练迭代
- 裁剪参数,更新掩码
- 重复2-3步骤直到得到裁剪充分的网络
- 重置裁剪后的神经网络权重,使权重值恢复到初始化值
论文提出两种算法,差别在于第二种每一轮修剪后使用已训练的权重进行重新训练,而第一种在每次重新训练前重置权重,实验证明算法一在训练速度和测试集精度上都要优于第二种
部分代码
# 创建掩码用于模型裁剪
def make_mask(model):global stepglobal maskstep = 0for name, param in model.named_parameters(): if 'weight' in name:step = step + 1mask = [None] * stepstep = 0for name, param in model.named_parameters(): if 'weight' in name:tensor = param.data.cpu().numpy()# 把None元素替换为1mask[step] = np.ones_like(tensor)step = step + 1step = 0
# 百分位数裁剪
def prune_by_percentile(percent, resample=False, reinit=False, **kwargs):global stepglobal maskglobal model# Calculate percentile valuestep = 0for name, param in model.named_parameters():# We do not prune bias termif 'weight' in name:tensor = param.data.cpu().numpy()alive = tensor[np.nonzero(tensor)] # flattened array of nonzero values# 求百分位值percentile_value = np.percentile(abs(alive), percent)# Convert Tensors to numpy and calculateweight_dev = param.device# 小于百分位数的参数设置为0new_mask = np.where(abs(tensor) < percentile_value, 0, mask[step])# Apply new weight and maskparam.data = torch.from_numpy(tensor * new_mask).to(weight_dev)mask[step] = new_maskstep += 1step = 0
完整代码链接
个人总结
本篇论文重点不在于如何进一步压缩模型大小,更多的是提高模型训练速度。在大模型中找到体积小、容易训练且不降低精度的子模型,这对模型的结构探索有很重要的意义。同时论文提出了一些问题有待解决,如为何原始初始化参数对子模型如此重要、在更深的网络中如何寻找到符合要求的子模型等等。
论文理解记录:The Lottery Ticket Hypothesis相关推荐
- 彩票假设 (Lottery Ticket Hypothesis) 在CV、NLP和OOD领域的应用
©PaperWeekly 原创 · 作者 | 张一帆 学校 | 中科院自动化所博士生 研究方向 | 计算机视觉 本文用三篇论文稍微普及和解读一下最近 Lottery Ticket Hypothesis ...
- The Lottery Ticket Hypothesis
The Lottery Ticket Hypothesis THE MOTIVATION Pruned network is difficult to train from the start. TH ...
- PacificA: Replication in Log-Based Distributed Storage Systems 论文理解
PacificA: Replication in Log-Based Distributed Storage Systems 论文理解 思考:论文有个结论说,相比 GFS 具有中心化的实体,Pacif ...
- Life Long Learning论文阅读记录之LwF
Life Long Learning论文阅读记录之LwF 写在前面 获取原文 问题 难点 目标 符号说明 现有方法 不使用旧数据集的方法 Learning without Forgetting(LwF ...
- 论文理解【RL - Exp Replay】—— 【ReMERN ReMERT】Regret Minimization Exp Replay in Off-Policy RL
标题:Regret Minimization Experience Replay in Off-Policy Reinforcement Learning 文章链接:Regret Minimizati ...
- ICCV2017 论文浏览记录(转)
mark一下,感谢作者分享! 作者将ICCV2017上的论文进行了汇总,在此记录下来,平时多注意阅读积累. 之前很早就想试着做一下试着把顶会的论文浏览一遍看一下自己感兴趣的,顺便统计一下国内高校或者研 ...
- ICCV2017 论文浏览记录
之前很早就想试着做一下试着把顶会的论文浏览一遍看一下自己感兴趣的,顺便统计一下国内高校或者研究机构的研究方向,下面是作为一个图像处理初学者在浏览完论文后的 觉得有趣的文章: ICCV2017 论文浏览 ...
- ResNet 论文理解含视频
ResNet 论文理解 问题导引论文理解 Q1.神经网络真的越深越好吗? Q2. 为什么加深网络会带来退化问题? Q3. 如何构建更深层的网络? 基于残差的深度学习框架 Residual Learni ...
- 鹅鹅鹅的论文投稿记录~
记录读研期间的论文.专利.会议论文投稿记录 2022.4 基于已有数据和修改好的word稿,在网上找了爱思唯尔的latex模板后排了几天版,然后读了期刊的投稿须知 2022.4.14 Submitte ...
最新文章
- Python基础学习3
- 多年密谋「闹独立」,谷歌为何拴不住DeepMind的心?
- DSP调试报错:OMAPL138 Connect to PRSC failed
- 传智C++课程笔记-1
- files函数提取文件名HTML,Javascript – 如何从文件input控件提取文件名
- can4--测试can
- Python中hasattr() getattr() setattr() 函数的使用
- 美团数据库高可用架构的演进与设想
- 首师大2计算机考研分数线,2021考研分数线:首都师范大学2021年考研复试分数线...
- 学习笔记(06):MySQL数据库运维与管理-01-用户创建及授权
- 开课吧Java课程之详解文件输出流FileInputStream
- 谷歌浏览器修改CSS和js后同步保存到文件中 (译)
- 字符串、数组处理方法总结
- c语言39关键字及其含义,C语言关键字含义
- 计算机操作系统(第四版)学习笔记
- 微信步数修改.html,httpCatcher,charles修改微信步数,支付宝森林能量满满
- jersey文件服务器,通过jersey实现客户端图片上传
- FPGA工程师面试试题集锦11~20
- 巾帼亮相申城,群英共筑梦想
- potentially fixable with the `--fix` option.