本文介绍了两种Proximal Methods的证明方法以及实现。内容主要来源于王然老师的《Proximal Methods》一文以及网络,加入了部分个人理解。由于水平有限,如有不妥之处,敬请指正。

为什么会有Proximal methods这个东东?

在机器学习的损失函数求解过程中,通过计算梯度然后迭代寻找最小值是一个常用的方法。而对于一些函数,是无法求导的,这时就无法用梯度下降等方法求解了。比如加了L1L1L1正则的损失函数。

argminβ1N∗∑i(yi−xi∗βt)+λ∗∥β∥1\mathop{\mathrm{argmin}} \limits_{\beta} \frac{1}{N}*\sum_i(y_i-x_i*\beta^t)+\lambda*\Vert \beta \Vert_1βargmin​N1​∗∑i​(yi​−xi​∗βt)+λ∗∥β∥1​

proximal methods主要就是解决这个问题的。

proximal methods证明前的铺垫

主要介绍sub-differential和proximal operator这两个概念,后面证明时会用到。

sub-differential 子梯度

先介绍一个概念,sub-differential 子梯度,也叫:subderivative, subgradient, and subdifferential,是对于不可导的凸函数的导数的一种推广。
比如,对于绝对值函数f(x)=∣x∣f(x)=\vert x \vertf(x)=∣x∣,当x=0x=0x=0时,函数是不可导的。
如下图,对于x0x_0x0​点不可导(类似绝对值函数),但是我们可以在点(x0,f(x0))(x_0,f(x_0))(x0​,f(x0​))上画一条线,这条线经过x0x_0x0​点,并且在曲线的下方,像这样的曲线的斜率就是sub-differential中的一个。

子梯度的严格定义:
对于凸函数f:I→Rf:I \to \mathbb{R}f:I→R在x0x_0x0​的子梯度是一个实数ccc,ccc满足以下条件:
f(x)−f(x0)≥c(x−x0)f(x)-f(x_0)\geq c(x-x_0)f(x)−f(x0​)≥c(x−x0​)
对于所有在III内的xxx,在x0x_0x0​的子梯度是一个非空的闭区间集合[a,b][a,b][a,b],其中:
a=lim⁡x→x0−f(x)−f(x0)x−x0a=\lim_{x \to x_0^-} \frac{f(x)-f(x_0)}{x-x_0}a=limx→x0−​​x−x0​f(x)−f(x0​)​

b=lim⁡x→x0+f(x)−f(x0)x−x0b=\lim_{x \to x_0^+} \frac{f(x)-f(x_0)}{x-x_0}b=limx→x0+​​x−x0​f(x)−f(x0​)​

sub-differential记为∂f\partial f∂f,有:

∂f={y∣f(x)−f(x0)≥yT(x−x0),forallx∈domf}\partial f = \{ y | f(x)-f(x_0)\geq y^T(x-x_0), for \ all \ x \in dom \ f\}∂f={y∣f(x)−f(x0​)≥yT(x−x0​),for all x∈dom f}

性质:
1、当一个凸函数在x0x_0x0​处的子梯度只有一个值,即a=ba=ba=b时,函数在这个点可导。
2、如果一个凸函数在x0x_0x0​处的子梯度集合为[a,b][a,b][a,b],当0∈[a,b]0 \in [a,b]0∈[a,b]时,函数在x0x_0x0​处取得最小值。
3、如果f,gf,gf,g两个函数都是凸函数,则:
∂(f+g)=∂f+∂g\partial(f+g)=\partial f + \partial g∂(f+g)=∂f+∂g

另外,维基百科上说,国内的部分机构认为的凸函数的定义与国外的正好相反,不过本文并不想纠结于这个问题。

详见:
https://en.wikipedia.org/wiki/Subderivative

Proximal Operator

还要介绍一个概念。Proximal操作算子:

provf(v)=argminx(f(x)+12∗∥x−v∥22)prov_f(v)=\mathop{\mathrm{argmin}} \limits_x (f(x) + \frac{1}{2}*\Vert x-v \Vert_2^2)provf​(v)=xargmin​(f(x)+21​∗∥x−v∥22​)

