简明《Stacked Capsule Autoencoders》
简明《Stacked Capsule Autoencoders》
什么是胶囊?什么是胶囊网络?胶囊真的有用吗?怎么实现一个胶囊网络?本文将会原理到实现,解读来自Hinton团队2019年发布的胶囊网络《Stacked Capsule Autoencoders》。由于个人水平有限,欢迎勘误,欢迎交流讨论。
Github:https://github.com/QiangZiBro
Zhihu:https://www.zhihu.com/people/QiangZiBro
公众号:QiangZiBro
笔者又制作了相关PPT,本文图片部分来自于PPT。搜索我的公众号QiangZiBro,回复”SCAE“获取PPT和相关资料。
一、前言
胶囊是什么?
它不是药物,而是来自人工智能领域被称作“神经网络之父”、“深度学习鼻祖”的科学家Hinton提出一种新型神经网络结构。从2017年开始,胶囊被认为是一个矢量神经元,对标于之前的标量网络。简而言之,胶囊是向量神经元,也即矢量神经元。
我们这里提到的神经元实际上来自被称作“神经网络(Neural Networks,NN)”的“人工神经网络”,神经元是NN的基本单位,神经元数学模型如下图所示
上一层输入神经元x1,…,xnx_1, … ,x_nx1,…,xn经过相关前馈计算,得到下一层输出神经元yyy。这些神经元由一位数字表示,因此它们被称作标量神经元。可以简单认为,胶囊网络将这里的标量神经元改进为矢量神经元。
【名词】
- 矢量神经元(Vector Neural,VN)即胶囊
- 标量神经元(Scalar Neural, SN)
为什么提出 CapsNet?
关于为什么要从标量迈向矢量,网上已有许多直觉的解释[1,2],这里不再赘述。简而言之,使用VN的CapsNet被认为有以下优点
- VN能更好地编码特征的一组相关信息,而SN不可以
- 引入VN的CapsNet解决了原本CNN的池化丢失信息问题
- CapsNet能学到姿态信息,换句话说,CapsNet是等变(equivariance)网络,而CNN是不变(invariance)网络。
上面提到的最后一点很重要,为什么我们关心等变性呢?因为(1)一个等变网络不需要大量训练数据(2)等变网络能学到输入的位姿。
这一版CapsNet有什么特点?
堆叠式胶囊自编码器(Stacked Capsule Autoencoders,SCAE)是Hinton团队提出第一、二两版胶囊网络后的又一版本,这篇使用无监督方法学习图像中的特征,并取得了最先进的结果。
直觉
中国有句古话:”横看成岭侧成峰,远近高低各不同“。人从不同视角(view)看object(比如山),看到的景物总是不一样的,然而object与part的关系确实固定的,正如山与树的关系。因此,object与part之间是视角不变的,人与object之间是视角等变的。因此,可以通过神经网络来显式地学习viewer与object的变换矩阵(OV)和object与part的变换矩阵(OP)。Hinton做过一个实验,如下图,旋转后的正方形看起来还像正方形吗?这个例子说明人的感官也会”欺骗“我们。
这一版胶囊网络将object胶囊表示为它对观察者的关系,因此这里不需要像CNN那样在空间上复制神经元的激活来表示一个object。
【使用到的名词】
- 观察者 (viewer)
- 对象 (object)
- 部件、部分 (part)
效果一览
- 无监督分类
- MNIST 98.7%
- SVHN 55%
阅读时思考几个问题
这一个版本怎么表示胶囊的?使用到的胶囊学到了什么?
论文里面提到的几种自编码器(CCAE,PCAE,OCAE,SCAE)的区别和联系是什么 ?
为什么需要把PCAE和OCAE两个自编码器堆叠在一起同时训练?
编码器部分:为什么使用Set Transformer对part capsules进行编码,这个编码器有什么特殊之处吗?
解码器部分:第一个例子中是怎样决策每个点的类别?第二个例子中图片怎么解码出来?
二、前置知识
2.1 自编码器
自编码器是一个试图去还原其原始输入的系统,自编码器由编码器和解码器组成。试图重建自己这个事情本身对我们来说没有太大意义,而真正有帮助的在于自编码器最中间的特征层。为什么中间学到的特征很重要?如果我们把每个样本最中间的学好的特征向量放在一起,会出现类似聚类的效果:同类之间距离小,不同类之间距离大。注意到中间特征层一般都是高维,高维空间很难去想象,我们可以用t-SNE进行降维可视化。而本文在MNIST数据集上学到降维可视化效果如下:
可见,不同样本在特征空间上分的很开,只有少数重叠,这个效果便很好。为什么自编码器能够达到这样的效果呢?因为在重建时,自编码器希望对每类样本都学到其最具有代表性的特征,只有能够清楚地区分它,才能够完成重建。而获得一个具有代表性的特征,才是特征学习或深度学习成功的关键。自编码器学习到优秀的特征后,可以用其做后续任务,比如本文用其做了一个简单的分类,在MNIST分类效果达到98.7%。
总结下,做深度学习时,我们更希望网络学到的特征能够尽可能全方位地代表这个样本,对应到特征空间,即是上图的效果;而不是某一方面来代表这个样本,自编码可能很好地做好前者,而不做自编码则很难获取有效的特征。因此,自编码器训练好之后,我们来进行后续的任务便较为容易了。
2.2 Set Transformer
笔者在这里第一次正式认识Transformer,发现这个词在中文里鲜有翻译,而是直接使用原词。笔者不是特别喜欢中文文章夹杂大量英文,但为了尽可能表示精准,还是使用适量英文。
2017年版本的胶囊网络[10]使用动态路由来完成一个聚类任务,而本文使用了Set Transformer[4],用来将part编码为object。Set Transformer输入包含nnn个样本的顺序无关的集合X∈Rn×dX \in \mathbb{R}^{n \times d}X∈Rn×d,输出包含kkk个样本的集合O∈Rk×dO \in \mathbb{R}^{k \times d}O∈Rk×d。
简介
Set Transformer有什么应用场景呢?它在最大值回归、计数不同字符、混合高斯、集合异常检测和点云分类五个任务上有较好表现。处理这些任务的模型有两个特性:
- 排列不变性:即输出结果与输入顺序无关;
- 能够处理任何大小的输入
模型框架
下面我们从至顶向下的方式了解Set Transformer细节。
首先解决集合排列不变的这类任务的模型,有文献[9]证明可以写成如下通式:
net({x1,…,xn})=ρ(pool({ϕ(x1),…,ϕ(xn)}))\operatorname{net}\left(\left\{x_{1}, \ldots, x_{n}\right\}\right)=\rho\left(\operatorname{pool}\left(\left\{\phi\left(x_{1}\right), \ldots, \phi\left(x_{n}\right)\right\}\right)\right) net({x1,…,xn})=ρ(pool({ϕ(x1),…,ϕ(xn)}))
上式是个自编码器,其中ϕ(x)\phi\left(x\right)ϕ(x)是编码器,ρ(x)\rho\left( x\right)ρ(x)是解码器,pool(x)\operatorname{pool}(x)pool(x)是池化操作,比如max,mean操作。
Set Transformer也是这类网络结构,当然也是自编码器,其编码器可以为下面两种的其中一种:
Z=Encoder(X)=SAB(SAB(X))∈Rn×dZ=\operatorname{Encoder}(X)=\operatorname{SAB}(\operatorname{SAB}(X)) \in \mathbb{R}^{n \times d} Z=Encoder(X)=SAB(SAB(X))∈Rn×d
Z=Encoder(X)=ISABm(ISABm(X))∈Rn×dZ=\operatorname{Encoder}(X)=\operatorname{ISAB}_{m}\left(\operatorname{ISAB}_{m}(X)\right) \in \mathbb{R}^{n \times d} Z=Encoder(X)=ISABm(ISABm(X))∈Rn×d
解码器是
O=Decoder(Z;λ)=rFF(SAB(PMAk(Z)))∈Rk×dO=\operatorname{Decoder}(Z ; \lambda)=\operatorname{rFF}\left(\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right)\right) \in \mathbb{R}^{k \times d} O=Decoder(Z;λ)=rFF(SAB(PMAk(Z)))∈Rk×d
其中池化操作PMAk(Z)=MAB(S,rFF(Z))∈Rk×d\operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z)) \in \mathbb{R}^{k \times d}PMAk(Z)=MAB(S,rFF(Z))∈Rk×d,kkk表示输出集合中实例的个数,k<nk < nk<n。
细节
模型框架引入若干概念,下面对它们一一作出解释。Set Transformer引入了Transformer相关概念,详细讲解可以参考[5,6,8]。
我们设Q∈Rn×dq,K∈Rnv×dv,V∈Rnv×dvQ \in \mathbb{R}^{n \times d_{q}},K \in \mathbb{R}^{n_v \times d_{v}},V \in \mathbb{R}^{n_v \times d_{v}}Q∈Rn×dq,K∈Rnv×dv,V∈Rnv×dv表示query,key,value;ωj(⋅)=softmax(⋅/d)\omega_{j}(\cdot)=\operatorname{softmax}(\cdot / \sqrt{d})ωj(⋅)=softmax(⋅/d
Transformer有注意力机制和多头注意力机制两个基础构成:
注意力机制:Att(Q,K,V;ω)=ω(QK⊤)V∈Rn×dv\operatorname{Att}(Q, K, V ; \omega)=\omega\left(Q K^{\top}\right) V \in \mathbb{R}^{n \times d_v}Att(Q,K,V;ω)=ω(QK⊤)V∈Rn×dv,计算过程如下图,图源[8]
【注】下面的例子改编自[5],如果有误,欢迎指正!
Query,Key,Value的概念取自于信息检索系统,比如我们在电商平台搜索某件商品(年轻女士冬季穿的红色薄款羽绒服)时,在搜索引擎上输入的内容便是Query,怎么通过上面的式子给出想要的结果呢?搜索引擎根据Query匹配Key(例如商品的种类,颜色,描述等)和Value(各种羽绒服结果),然后根据Query和Key相乘并且通过激活函数ω\omegaω得到一个相似度矩阵ω(QK⊤)\omega\left(Q K^{\top}\right)ω(QK⊤),相似度矩阵矩阵再和Value矩阵进行相乘得到最终需要的加权结果。
多头注意力机制
Multihead(Q,K,V;λ,ω)=concat(Z1,…,Zh)WO∈Rn×d,Zj=Att(QWjQ,KWjK,VWjV;ωj)\operatorname{Multihead} (Q, K, V ; \lambda, \omega) = \operatorname{concat(Z_1,…,Z_h)}W^O \in \mathbb{R}^{n \times d},\\ Z_{j}=\operatorname{Att}\left(Q W_{j}^{Q}, K W_{j}^{K}, V W_{j}^{V} ; \omega_{j}\right) Multihead(Q,K,V;λ,ω)=concat(Z1,…,Zh)WO∈Rn×d,Zj=Att(QWjQ,KWjK,VWjV;ωj)
同样是Q,K,V输入,用多个注意机制获得多个输出,再将它们合并起来并用一个矩阵进行变换。在下图例子中h=7h=7h=7
- rFF(x)\operatorname{rFF}(x)rFF(x) Row-wise 前馈层,它同等地、独立地处理每个instance
Set Transformer根据上面的基础,提出下面的模块:
多头注意力模块(Multihead Attention Block)
MAB(X,Y)=LayerNorm(H+rFF(H))\operatorname{MAB}(X, Y)=\operatorname{LayerNorm}(H+\operatorname{rFF}(H)) MAB(X,Y)=LayerNorm(H+rFF(H))
其中H=LayerNorm (X+Multihead(X,Y,Y;ω))H=\text { LayerNorm }(X+\operatorname{Multihead}(X, Y, Y ; \omega))H=LayerNorm(X+Multihead(X,Y,Y;ω))集合注意力模块(Set Attention Block ),计算复杂度O(n2)\mathcal{O}\left(n^{2}\right)O(n2)
SAB(X)=MAB(X,X)\operatorname{SAB}(X) = \operatorname{MAB}(X,X) SAB(X)=MAB(X,X)
- 诱导集合注意力模块(Induced Set Attention Block )
ISABm(X)=MAB(X,H)∈Rn×d\operatorname{ISAB}_m(X)=\operatorname{MAB}(X, H) \in \mathbb{R}^{n \times d} ISABm(X)=MAB(X,H)∈Rn×d
其中H=MAB(I,X)∈Rm×dH=\operatorname{MAB}(I,X) \in \mathbb{R}^{m \times d}H=MAB(I,X)∈Rm×d,I∈Rm×dI \in \mathbb{R}^{m \times d}I∈Rm×d为可学习参数。
为什么提出ISAB呢?[7]SAB的问题是transformer的传统问题,复杂度太高。所以引入诱导点(induced points)矩阵I∈Rm×dI \in \mathbb{R}^{m \times d}I∈Rm×d,类似于矩阵分解,将原来的一步attention拆分成两步attention,首先用III对输入X做self-attention,接着用得到的结果对输入做attention。将复杂度从O(n2)\mathcal{O}\left(n^{2}\right)O(n2)降到O(mn)\mathcal{O}\left(mn\right)O(mn)
- 多头注意力机制的池化(Pooling by Multihead Attention)。池化是一种常见的聚合(aggregation)操作。上面提到,池化可以是最大或是平均。这里提出的池化是应用一个MAB在一个可学习的矩阵S∈Rk×dS \in \mathbb{R}^{k \times d}S∈Rk×d上。在一些聚类任务上,kkk设为我们需要的类别数。使用基于注意力的池化的直觉是,每个实例对target的重要性都不一样
PMAk(Z)=MAB(S,rFF(Z))\operatorname{PMA}_{k}(Z)=\operatorname{MAB}(S, \operatorname{rFF}(Z)) PMAk(Z)=MAB(S,rFF(Z))
H=SAB(PMAk(Z))H=\operatorname{SAB}\left(\operatorname{PMA}_{k}(Z)\right) H=SAB(PMAk(Z))
小结
Set Transformer是一个使用注意力机制的、以集合为输入的自编码器,这个模型在编码和池化模块中都使用了注意力机制。改变池化模块kkk的大小,可以改变输出集合大小。
三、 方法
3.1 一个玩具样例——集群自编码器(Constellation Autoencoder,CCAE)
“Constellation”中文译为星座,机器之心曾译作集群。
作者在这里做的任务是:对二维平面上的一组点{xm∣m=1,…,M}\left\{\mathbf{x}_{m} \mid m=1, \ldots, M\right\}{xm∣m=1,…,M}进行聚类。如下图,平面上两个正方形、一个三角形共有11个点构成,每个点即是part,通过无监督聚类去把每个点划分到对应的类别当中。自编码器在学习的过程,是在学习数据的分布和表征。
这个自编码器整体思路如下图所示
编码器Set Transformer将11个2维点集(每个点是part)编码为3个object capsule,每个object 由一个2×22\times 22×2大小的观察者——对象(OV)矩阵,一个特征向量cic_ici,和概率aia_iai这三个量表示。每个object通过MLP进行解码获得4个part capsule,每个part由2×12\times 12×1对象——部件(OP)矩阵,概率ai,ja_{i,j}ai,j,和标准差λi,j\lambda_{i,j}λi,j构成。每个part用高斯分布近似表示在平面上点的概率分布,具体来讲,通过对应OV矩阵和OP矩阵相乘表示高斯分布的均值,λi,j\lambda_{i,j}λi,j表示这个高斯分量的方差。因此可以通过这个混合高斯模型来计算整个数据分布的可能性:
p(x1:M)=∏m=1M∑k=1K∑n=1Nakak,n∑iai∑jai,jp(xm∣k,n)p\left(\mathbf{x}_{1: M}\right)=\prod_{m=1}^{M} \sum_{k=1}^{K} \sum_{n=1}^{N} \frac{a_{k} a_{k, n}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} p\left(\mathbf{x}_{m} \mid k, n\right) p(x1:M)=m=1∏Mk=1∑Kn=1∑N∑iai∑jai,jakak,np(xm∣k,n)
通过最大似然估计来对上式进行求解,通过RMSProp来对参数进行更新。
对每个点,其类别的最终决策为:
k⋆=argmaxkakak,np(xm∣k,n)k^{\star}=\arg \max _{k} a_{k} a_{k, n} p\left(\mathbf{x}_{m} \mid k, n\right) k⋆=argkmaxakak,np(xm∣k,n)
也就是说,每个点最终都会属于三个object capsule的其中一个。这里,object capsule和part capsule的个数都是超参,因为我们预先知道数据可以分三类,所以将object的个数定为3。
【注】
看原文的OV和OP矩阵大小可能会让我们产生疑惑,原文表述两个矩阵是3×33 \times 33×3,实际上,OV矩阵大小2×22\times 22×2,OP矩阵大小2×12\times 12×1,相乘得到大小为2×12\times 12×1的均值。也即每个高斯分量都是个二元的高斯分布。
3.2 图片实验——SCAE
模型概览
上面toy setup讨论完,来到图像实验。图片实验的模型SCAE也照应了标题,SCAE的框架如下图所示
SCAE由两个自编码器构成:
- 部分胶囊自编码器 (Part Capsule Autoencoder,PCAE)
图片 --> part --> 图片
其中图片到part使用了CNN,一直卷积到特征图大小为1,对结果特征图划分了24个part胶囊,每个胶囊长度为(6+n+1)位(见下表)
- 对象胶囊自编码器 (Object Capsule Autoencoder,OCAE)
part --> object --> part
part到object编码器是Set Transformer,object到part解码器是mlp,这些可以参考3.1和2.2。
胶囊表示方法对比
SCAE和CCAE胶囊表示方法的不同,笔者将其归纳如下表(红色显示出来不同项)
PCAE
原文对PCAE表述如下
第一行式子是CNN编码part,下面几行表示将part解码回图片,是我们需要重点了解的。基于图像模板技术重建图片流程可以用下图所示
图像实验中的part可以比作CCAE例子中平面的点,CCAE的part是2维,PCAE的part是(6+1+n)维。
PCAE如何进行解码的? 原作者使用了图片模板技术将part解码图片,具体讲,第mmm个part使用一个可学习的图片模板Tm∈[0,1]ht×wt×(c+1)T_{m} \in[0,1]^{h_{t} \times w_{t} \times(c+1)}Tm∈[0,1]ht×wt×(c+1),这个模板比输入图片小,并且多一个通道TaT^aTa表示被其他模板的遮挡。比如MNIST大小为28×28×128 \times 28 \times 128×28×1,那么使用的模板大小有11×11×(1+1)11\times11 \times (1+1)11×11×(1+1)。应用第mmm个part的姿态矩阵xmx_mxm在TmT_mTm上,即可得到和输入大小相同的图片T^m\widehat{T}_{m}T
**每个part解码出的结果怎么表示图片呢?**对输出图片每个像素位置(i,j)(i,j)(i,j),可以得到关于像素强度的高斯分布。有了这个分布,可以确定可能性最大的像素强度,进而确定一张由part预测的图片。若干part解码的高斯分布的叠加,得到总的输出图片。可以用下式表示图片的总似然值
p(y)=∏i,j∑m=1Mpm,i,jyN(yi,j∣cm⋅T^m,i,jc;σy2)p(\mathbf{y})=\prod_{i, j} \sum_{m=1}^{M} p_{m, i, j}^{y} \mathcal{N}\left(y_{i, j} \mid \boldsymbol{c}_{m} \cdot \widehat{T}_{m, i, j}^{c} ; \sigma_{y}^{2}\right) p(y)=i,j∏m=1∑Mpm,i,jyN(yi,j∣cm⋅T
OCAE
最后 使用OCAE对part进行自编码
p(x1:M,d1:M)=∏m=1M[∑k=1Kakak,m∑iai∑jai,jp(xm∣k,m)]dmp\left(\mathbf{x}_{1: M}, d_{1: M}\right)=\prod_{m=1}^{M}\left[\sum_{k=1}^{K} \frac{a_{k} a_{k, m}}{\sum_{i} a_{i} \sum_{j} a_{i, j}} p\left(\mathbf{x}_{m} \mid k, m\right)\right]^{d_{m}} p(x1:M,d1:M)=m=1∏M[k=1∑K∑iai∑jai,jakak,mp(xm∣k,m)]dm
小结
上面是全部的SCAE全部内容,理论上来说,只需要对目标函数logp(y)+logp(x1:M))\left.\log p(\mathbf{y})+\log p\left(\mathbf{x}_{1: M}\right)\right)logp(y)+logp(x1:M))进行训练就好,但作者考虑不对目标函数进行约束可能遇到两个问题:(1)使用所有part和object去推理一张图片 (2)不管输入是什么,总是用一部分胶囊学习。为了让模型能够使用不同的part,以及让object的特别性,作者引入了稀疏性和交叉熵约束,这里的具体数学细节可以查阅原文。
四、源码阅读
原作者开源的代码基于TensorFlow1.15[11]。笔者没有学过tf,虽然网上也有一些开源复现版本[12],但还是原作者tf版本写的最详细,遂硬着头皮读源码。
4.1 配置代码
作者写的代码规范性良好,而且也很容易配置好环境跑起来,不过鉴于开源在了google-research这个大仓库里,进行调试开发时还是要进行相关的配置。
使用环境
- ubuntu18(GP100)
- pycharm
配置项目
这个项目放在谷歌的大仓库里面,我们把它单独提出来。有两个方案,使用svn更快一些。
# 方案1 使用git
git clone https://github.com/google-research/google-research --depth 1
mkdir stacked_capsule_autoencoders
cp -r google-research/stacked_capsule_autoencoders stacked_capsule_autoencoders/# 方案2 使用svn
sudo apt-get install subversion -y
SUBDIR="stacked_capsule_autoencoders"
svn export https://github.com/google-research/google-research/trunk/$SUBDIR
mkdir tmp
mv stacked_capsule_autoencoders tmp
mv tmp stacked_capsule_autoencoders #项目根目录、工作目录
- 代码风格
原作者使用2个空格缩进,笔者习惯4个空格。对原作者所有代码运用了autopep8,2个空格的缩进改成了4个空格缩进。一个命令即可:
cd stacked_capsule_autoencoders
find . -name '*.py' -exec autopep8 --in-place --aggressive --aggressive {} \;
- Python环境配置
作者使用了virtualenv来单独管理这个python环境,使用下面这个脚本便可以一键设置python环境,很省心。注意要配置好镜像源,否则下载会比较慢。
【Note】 在安装环境时遇到一个error,笔者将requirements.txt第一行改成了
absl-py<0.11,>=0.9
。
bash stacked_capsule_autoencoders/setup_virtualenv.sh
环境配置好之后,目录结构如下图所示。接着我们在pycharm里选择这个编译器,这里就不细讲了。
- pycharm配置和运行。根据提供的运行脚本,编辑pycharm的运行配置,方便IDE调试。这里提供笔者运行的参数。
注意工作目录选项目根目录
Constellation,发现不能使用–plot参数进行可视化,有bug。
--name=constellation
--model=constellation
--dataset=constellation
--prior_within_example_sparsity_weight=1.
--prior_between_example_sparsity_weight=1.
--posterior_within_example_sparsity_weight=0.
--posterior_between_example_sparsity_weight=0.
--overwrite
MNIST
--name=mnist
--model=scae
--dataset=mnist
--max_train_steps=300000
--batch_size=128
--lr=3e-5
--use_lr_schedule=True
--canvas_size=40
--n_part_caps=40
--n_obj_caps=32
--colorize_templates=True
--use_alpha_channel=True
--plot=True
--posterior_between_example_sparsity_weight=0.2
--posterior_within_example_sparsity_weight=0.7
--prior_between_example_sparsity_weight=0.35
--prior_within_example_constant=4.3
--prior_within_example_sparsity_weight=2.
--color_nonlin=sigmoid
--template_nonlin=sigmoid
- 命令行运行
bash stacked_capsule_autoencoders/run_constellation.sh
bash stacked_capsule_autoencoders/run_mnist.sh
- 调试
由于TensorFlow使用了计算图机制,调试变得较为复杂,笔者搜集了关于调试TensorFlow的资料,可以参考[16]
4.2 技术框架
每个作者都有不同的工具喜好,首先简要了解下原作者使用了什么技术进行构建,方便我们更好理解项目。TensorFlow不用说,作者还使用了monty
、sonnet
和absl
。
- monty
Monty是Python缺少特性的补充。Monty为不属于标准库的Python实现了补充的helper函数。如对文件压缩的透明支持,有用的设计模式,如单例和缓存类等。笔者最喜欢的一个数据结构是支持字典的点索引,例子:
from monty.collections import AttrDict
d = AttrDict(foo=1, bar=2)
assert d["foo"] == d.foo
d.bar = "hello"
assert d.bar == "hello"
- Sonnet
Sonnet是一个建立在TensorFlow 之上的库,旨在为机器学习研究提供简单的、可组合的抽象。比如在这个项目中会看到每个模型里定义了_build
类方法,这个方法类似pytorchModule.forward
方法。
- absl
这个库用于构建Python应用程序的Python库代码的集合。 该代码是从Google自己的Python代码库中收集的,并且已经过广泛测试并用于生产中。可以参考其例子[13]和文档[14]进行学习。这里也举个例子
from absl import app
from absl import flagsFLAGS = flags.FLAGSflags.DEFINE_string('name', 'Jane Random', 'Your name.')def main(argv):print('Happy, ', FLAGS.name)if __name__ == '__main__':app.run(main)
当程序运行时,app.run()
解析flags命令行参数,并运行main程序。
4.3 从训练管窥整个框架
由于这个项目代码量相对较多,我们从顶层逐渐阅读。我们这里以mnist为例,弄明白下面三个问题:
- 训练的主循环
这里最大训练30万次,每500次进行一个train-val和report
- 了解模型是怎样定义前向计算
结论
- 每个模型里定义的
_build
类方法,类似pytorchModule.forward
方法,定义了相关前向计算 - 每个模型里定义的
_loss
类方法,是在前向计算好的结果上,求和得到最终损失 - 每个模型里面
make_target
先调用_build
后调用_loss
计算损失
【探究细节】
作者基于
snt.AbstractModule
实现了一个抽象类Model
类,make_target
是继承于Model
类的方法,其首先计算一次前向运算的相关结果,并且计算损失。其在Model
类实现如下简而言之,
Model
类的make_target
需要调用self.__call__
方法,ok,没问题,来看看self.__call__
的实现粗读其注释,了解到
_call
方法通过_build
方法将模块连到计算图里面。也就是说,我们自己构建模型需要实现一个类似pytorchforward
的方法,这个方法取名叫_build
。adamk同学对每个模型都实现了_build
方法,我们直接看图像SCAE的前向计算
3.如何使用定义好的模型?
实现的模型放在models
这个文件夹里,每个模型都是一个python类。模型实现好之后,在model_config.py
有一个get
函数,根据配置实例化一个对象来。这是一个典型的工厂模式
get
函数返回模型架构、优化器、学习率。我们使用下面代码即可做相应的调用
model_dict = model_config.get(FLAGS)lr = model_dict.lr
opt = model_dict.opt
model = model_dict.model
4.4 Set Transformer的实现
Set Transformer是本文用到的一个编码器,其完成part到object的编码。来看看其具体实现
class SetTransformer(snt.AbstractModule):"""Permutation-invariant Transformer."""def __init__(self,n_layers,n_heads,n_dims,n_output_dims,n_outputs,layer_norm=False,dropout_rate=0.,n_inducing_points=0):super(SetTransformer, self).__init__()self._n_layers = n_layersself._n_heads = n_headsself._n_dims = n_dimsself._n_output_dims = n_output_dimsself._n_outputs = n_outputsself._layer_norm = layer_normself._dropout_rate = dropout_rateself._n_inducing_points = n_inducing_pointsdef _build(self, x, presence=None):batch_size = int(x.shape[0])h = snt.BatchApply(snt.Linear(self._n_dims))(x)args = [self._n_heads, self._layer_norm, self._dropout_rate]klass = SelfAttention # MABif self._n_inducing_points > 0:args = [self._n_inducing_points] + argsklass = InducedSelfAttention #ISABfor _ in range(self._n_layers):h = klass(*args)(h, presence)z = snt.BatchApply(snt.Linear(self._n_output_dims))(h)inducing_points = tf.get_variable('inducing_points', shape=[1, self._n_outputs, self._n_output_dims])inducing_points = snt.TileByDim([0], [batch_size])(inducing_points)return MultiHeadQKVAttention(self._n_heads)(inducing_points, z, z, presence)
4.5 CCAE实现
model_config.py
里配置的CCAE模型:
model = ConstellationAutoencoder(encoder=encoder, # Set Transformerdecoder=decoder, # mlpmixing_kl_weight=config.mixing_kl_weight, #0,作者没有用sparsity_weight=config.sparsity_weight, #10 这一个和下一个用于对损失加权dynamic_l2_weight=config.dynamic_l2_weight, #10#prior_sparsity_loss_type='l2',prior_within_example_sparsity_weight=config.prior_within_example_sparsity_weight,# pylint:disable=line-too-longprior_between_example_sparsity_weight=config.prior_within_example_sparsity_weight,# pylint:disable=line-too-longprior_within_example_constant=0.,posterior_sparsity_loss_type='entropy',posterior_within_example_sparsity_weight=config.posterior_within_example_sparsity_weight,# pylint:disable=line-too-longposterior_between_example_sparsity_weight=config.posterior_between_example_sparsity_weight,# pylint:disable=line-too-long
)
模型前向计算
def _build(self, data):x = data[self._input_key]presence = data[self._presence_key] if self._presence_key else Noneinputs = nest.flatten(x)if presence is not None:inputs.append(presence)h = self._encoder(*inputs)res = self._decoder(h, *inputs)n_points = int(res.posterior_mixing_probs.shape[1])mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs, 1)(res.posterior_within_sparsity_loss,res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(self._posterior_sparsity_loss_type,mass_explained_by_capsule / n_points,num_classes=self._n_classes)(res.prior_within_sparsity_loss,res.prior_between_sparsity_loss) = _capsule.sparsity_loss(self._prior_sparsity_loss_type,res.caps_presence_prob,num_classes=self._n_classes,within_example_constant=self._prior_within_example_constant)return res
4.6 SCAE实现
model_config.py
里SCAE调用
model = ImageAutoencoder(primary_encoder=part_encoder, # 编码器,图片 --> partprimary_decoder=part_decoder, # 基于图片模板的解码器, part --> 图片encoder=obj_encoder, # SetTransformer编码器, part --> objectdecoder=obj_decoder, # mlp解码器 object --> partinput_key='image',label_key='label',n_classes=10,dynamic_l2_weight=10,caps_ll_weight=1.,vote_type='enc', # pose, soft, hardpres_type='enc',stop_grad_caps_inpt=True,stop_grad_caps_target=True,prior_sparsity_loss_type='l2',prior_within_example_sparsity_weight=config.prior_within_example_sparsity_weight,# pylint:disable=line-too-longprior_between_example_sparsity_weight=config.prior_between_example_sparsity_weight,# pylint:disable=line-too-longposterior_sparsity_loss_type='entropy',posterior_within_example_sparsity_weight=config.posterior_within_example_sparsity_weight,# pylint:disable=line-too-longposterior_between_example_sparsity_weight=config.posterior_between_example_sparsity_weight,# pylint:disable=line-too-long
)
SCAE的计算细节?
可以通过_build
了解每个模型的计算细节。看看SCAE的前向计算
# ImageAutoencoder,即SCAE
def _build(self, data):input_x = self._img(data, False)target_x = self._img(data, prep=self._prep)batch_size = int(input_x.shape[0])primary_caps = self._primary_encoder(input_x) # CNN构造初级胶囊pres = primary_caps.presence # 每个part的存在概率expanded_pres = tf.expand_dims(pres, -1)pose = primary_caps.pose # 每个part的姿态input_pose = tf.concat([pose, 1. - expanded_pres], -1) # 将part的姿态信息和概率信息作为 object编码器的输入input_pres = pres # 将part的概率作为 object编码器的输入if self._stop_grad_caps_inpt:input_pose = tf.stop_gradient(input_pose)input_pres = tf.stop_gradient(pres)target_pose, target_pres = pose, pres # OCAE的输出if self._stop_grad_caps_target:target_pose = tf.stop_gradient(target_pose)target_pres = tf.stop_gradient(target_pres)# skip connection from the img to the higher level capsuleif primary_caps.feature is not None:input_pose = tf.concat([input_pose, primary_caps.feature], -1)# OCAE的编码,注意这里给Set Transformer输入的细节:# try to feed presence as a separate input# and if that works, concatenate templates to poses# this is necessary for set transformern_templates = int(primary_caps.pose.shape[1])templates = self._primary_decoder.make_templates(n_templates,primary_caps.feature) try:if self._feed_templates:inpt_templates = templatesif self._stop_grad_caps_inpt:inpt_templates = tf.stop_gradient(inpt_templates)if inpt_templates.shape[0] == 1:inpt_templates = snt.TileByDim([0], [batch_size])(inpt_templates)inpt_templates = snt.BatchFlatten(2)(inpt_templates)pose_with_templates = tf.concat([input_pose, inpt_templates], -1)else:pose_with_templates = input_poseh = self._encoder(pose_with_templates, input_pres)except TypeError:h = self._encoder(input_pose)# OCAE的解码res = self._decoder(h, target_pose, target_pres)# 下面是其他中间结果、损失计算res.primary_presence = primary_caps.presenceif self._vote_type == 'enc':primary_dec_vote = primary_caps.poseelif self._vote_type == 'soft':primary_dec_vote = res.soft_winnerelif self._vote_type == 'hard':primary_dec_vote = res.winnerelse:raise ValueError('Invalid vote_type="{}"".'.format(self._vote_type))if self._pres_type == 'enc':primary_dec_pres = preselif self._pres_type == 'soft':primary_dec_pres = res.soft_winner_preselif self._pres_type == 'hard':primary_dec_pres = res.winner_preselse:raise ValueError('Invalid pres_type="{}"".'.format(self._pres_type))# PCAE解码部分# part重建图片res.bottom_up_rec = self._primary_decoder(primary_caps.pose,primary_caps.presence,template_feature=primary_caps.feature,img_embedding=primary_caps.img_embedding)# object重建出概率大的part,来重建图片res.top_down_rec = self._primary_decoder(res.winner,primary_caps.presence,template_feature=primary_caps.feature,img_embedding=primary_caps.img_embedding)#?rec = self._primary_decoder(primary_dec_vote,primary_dec_pres,template_feature=primary_caps.feature,img_embedding=primary_caps.img_embedding)tile = snt.TileByDim([0], [res.vote.shape[1]])tiled_presence = tile(primary_caps.presence)tiled_feature = primary_caps.featureif tiled_feature is not None:tiled_feature = tile(tiled_feature)tiled_img_embedding = tile(primary_caps.img_embedding)# object重建出概率大的part,来重建图片res.top_down_per_caps_rec = self._primary_decoder(snt.MergeDims(0, 2)(res.vote),snt.MergeDims(0, 2)(res.vote_presence) * tiled_presence,template_feature=tiled_feature,img_embedding=tiled_img_embedding)res.templates = templatesres.template_pres = presres.used_templates = rec.transformed_templatesres.rec_mode = rec.pdf.mode()res.rec_mean = rec.pdf.mean()res.mse_per_pixel = tf.square(target_x - res.rec_mode)res.mse = math_ops.flat_reduce(res.mse_per_pixel)res.rec_ll_per_pixel = rec.pdf.log_prob(target_x)res.rec_ll = math_ops.flat_reduce(res.rec_ll_per_pixel)n_points = int(res.posterior_mixing_probs.shape[1])mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs, 1)(res.posterior_within_sparsity_loss,res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(self._posterior_sparsity_loss_type,mass_explained_by_capsule / n_points,num_classes=self._n_classes)(res.prior_within_sparsity_loss,res.prior_between_sparsity_loss) = _capsule.sparsity_loss(self._prior_sparsity_loss_type,res.caps_presence_prob,num_classes=self._n_classes,within_example_constant=self._prior_within_example_constant)label = self._label(data)if label is not None:res.posterior_cls_xe, res.posterior_cls_acc = probe.classification_probe(mass_explained_by_capsule,label,self._n_classes,labeled=data.get('labeled', None))res.prior_cls_xe, res.prior_cls_acc = probe.classification_probe(res.caps_presence_prob,label,self._n_classes,labeled=data.get('labeled', None))res.best_cls_acc = tf.maximum(res.prior_cls_acc, res.posterior_cls_acc)res.primary_caps_l1 = math_ops.flat_reduce(res.primary_presence)if self._weight_decay > 0.0:decay_losses_list = []for var in tf.trainable_variables():if 'w:' in var.name or 'weights:' in var.name:decay_losses_list.append(tf.nn.l2_loss(var))res.weight_decay_loss = tf.reduce_sum(decay_losses_list)else:res.weight_decay_loss = 0.0return res
这里的res
是个字典,在return打个断点,可以看到其保存许多计算结果
>>> res.keys()
dict_keys(['vote', 'scale', 'vote_presence', 'pres_logit_per_caps', 'pres_logit_per_vote', 'dynamic_weights_l2', 'raw_caps_params', 'raw_caps_features', 'log_prob', 'winner', 'winner_pres', 'is_from_capsule', 'mixing_logits', 'mixing_log_prob', 'soft_winner', 'soft_winner_pres', 'posterior_mixing_probs', 'caps_presence_prob', 'primary_presence', 'bottom_up_rec', 'top_down_rec', 'top_down_per_caps_rec', 'templates', 'template_pres', 'used_templates', 'rec_mode', 'rec_mean', 'mse_per_pixel', 'mse', 'rec_ll_per_pixel', 'rec_ll', 'posterior_within_sparsity_loss', 'posterior_between_sparsity_loss', 'prior_within_sparsity_loss', 'prior_between_sparsity_loss', 'posterior_cls_xe', 'posterior_cls_acc', 'prior_cls_xe', 'prior_cls_acc', 'best_cls_acc', 'primary_caps_l1', 'weight_decay_loss'])
有了这些计算结果,其损失即可计算:
def _loss(self, data, res):loss = (-res.rec_ll - self._caps_ll_weight * res.log_prob +self._dynamic_l2_weight * res.dynamic_weights_l2 +self._primary_caps_sparsity_weight * res.primary_caps_l1 +self._posterior_within_example_sparsity_weight *res.posterior_within_sparsity_loss -self._posterior_between_example_sparsity_weight *res.posterior_between_sparsity_loss +self._prior_within_example_sparsity_weight *res.prior_within_sparsity_loss -self._prior_between_example_sparsity_weight *res.prior_between_sparsity_loss +self._weight_decay * res.weight_decay_loss)try:loss += res.posterior_cls_xe + res.prior_cls_xeexcept AttributeError:passreturn loss
五、总结
本文是笔者对Stacked Capsule Autoencoders这篇文章的学习笔记,并研究了原版TensorFlow的实现。作者使用的技术细节很多,还有许多地方没有cover到。下一步打算动手实现一个pytorch版本,增进理解。
参考资料
[1] Capsule Networks (CapsNets) – Tutorial https://www.youtube.com/watch?v=pPN8d0E3900&t=425s
[2] 胶囊 (向量神经) 网络 https://mp.weixin.qq.com/s/Gjf_0y8waZ6Xx7r0Qvf4Ow
[3] Set Transformer系列:从集合数据处理网络到集合数据生成模型 https://zhuanlan.zhihu.com/p/264788321
[4] J. Lee, Y. Lee, J. Kim, A. R. Kosiorek, S. Choi, and Y. W. Teh, “Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks,” arXiv:1810.00825 [cs, stat], May 2019, Accessed: Jun. 26, 2020. [Online]. Available: http://arxiv.org/abs/1810.00825.
[5] 详解Transformer (Attention Is All You Need) https://zhuanlan.zhihu.com/p/48508221
[6] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need [C]//Advances in Neural Information Processing Systems. 2017: 5998-6008.
[7] Transformers Assemble(PART V) https://zhuanlan.zhihu.com/p/112477169
[8] The Illustrated Transformer http://jalammar.github.io/illustrated-transformer/
[9] Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., and Smola, A. J. Deep sets. In Advances in Neural Information Processing Systems (NeurIPS), 2017.
[10] S. Sabour, N. Frosst, and G. E. Hinton, “Dynamic Routing Between Capsules,” p. 11.
[11] 原作者代码https://github.com/google-research/google-research/tree/master/stacked_capsule_autoencoders
[12] Pytorch复现 https://github.com/phanideepgampa/stacked-capsule-networks
[13] https://github.com/abseil/abseil-py/blob/master/smoke_tests/sample_app.py
[14] https://abseil.io/docs/python/quickstart
[15] akosiorek的githubhttps://github.com/akosiorek
[16] https://www.cnblogs.com/huangshiyu13/p/6721805.html
[17] https://towardsdatascience.com/debugging-in-tensorflow-392b193d0b8
简明《Stacked Capsule Autoencoders》相关推荐
- Stacked Capsule Autoencoders
引言 a) 物体由一系列几何组织部件所组成.本文引入了一个无监督的胶囊自动编码器(SCAE),意在使用部件之间的几何关系去推理对象: b) 由于部件间的关系,不依赖于视角的变动,所以该模型对视点的变化 ...
- 【文献阅读】Stacked What-Where Auto-encoders -ICLR-2016
一.Abstract 提出一种新的autoencoder -- SWWAE(stacked what-where auto-encoders),更准确的说是一种 convolutional autoe ...
- Stacked Denoising Autoencoders (SDAE)
教程地址:http://www.deeplearning.net/tutorial/SdA.html The Stacked Denoising Autoencoder (SdA) is an ext ...
- 图灵奖得主Geoffrey Hinton:脱缰的无监督学习,将带来什么
与6位图灵奖得主和100多位专家 共同探讨人工智能的下一个十年 北京智源大会倒计时:4天 在即将举行的第二届北京智源大会上(官网:https://2020.baai.ac.cn),图灵奖获得者Geof ...
- 一文全览,AAAI 2020上的知识图谱
2020-02-15 05:34:40 作者 | 杨晓凡 责编 | 贾伟 AI 科技评论按:2020 年 2 月 9 日,AAAI 2020 的主会议厅讲台上迎来了三位重量级嘉宾,这三位也是我们熟悉. ...
- 三巨头共聚AAA:ICapsule没有错,LeCun看好自监督,Bengio谈注意力
2020-02-11 15:38:14 机器之心报道 参与:思源.Jamin 深度学习三巨头在 AAAI 讲了什么?2019 版 Capsule 这条路走得对:自监督学习是未来:注意力机制是逻辑推理与 ...
- Hinton向AAAI提交论文竟收到最差评价!深度学习三教父再押宝,AI或突破常识瓶颈...
点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自新智元. 新智元报道 来源:zdnet等 编辑:张佳.程旸.鹏飞 [新智元导读]日前,图灵奖获得者.深度学习三巨头Geoffrey H ...
- 【今日CV 计算机视觉论文速览 第132期】Tue, 18 Jun 2019
今日CS.CV 计算机视觉论文速览 Tue, 18 Jun 2019 Totally 64 papers ?上期速览✈更多精彩请移步主页 Interesting: ?****MMDetection, ...
- Hinton AAAI2020 演讲——胶囊网络
2020 年 2 月 9 日,AAAI 2020 的主会议厅讲台上迎来了三位重量级嘉宾,这三位也是我们熟悉.拥戴的深度学习时代的开拓者:Geoffrey Hinton,Yann LeCun,Yoshu ...
最新文章
- 干货丨计算机视觉必读:图像分类、定位、检测,语义分割和实例分割方法梳理(经典长文,值得收藏)
- 2018, 自动驾驶异常艰难的一年
- 10年嵌入式工程师经验之谈:对于研发工作的感悟
- php5.3教程,php5.3.3配置教程
- 每日一题/007/级数/设a_n=1-1/2+1/3- ... + (-1)^(n-1)*1/n,求 lim_{n\to\infty}a_n
- python 公司教程_最全Python快速入门教程,满满都是干货
- Session存放token/获取token,销毁session
- matlab 指数函数拟合,[转载]MATLAB数据拟合例子(一次函数、指数函数、双曲线)...
- MyBatis-18MyBatis代码生成器-Example讲解
- VM12 虚拟机使用桥接模式却连不上网的解决办法
- python中怎么统计英文字符的个数_【Python练习1】统计一串字符中英文字母、空格、数字和其他字符的个数...
- python web py入门(12)- 实现用户登录论坛
- HBase-14.1-JMX监控实战-hadoop
- 车企围攻整车OS,这张“新王牌”怎么打?
- 线性回归系数的几个性质
- 服务器操作系统windows2016,微软正式发布服务器操作系统系统Windows Server 2016
- Python抓图必学的8种方式!
- linux redis查看密码,Redis集群设置密码和查看密码方法
- JavaScript 伪数组和数组
- 《Adobe Photoshop CS6中文版经典教程(彩色版)》—第2课2.12节保存用于四色印刷的图像...
热门文章
- 第三章 基本数据类型
- MATLAB 存储读入:dat文件、bin文件以及mat文件
- 广西北海中学2021年高考成绩查询,北海中学排名前十名,2021年北海中学排名一览表...
- c语言作业朱鸣华,2c语言程序设计教程 上机实验答案 朱鸣华 刘旭麟 杨微 著 机械工业出版社.pdf...
- SAP中物料成本核算单位不能小于价格单位
- 国外别墅后期PS教程洛阳生
- 智慧地产.美的·云与湖科技潮玩体育公园
- vue.js-脚手架
- android7.1 修改TTS文字转语音选项的首选引擎默认项
- MSN三叉戟的要害程度都强于C罗