文章目录

  • 前言
  • 1. 交替方向乘子法
  • 2. 论文中的表述
  • 3. 对论文中的公式进行推导
  • 4. 代码流程
  • 5. 主要函数实现
  • 6. dense vs. prune(finetune)
  • 结束语

前言

  本篇博客记录一下自己根据对论文 GRIM: A General, Real-Time Deep Learning Inference Framework for Mobile Devices based on Fine-Grained Structured Weight Sparsity 中提到的ADMM算法的理解,给出了ADMM算法的推导过程,并在文章的末尾提供了实现的代码。

1. 交替方向乘子法

  交替方向乘子法(Alternating Direction Method of Multipliers, ADMM)作为一种求解优化问题的计算框架,适用于求解凸优化问题。ADMM算法的思想根源可以追溯到20世纪50年代,在20世纪八九十年代中期存在大量的文章分析这种方法的性质,但是当时ADMM主要用于解决偏微分方程问题。1970年由 R. GlowinskiD. Gabay 等提出的一种适用于可分离凸优化的简单有效方法,并在统计机器学习、数据挖掘和计算机视觉等领域中得到了广泛应用。ADMM算法主要解决带有等式约束的关于两个变量的目标函数的最小化问题,可以看作在增广拉朗格朗日算法基础上发展的算法,混合了对偶上升算法(Dual Ascent)的可分解性和乘子法(Method of Multipliers)的算法优越的收敛性。相对于乘子法,ADMM算法最大的优势在于其能够充分利用目标函数的可分解性,对目标函数中的多变量进行交替优化。在解决大规模问题上,利用ADMM算法可以将原问题的目标函数等价地分解成若干个可求解的子问题,然后并行求解每一个子问题,最后协调子问题的解得到原问题的全局解。1

  优化问题
minimizef(x)+g(z)subjecttoAx+Bz=cminimize\ f(x)+g(z) \\ subject\ to\ Ax+Bz=cminimize f(x)+g(z)subject to Ax+Bz=c  其中,x∈Rn,z∈Rm,A∈Rp×n,B∈Rp×m,c∈Rpx \in R^n,z \in R^m,A \in R^{p \times n},B \in R^{p \times m},c \in R^px∈Rn,z∈Rm,A∈Rp×n,B∈Rp×m,c∈Rp,构造拉格朗日函数为
Lp(x,z,λ)=f(x)+g(z)+λT(Ax+Bz−c)L_p(x,z,\lambda )=f(x)+g(z)+\lambda ^{T}(Ax+Bz-c)Lp​(x,z,λ)=f(x)+g(z)+λT(Ax+Bz−c)  其增广拉格朗日函数(augmented Lagrangian function)
Lp(x,z,λ)=f(x)+g(z)+λT(Ax+Bz−c)+ρ2∣∣Ax+Bz−c∣∣2L_p(x,z,\lambda )=f(x)+g(z)+\lambda ^{T}(Ax+Bz-c)+ \frac {\rho} {2}||Ax+Bz-c||^{2}Lp​(x,z,λ)=f(x)+g(z)+λT(Ax+Bz−c)+2ρ​∣∣Ax+Bz−c∣∣2  对偶上升法迭代更新
(xk+1,zk+1)=argminx,zLp(x,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1−c)(x^{k+1},z^{k+1})=\underset {x,z} {argmin\ } L_p(x,z,\lambda ^k) \\ \lambda ^{k+1}=\lambda ^k+\rho (Ax^{k+1}+Bz^{k+1}-c)(xk+1,zk+1)=x,zargmin ​Lp​(x,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1−c)  交替方向乘子法则是在(x,z)(x,z)(x,z)一起迭代的基础上将x,zx,zx,z分别固定单独交替迭代,即
xk+1=argminxLp(x,zk,λk)zk+1=argminzLp(xk+1,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1−c)x^{k+1}=\underset {x} {argmin\ }L_p(x,z^k,\lambda ^k) \\ z^{k+1}=\underset {z} {argmin\ }L_p(x^{k+1},z,\lambda ^k) \\ \lambda ^{k+1}=\lambda ^k+\rho (Ax^{k+1}+Bz^{k+1}-c)xk+1=xargmin ​Lp​(x,zk,λk)zk+1=zargmin ​Lp​(xk+1,z,λk)λk+1=λk+ρ(Axk+1+Bzk+1−c)  交替方向乘子的另一种等价形式,将残差定义为rk=Axk+Bzk−cr^k=Ax^k+Bz^k-crk=Axk+Bzk−c,同时定义uk=1ρλku^k=\frac {1} {\rho} \lambda ^kuk=ρ1​λk作为缩放的对偶变量(dual variable),有
(λk)Trk+ρ2∣∣rk∣∣2=ρ2∣∣rk+uk∣∣2−ρ2∣∣uk∣∣2(\lambda ^k)^Tr^k+\frac {\rho} {2} ||r^k||^2=\frac {\rho} {2}||r^k+u^k||^2-\frac {\rho} {2}||u^k||^2(λk)Trk+2ρ​∣∣rk∣∣2=2ρ​∣∣rk+uk∣∣2−2ρ​∣∣uk∣∣2  改写 ADMM 的迭代过程
xk+1=argminx{f(x)+ρ2∣∣Ax+Bzk−c+uk∣∣2}zk+1=argminz{g(z)+ρ2∣∣Axk+1+Bz−c+uk∣∣2}uk+1=uk+Axk+1+Bzk+1−cx^{k+1} =\underset {x} {argmin\ }\bigg\{f(x)+\frac {\rho} {2}||Ax+Bz^k-c+u^k||^2\bigg\} \\[5pt] z^{k+1}=\underset {z} {argmin\ }\bigg\{g(z)+\frac {\rho} {2}||Ax^{k+1}+Bz-c+u^k||^2\bigg\} \\[5pt] u^{k+1}=u^k+Ax^{k+1}+Bz^{k+1}-c xk+1=xargmin ​{f(x)+2ρ​∣∣Ax+Bzk−c+uk∣∣2}zk+1=zargmin ​{g(z)+2ρ​∣∣Axk+1+Bz−c+uk∣∣2}uk+1=uk+Axk+1+Bzk+1−c

