简明《Stacked Capsule Autoencoders》

什么是胶囊?什么是胶囊网络?胶囊真的有用吗?怎么实现一个胶囊网络?本文将会原理到实现,解读来自Hinton团队2019年发布的胶囊网络《Stacked Capsule Autoencoders》。由于个人水平有限,欢迎勘误,欢迎交流讨论。

  • Github:https://github.com/QiangZiBro

  • Zhihu:https://www.zhihu.com/people/QiangZiBro

  • 公众号:QiangZiBro

笔者又制作了相关PPT,本文图片部分来自于PPT。搜索我的公众号QiangZiBro,回复”SCAE“获取PPT和相关资料。

一、前言

胶囊是什么?

它不是药物,而是来自人工智能领域被称作“神经网络之父”、“深度学习鼻祖”的科学家Hinton提出一种新型神经网络结构。从2017年开始,胶囊被认为是一个矢量神经元,对标于之前的标量网络。简而言之,胶囊是向量神经元,也即矢量神经元。

我们这里提到的神经元实际上来自被称作“神经网络(Neural Networks,NN)”的“人工神经网络”,神经元是NN的基本单位,神经元数学模型如下图所示

上一层输入神经元x1,…,xnx_1, … ,x_nx1,,xn经过相关前馈计算,得到下一层输出神经元yyy。这些神经元由一位数字表示,因此它们被称作标量神经元。可以简单认为,胶囊网络将这里的标量神经元改进为矢量神经元。

【名词】

  • 矢量神经元(Vector Neural,VN)即胶囊
  • 标量神经元(Scalar Neural, SN)

为什么提出 CapsNet?

关于为什么要从标量迈向矢量,网上已有许多直觉的解释[1,2],这里不再赘述。简而言之,使用VN的CapsNet被认为有以下优点

  • VN能更好地编码特征的一组相关信息,而SN不可以
  • 引入VN的CapsNet解决了原本CNN的池化丢失信息问题
  • CapsNet能学到姿态信息,换句话说,CapsNet是等变(equivariance)网络,而CNN是不变(invariance)网络。

上面提到的最后一点很重要,为什么我们关心等变性呢?因为(1)一个等变网络不需要大量训练数据(2)等变网络能学到输入的位姿。

这一版CapsNet有什么特点?

堆叠式胶囊自编码器(Stacked Capsule Autoencoders,SCAE)是Hinton团队提出第一、二两版胶囊网络后的又一版本,这篇使用无监督方法学习图像中的特征,并取得了最先进的结果。

直觉

中国有句古话:”横看成岭侧成峰,远近高低各不同“。人从不同视角(view)看object(比如山),看到的景物总是不一样的,然而object与part的关系确实固定的,正如山与树的关系。因此,object与part之间是视角不变的,人与object之间是视角等变的。因此,可以通过神经网络来显式地学习viewer与object的变换矩阵(OV)和object与part的变换矩阵(OP)。Hinton做过一个实验,如下图,旋转后的正方形看起来还像正方形吗?这个例子说明人的感官也会”欺骗“我们。

Is it a square, or is it a rhombus?

这一版胶囊网络将object胶囊表示为它对观察者的关系,因此这里不需要像CNN那样在空间上复制神经元的激活来表示一个object。

【使用到的名词】

  • 观察者 (viewer)
  • 对象 (object)
  • 部件、部分 (part)

效果一览

  • 无监督分类

    • MNIST 98.7%
    • SVHN 55%

阅读时思考几个问题

  • 这一个版本怎么表示胶囊的?使用到的胶囊学到了什么?

  • 论文里面提到的几种自编码器(CCAE,PCAE,OCAE,SCAE)的区别和联系是什么 ?

  • 为什么需要把PCAE和OCAE两个自编码器堆叠在一起同时训练?

  • 编码器部分:为什么使用Set Transformer对part capsules进行编码,这个编码器有什么特殊之处吗?

  • 解码器部分:第一个例子中是怎样决策每个点的类别?第二个例子中图片怎么解码出来?

二、前置知识

2.1 自编码器

自编码器是一个试图去还原其原始输入的系统,自编码器由编码器和解码器组成。试图重建自己这个事情本身对我们来说没有太大意义,而真正有帮助的在于自编码器最中间的特征层。为什么中间学到的特征很重要?如果我们把每个样本最中间的学好的特征向量放在一起,会出现类似聚类的效果:同类之间距离小,不同类之间距离大。注意到中间特征层一般都是高维,高维空间很难去想象,我们可以用t-SNE进行降维可视化。而本文在MNIST数据集上学到降维可视化效果如下:

image-20201114162935734

可见,不同样本在特征空间上分的很开,只有少数重叠,这个效果便很好。为什么自编码器能够达到这样的效果呢?因为在重建时,自编码器希望对每类样本都学到其最具有代表性的特征,只有能够清楚地区分它,才能够完成重建。而获得一个具有代表性的特征,才是特征学习或深度学习成功的关键。自编码器学习到优秀的特征后,可以用其做后续任务,比如本文用其做了一个简单的分类,在MNIST分类效果达到98.7%。

总结下,做深度学习时,我们更希望网络学到的特征能够尽可能全方位地代表这个样本,对应到特征空间,即是上图的效果;而不是某一方面来代表这个样本,自编码可能很好地做好前者,而不做自编码则很难获取有效的特征。因此,自编码器训练好之后,我们来进行后续的任务便较为容易了。

2.2 Set Transformer

笔者在这里第一次正式认识Transformer,发现这个词在中文里鲜有翻译,而是直接使用原词。笔者不是特别喜欢中文文章夹杂大量英文,但为了尽可能表示精准,还是使用适量英文。

2017年版本的胶囊网络[10]使用动态路由来完成一个聚类任务,而本文使用了Set Transformer[4],用来将part编码为object。Set Transformer输入包含nnn个样本的顺序无关的集合X∈Rn×dX \in \mathbb{R}^{n \times d}XRn×d,输出包含kkk个样本的集合O∈Rk×dO \in \mathbb{R}^{k \times d}ORk×d

