摘要

分类变量是代表世界离散结构的自然选择。然而,随机神经网络很少使用分类潜在变量,因为无法通过样本来反向传播梯度。在这项工作中,我们提出了一种有效的梯度估计器,通过使用一种基于Gumbel-Softmax分布的可微分采样,该估计器替代了分类分布中的不可微分采样。该分布具有必要的属性,可以将其顺利退火到分类分布中。我们表明,我们的Gumbel-Softmax估计器在具有分类潜在变量的结构化输出预测和无监督的生成模型任务中到达可最先进的梯度估计,并在半监督分类实现了加速。

1.介绍

具有离散随机变量的随机神经网络是在无监督学习,语言建模,注意力机制和强化学习领域中表示分布的强大技术。例如,离散变量已被用于学习对应于不同语义类,图像区域和存储器位置的概率潜在表示。离散表示往往比其连续表示更具可解释和高效的计算
  然而,具有离散变量的随机网络因为返乡传播算法而难以训练 (需要允许计算参数梯度),不能应用于不可微分层。在随机梯度估计上的先前工作传统上集中于使用蒙特卡罗方差减少技术增强的得分函数估计,或使用Bernoulli变量的偏差路径推导估计。但是,没有针对分类变量制定现有梯度估计器。因此,本文工作的贡献主要有三个方面:

  1. 我们引入Gumbel-Softmax,在可以近似分类样本的连续分布,并且可以通过重参数化技巧容易地计算其参数梯度。
  2. 我们通过实验显示Gumbel-Softmax优于Bernoulli变量和分类变量的所有单样本梯度估计。
  3. 我们表明,该估算器可用于有效地训练半监督模型(例如,Kingma et al. (2014)),并且在不可观测的分类潜变量上没有昂贵的边缘化。

本文提出的方法实际是一种用于分类变量的简单的,可微分的近似采样机制,可以集成到神经网络中并使用标准反向传播训练。

2.THE GUMBEL-SOFTMAX DISTRIBUTION


  我们首先定义了Gumbel-Softmax分布,可以将样本的分类分布近似为一个连续分布。令zzz为具有分类概率π1,π2,...πkπ_1,π_2,...π_kπ1​,π2​,...πk​的类别变量。对于本文的剩余部分,我们假设分类样本被编码为位于k−1k-1k−1维单纯形端点Δk−1Δ^{k-1}Δk−1上的kkk维one-hot向量。这允许我们定义这些向量的元素均值Ep[z]=[π1,...,πk]\mathbb E_p[z]=[π_1,...,π_k]Ep​[z]=[π1​,...,πk​]的数量。
  Gumbel-Max技巧提供了一种简单有效的方法,可以从具有分类概率πππ的类别分布中采样样本zzz:
z=one_hot(argmaxi[gi+logπi])(1)z=one\_hot\bigg (\mathop{argmax}\limits_{i}[g_i+log~\pi_i]\bigg)\tag{1}z=one_hot(iargmax​[gi​+log πi​])(1)
其中g1...gkg_1... g_kg1​...gk​是从gumbel(0,1)gumbel(0,1)gumbel(0,1)中采样的独立同分布样本。我们使用SoftMax函数作为argmaxarg~maxarg max连续的,可微分的近似,并生成kkk维样本向量y≥Δk−1y≥Δ^{k-1}y≥Δk−1:
yi=exp((log(πi))/τ)∑j=1kexp((log(πj)+gj)/τ)fori=1,...,k.(2)y_i=\frac{exp((log(\pi_i))/\tau)}{\sum^k_{j=1}exp((log(\pi_j)+g_j)/\tau)}\qquad for~i=1,...,k.\tag{2}yi​=∑j=1k​exp((log(πj​)+gj​)/τ)exp((log(πi​))/τ)​for i=1,...,k.(2)
  Gumbel-Softmax的密度分布(推导见附录B)是:
pπ,τ(y1,...,yk)=Γ(k)τk−1(∑i=1kπi/yiτ)−k∏i=1k(πi/yiτ+1)(3)p_{\pi,\tau}(y_1,...,y_k)=\Gamma(k)\tau^{k-1}\bigg(\sum^k_{i=1}\pi_i/y^{\tau}_i\bigg)^{-k}\prod^k_{i=1}(\pi_i/y^{\tau+1}_i)\tag{3}pπ,τ​(y1​,...,yk​)=Γ(k)τk−1(i=1∑k​πi​/yiτ​)−ki=1∏k​(πi​/yiτ+1​)(3)
该分布由 Maddison et al. (2016)独立发现,它被称为 concret分布。随着softmax温度τττ接近0,来自Gumbel-Softmax分布的样本变为one-hot,并且Gumbel-Softmax分布与分类分布p(z)p(z)p(z)变得相同。