Proximal Operator有两个神奇的性质,一是不动点,二是proximal operator和sub-differential之间有一定的关系。

  • 性质一:不动点

当x∗x^*x∗是f(x)f(x)f(x)的最小值时,等价于:
x∗=provf(x∗)x^*=prov_f(x^*)x∗=provf​(x∗)
证明:
首先证明:x∗x^*x∗是f(x)f(x)f(x)的最小值时,x∗=provf(x∗)x^*=prov_f(x^*)x∗=provf​(x∗)
f(x)+12∗∥x−x∗∥22≥f(x∗)=f(x∗)+12∥x∗−x∗∥22\begin{aligned} f(x) + \frac{1}{2}* \Vert x-x^* \Vert^2_2 & \geq f(x^*) \\ &=f(x^*)+\frac{1}{2} \Vert x^*-x^*\Vert^2_2 \\ \end{aligned}f(x)+21​∗∥x−x∗∥22​​≥f(x∗)=f(x∗)+21​∥x∗−x∗∥22​​
即f(x)+12∗∥x−x∗∥22f(x) + \frac{1}{2}* \Vert x-x^* \Vert^2_2f(x)+21​∗∥x−x∗∥22​在x=x∗x=x^*x=x∗处取得最小值,即:x∗=argminx(f(x)+12∗∥x−x∗∥22)x^*=\mathop{\mathrm{argmin}} \limits_x (f(x) + \frac{1}{2}* \Vert x-x^* \Vert^2_2)x∗=xargmin​(f(x)+21​∗∥x−x∗∥22​),也就是x∗=provf(x∗)x^*=prov_f(x^*)x∗=provf​(x∗)啦。

再证明:当x∗=provf(x∗)x^*=prov_f(x^*)x∗=provf​(x∗)时,x∗x^*x∗是f(x)f(x)f(x)的最小值。

x∗=provf(x∗)x^*=prov_f(x^*)x∗=provf​(x∗),根据sub-differential的性质,有:
0∈∂(provf(x∗))0∈∂(f(x)+12∥x−x∗∥22)0∈∂f(x)+(x−x∗)令x=x∗,则有:0∈∂f(x∗)\begin{aligned} 0 &\in \partial (prov_f(x^*)) \\ 0 &\in \partial (f(x)+\frac{1}{2}\Vert x-x^*\Vert^2_2)\\ 0 &\in \partial f(x) +(x-x^*)\\ 令x&=x^*,则有:\\ 0 &\in \partial f(x^*) \end{aligned}000令x0​∈∂(provf​(x∗))∈∂(f(x)+21​∥x−x∗∥22​)∈∂f(x)+(x−x∗)=x∗,则有:∈∂f(x∗)​
即x∗x^*x∗是f(x)f(x)f(x)的最小值。

  • 性质二:proximal operator实际上是sub-differential的一种解析形式,有:
    provλf=(I+λ∂f)−1\begin{aligned} prov_{\lambda f}=(I+\lambda \partial f)^{-1} \end{aligned}provλf​=(I+λ∂f)−1​
    说明:provλfprov_{\lambda f}provλf​和(I+λ∂f)−1(I+\lambda \partial f)^{-1}(I+λ∂f)−1都是操作算子,provλf(v)=argminx(f(x)+12λ∥x−v∥22)prov_{\lambda f}(v)=\mathop{\mathrm{argmin}} \limits_x(f(x)+\frac{1}{2\lambda}\Vert x-v \Vert_2^2)provλf​(v)=xargmin​(f(x)+2λ1​∥x−v∥22​),(I+λ∂f)−1(I+\lambda \partial f)^{-1}(I+λ∂f)−1是(I+λ∂f)(I+\lambda \partial f)(I+λ∂f)的反函数。
    证明:
    如果:
    z∈(I+λ∂f)−1(x)(I+λ∂f)(z)∋xz+λ∂f(z)∋x0∈λ∂f(z)+(z−x)0∈∂(λf(z)+12∥z−x∥22)0∈∂(f(z)+12λ∥z−x∥22)即:z=argminu(f(u)+12λ∥u−x∥22)\begin{aligned} z &\in (I+\lambda \partial f)^{-1}(x)\\ (I+\lambda \partial f)(z) &\ni x\\ z+\lambda \partial f(z) &\ni x\\ 0 &\in \lambda \partial f(z)+(z-x)\\ 0 &\in \partial(\lambda f(z)+ \frac{1}{2}\Vert z-x\Vert_2^2)\\ 0 &\in \partial( f(z)+ \frac{1}{2\lambda}\Vert z-x\Vert_2^2)\\ 即:\\ z&=\mathop{\mathrm{argmin}} \limits_u (f(u)+\frac{1}{2\lambda} \Vert u-x\Vert^2_2)\\ \end{aligned}z(I+λ∂f)(z)z+λ∂f(z)000即:z​∈(I+λ∂f)−1(x)∋x∋x∈λ∂f(z)+(z−x)∈∂(λf(z)+21​∥z−x∥22​)∈∂(f(z)+2λ1​∥z−x∥22​)=uargmin​(f(u)+2λ1​∥u−x∥22​)​
    即:(f(u)+12λ∥u−x∥22)( f(u)+ \frac{1}{2\lambda}\Vert u-x\Vert_2^2)(f(u)+2λ1​∥u−x∥22​)在zzz处取得最小值,z=provλf(x)z=prov_{\lambda f}(x)z=provλf​(x),注意这里的xxx其实是前面的vvv。
    这里有点儿神奇,当z∈(I+λ∂f)−1(x)z \in (I+\lambda \partial f)^{-1}(x)z∈(I+λ∂f)−1(x)时,z=provλf(x)z=prov_{\lambda f}(x)z=provλf​(x)
    两个看起来没什么关系的东西竟然也能联系在一起。。。