简介

Set Transformer有什么应用场景呢?它在最大值回归计数不同字符混合高斯集合异常检测点云分类五个任务上有较好表现。处理这些任务的模型有两个特性:

  • 排列不变性:即输出结果与输入顺序无关;
  • 能够处理任何大小的输入

模型框架

下面我们从至顶向下的方式了解Set Transformer细节。

首先解决集合排列不变的这类任务的模型,有文献[9]证明可以写成如下通式:
net⁡({x1,…,xn})=ρ(pool⁡({ϕ(x1),…,ϕ(xn)}))\operatorname{net}\left(\left\{x_{1}, \ldots, x_{n}\right\}\right)=\rho\left(\operatorname{pool}\left(\left\{\phi\left(x_{1}\right), \ldots, \phi\left(x_{n}\right)\right\}\right)\right) net({x1,,xn})=ρ(pool({ϕ(x1),,ϕ(xn)}))
上式是个自编码器,其中ϕ(x)\phi\left(x\right)ϕ(x)是编码器,ρ(x)\rho\left( x\right)ρ(x)是解码器,pool⁡(x)\operatorname{pool}(x)pool(x)是池化操作,比如max,mean操作。

Set Transformer也是这类网络结构,当然也是自编码器,其编码器可以为下面两种的其中一种:
Z=Encoder⁡(X)=SAB⁡(SAB⁡(X))∈Rn×dZ=\operatorname{Encoder}(X)=\operatorname{SAB}(\operatorname{SAB}(X)) \in \mathbb{R}^{n \times d} Z=Encoder(X)=SAB(SAB(X))Rn×d

Z=Encoder⁡(X)=ISAB⁡m(ISAB⁡m(X))∈Rn×dZ=\operatorname{Encoder}(X)=\operatorname{ISAB}_{m}\left(\operatorname{ISAB}_{m}(X)\right) \in \mathbb{R}^{n \times d} Z=Encoder(X)=ISABm(ISABm(X))Rn×d

解码器是
O=Decoder⁡(Z;λ)=rFF⁡(SAB⁡(PMA⁡k(Z)))∈Rk×dO=\operatorname{Decoder}(Z ; \lambda)=\operatorname{rFF}\left(\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right)\right) \in \mathbb{R}^{k \times d} O=Decoder(Z;λ)=rFF(SAB(PMAk(Z)))Rk×d
其中池化操作PMA⁡k(Z)=MAB⁡(S,rFF⁡(Z))∈Rk×d\operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z)) \in \mathbb{R}^{k \times d}PMAk(Z)=MAB(S,rFF(Z))Rk×dkkk表示输出集合中实例的个数,k<nk < nk<n

细节

模型框架引入若干概念,下面对它们一一作出解释。Set Transformer引入了Transformer相关概念,详细讲解可以参考[5,6,8]。

我们设Q∈Rn×dq,K∈Rnv×dv,V∈Rnv×dvQ \in \mathbb{R}^{n \times d_{q}},K \in \mathbb{R}^{n_v \times d_{v}},V \in \mathbb{R}^{n_v \times d_{v}}QRn×dq,KRnv×dv,VRnv×dv表示query,key,value;ωj(⋅)=softmax⁡(⋅/d)\omega_{j}(\cdot)=\operatorname{softmax}(\cdot / \sqrt{d})ωj()=softmax(/d

)表示一个变化的softmax函数,λ={WjQ,WjK,WjV}j=1h\lambda=\left\{W_{j}^{Q}, W_{j}^{K}, W_{j}^{V}\right\}_{j=1}^{h}λ={WjQ,WjK,WjV}j=1h表示一组系数矩阵。

Transformer有注意力机制和多头注意力机制两个基础构成:

  • 注意力机制:Att⁡(Q,K,V;ω)=ω(QK⊤)V∈Rn×dv\operatorname{Att}(Q, K, V ; \omega)=\omega\left(Q K^{\top}\right) V \in \mathbb{R}^{n \times d_v}Att(Q,K,V;ω)=ω(QK)VRn×dv,计算过程如下图,图源[8]

    【注】下面的例子改编自[5],如果有误,欢迎指正!

    Query,Key,Value的概念取自于信息检索系统,比如我们在电商平台搜索某件商品(年轻女士冬季穿的红色薄款羽绒服)时,在搜索引擎上输入的内容便是Query,怎么通过上面的式子给出想要的结果呢?搜索引擎根据Query匹配Key(例如商品的种类,颜色,描述等)和Value(各种羽绒服结果),然后根据Query和Key相乘并且通过激活函数ω\omegaω得到一个相似度矩阵ω(QK⊤)\omega\left(Q K^{\top}\right)ω(QK),相似度矩阵矩阵再和Value矩阵进行相乘得到最终需要的加权结果。

  • 多头注意力机制

Multihead⁡(Q,K,V;λ,ω)=concat(Z1,…,Zh)⁡WO∈Rn×d,Zj=Att⁡(QWjQ,KWjK,VWjV;ωj)​\operatorname{Multihead} (Q, K, V ; \lambda, \omega) = \operatorname{concat(Z_1,…,Z_h)}W^O \in \mathbb{R}^{n \times d},\\ Z_{j}=\operatorname{Att}\left(Q W_{j}^{Q}, K W_{j}^{K}, V W_{j}^{V} ; \omega_{j}\right)​ Multihead(Q,K,V;λ,ω)=concat(Z1,,Zh)WORn×d,Zj=Att(QWjQ,KWjK,VWjV;ωj)

​ 同样是Q,K,V输入,用多个注意机制获得多个输出,再将它们合并起来并用一个矩阵进行变换。在下图例子中h=7h=7h=7

  • rFF⁡(x)\operatorname{rFF}(x)rFF(x) Row-wise 前馈层,它同等地、独立地处理每个instance