2.1 GUMBEL-SOFTMAX ESTIMATOR

当τ>0τ> 0τ>0,Gumbel-Softmax分布是平滑的,因此具有相对于参数πππ的明确定义的梯度∂y/∂π∂y/∂π∂y/∂π。因此,通过使用Gumbel-Softmax样本替换分类样本,我们可以使用反向传播来计算梯度(参见第3.1节)。我们将这种在训练期间用可微分的近似替换不可微分分类样本的过程,称为Gumbel-Softmax估计器
  虽然Gumbel-Softmax样本可微分,但它们与来自相应的非零温度的分类分布的样本不完全相同。为了进行学习,在温度大小之间存在权衡,小温度的样本接近one-hot,但梯度的方差很大,大温度的样本很平滑,但梯度的方差很小(图1)。在实验中,我们从高温度开始并退火到一个小但非零的温度。
  在我们的实验中,我们发现Softmax温度τττ可以根据各种schedules进行退火,并且仍然表现良好。如果τττ是可学习的参数(而不是通过固定schedules退火),则该方案可以被解释为熵正则化,其中Gumbel-Softmax分布可以在训练过程中自适应地调整所采样的样本的“信心”。

2.2 STRAIGHT-THROUGH GUMBEL-SOFTMAX ESTIMATOR

one-hot向量的连续松弛适用于学习隐藏表示和序列建模等问题。对于我们被限制为采样离散值的情况(例如,来自用于强化学习的离散动作空间或量化压缩),我们使用argmaxargmaxargmax获取离散的yyy,而不是是通过近似∇θz≈∇θy∇_θz≈∇_θy∇θ​z≈∇θ​y来使用我们的连续近似。我们称之为直通(ST)Gumbel估计器。ST Gumbel-SoftMax允许样本即使在高温度τττ时也会稀疏。

3.相关工作


  在本节中,我们回顾了用于离散变量的现有随机梯度估计技术(图2中所示)。考虑具有离散随机变量zzz的随机计算图,其分布取决于参数θθθ以及成本函数f(z)f(z)f(z)。目的是通过梯度下降使预期成本L(θ)=Ez∼pθ(z)[f(z)]L(θ)=\mathbb E_{z\sim p_θ(z)}[f(z)]L(θ)=Ez∼pθ​(z)​[f(z)]最小化,这需要我们估计∇θEzpθ(z)[f(z)]∇_θ\mathbb E{z~p_θ(z)}[f(z)]∇θ​Ez pθ​(z)[f(z)]。

3.1 PATH DERIVATIVE GRADIENT ESTIMATORS

对于能被重参数化的分布,我们可以将样本zzz作为参数θθθ和独立随机变量ϵ\epsilonϵ的确定性函数ggg来计算,使得z=g(θ,ϵ)z=g(θ,\epsilon)z=g(θ,ϵ)。然后可以计算从fff到θθθ的路径梯度,而无需遇到任何随机节点:
∂∂θEz∼pθ[f(z)]=∂∂θEϵ[f(g(θ,ϵ))]=Eϵ∼pϵ[∂f∂g∂g∂θ](4)\frac{\partial}{\partial \theta}\mathbb E_{z\sim p_{\theta}}[f(z)]=\frac{\partial}{\partial \theta}\mathbb E_{\epsilon}[f(g(\theta,\epsilon))]=\mathbb E_{\epsilon\sim p_{\epsilon}}\bigg[\frac{\partial f}{\partial g}\frac{\partial g}{\partial \theta}\bigg]\tag{4}∂θ∂​Ez∼pθ​​[f(z)]=∂θ∂​Eϵ​[f(g(θ,ϵ))]=Eϵ∼pϵ​​[∂g∂f​∂θ∂g​](4)
  例如,正态分布z〜N(μ,σ)z〜\mathcal N(μ,σ)z〜N(μ,σ)可以重写为μ+σ⋅N(0,1)μ+σ·\mathcal N(0,1)μ+σ⋅N(0,1),使其变得可以计算∂z/∂μ∂z/∂μ∂z/∂μ和∂z/∂σ∂z/∂σ∂z/∂σ。该重参数化技巧通常应用于使用反向传播训练具有连续潜在变量的变分自编码器。如图2所示,我们在Gumbel-Softmax估计器的构造中利用了这种技巧。
  即使当zzz是不可冲参数化的,也可以使用偏差路径推导估计器。通常,我们可以近似∇θz≈∇θm(θ)∇_θz≈∇_θm(θ)∇θ​z≈∇θ​m(θ),其中mmm是随机采样的可微分代理。对于具有平均参数θθθ的Bernoulli变量,直通(ST)估计器近似m=μθ(z)m=μ_θ(z)m=μθ​(z),因此∇θm=1∇_θm=1∇θ​m=1。对于k=2k=2k=2(Bernoulli),ST Gumbel-SoftMax类似于Chung et al. (2016)提出的slope-annealed Straight-Through估计器,但使用Softmax而不是硬的sigmoid来确定斜率。Rolfe (2016) 考虑一种替代方法,其中每个二元潜在变量参数化连续混合模型。 重参数化梯度是通过反向传播连续变量并使二元变量边缘化而获得的。
  ST估计器的一个限制是样本无关的反向传播意味着可能导致前向和后向计算之间的差异,导致更高的方差。Gumbel-Softmax避免了这个问题,因为每个样本yyy是相应的离散样本zzz的可微分代理。