2. 论文中的表述


3. 对论文中的公式进行推导

  为便于推导公式,将论文中的进行简化,参数W和b简记为W,此时的优化问题变为
minimizef(Wi)+∑i=1Ng(Zi)subjecttoWi=Zi,i=1,2,...,Nminimize\ f(W_i)+\sum_{i=1}^{N} g(Z_i) \\[4pt] subject\ to\ W_i=Z_i, i=1,2,...,Nminimize f(Wi​)+i=1∑N​g(Zi​)subject to Wi​=Zi​,i=1,2,...,N  构造拉格朗日函数为
Lp(w,z,λ)=f(w)+∑g(z)+λT(w−z)L_p(w,z,\lambda )=f(w)+\sum g(z)+\lambda ^{T}(w-z)Lp​(w,z,λ)=f(w)+∑g(z)+λT(w−z)  其增广拉格朗日函数为
Lp(w,z,λ)=f(w)+∑g(z)+λT(w−z)+∑ρ2∣∣w−z∣∣2L_p(w,z,\lambda )=f(w)+\sum g(z)+\lambda ^{T}(w-z)+ \sum \frac {\rho} {2}||w-z||^{2}Lp​(w,z,λ)=f(w)+∑g(z)+λT(w−z)+∑2ρ​∣∣w−z∣∣2  交替方向乘子法:在(x, z)一起迭代的基础上将 x, z 分别固定,单独交替迭代,即
wk+1=argminwLp(w,zk,λk)zk+1=argminzLp(wk+1,z,λk)λk+1=λk+∑ρ(w−z)w^{k+1}=\underset {w} {argmin\ }L_p(w,z^k,\lambda ^k) \\[4pt] z^{k+1}=\underset {z} {argmin\ }L_p(w^{k+1},z,\lambda ^k) \\[4pt] \lambda ^{k+1}=\lambda ^k+\sum \rho (w-z)wk+1=wargmin ​Lp​(w,zk,λk)zk+1=zargmin ​Lp​(wk+1,z,λk)λk+1=λk+∑ρ(w−z)  定义一个对偶变量
uk=1ρλku^k=\frac {1} {\rho} \lambda ^kuk=ρ1​λk  改写ADMM的迭代过程
wk+1=argminw{f(w)+∑ρ2∣∣w−zk+uk∣∣2}zk+1=argminz{∑g(z)+∑ρ2∣∣wk+1−z+uk∣∣2}uk+1=uk+wk+1−zk+1w^{k+1} =\underset {w} {argmin\ }\bigg\{f(w)+\sum \frac {\rho} {2}||w-z^k+u^k||^2\bigg\} \\[5pt] z^{k+1}=\underset {z} {argmin\ }\bigg\{\sum g(z)+\sum \frac {\rho} {2}||w^{k+1}-z+u^k||^2\bigg\} \\[5pt] u^{k+1}=u^k+w^{k+1}-z^{k+1}wk+1=wargmin ​{f(w)+∑2ρ​∣∣w−zk+uk∣∣2}zk+1=zargmin ​{∑g(z)+∑2ρ​∣∣wk+1−z+uk∣∣2}uk+1=uk+wk+1−zk+1