Set Transformer根据上面的基础,提出下面的模块:

  • 多头注意力模块(Multihead Attention Block)
    MAB⁡(X,Y)=LayerNorm⁡(H+rFF⁡(H))\operatorname{MAB}(X, Y)=\operatorname{LayerNorm}(H+\operatorname{rFF}(H)) MAB(X,Y)=LayerNorm(H+rFF(H))
    其中H=LayerNorm (X+Multihead⁡(X,Y,Y;ω))H=\text { LayerNorm }(X+\operatorname{Multihead}(X, Y, Y ; \omega))H=LayerNorm(X+Multihead(X,Y,Y;ω))

  • 集合注意力模块(Set Attention Block ),计算复杂度O(n2)\mathcal{O}\left(n^{2}\right)O(n2)

SAB⁡(X)=MAB⁡(X,X)\operatorname{SAB}(X) = \operatorname{MAB}(X,X) SAB(X)=MAB(X,X)

  • 诱导集合注意力模块(Induced Set Attention Block )

ISAB⁡m(X)=MAB⁡(X,H)∈Rn×d​\operatorname{ISAB}_m(X)=\operatorname{MAB}(X, H) \in \mathbb{R}^{n \times d}​ ISABm(X)=MAB(X,H)Rn×d

​ 其中H=MAB⁡(I,X)∈Rm×dH=\operatorname{MAB}(I,X) \in \mathbb{R}^{m \times d}H=MAB(I,X)Rm×dI∈Rm×dI \in \mathbb{R}^{m \times d}IRm×d为可学习参数。

​ 为什么提出ISAB呢?[7]SAB的问题是transformer的传统问题,复杂度太高。所以引入诱导点(induced points)矩阵I∈Rm×dI \in \mathbb{R}^{m \times d}IRm×d,类似于矩阵分解,将原来的一步attention拆分成两步attention,首先用III对输入X做self-attention,接着用得到的结果对输入做attention。将复杂度从O(n2)\mathcal{O}\left(n^{2}\right)O(n2)降到O(mn)\mathcal{O}\left(mn\right)O(mn)

  • 多头注意力机制的池化(Pooling by Multihead Attention)。池化是一种常见的聚合(aggregation)操作。上面提到,池化可以是最大或是平均。这里提出的池化是应用一个MAB在一个可学习的矩阵S∈Rk×dS \in \mathbb{R}^{k \times d}SRk×d上。在一些聚类任务上,kkk设为我们需要的类别数。使用基于注意力的池化的直觉是,每个实例对target的重要性都不一样

PMA⁡k(Z)=MAB⁡(S,rFF⁡(Z))\operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z)) PMAk(Z)=MAB(S,rFF(Z))

H=SAB⁡(PMA⁡k(Z))H=\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right) H=SAB(PMAk(Z))

小结

Set Transformer是一个使用注意力机制的、以集合为输入的自编码器,这个模型在编码和池化模块中都使用了注意力机制。改变池化模块kkk的大小,可以改变输出集合大小。

三、 方法

3.1 一个玩具样例——集群自编码器(Constellation Autoencoder,CCAE)

“Constellation”中文译为星座,机器之心曾译作集群。

作者在这里做的任务是:对二维平面上的一组点{xm∣m=1,…,M}\left\{\mathbf{x}_{m} \mid m=1, \ldots, M\right\}{xmm=1,,M}进行聚类。如下图,平面上两个正方形、一个三角形共有11个点构成,每个点即是part,通过无监督聚类去把每个点划分到对应的类别当中。自编码器在学习的过程,是在学习数据的分布和表征。

这个自编码器整体思路如下图所示

编码器Set Transformer将11个2维点集(每个点是part)编码为3个object capsule,每个object 由一个2×22\times 22×2大小的观察者——对象(OV)矩阵,一个特征向量cic_ici,和概率aia_iai这三个量表示。每个object通过MLP进行解码获得4个part capsule,每个part由2×12\times 12×1对象——部件(OP)矩阵,概率ai,ja_{i,j}ai,j,和标准差λi,j\lambda_{i,j}λi,j构成。每个part用高斯分布近似表示在平面上点的概率分布,具体来讲,通过对应OV矩阵和OP矩阵相乘表示高斯分布的均值,λi,j\lambda_{i,j}λi,j表示这个高斯分量的方差。因此可以通过这个混合高斯模型来计算整个数据分布的可能性:
p(x1:M)=∏m=1M∑k=1K∑n=1Nakak,n∑iai∑jai,jp(xm∣k,n)p\left(\mathbf{x}_{1: M}\right)=\prod_{m=1}^{M} \sum_{k=1}^{K} \sum_{n=1}^{N} \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} p\left(\mathbf{x}_{m} \mid k, n\right) p(x1:M)=m=1Mk=1Kn=1Niaijai,jakak,np(xmk,n)
通过最大似然估计来对上式进行求解,通过RMSProp来对参数进行更新。

对每个点,其类别的最终决策为:
k⋆=arg⁡max⁡kakak,np(xm∣k,n)k^{\star}=\arg \max _{k} a_{k} a_{k, n} p\left(\mathbf{x}_{m} \mid k, n\right) k=argkmaxakak,np(xmk,n)
也就是说,每个点最终都会属于三个object capsule的其中一个。这里,object capsule和part capsule的个数都是超参,因为我们预先知道数据可以分三类,所以将object的个数定为3。

【注】

看原文的OV和OP矩阵大小可能会让我们产生疑惑,原文表述两个矩阵是3×33 \times 33×3,实际上,OV矩阵大小2×22\times 22×2,OP矩阵大小2×12\times 12×1,相乘得到大小为2×12\times 12×1的均值。也即每个高斯分量都是个二元的高斯分布。

3.2 图片实验——SCAE

模型概览

上面toy setup讨论完,来到图像实验。图片实验的模型SCAE也照应了标题,SCAE的框架如下图所示

SCAE由两个自编码器构成:

  • 部分胶囊自编码器 (Part Capsule Autoencoder,PCAE)
图片 --> part --> 图片