3.2 SCORE FUNCTION-BASED GRADIENT ESTIMATORS

得分函数估计器(SF,也称为REINFORCE和似然比估计器)使用标识∇θpθ(z)=pθ(z)∇θlogpθ(z)∇_θp_θ(z)=p_θ(z)∇_θlogp_θ(z)∇θ​pθ​(z)=pθ​(z)∇θ​logpθ​(z)来导出以下无偏的估计器:
∇θEz[f(z)]=Ez[f(z)∇θlogpθ(z)](5)∇_θ\mathbb E_z[f(z)]=\mathbb E_z[f(z)∇_θlog~p_{\theta}(z)]\tag{5}∇θ​Ez​[f(z)]=Ez​[f(z)∇θ​log pθ​(z)](5)
  SF仅要求pθ(z)p_θ(z)pθ​(z)在θ中连续,并且不需要通过fff或样本zzz来反向传播。然而,SF会面临高方差,因此难于收敛。特别地,SF方差会随采样向量的维度数量进行线性缩放,使其对分类分布特别具有挑战性。
  可以通过从学习信号fff中减去控制变量b(z)b(z)b(z)来减少得分函数估计器的方差,并加入其分析期望μb=Ez[b(z)∇θlogpθ(z)]μ_b=\mathbb E_z [b(z)∇_{\theta}log~p_θ(z)]μb​=Ez​[b(z)∇θ​log pθ​(z)]以保持估计器的无偏:
∇θEz[f(z)]=Ez[f(z)∇θlogpθ(z)+(b(z)∇θlogpθ(z)−b(z)∇θlogpθ(z))](6)∇_{\theta}\mathbb E_z[f(z)]=\mathbb E_z[f(z)∇_{\theta}log~p_{\theta}(z)+(b(z)∇_{\theta}log~p_θ(z)-b(z)∇_{\theta}log~p_θ(z))]\tag{6}∇θ​Ez​[f(z)]=Ez​[f(z)∇θ​log pθ​(z)+(b(z)∇θ​log pθ​(z)−b(z)∇θ​log pθ​(z))](6)
=Ez[(f(z)−b(z))∇θlogpθ(z)]+μb(7)=\mathbb E_z[(f(z)-b(z))∇_{\theta}log~p_{\theta}(z)]+\mu_b\tag{7}=Ez​[(f(z)−b(z))∇θ​log pθ​(z)]+μb​(7)
  我们简要概括了利用控制变量的最近随机梯度估计器。我们期望读者通过Gu et al. (2016) 了解这些技术的详细信息。

  • NVIL 使用两个基线:(1)fff的移动平均值fˉ\bar ffˉ​为中心,以得到学习信号,(2)由1层神经网络计算的输入相关基线f−fˉf-\bar ff−fˉ​(学习信号本身的控制变量)。最后,方差归一化通过max(1,σf)max(1,\sigma_f)max(1,σf​)划分学习信号,其中σf2σ^2_fσf2​是var[f]var[f]var[f]的移动平均值。
  • DARN 使用b=f(zˉ)+f′(zˉ)(z−zˉ)b=f(\bar z)+f'(\bar z)(z-\bar z)b=f(zˉ)+f′(zˉ)(z−zˉ),其中基线对应于f(z)f(z)f(z)在f(zˉ)f(\bar z)f(zˉ)的第一阶泰勒近似。zzz以1/2从伯努利变量中被选择,这使得估计器用于非二次fff,因为它忽略了估计表达式中的校正项μbμ_bμb​。
  • MuProp 同样将基线建模为一阶泰勒展开:b=f(zˉ)+f′(zˉ)(z−zˉ)b=f(\bar z)+f'(\bar z)(z-\bar z)b=f(zˉ)+f′(zˉ)(z−zˉ),且μb=f′(zˉ)∇θEz[z]μ_b=f'(\bar z)∇_θ\mathbb E_z[z]μb​=f′(zˉ)∇θ​Ez​[z]。为了克服离散采样的反向传播,使用平均场近似fMF(μθ(z))f_{MF}(μ_θ(z))fMF​(μθ​(z))代替f(z)f(z)f(z)以计算基线并导出相关梯度。
  • VIMCO 是用于多样本目标的梯度估计器,其使用其他样本的均值b=1/m∑j≠if(zj)b=1/m\sum_{j\ne i}f(z_j)b=1/m∑j​=i​f(zj​)来构建每个样本zi∈z1:Mz_i\in z_{1:M}zi​∈z1:M​的基线。我们从我们的实验中排除了VIMCO,因为我们正在比较单个样本目标的估计,尽管Gumbel-Softmax可以很容易地扩展到多样本目标。