Proximal Methods的求解证明

文章的开头,我们就提出了一个问题:对于两个函数f+gf+gf+g,当fff可导,但ggg不可导时,如何求解最小值呢?

我们先给出答案,再对其进行证明。
通过以下迭代,能够计算出f+gf+gf+g的最小值。

xk+1=provλkg(xk−λk∇f(xk))x^{k+1}=prov_{\lambda^k g}(x^k-\lambda^k \nabla f(x^k))xk+1=provλkg​(xk−λk∇f(xk))

  • 证明方法一
    如果x∗x^*x∗是f+gf+gf+g的最小值,则有0∈∇f(x∗)+∂g(x∗)0 \in \nabla f(x^*)+ \partial g(x^*)0∈∇f(x∗)+∂g(x∗),
    0∈λ∇f(x∗)+λ∂g(x∗)0∈λ∇f(x∗)−x∗+x∗+λ∂g(x∗)0∈λ∇f(x∗)−x∗+(I+λ∂g)(x∗)(I+λ∂g)(x∗)∋x∗−λ∇f(x∗)x∗=(I+λ∂g)−1(x∗−λ∇f(x∗))x∗=provλg(x∗−λ∇f(x∗))\begin{aligned} 0& \in \lambda \nabla f(x^*)+ \lambda \partial g(x^*) \\ 0& \in \lambda \nabla f(x^*)- x^* + x^* + \lambda \partial g(x^*)\\ 0& \in \lambda \nabla f(x^*)-x^* + (I+ \lambda \partial g)(x^*)\\ (I+ \lambda \partial g)(x^*) &\ni x^*-\lambda \nabla f(x^*)\\ x^* &= (I+\lambda \partial g)^{-1}(x^*-\lambda \nabla f(x^*))\\ x^* &= prov_{\lambda g}(x^*-\lambda \nabla f(x^*)) \end{aligned}000(I+λ∂g)(x∗)x∗x∗​∈λ∇f(x∗)+λ∂g(x∗)∈λ∇f(x∗)−x∗+x∗+λ∂g(x∗)∈λ∇f(x∗)−x∗+(I+λ∂g)(x∗)∋x∗−λ∇f(x∗)=(I+λ∂g)−1(x∗−λ∇f(x∗))=provλg​(x∗−λ∇f(x∗))​
    这个证明过程也是很神奇的。。。

  • 证明方法二