4. 代码流程

# 初始化参数Z和U
Z, U = initialize_Z_and_U(model)# 训练model,并更新X,Z,U,损失函数为admm loss
for epoch in range(epochs):for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = admm_loss(model, Z, U, output, target)loss.backward()optimizer.step()W = update_W(model)Z = update_Z(W, U, percent)U = update_U(U, W, Z)# 对weight进行剪枝,返回 mask
mask = apply_prune(model, percent)# 对剪枝后的model进行finetune
finetune(model, mask, train_loader, test_loader, optimizer)

5. 主要函数实现

def admm_loss(args, device, model, Z, U, output, target):idx = 0loss = F.nll_loss(output, target)for name, param in model.named_parameters():if name.split('.')[-1] == "weight":u = U[idx].to(device)z = Z[idx].to(device)# 这里就是推导出来的admm的表达式loss += args.rho / 2 * (param - z + u).norm()return lossdef update_W(model):W = ()for name, param in model.named_parameters():if name.split('.')[-1] == "weight":W += (param.detach().cpu().clone(),)return Wdef update_Z(W, U, args):new_Z = ()idx = 0for w, u in zip(W, U):z = w + upcen = np.percentile(abs(z), 100*args.percent[idx])under_threshold = abs(z) < pcen# percent剪枝率,小于percent分位数的置为0z.data[under_threshold] = 0new_Z += (z,)idx += 1return new_Zdef update_U(U, W, Z):new_U = ()for u, w, z in zip(U, W, Z):new_u = u + w - znew_U += (new_u,)return new_Udef prune_weight(weight, device, percent):# to work with admm, we calculate percentile based on all elements instead of nonzero elements.weight_numpy = weight.detach().cpu().numpy()pcen = np.percentile(abs(weight_numpy), 100*percent)under_threshold = abs(weight_numpy) < pcen# 非结构化剪枝weight_numpy[under_threshold] = 0mask = torch.Tensor(abs(weight_numpy) >= pcen).to(device)return mask

6. dense vs. prune(finetune)

结束语

  对论文中算法的推导仅限于自己的理解,可能还存在一些问题,欢迎来评论区交流哦^_^

参考教程


  1. 《分布式机器学习:交替方向乘子法在机器学习中的应用》---- 雷大江著 ↩︎