3.3 SEMI-SUPERVISED GENERATIVE MODELS

半监督学习考虑了从有标注数据(x,y)〜DL(x,y)〜\mathcal D_L(x,y)〜DL​和无标注数据x〜DUx〜\mathcal D_Ux〜DU​中学习的问题,其中xxx是观测值(即图像),yyy是相应的标签(例如语义类别)。对于半监督分类,Kingma et al. (2014) 提出了一种变分自编码器(VAE),其潜在状态是高斯分布变量zzz和语义类别变量yyy(附录图6)。 VAE目标通过最大化对数似然上的变分下界,使判别网络qφ(y∣x)q_φ(y|x)qφ​(y∣x),推理网络qφ(z∣x,y)qφ(z|x,y)qφ(z∣x,y),以及生成网络pθ(x∣y,z)p_θ(x|y,z)pθ​(x∣y,z)来训练端到端生成模型。对于标注数据,类别yyy已知,因此仅在z〜q(z∣x,y)z〜q(z|x,y)z〜q(z∣x,y)上完成推断。标注数据的变分下限为:
logpθ(x,y)≥−L(x,y)=Ez∼qϕ(z∣x,y)[logpθ(x∣y,z)]−KL[q(z∣x,y)∣∣pθ(y)p(z)](8)log~p_{\theta}(x,y)\ge -\mathcal L(x,y)=\mathbb E_{z\sim q_{\phi}(z|x,y)}[log~p_{\theta}(x|y,z)]-KL[q(z|x,y)||p_{\theta}(y)p(z)]\tag{8}log pθ​(x,y)≥−L(x,y)=Ez∼qϕ​(z∣x,y)​[log pθ​(x∣y,z)]−KL[q(z∣x,y)∣∣pθ​(y)p(z)](8)
  对于未标注数据,将会有一些困难,因为类别分布不能重参数化。Kingma et al. (2014) 通过在所有类别上都是边缘化yyy来处理这一点,因此对于未标注的数据,推理仍然是每个yyy的qφ(z∣x,y)q_φ(z|x,y)qφ​(z∣x,y)。未标注数据的下限是:
logpθ(x)≥−U(x)=Ez∼qϕ(y,z∣x)[logpθ(x∣y,z)+logpθ(y)+logp(z)−qϕ(y,z∣x)](9)log~p_{\theta}(x)\ge -\mathcal U(x)=\mathbb E_{z\sim q_{\phi}(y,z|x)}[log~p_{\theta}(x|y,z)+log~p_{\theta}(y)+log~p(z)-q_{\phi}(y,z|x)]\tag{9}log pθ​(x)≥−U(x)=Ez∼qϕ​(y,z∣x)​[log pθ​(x∣y,z)+log pθ​(y)+log p(z)−qϕ​(y,z∣x)](9)
=∑yqϕ(y∣x)(−L(x,y)+H(qϕ(y∣x)))(10)=\sum_y q_{\phi}(y|x)(-\mathcal L(x,y)+\mathcal H(q_{\phi}(y|x)))\tag{10}=y∑​qϕ​(y∣x)(−L(x,y)+H(qϕ​(y∣x)))(10)

CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX翻译相关推荐

  1. TF笔记:小trick之gumbel softmax

    TF笔记:小trick之gumbel softmax 0. 引言 1. gumbel softmax 2. tf代码实现 3. 参考链接 0. 引言 故事的起因在于我们在实际工作中遇到的一个小的需求, ...

  2. Gumbel Max与Gumbel Softmax演示动画

    Gumbel Max以及Gumbel Softmax的理论证明见: 漫谈重参数:从正态分布到Gumbel Softmax 我用js写了一个利用Gumbel Max来对离散分布进行重参数化的过程,地址: ...

  3. Softmax和Cross-entropy是什么关系?

    公众号关注 "视学算法" 设为 "星标",DLCV消息即可送达! 来自 | 知乎   作者 | 董鑫 https://www.zhihu.com/questio ...

  4. 带你认识神奇的Gumbel trick

    The Gumbel soft-max Gumbel trick有两个用途,一个用途是是用来对离散分布进行采样,这是一种重参数化(reparameterization trick)的技巧,另外一个用途 ...

  5. 【论文翻译】统一知识图谱学习和建议:更好地理解用户偏好

    一.摘要 将知识图谱(KG)纳入推荐系统有望提高推荐的准确性和可解释性.然而,现有方法主要假设KG是完整的并且简单地在实体原始数据或嵌入的浅层中转移KG中的"知识".这可能导致性能 ...

  6. 卑微且强大的Gumbel分布

    GAN从诞生开始,广泛应用于计算机视觉领域,并取得了巨大地成功,相比而言,GAN在NLP领域的应用还是相对较少,这是因为GAN对离散型数据序列的处理显得无能为力,从离散分布中采样的数据时不可导的,在使 ...

  7. ECCV2020论文-稀疏性表示-Neural Sparse Representation for Image Restoration翻译

    Neural Sparse Representation for Image Restoration 用于图像复原的神经稀疏表示 Abstract 在基于稀疏编码的图像恢复模型中,基于稀疏表示的鲁棒性 ...

  8. 计算机如何“看懂”图片?达摩院提出新的研究方法

    简介: 本文的部分内容基于英文论文"Learning in the frequency domain"翻译而来,英文论文已经被计算机视觉顶级会议Computer Vision an ...

  9. ICCV2021 |优胜劣汰,MIT团队提出自适应多模态选取框架用于视频理解

    关注公众号,发现CV技术之美 AdaMML: Adaptive Multi-Modal Learning for Efficient Video Recognition 项目主页:https://rp ...

最新文章

  1. 为何程序员们没事总爱戴个耳机,看完恍然大悟......
  2. CList添加出错AddTail();rror C2664: 'struct __POSITION *__thiscall CList错误
  3. java怎么解决重复支付问题_支付系统设计中,如何防止重复支付?
  4. 震惊!温州一程序员编完八皇后吐血而亡,他的代码是什么样子?!!
  5. 关于异地高考引发的又一次舆论攻势
  6. 卷积神经网络原理_人脸识别背后,卷积神经网络的数学原理原来是这样的
  7. 单目摄像机测距(python+opencv)(转载)
  8. ROS2的学习笔记(legacy)
  9. 算法眼中的世界是什么样子?他们用一些彩色方块画了出来
  10. 【离散】画哈斯图--最好理解绝不会出错
  11. 计算机二级MSoffice
  12. java继承和接口的优缺点_Java抽象类和接口的优缺点---总结-2
  13. 【DeepLab v1 2016】SEMANTIC IMAGE SEGMENTATION WITH DEEP CON- VOLUTIONAL NETS AND FULLY CONNECTED CRFS
  14. CSS单行/多行文本溢出隐藏
  15. python网络安全应用中心_安全人员常用的python库
  16. 港府:“同股同权”制度也有调整空间 尊重阿里
  17. JavaScript日记——实现图片的瀑布流和底部刷新
  18. 你的浏览器正在排斥 IPv6
  19. SQL常见函数以及使用
  20. 芝麻动态码-小程序动态二维码生成

热门文章

  1. matlab pwm整流仿真
  2. 手风琴效果(vue实现)
  3. [转]Clion2019破解-Jetbrains系列产品2019.1.1最新激活方法[持续更新]
  4. gallery3D(3)
  5. 菜鸟的linux云服务器第一次木马入侵处理记录(名为xmrigMiner的木马)
  6. 如何将JPG转换为PNG?两种图片格式转换的方法交给你
  7. eclipse点餐系统的框架
  8. TICKScript简介
  9. Linux 下后台运行程序,查看和关闭后台运行程序(转载)
  10. 初二因式分解奥数竞赛题_因式分解(竞赛题)含答案