xk+1=provλkg(xk−λk∇f(xk))xk+1=argminx(g(x)+12λk∥x−(xk−λk∇f(xk))∥22)xk+1=argminx(g(x)+λk2∥∇f(xk)∥22+∇f(xk)T(x−xk)+12λk∥x−xk∥22)\begin{aligned} x^{k+1}&=prov_{\lambda^k g}(x^k-\lambda^k \nabla f(x^k))\\ x^{k+1}&=\mathop{\mathrm{argmin}} \limits_{x} (g(x) +\frac{1}{2\lambda^k}\Vert x-(x^k-\lambda^k \nabla f(x^k))\Vert^2_2)\\ x^{k+1}&=\mathop{\mathrm{argmin}} \limits_{x} (g(x) +\frac{\lambda^k}{2}\Vert \nabla f(x^k) \Vert^2_2 + \nabla f(x^k)^T (x-x^k)+\frac{1}{2 \lambda^k}\Vert x-x^k\Vert^2_2)\\ \end{aligned}xk+1xk+1xk+1​=provλkg​(xk−λk∇f(xk))=xargmin​(g(x)+2λk1​∥x−(xk−λk∇f(xk))∥22​)=xargmin​(g(x)+2λk​∥∇f(xk)∥22​+∇f(xk)T(x−xk)+2λk1​∥x−xk∥22​)​
由于上式是对于xxx求最小值,而λk2∥∇f(xk)∥22\frac{\lambda^k}{2}\Vert \nabla f(x^k) \Vert^2_22λk​∥∇f(xk)∥22​是一个与xxx无关的常量,则可将其替换为f(xk)f(x^k)f(xk),则上式等价于:
xk+1=argminx(g(x)+f(xk)+∇f(xk)T(x−xk)+12λk∥x−xk∥22)\begin{aligned} x^{k+1}=\mathop{\mathrm{argmin}} \limits_{x} (g(x) +f(x^k) + \nabla f(x^k)^T (x-x^k)+\frac{1}{2 \lambda^k}\Vert x-x^k\Vert^2_2) \end{aligned}xk+1=xargmin​(g(x)+f(xk)+∇f(xk)T(x−xk)+2λk1​∥x−xk∥22​)​
根据泰勒级数展开:
f(x)=f(xk)+∇f(xk)T(x−xk)+12λk∥x−xk∥22\begin{aligned} f(x)=f(x^k) + \nabla f(x^k)^T (x-x^k)+\frac{1}{2 \lambda^k}\Vert x-x^k\Vert^2_2 \end{aligned}f(x)=f(xk)+∇f(xk)T(x−xk)+2λk1​∥x−xk∥22​​
则有:
xk+1=argminx(g(x)+f(x))\begin{aligned} x^{k+1}=\mathop{\mathrm{argmin}} \limits_{x} (g(x) +f(x)) \end{aligned}xk+1=xargmin​(g(x)+f(x))​
说句实在话,对于上面这种方式,个人表示还能凑合着理解,第一种证明的思路实在是难以想象。

根据前文不动点的性质,x∗=provf(x∗)x^*=prov_f(x^*)x∗=provf​(x∗),类似xk+1=provλkg(xk−λk∇f(xk))x^{k+1}=prov_{\lambda^k g}(x^k-\lambda^k \nabla f(x^k))xk+1=provλkg​(xk−λk∇f(xk))这种形式迭代方式也称为不动点迭代,

对于Proximal Method的理解

这是我在网上找到的比较能够理解的说法:
对于函数f+gf+gf+g,给定起点xkx^{k}xk,首先可微函数f(x)f(x)f(x)沿着起点的负梯度方向,作步长为λk\lambda^kλk的梯度下降得到一个预更新值xk−λk∇f(x)x^k-\lambda^k \nabla f(x)xk−λk∇f(x),然后使用近端映射寻找一个xxx ,这个xxx 能使得不可微函数g(x)g(x)g(x)足够小,且接近这个预更新值xk−λk∇f(x)x^k-\lambda^k \nabla f(x)xk−λk∇f(x),就用这个xxx作为本次迭代的更新值xk+1x^{k+1}xk+1 。

还有一个问题