ADMM算法在神经网络模型剪枝方面的应用相关推荐

  1. 微软NNI进行神经网络模型剪枝压缩的踩坑记录

    微软NNI进行神经网络模型剪枝压缩的踩坑记录 NNI进行模型剪枝分类 NNI剪枝的流程 NNI现有剪枝方法 剩下的剪枝操作比较复杂还没有研究透,不过应该大同小异,有机会继续研究更新. 最近做毕设嵌入式 ...

  2. 深度神经网络模型剪枝

    深度神经网络模型剪枝 神经网络剪枝 Neural Network Pruning 下面是我对剪枝的一点点理解,如有理解不到位,请大家指正 ▶剪枝只是将模型中权重比较小,对输出影响不大的神经元参数置0, ...

  3. DL:深度学习算法(神经网络模型集合)概览之《THE NEURAL NETWORK ZOO》的中文解释和感悟(六)

    DL:深度学习算法(神经网络模型集合)概览之<THE NEURAL NETWORK ZOO>的中文解释和感悟(六) 目录 DRN DNC NTM CN KN AN 相关文章 DL:深度学习 ...

  4. DL:深度学习算法(神经网络模型集合)概览之《THE NEURAL NETWORK ZOO》的中文解释和感悟(四)

    DL:深度学习算法(神经网络模型集合)概览之<THE NEURAL NETWORK ZOO>的中文解释和感悟(四) 目录 CNN DN DCIGN 相关文章 DL:深度学习算法(神经网络模 ...

  5. DL:深度学习算法(神经网络模型集合)概览之《THE NEURAL NETWORK ZOO》的中文解释和感悟(二)

    DL:深度学习算法(神经网络模型集合)概览之<THE NEURAL NETWORK ZOO>的中文解释和感悟(二) 目录 AE VAE DAE SAE 相关文章 DL:深度学习算法(神经网 ...

  6. DL:深度学习算法(神经网络模型集合)概览之《THE NEURAL NETWORK ZOO》的中文解释和感悟(一)

    DL:深度学习算法(神经网络模型集合)概览之<THE NEURAL NETWORK ZOO>的中文解释和感悟(一) 目录 THE NEURAL NETWORK ZOO perceptron ...

  7. 建模算法(六)——神经网络模型

    (一)神经网络简介 主要是利用计算机的计算能力,对大量的样本进行拟合,最终得到一个我们想要的结果,结果通过0-1编码,这样就OK啦 (二)人工神经网络模型 一.基本单元的三个基本要素 1.一组连接(输 ...

  8. BP神经网络模型与学习算法

    转载自:http://www.cnblogs.com/wentingtu/archive/2012/06/05/2536425.html 一,什么是BP "BP(Back Propagati ...

  9. 数学建模神经网络模型,数学建模神经网络算法

    神经网络能对数据进行预测吗 数学建模 . 神经网络本身就是数学的逼近模型,网络最早是由数学中的函数逼近技术而来,按照统计学规律,组合成线性叠加网络,从中分析出一些现实中高度非线性的模型,神经网络本身就 ...

  10. trainlm算法c语言,粒子群优化的BP神经网络模型对C、Mn两种元素收得率的预测

    粒子群优化的BP神经网络模型对C.Mn两种元素收得率的预测 来源:用户上传 作者: 摘 要:本文首先对数据进行处理,利用收得率公式求出历史收得率:并利用已知的影响元素收得率的主要影响因素结合BP神经网 ...

最新文章

  1. Java语法基础-1
  2. POJ 3855 计算几何·多边形重心
  3. kafka消费者接收分区测试
  4. ux设计师薪水_我是如何从33岁的博物馆导游变成专业的Web开发人员和UX设计师的:我的…...
  5. 在.NET Core 3.0 Preview上使用Windows窗体设计器
  6. xcode 配置wechat_友盟微信、QQ等分享提示未验证应用配置
  7. C/C++[codeup 1929,]今天星期几
  8. dw网页制作入学教程_网站制作DW教程:Dreamweaver CC零基础入门视频课程
  9. 小程序增加 文章 / 新闻 / 资讯 / 动态 功能,支持用户投稿
  10. 删除xp计算机用户账户,XP系统怎么删除多余的用户帐号?XP系统删除多余用户帐号的方法...
  11. ThinkPHP开发手册
  12. 首届中国移动互联网直播行业峰会在京召开
  13. 数据可视化 信息可视化_可视化数据操作数据可视化与纪录片的共同点
  14. 洛谷 P4735 最大异或和
  15. 吐血整理-周志华演讲合集
  16. [模仿]html5游戏_FlyppyBird
  17. 安装MAC系统必备工具
  18. 一文简要概述Seata AT与TCC的区别
  19. 游戏编程大师技巧—windows程序的基本构造
  20. 初探Turtlebot2

热门文章

  1. Windows 10 RTM 官方正式版
  2. 腾讯IM : 如何替换String 表情库
  3. RE-Base64编码分析
  4. 36.windbg-!peb(手工分析PEB结构)
  5. 电视盒子_刷机固件_免费合集分享
  6. c语言知识点总结大全(史上最全)
  7. 文本在线查重系统的设计与实现(毕业设计)
  8. 联想thinkpadE14 vm蓝屏问题解决
  9. Hadoop大数据实战权威指南
  10. sqlbulkcopy是覆盖式更新吗_酒店无线覆盖解决方案,一文了解清楚