其中图片到part使用了CNN,一直卷积到特征图大小为1,对结果特征图划分了24个part胶囊,每个胶囊长度为(6+n+1)位(见下表)

  • 对象胶囊自编码器 (Object Capsule Autoencoder,OCAE)
part --> object --> part

part到object编码器是Set Transformer,object到part解码器是mlp,这些可以参考3.1和2.2。

胶囊表示方法对比

SCAE和CCAE胶囊表示方法的不同,笔者将其归纳如下表(红色显示出来不同项)

image-20201115115130399

PCAE

原文对PCAE表述如下

第一行式子是CNN编码part,下面几行表示将part解码回图片,是我们需要重点了解的。基于图像模板技术重建图片流程可以用下图所示

图像实验中的part可以比作CCAE例子中平面的点,CCAE的part是2维,PCAE的part是(6+1+n)维。

PCAE如何进行解码的? 原作者使用了图片模板技术将part解码图片,具体讲,第mmm个part使用一个可学习的图片模板Tm∈[0,1]ht×wt×(c+1)T_{m} \in[0,1]^{h_{t} \times w_{t} \times(c+1)}Tm[0,1]ht×wt×(c+1),这个模板比输入图片小,并且多一个通道TaT^aTa表示被其他模板的遮挡。比如MNIST大小为28×28×128 \times 28 \times 128×28×1,那么使用的模板大小有11×11×(1+1)11\times11 \times (1+1)11×11×(1+1)。应用第mmm个part的姿态矩阵xmx_mxmTmT_mTm上,即可得到和输入大小相同的图片T^m\widehat{T}_{m}T

m。由上图可以看到TmT_{m}Tm学习到了字体的笔画细节,变换后的对最终总的输出重建有T^m\widehat{T}_{m}T

m
贡献,加权这些分量,即可得到最终的预测。

**每个part解码出的结果怎么表示图片呢?**对输出图片每个像素位置(i,j)(i,j)(i,j),可以得到关于像素强度的高斯分布。有了这个分布,可以确定可能性最大的像素强度,进而确定一张由part预测的图片。若干part解码的高斯分布的叠加,得到总的输出图片。可以用下式表示图片的总似然值
p(y)=∏i,j∑m=1Mpm,i,jyN(yi,j∣cm⋅T^m,i,jc;σy2)p(\mathbf{y})=\prod_{i, j} \sum_{m=1}^{M} p_{m, i, j}^{y} \mathcal{N}\left(y_{i, j} \mid \boldsymbol{c}_{m} \cdot \widehat{T}_{m, i, j}^{c} ; \sigma_{y}^{2}\right) p(y)=i,jm=1Mpm,i,jyN(yi,jcmT

m,i,jc;σy2)

OCAE

最后 使用OCAE对part进行自编码
p(x1:M,d1:M)=∏m=1M[∑k=1Kakak,m∑iai∑jai,jp(xm∣k,m)]dmp\left(\mathbf{x}_{1: M}, d_{1: M}\right)=\prod_{m=1}^{M}\left[\sum_{k=1}^{K} \frac{a_{k} a_{k, m}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} p\left(\mathbf{x}_{m} \mid k, m\right)\right]^{d_{m}} p(x1:M,d1:M)=m=1M[k=1Kiaijai,jakak,mp(xmk,m)]dm

小结

上面是全部的SCAE全部内容,理论上来说,只需要对目标函数log⁡p(y)+log⁡p(x1:M))\left.\log p(\mathbf{y})+\log p\left(\mathbf{x}_{1: M}\right)\right)logp(y)+logp(x1:M))进行训练就好,但作者考虑不对目标函数进行约束可能遇到两个问题:(1)使用所有part和object去推理一张图片 (2)不管输入是什么,总是用一部分胶囊学习。为了让模型能够使用不同的part,以及让object的特别性,作者引入了稀疏性和交叉熵约束,这里的具体数学细节可以查阅原文。

四、源码阅读

原作者开源的代码基于TensorFlow1.15[11]。笔者没有学过tf,虽然网上也有一些开源复现版本[12],但还是原作者tf版本写的最详细,遂硬着头皮读源码。

4.1 配置代码

作者写的代码规范性良好,而且也很容易配置好环境跑起来,不过鉴于开源在了google-research这个大仓库里,进行调试开发时还是要进行相关的配置。

  • 使用环境

    • ubuntu18(GP100)
    • pycharm
  • 配置项目

这个项目放在谷歌的大仓库里面,我们把它单独提出来。有两个方案,使用svn更快一些。

# 方案1 使用git
git clone https://github.com/google-research/google-research --depth 1
mkdir stacked_capsule_autoencoders
cp -r google-research/stacked_capsule_autoencoders stacked_capsule_autoencoders/# 方案2 使用svn
sudo apt-get install subversion -y
SUBDIR="stacked_capsule_autoencoders"
svn export https://github.com/google-research/google-research/trunk/$SUBDIR
mkdir tmp
mv stacked_capsule_autoencoders tmp
mv tmp stacked_capsule_autoencoders #项目根目录、工作目录
  • 代码风格

原作者使用2个空格缩进,笔者习惯4个空格。对原作者所有代码运用了autopep8,2个空格的缩进改成了4个空格缩进。一个命令即可:

cd stacked_capsule_autoencoders
find . -name '*.py' -exec autopep8 --in-place --aggressive --aggressive {} \;
  • Python环境配置

作者使用了virtualenv来单独管理这个python环境,使用下面这个脚本便可以一键设置python环境,很省心。注意要配置好镜像源,否则下载会比较慢。

【Note】 在安装环境时遇到一个error,笔者将requirements.txt第一行改成了absl-py<0.11,>=0.9

bash stacked_capsule_autoencoders/setup_virtualenv.sh

环境配置好之后,目录结构如下图所示。接着我们在pycharm里选择这个编译器,这里就不细讲了。

  • pycharm配置和运行。根据提供的运行脚本,编辑pycharm的运行配置,方便IDE调试。这里提供笔者运行的参数。