xk+1=provλkg(xk−λk∇f(xk))x^{k+1}=prov_{\lambda^k g}(x^k-\lambda^k \nabla f(x^k))xk+1=provλkg​(xk−λk∇f(xk)),这个迭代算法为什么会成立?
除了不动点迭代外,还有一种解释这里只简单提一下,我也没深入研究(其实是水平不够,看文章太累了。。。),只是看了个皮毛。
当∇f\nabla f∇f是 Lipschitz continuous的,并且Lipshitz constant是LLL的情况下,当λk∈(0,1/L]\lambda^k \in (0,1/L]λk∈(0,1/L]时,这是一个majorization-minimization method,具体可以查一下这个算法相关的资料。当λk>1/L\lambda^k > 1/Lλk>1/L时,是另外一个问题。

关于不动点迭代的问题,继续解释可以了解:Forward-backward integration of gradient flow。

Proximal Methods的应用

设fβ(X)f_\beta(X)fβ​(X)是负对数似然函数,其中β\betaβ是需要求解的参数,XXX是样本数据,我们希望得到下面式子的最小值:
fβ(X)+λ∥β∥1,其中λ>0\begin{aligned} f_\beta(X)+\lambda \Vert \beta \Vert_1,其中 \lambda >0 \end{aligned}fβ​(X)+λ∥β∥1​,其中λ>0​
怎么求解β\betaβ呢?
我们直接用xk+1=provλkg(xk−λk∇f(xk))x^{k+1}=prov_{\lambda^k g}(x^k-\lambda^k \nabla f(x^k))xk+1=provλkg​(xk−λk∇f(xk))这个迭代来搞定。
为了计算方便,我们令ω=βk−λk∇fβk(xk)\omega=\beta^k-\lambda^k \nabla f_{\beta^k}(x^k)ω=βk−λk∇fβk​(xk),其中λk\lambda^kλk中在第kkk步迭代的步长,βk\beta^kβk是在第kkk步迭代的β\betaβ。
则有:
βk+1=provλg(ω)=argminβk(λkλg(βk)+12∥βk−ω∥22)=argminβk(λ∥βk∥1+12λk∥βk−ω∥22)\begin{aligned} \beta^{k+1}&=prov_{\lambda g}(\omega)\\ =&\mathop{\mathrm{argmin} }\limits_{\beta_k}(\lambda^k \lambda g(\beta^k) + \frac{1}{2} \Vert \beta^k - \omega \Vert^2_2)\\ =&\mathop{\mathrm{argmin} }\limits_{\beta_k}(\lambda \Vert \beta^k\Vert_1 + \frac{1}{2\lambda^k} \Vert \beta^k - \omega \Vert^2_2)\\ \end{aligned}βk+1==​=provλg​(ω)βk​argmin​(λkλg(βk)+21​∥βk−ω∥22​)βk​argmin​(λ∥βk∥1​+2λk1​∥βk−ω∥22​)​
而:
∥βk∥1=∑i∣βi∣,∥βk−ω∥22=∑i(βi−ωi)2\Vert \beta^k \Vert_1=\sum_i \vert \beta_i\vert, \Vert \beta^k - \omega \Vert^2_2=\sum_i (\beta_i- \omega_i)^2∥βk∥1​=∑i​∣βi​∣,∥βk−ω∥22​=∑i​(βi​−ωi​)2
要计算λ∥βk∥1+12λk∥βk−ω∥22)\lambda \Vert \beta^k\Vert_1 + \frac{1}{2\lambda^k} \Vert \beta^k - \omega \Vert^2_2)λ∥βk∥1​+2λk1​∥βk−ω∥22​)的最小值,我们只要找到每个λ∣βi∣+12λk(βi−ωi)2\lambda \vert \beta_i \vert+\frac{1}{2\lambda_k}(\beta_i-\omega_i)^2λ∣βi​∣+2λk​1​(βi​−ωi​)2的最小值,然后求和就是总体的最小值了。

对于λ∣βi∣+12λk(βi−ωi)2\lambda \vert \beta_i \vert+\frac{1}{2\lambda_k}(\beta_i-\omega_i)^2λ∣βi​∣+2λk​1​(βi​−ωi​)2的最小值,因为有绝对值,需要分类讨论:

  • 当βi≥0\beta_i \geq0βi​≥0时

λ∣βi∣+12λk(βi−ωi)2=12λk(βi2+2(λkλ−ωi)βi+ω2)\begin{aligned} &\lambda \vert \beta_i \vert+\frac{1}{2\lambda_k}(\beta_i-\omega_i)^2\\ =&\frac{1}{2\lambda_k}(\beta_i^2+2(\lambda_k\lambda -\omega_i)\beta_i+\omega^2) \end{aligned}=​λ∣βi​∣+2λk​1​(βi​−ωi​)22λk​1​(βi2​+2(λk​λ−ωi​)βi​+ω2)​
此时,当βi=ωi−λkλ\beta_i=\omega_i-\lambda_k\lambdaβi​=ωi​−λk​λ时,取得最小值,由于βi≥0\beta_i \geq0βi​≥0,要求:ωi−λkλ≥0\omega_i-\lambda_k\lambda\geq0ωi​−λk​λ≥0。
但如果:ωi−λkλ<0\omega_i-\lambda_k\lambda<0ωi​−λk​λ<0,βi\beta_iβi​无法取到ωi−λkλ\omega_i-\lambda_k\lambdaωi​−λk​λ,当βi=0\beta_i=0βi​=0时,取到最小值。

  • 当βi<0\beta_i<0βi​<0时
    λ∣βi∣+12λk(βi−ωi)2=12λk(βi2−2(λkλ+ωi)βi+ω2)\begin{aligned} &\lambda \vert \beta_i \vert+\frac{1}{2\lambda_k}(\beta_i-\omega_i)^2\\ =&\frac{1}{2\lambda_k}(\beta_i^2-2(\lambda_k\lambda +\omega_i)\beta_i+\omega^2) \end{aligned}=​λ∣βi​∣+2λk​1​(βi​−ωi​)22λk​1​(βi2​−2(λk​λ+ωi​)βi​+ω2)​
    此时,当βi=ωi+λkλ\beta_i=\omega_i+\lambda_k\lambdaβi​=ωi​+λk​λ时,取得最小值,由于βi<0\beta_i <0βi​<0,要求:ωi+λkλ<0\omega_i+\lambda_k\lambda<0ωi​+λk​λ<0。
    但如果:ωi+λkλ>0\omega_i+\lambda_k\lambda>0ωi​+λk​λ>0,βi\beta_iβi​无法取到ωi+λkλ\omega_i+\lambda_k\lambdaωi​+λk​λ,当βi=0\beta_i=0βi​=0时,取到最小值。
  • 综上:
    βi={ωi−λkλ,ωi>λkλ0,−λλk<ω<λλkωi+λkλ,ωi<−λkλ\begin{aligned} \beta_i=\begin{cases} \omega_i-\lambda_k\lambda ,      \omega_i>\lambda_k\lambda\\ 0,     -\lambda \lambda_k<\omega<\lambda\lambda_k\\ \omega_i+\lambda_k\lambda,     \omega_i<-\lambda_k\lambda \end{cases} \end{aligned}βi​=⎩⎪⎨⎪⎧​ωi​−λk​λ,     ωi​>λk​λ0,     −λλk​<ω<λλk​ωi​+λk​λ,     ωi​<−λk​λ​​

正是由于−λλk<ω<λλk-\lambda \lambda_k<\omega<\lambda\lambda_k−λλk​<ω<λλk​时,βi\beta_iβi​会出现截断,取值为0时才能取得最小值,才使得损失函数+L1L1L1正则化时,得到稀疏解。

Proxiaml Methods的实现

这里我就不贴自己写的代码了,直接贴一下王然老师的代码:

  1. 构造一个sigmoid函数:
def sigmoid(x):return 0.5 * (jnp.tanh(x / 2) + 1)
  1. 构建逻辑回归模型:
def predict(beta, x):return sigmoid(x.dot(beta))
  1. 构造数据
key = random.PRNGKey(0)
x_key, beta_key, beta_test_key = random.split(key,3)
x = random.normal(x_key, (10000, 10))
beta = random.normal(beta_key, (10,))*2.0    #beta是一个列向量
beta_test = random.normal(beta_test_key, (10,))
y = (sigmoid(x.dot(beta))>=0.5).astype(jnp.float32)
  1. 建立逻辑回归的对数似然函数
def loss(beta):preds = predict(beta,x)#下面用了一个trick,进行了计算简化,如果不简化的话,应该是:y*jnp.log(preds) + (1 - y)jnp.log(1 - preds) ,而由于y只能为0或1,所以可以通过简化用以下的步骤实现:label_probs = preds * y + (1 - preds) * (1 - y) return -jnp.sum(jnp.log(label_probs))/10000.00
  1. 对损失函数求梯度,有两种方式,两个的结果是一样的:
    一是数学推导,如下:
def custom_grad(beta):residual = y - predict(beta, x)return jnp.transpose(x).dot(-residual)/10000.00

二是通过jax.grad进行计算:

grad_func = jax.grad(loss)
  1. 构造软阈值函数,就是Proximal Method最后那个βi\beta_iβi​。这里是通过jax.lax.cond来实现的,具体的介绍可以看一下官方文档,这个比较简单。
    前面写了那么那么多,在代码实现的时候,只有最后的结论能用的上。。。
def soft_threshold(x, thres):return jax.lax.cond(x > thres,lambda _: x - thres,lambda _: jax.lax.cond(x < -thres,lambda _: x + thres,lambda _:0.0,None),None)
  1. Proximal methods算法的迭代过程,具体我不多介绍了,应该算是一个比较标准的迭代过程。
    特别要说明一下,其实写这些代码的关键在于如何检测每步计算都是正确的,特别是在有向量,矩阵,求导,迭代的过程中,如何验证正确性是很麻烦的,检测的过程是保证结果正确的关键。

另外这里面计算每个βi\beta_iβi​时,用的是jax.vmap实现的并行计算。
对于jax.vmap,可以参考:https://jiayiwu.me/blog/2021/04/05/learning-about-jax-axes-in-vmap.html

def proximal_methods(beta_init, max_iter, eps, lr, penalty):converged = Falsebeta_old = beta_initbeta_new = beta_initsoft_threshold_partial = lambda x: soft_threshold(x, lr*penalty)current_iter = 0while not converged and current_iter < max_iter:print("Current iteration is %d"% current_iter)beta_copy = beta_old current_loss = loss(beta_copy) + penalty*jnp.linalg.norm(beta_copy, 1)current_grad = custom_grad(beta_copy)w = beta_copy - lr*current_gradbeta_new = jax.vmap(soft_threshold_partial, 0)(w)new_loss = loss(beta_new) + penalty*jnp.linalg.norm(beta_new, 1)diff = jnp.abs(new_loss-current_loss)print("The difference is %.5f"%diff, "   current_loss%.5f"%current_loss, "   new_loss%.5f"%new_loss,)beta_old = beta_newif diff <= eps:   converged = Trueprint("Algorithm converged")breakelse:current_iter +=1if current_iter >= max_iter:print("The algorithm have failed to converge.")breakreturn beta_new, converged

参考资料

  1. Proximal Mehtods,Ran Wang
  2. Proximal Algorithm,Neal Parikh,Department of Computer Science Stanford University
  3. 机器学习 | 近端梯度下降法 (proximal gradient descent)
  4. 对近端梯度算法(Proximal Gradient Method)的理解
  5. 【凸优化笔记4】-近端梯度下降(Proximal gradient descent)
  6. Majorization-Minimization优化框架
  7. 浅谈MM优化算法以及CCP算法

关于Proximal Methods,近端梯度下降的理解相关推荐

  1. 【优化】近端梯度下降(Proximal Gradient Descent)求解Lasso线性回归问题

    文章目录 近端梯度下降的背景 常见线性回归问题 近端算子(Proximal Operator) 近端梯度下降迭代递推方法 以Lasso线性回归问题为例 参考资料 近端梯度下降的背景 近端梯度下降(Pr ...

  2. 近端梯度下降与软阈值迭代:PGD and ISTA

    近端梯度下降与软阈值迭代:PGD and ISTA 简介 对于包含不可微部分的优化目标如lasso回归的求解,梯度下降等算法已不再适用.近端梯度下降(Proximal Gradient Descent ...

  3. 怎样理解近端梯度下降PGD?

    周志华<机器学习>第11章L1正则化部分提到的 知乎有人解答的挺好的~ 就是 4) 对泰勒展开式简化 公式第三行的第二个减号应该是+ https://www.zhihu.com/quest ...

  4. 梯度下降图示理解(以二元函数为例)

  5. APG(Accelerate Proximal Gradient)加速近端梯度算法 和 NAG(Nesterov accelerated gradient)优化器原理 (一)

    文章目录 前言 APG(Accelerate Proximal Gradient)加速近端梯度算法[^1] PGD (Proximal Gradient Descent)近端梯度下降法推导[^2] E ...

  6. LASSO近端梯度下降法Proximal Gradient Descent公式推导及代码

    文章目录 LASSO by Proximal Gradient Descent Proximal Gradient Descent Framework近端梯度下降算法框架 Proximal Gradi ...

  7. 稀疏编码: LASSO,近端梯度优化(PGD),迭代软阈值算法(ISTA),L-Lipschitz条件,软阈值

    在用到稀疏编码时,难免会提到以下概念: (1)LASSO(Least Absolute Shrinkage and Selection Operatior): 最小绝对收缩选择算子.这个名词听名字比较 ...

  8. 随机梯度下降(Stochastic gradient descent)

    总目录 一. 凸优化基础(Convex Optimization basics) 凸优化基础(Convex Optimization basics) 二. 一阶梯度方法(First-order met ...

  9. 【转载】深度学习数学基础(二)~随机梯度下降(Stochastic Gradient Descent, SGD)

    Source: 作者:Evan 链接:https://www.zhihu.com/question/264189719/answer/291167114 来源:知乎 著作权归作者所有.商业转载请联系作 ...

  10. 13_线性回归分析、线性模型、损失函数、最小二乘法之梯度下降、回归性能评估、sklearn回归评估API、线性回归正规方程,梯度下降API、梯度下降 和 正规方程对比

    1.线性回归 1.1 线性模型 试图学得一个通过属性的线性组合来进行预测的函数: 1.2 定义 定义:线性回归通过一个或者多个自变量与因变量之间进行建模的回归分析.其中特点为一个或多个称为回归系数的模 ...

最新文章

  1. Bzoj3998: [TJOI2015]弦论
  2. 关于${ctx}拿不到值的问题
  3. 递归神经网络部分组件(七)
  4. android Glide简单使用
  5. Python_字符串
  6. linux如何打开22端口?如何开启ssh远程链接
  7. 函数(详解)——C语言
  8. 订单外卖小程序前台后台项目设计
  9. 【论文】文本相似度计算方法综述
  10. android仿照网易新闻源代码
  11. PXC 5.7 WSREP_SST: [ERROR] xtrabackup_checkpoints missing
  12. Matlab读取pfm文件
  13. iOS CI自动化构建:FastLane+企业重签+上传fir+jenkins
  14. linux7.3启动iscsi服务,RHEL7 配置iscsi服务端并实现客户端自动开机挂载
  15. 支持鸿蒙系统的手机名单,鸿蒙系统支持哪些手机
  16. Ubuntu 下安装官网最新版Mongodb (apt命令安装)
  17. 【IoT】WiFi 模组主流供应商
  18. C语言:输入1到5的阶乘
  19. 装修验收自备“精明眼”居住半年就后悔
  20. java-数字金额大写转换

热门文章

  1. CCIE实验笔记之---第1章WAN协议(HDLC)
  2. java计算机毕业设计共享充电宝管理系统源码+mysql数据库+系统+lw文档+部署
  3. Rasa 聊天机器人Rasa_NLU_Chi
  4. 深度学习在训练时对图片随机剪裁(random crop)
  5. minaRActivator三网完美解信号,支持IOS15.6
  6. Docker资源控制的Cgroup--CPU权重等--Block IO、bps和iops 的限制详细操作
  7. java八皇后答案_java八皇后问题详解
  8. JEECG 3.7.2 专业接口开发版本发布,企业级JAVA快速开发平台
  9. MATLAB写入文件的操作
  10. Github开始强制使用PAT(Personal Access Token)了