注意工作目录选项目根目录

Constellation,发现不能使用–plot参数进行可视化,有bug。

--name=constellation
--model=constellation
--dataset=constellation
--prior_within_example_sparsity_weight=1.
--prior_between_example_sparsity_weight=1.
--posterior_within_example_sparsity_weight=0.
--posterior_between_example_sparsity_weight=0.
--overwrite

MNIST

--name=mnist
--model=scae
--dataset=mnist
--max_train_steps=300000
--batch_size=128
--lr=3e-5
--use_lr_schedule=True
--canvas_size=40
--n_part_caps=40
--n_obj_caps=32
--colorize_templates=True
--use_alpha_channel=True
--plot=True
--posterior_between_example_sparsity_weight=0.2
--posterior_within_example_sparsity_weight=0.7
--prior_between_example_sparsity_weight=0.35
--prior_within_example_constant=4.3
--prior_within_example_sparsity_weight=2.
--color_nonlin=sigmoid
--template_nonlin=sigmoid
  • 命令行运行
bash stacked_capsule_autoencoders/run_constellation.sh
bash stacked_capsule_autoencoders/run_mnist.sh
  • 调试

由于TensorFlow使用了计算图机制,调试变得较为复杂,笔者搜集了关于调试TensorFlow的资料,可以参考[16]

4.2 技术框架

每个作者都有不同的工具喜好,首先简要了解下原作者使用了什么技术进行构建,方便我们更好理解项目。TensorFlow不用说,作者还使用了montysonnetabsl

  • monty

Monty是Python缺少特性的补充。Monty为不属于标准库的Python实现了补充的helper函数。如对文件压缩的透明支持,有用的设计模式,如单例和缓存类等。笔者最喜欢的一个数据结构是支持字典的点索引,例子:

from monty.collections import AttrDict
d = AttrDict(foo=1, bar=2)
assert d["foo"] == d.foo
d.bar = "hello"
assert d.bar == "hello"
  • Sonnet

Sonnet是一个建立在TensorFlow 之上的库,旨在为机器学习研究提供简单的、可组合的抽象。比如在这个项目中会看到每个模型里定义了_build类方法,这个方法类似pytorchModule.forward 方法。

  • absl

这个库用于构建Python应用程序的Python库代码的集合。 该代码是从Google自己的Python代码库中收集的,并且已经过广泛测试并用于生产中。可以参考其例子[13]和文档[14]进行学习。这里也举个例子

from absl import app
from absl import flagsFLAGS = flags.FLAGSflags.DEFINE_string('name', 'Jane Random', 'Your name.')def main(argv):print('Happy, ', FLAGS.name)if __name__ == '__main__':app.run(main)

当程序运行时,app.run()解析flags命令行参数,并运行main程序。

4.3 从训练管窥整个框架

由于这个项目代码量相对较多,我们从顶层逐渐阅读。我们这里以mnist为例,弄明白下面三个问题:

  1. 训练的主循环

这里最大训练30万次,每500次进行一个train-val和report

  1. 了解模型是怎样定义前向计算

结论

  • 每个模型里定义的_build类方法,类似pytorchModule.forward 方法,定义了相关前向计算
  • 每个模型里定义的_loss类方法,是在前向计算好的结果上,求和得到最终损失
  • 每个模型里面make_target先调用_build后调用_loss计算损失

【探究细节】

作者基于snt.AbstractModule实现了一个抽象类Model类,make_target是继承于Model类的方法,其首先计算一次前向运算的相关结果,并且计算损失。其在Model类实现如下

image-20201106154107197

简而言之,Model类的make_target需要调用self.__call__方法,ok,没问题,来看看self.__call__的实现

image-20201106154916946

粗读其注释,了解到_call方法通过_build方法将模块连到计算图里面。也就是说,我们自己构建模型需要实现一个类似pytorchforward 的方法,这个方法取名叫_build 。adamk同学对每个模型都实现了_build方法,我们直接看图像SCAE的前向计算

3.如何使用定义好的模型?

实现的模型放在models这个文件夹里,每个模型都是一个python类。模型实现好之后,在model_config.py有一个get函数,根据配置实例化一个对象来。这是一个典型的工厂模式
60607651.png

get函数返回模型架构、优化器、学习率。我们使用下面代码即可做相应的调用

model_dict = model_config.get(FLAGS)lr = model_dict.lr
opt = model_dict.opt
model = model_dict.model

4.4 Set Transformer的实现

Set Transformer是本文用到的一个编码器,其完成part到object的编码。来看看其具体实现

class SetTransformer(snt.AbstractModule):"""Permutation-invariant Transformer."""def __init__(self,n_layers,n_heads,n_dims,n_output_dims,n_outputs,layer_norm=False,dropout_rate=0.,n_inducing_points=0):super(SetTransformer, self).__init__()self._n_layers = n_layersself._n_heads = n_headsself._n_dims = n_dimsself._n_output_dims = n_output_dimsself._n_outputs = n_outputsself._layer_norm = layer_normself._dropout_rate = dropout_rateself._n_inducing_points = n_inducing_pointsdef _build(self, x, presence=None):batch_size = int(x.shape[0])h = snt.BatchApply(snt.Linear(self._n_dims))(x)args = [self._n_heads, self._layer_norm, self._dropout_rate]klass = SelfAttention # MABif self._n_inducing_points > 0:args = [self._n_inducing_points] + argsklass = InducedSelfAttention #ISABfor _ in range(self._n_layers):h = klass(*args)(h, presence)z = snt.BatchApply(snt.Linear(self._n_output_dims))(h)inducing_points = tf.get_variable('inducing_points', shape=[1, self._n_outputs, self._n_output_dims])inducing_points = snt.TileByDim([0], [batch_size])(inducing_points)return MultiHeadQKVAttention(self._n_heads)(inducing_points, z, z, presence)

4.5 CCAE实现

model_config.py里配置的CCAE模型:

model = ConstellationAutoencoder(encoder=encoder, # Set Transformerdecoder=decoder, # mlpmixing_kl_weight=config.mixing_kl_weight, #0,作者没有用sparsity_weight=config.sparsity_weight, #10 这一个和下一个用于对损失加权dynamic_l2_weight=config.dynamic_l2_weight, #10#prior_sparsity_loss_type='l2',prior_within_example_sparsity_weight=config.prior_within_example_sparsity_weight,# pylint:disable=line-too-longprior_between_example_sparsity_weight=config.prior_within_example_sparsity_weight,# pylint:disable=line-too-longprior_within_example_constant=0.,posterior_sparsity_loss_type='entropy',posterior_within_example_sparsity_weight=config.posterior_within_example_sparsity_weight,# pylint:disable=line-too-longposterior_between_example_sparsity_weight=config.posterior_between_example_sparsity_weight,# pylint:disable=line-too-long
)

模型前向计算

def _build(self, data):x = data[self._input_key]presence = data[self._presence_key] if self._presence_key else Noneinputs = nest.flatten(x)if presence is not None:inputs.append(presence)h = self._encoder(*inputs)res = self._decoder(h, *inputs)n_points = int(res.posterior_mixing_probs.shape[1])mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs, 1)(res.posterior_within_sparsity_loss,res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(self._posterior_sparsity_loss_type,mass_explained_by_capsule / n_points,num_classes=self._n_classes)(res.prior_within_sparsity_loss,res.prior_between_sparsity_loss) = _capsule.sparsity_loss(self._prior_sparsity_loss_type,res.caps_presence_prob,num_classes=self._n_classes,within_example_constant=self._prior_within_example_constant)return res

4.6 SCAE实现

model_config.py里SCAE调用

model = ImageAutoencoder(primary_encoder=part_encoder, # 编码器,图片 --> partprimary_decoder=part_decoder, # 基于图片模板的解码器,  part --> 图片encoder=obj_encoder, # SetTransformer编码器, part --> objectdecoder=obj_decoder, # mlp解码器  object --> partinput_key='image',label_key='label',n_classes=10,dynamic_l2_weight=10,caps_ll_weight=1.,vote_type='enc', # pose, soft, hardpres_type='enc',stop_grad_caps_inpt=True,stop_grad_caps_target=True,prior_sparsity_loss_type='l2',prior_within_example_sparsity_weight=config.prior_within_example_sparsity_weight,# pylint:disable=line-too-longprior_between_example_sparsity_weight=config.prior_between_example_sparsity_weight,# pylint:disable=line-too-longposterior_sparsity_loss_type='entropy',posterior_within_example_sparsity_weight=config.posterior_within_example_sparsity_weight,# pylint:disable=line-too-longposterior_between_example_sparsity_weight=config.posterior_between_example_sparsity_weight,# pylint:disable=line-too-long
)

SCAE的计算细节?

可以通过_build了解每个模型的计算细节。看看SCAE的前向计算

# ImageAutoencoder,即SCAE
def _build(self, data):input_x = self._img(data, False)target_x = self._img(data, prep=self._prep)batch_size = int(input_x.shape[0])primary_caps = self._primary_encoder(input_x) # CNN构造初级胶囊pres = primary_caps.presence # 每个part的存在概率expanded_pres = tf.expand_dims(pres, -1)pose = primary_caps.pose # 每个part的姿态input_pose = tf.concat([pose, 1. - expanded_pres], -1) # 将part的姿态信息和概率信息作为 object编码器的输入input_pres = pres # 将part的概率作为 object编码器的输入if self._stop_grad_caps_inpt:input_pose = tf.stop_gradient(input_pose)input_pres = tf.stop_gradient(pres)target_pose, target_pres = pose, pres # OCAE的输出if self._stop_grad_caps_target:target_pose = tf.stop_gradient(target_pose)target_pres = tf.stop_gradient(target_pres)# skip connection from the img to the higher level capsuleif primary_caps.feature is not None:input_pose = tf.concat([input_pose, primary_caps.feature], -1)# OCAE的编码,注意这里给Set Transformer输入的细节:# try to feed presence as a separate input# and if that works, concatenate templates to poses# this is necessary for set transformern_templates = int(primary_caps.pose.shape[1])templates = self._primary_decoder.make_templates(n_templates,primary_caps.feature)     try:if self._feed_templates:inpt_templates = templatesif self._stop_grad_caps_inpt:inpt_templates = tf.stop_gradient(inpt_templates)if inpt_templates.shape[0] == 1:inpt_templates = snt.TileByDim([0], [batch_size])(inpt_templates)inpt_templates = snt.BatchFlatten(2)(inpt_templates)pose_with_templates = tf.concat([input_pose, inpt_templates], -1)else:pose_with_templates = input_poseh = self._encoder(pose_with_templates, input_pres)except TypeError:h = self._encoder(input_pose)# OCAE的解码res = self._decoder(h, target_pose, target_pres)# 下面是其他中间结果、损失计算res.primary_presence = primary_caps.presenceif self._vote_type == 'enc':primary_dec_vote = primary_caps.poseelif self._vote_type == 'soft':primary_dec_vote = res.soft_winnerelif self._vote_type == 'hard':primary_dec_vote = res.winnerelse:raise ValueError('Invalid vote_type="{}"".'.format(self._vote_type))if self._pres_type == 'enc':primary_dec_pres = preselif self._pres_type == 'soft':primary_dec_pres = res.soft_winner_preselif self._pres_type == 'hard':primary_dec_pres = res.winner_preselse:raise ValueError('Invalid pres_type="{}"".'.format(self._pres_type))# PCAE解码部分# part重建图片res.bottom_up_rec = self._primary_decoder(primary_caps.pose,primary_caps.presence,template_feature=primary_caps.feature,img_embedding=primary_caps.img_embedding)# object重建出概率大的part,来重建图片res.top_down_rec = self._primary_decoder(res.winner,primary_caps.presence,template_feature=primary_caps.feature,img_embedding=primary_caps.img_embedding)#?rec = self._primary_decoder(primary_dec_vote,primary_dec_pres,template_feature=primary_caps.feature,img_embedding=primary_caps.img_embedding)tile = snt.TileByDim([0], [res.vote.shape[1]])tiled_presence = tile(primary_caps.presence)tiled_feature = primary_caps.featureif tiled_feature is not None:tiled_feature = tile(tiled_feature)tiled_img_embedding = tile(primary_caps.img_embedding)# object重建出概率大的part,来重建图片res.top_down_per_caps_rec = self._primary_decoder(snt.MergeDims(0, 2)(res.vote),snt.MergeDims(0, 2)(res.vote_presence) * tiled_presence,template_feature=tiled_feature,img_embedding=tiled_img_embedding)res.templates = templatesres.template_pres = presres.used_templates = rec.transformed_templatesres.rec_mode = rec.pdf.mode()res.rec_mean = rec.pdf.mean()res.mse_per_pixel = tf.square(target_x - res.rec_mode)res.mse = math_ops.flat_reduce(res.mse_per_pixel)res.rec_ll_per_pixel = rec.pdf.log_prob(target_x)res.rec_ll = math_ops.flat_reduce(res.rec_ll_per_pixel)n_points = int(res.posterior_mixing_probs.shape[1])mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs, 1)(res.posterior_within_sparsity_loss,res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(self._posterior_sparsity_loss_type,mass_explained_by_capsule / n_points,num_classes=self._n_classes)(res.prior_within_sparsity_loss,res.prior_between_sparsity_loss) = _capsule.sparsity_loss(self._prior_sparsity_loss_type,res.caps_presence_prob,num_classes=self._n_classes,within_example_constant=self._prior_within_example_constant)label = self._label(data)if label is not None:res.posterior_cls_xe, res.posterior_cls_acc = probe.classification_probe(mass_explained_by_capsule,label,self._n_classes,labeled=data.get('labeled', None))res.prior_cls_xe, res.prior_cls_acc = probe.classification_probe(res.caps_presence_prob,label,self._n_classes,labeled=data.get('labeled', None))res.best_cls_acc = tf.maximum(res.prior_cls_acc, res.posterior_cls_acc)res.primary_caps_l1 = math_ops.flat_reduce(res.primary_presence)if self._weight_decay > 0.0:decay_losses_list = []for var in tf.trainable_variables():if 'w:' in var.name or 'weights:' in var.name:decay_losses_list.append(tf.nn.l2_loss(var))res.weight_decay_loss = tf.reduce_sum(decay_losses_list)else:res.weight_decay_loss = 0.0return res

这里的res是个字典,在return打个断点,可以看到其保存许多计算结果

>>> res.keys()
dict_keys(['vote', 'scale', 'vote_presence', 'pres_logit_per_caps', 'pres_logit_per_vote', 'dynamic_weights_l2', 'raw_caps_params', 'raw_caps_features', 'log_prob', 'winner', 'winner_pres', 'is_from_capsule', 'mixing_logits', 'mixing_log_prob', 'soft_winner', 'soft_winner_pres', 'posterior_mixing_probs', 'caps_presence_prob', 'primary_presence', 'bottom_up_rec', 'top_down_rec', 'top_down_per_caps_rec', 'templates', 'template_pres', 'used_templates', 'rec_mode', 'rec_mean', 'mse_per_pixel', 'mse', 'rec_ll_per_pixel', 'rec_ll', 'posterior_within_sparsity_loss', 'posterior_between_sparsity_loss', 'prior_within_sparsity_loss', 'prior_between_sparsity_loss', 'posterior_cls_xe', 'posterior_cls_acc', 'prior_cls_xe', 'prior_cls_acc', 'best_cls_acc', 'primary_caps_l1', 'weight_decay_loss'])

有了这些计算结果,其损失即可计算:

def _loss(self, data, res):loss = (-res.rec_ll - self._caps_ll_weight * res.log_prob +self._dynamic_l2_weight * res.dynamic_weights_l2 +self._primary_caps_sparsity_weight * res.primary_caps_l1 +self._posterior_within_example_sparsity_weight *res.posterior_within_sparsity_loss -self._posterior_between_example_sparsity_weight *res.posterior_between_sparsity_loss +self._prior_within_example_sparsity_weight *res.prior_within_sparsity_loss -self._prior_between_example_sparsity_weight *res.prior_between_sparsity_loss +self._weight_decay * res.weight_decay_loss)try:loss += res.posterior_cls_xe + res.prior_cls_xeexcept AttributeError:passreturn loss

五、总结

本文是笔者对Stacked Capsule Autoencoders这篇文章的学习笔记,并研究了原版TensorFlow的实现。作者使用的技术细节很多,还有许多地方没有cover到。下一步打算动手实现一个pytorch版本,增进理解。

参考资料

[1] Capsule Networks (CapsNets) – Tutorial https://www.youtube.com/watch?v=pPN8d0E3900&t=425s

[2] 胶囊 (向量神经) 网络 https://mp.weixin.qq.com/s/Gjf_0y8waZ6Xx7r0Qvf4Ow

[3] Set Transformer系列:从集合数据处理网络到集合数据生成模型 https://zhuanlan.zhihu.com/p/264788321

[4] J. Lee, Y. Lee, J. Kim, A. R. Kosiorek, S. Choi, and Y. W. Teh, “Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks,” arXiv:1810.00825 [cs, stat], May 2019, Accessed: Jun. 26, 2020. [Online]. Available: http://arxiv.org/abs/1810.00825.

[5] 详解Transformer (Attention Is All You Need) https://zhuanlan.zhihu.com/p/48508221

[6] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need [C]//Advances in Neural Information Processing Systems. 2017: 5998-6008.

[7] Transformers Assemble(PART V) https://zhuanlan.zhihu.com/p/112477169

[8] The Illustrated Transformer http://jalammar.github.io/illustrated-transformer/

[9] Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., and Smola, A. J. Deep sets. In Advances in Neural Information Processing Systems (NeurIPS), 2017.

[10] S. Sabour, N. Frosst, and G. E. Hinton, “Dynamic Routing Between Capsules,” p. 11.

[11] 原作者代码https://github.com/google-research/google-research/tree/master/stacked_capsule_autoencoders

[12] Pytorch复现 https://github.com/phanideepgampa/stacked-capsule-networks

[13] https://github.com/abseil/abseil-py/blob/master/smoke_tests/sample_app.py

[14] https://abseil.io/docs/python/quickstart

[15] akosiorek的githubhttps://github.com/akosiorek

[16] https://www.cnblogs.com/huangshiyu13/p/6721805.html

[17] https://towardsdatascience.com/debugging-in-tensorflow-392b193d0b8

简明《Stacked Capsule Autoencoders》相关推荐

  1. Stacked Capsule Autoencoders

    引言 a) 物体由一系列几何组织部件所组成.本文引入了一个无监督的胶囊自动编码器(SCAE),意在使用部件之间的几何关系去推理对象: b) 由于部件间的关系,不依赖于视角的变动,所以该模型对视点的变化 ...

  2. 【文献阅读】Stacked What-Where Auto-encoders -ICLR-2016

    一.Abstract 提出一种新的autoencoder -- SWWAE(stacked what-where auto-encoders),更准确的说是一种 convolutional autoe ...

  3. Stacked Denoising Autoencoders (SDAE)

    教程地址:http://www.deeplearning.net/tutorial/SdA.html The Stacked Denoising Autoencoder (SdA) is an ext ...

  4. 图灵奖得主Geoffrey Hinton:脱缰的无监督学习,将带来什么

    与6位图灵奖得主和100多位专家 共同探讨人工智能的下一个十年 北京智源大会倒计时:4天 在即将举行的第二届北京智源大会上(官网:https://2020.baai.ac.cn),图灵奖获得者Geof ...

  5. 一文全览,AAAI 2020上的知识图谱

    2020-02-15 05:34:40 作者 | 杨晓凡 责编 | 贾伟 AI 科技评论按:2020 年 2 月 9 日,AAAI 2020 的主会议厅讲台上迎来了三位重量级嘉宾,这三位也是我们熟悉. ...

  6. 三巨头共聚AAA:ICapsule没有错,LeCun看好自监督,Bengio谈注意力

    2020-02-11 15:38:14 机器之心报道 参与:思源.Jamin 深度学习三巨头在 AAAI 讲了什么?2019 版 Capsule 这条路走得对:自监督学习是未来:注意力机制是逻辑推理与 ...

  7. Hinton向AAAI提交论文竟收到最差评价!深度学习三教父再押宝,AI或突破常识瓶颈...

    点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自新智元.   新智元报道   来源:zdnet等 编辑:张佳.程旸.鹏飞 [新智元导读]日前,图灵奖获得者.深度学习三巨头Geoffrey H ...

  8. 【今日CV 计算机视觉论文速览 第132期】Tue, 18 Jun 2019

    今日CS.CV 计算机视觉论文速览 Tue, 18 Jun 2019 Totally 64 papers ?上期速览✈更多精彩请移步主页 Interesting: ?****MMDetection, ...

  9. Hinton AAAI2020 演讲——胶囊网络

    2020 年 2 月 9 日,AAAI 2020 的主会议厅讲台上迎来了三位重量级嘉宾,这三位也是我们熟悉.拥戴的深度学习时代的开拓者:Geoffrey Hinton,Yann LeCun,Yoshu ...

最新文章

  1. 干货丨计算机视觉必读:图像分类、定位、检测,语义分割和实例分割方法梳理(经典长文,值得收藏)
  2. 2018, 自动驾驶异常艰难的一年
  3. 10年嵌入式工程师经验之谈:对于研发工作的感悟
  4. php5.3教程,php5.3.3配置教程
  5. 每日一题/007/级数/设a_n=1-1/2+1/3- ... + (-1)^(n-1)*1/n,求 lim_{n\to\infty}a_n
  6. python 公司教程_最全Python快速入门教程,满满都是干货
  7. Session存放token/获取token,销毁session
  8. matlab 指数函数拟合,[转载]MATLAB数据拟合例子(一次函数、指数函数、双曲线)...
  9. MyBatis-18MyBatis代码生成器-Example讲解
  10. VM12 虚拟机使用桥接模式却连不上网的解决办法
  11. python中怎么统计英文字符的个数_【Python练习1】统计一串字符中英文字母、空格、数字和其他字符的个数...
  12. python web py入门(12)- 实现用户登录论坛
  13. HBase-14.1-JMX监控实战-hadoop
  14. 车企围攻整车OS,这张“新王牌”怎么打?
  15. 线性回归系数的几个性质
  16. 服务器操作系统windows2016,微软正式发布服务器操作系统系统Windows Server 2016
  17. Python抓图必学的8种方式!
  18. linux redis查看密码,Redis集群设置密码和查看密码方法
  19. JavaScript 伪数组和数组
  20. 《Adobe Photoshop CS6中文版经典教程(彩色版)》—第2课2.12节保存用于四色印刷的图像...

热门文章

  1. 第三章 基本数据类型
  2. MATLAB 存储读入:dat文件、bin文件以及mat文件
  3. 广西北海中学2021年高考成绩查询,北海中学排名前十名,2021年北海中学排名一览表...
  4. c语言作业朱鸣华,2c语言程序设计教程 上机实验答案 朱鸣华 刘旭麟 杨微 著 机械工业出版社.pdf...
  5. SAP中物料成本核算单位不能小于价格单位
  6. 国外别墅后期PS教程洛阳生
  7. 智慧地产.美的·云与湖科技潮玩体育公园
  8. vue.js-脚手架
  9. android7.1 修改TTS文字转语音选项的首选引擎默认项
  10. MSN三叉戟的要害程度都强于C罗