《Spectral Normalization for Generative Adversarial Networks》【1】是Takeru Miyato在2018年2月发表的一篇将谱理论应用于Gan上的文章,在2017年,本文的第3作者Yuichi Yoshida就发表了一篇著名的谱范数正则(Spectral Norm Regularization)的文章【2】,如有兴趣也可参看我的上一篇Blog:https://blog.csdn.net/StreamRock/article/details/83539937
【1】、【2】两篇文章从不同的角度讨论了:参数矩阵的谱范数对多层神经网络的泛化的影响,并分别给出了两个不同的应对方法:前者对Discriminator矩阵参数进行归一化处理,后者可以加入任意多层网络(在更新梯度时加入了谱范数正则项)。本文将在【1】的阅读理解基础上,探讨其实现的方法。

一、Gan的Lipschitz稳定性约束

Gan好是好,但训练难,主要体现在:1)模式坍塌,即最后生成的对象就只有少数几个模式;2)不收敛,在训练过程中,Discriminator很早就进入了理想状态,总能perfectly分辨出真假,因此无法给Generator提供梯度信息,而导致训练无法进行下去。Martin Arjovsky在《Towards principled methods for training generative adversarial networks》【4】、《Wasserstein GAN》【5】文章中,对Gan难训练的原因做了详细的讨论,并给出一种新的Loss定义,即Wasserstein Distance:
W ( P r , P g ) = inf ⁡ γ ∈ ∏ ( P r , P g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] ( 1 ) W(P_r,P_g)=\inf_{\gamma\in\prod(P_r,P_g)}E_{(x,y)\sim \gamma}[\Vert x-y\Vert]\qquad(1) W(Pr​,Pg​)=γ∈∏(Pr​,Pg​)inf​E(x,y)∼γ​[∥x−y∥](1)
实际Wasserstein Distance的计算是通过它的变形来完成的:
W ( P r , P g ) = sup ⁡ ∥ f ∥ L i p E x ∼ P r [ f ( x ) ] − E x ∼ P g [ f ( x ) ] ( 2 ) W(P_r,P_g)=\sup_{\Vert f \Vert_{Lip}}E_{x∼P_r}[f(x)]−E_{x∼P_g}[f(x)]\qquad(2) W(Pr​,Pg​)=∥f∥Lip​sup​Ex∼Pr​​[f(x)]−Ex∼Pg​​[f(x)](2)
(2)式只要求 f ( ⋅ ) f(\cdot) f(⋅) 满足Lipschitz约束即可,在Gan中,判别器的映射函数可充当(2)式中的 f ( ⋅ ) f(\cdot) f(⋅) ,于是加入此一约束的Gan网络有了一个新的名称:WGan。
引入Wasserstein Distance,将传统Gan转变为WGan是有许多好处的,因为Wasserstein Distance具有如下优点:
1、 W ( P r , P g ) ≥ 0 W(P_r,P_g)\ge0 W(Pr​,Pg​)≥0, 等号在 P r , P g P_r,P_g Pr​,Pg​分布完全重合时成立;
2、 W ( P r , P g ) W(P_r,P_g) W(Pr​,Pg​)是对称的,较常用的 KL Divergence 的不对称,有优势;
3、即使两个分布 P r , P g P_r,P_g Pr​,Pg​ 的支撑不相交,亦可以作为衡量差异的距离,并在满足一定条件下可微,具备了后向传输的能力。
当 WGan 的 Discriminator 采用了这种距离来训练后,可以消除传统Gan训练时出现的收敛问题,使训练过程变得稳定。另外,要实施此策略也很简单,只需在传统Gan的Discriminator的参数矩阵上加上Lipschitz约束即可,其它的几乎不用改。


Lipschitz约束简单而言就是:要求在整个 f ( ⋅ ) f(\cdot) f(⋅) 的定义域内有
∥ f ( x ) − f ( x ′ ) ∥ 2 ∥ x − x ′ ∥ 2 ≤ M ( 3 ) \frac{\Vert f(x)-f(x') \Vert_2}{\Vert x-x' \Vert_2} \le M \qquad(3) ∥x−x′∥2​∥f(x)−f(x′)∥2​​≤M(3)
其中,M是一个常数。满足公式(3)的函数 f ( ⋅ ) f(\cdot) f(⋅),具体表现为:函数变化不会太快,其梯度总是有限的,即使最剧烈时,也被限制在小于等于M的范围。


WGan首先提出Discriminator的参数矩阵需要满足Lipschitz约束,但其方法比较简单粗暴:直接对参数矩阵中元素进行限制,不让其大于某个值。这种方法,是可以保证Lipschitz约束的,但在削顶的同时,也破坏了整个参数矩阵的结构——各参数之间的比例关系。针对这个问题,【1】提出了一个既满足Lipschitz条件,又不用破坏矩阵结构的方法——Spectral Normalization。

二、多层神经网络的分析

为简便分析,可将Discriminator看作是多层网络,因为CNNs可看作是特殊的多层网络。对于多层网络的第n层,其输入与输出关系可以表示为:
x n = a n ( W n x n − 1 + b n ) ( 4 ) \mathbf x_n = a_n(W_n\mathbf x_{n-1}+\mathbf b_n)\qquad(4) xn​=an​(Wn​xn−1​+bn​)(4)
其中, a n ( ⋅ ) a_n(\cdot) an​(⋅) 是该层网络的非线性激活函数,可采用ReLU; W l W_l Wl​ 是网络参数矩阵, b l \mathbf b_l bl​ 是网络的偏置,为推导方便,对 b l \mathbf b_l bl​ 进行省略处理,则(4)式可写为:
x n = D n W n x n − 1 ( 5 ) \mathbf x_n = D_n W_n\mathbf x_{n-1} \qquad(5) xn​=Dn​Wn​xn−1​(5)
其中 D n D_n Dn​ 是对角矩阵,用于表示ReLU的作用,当其对应输入为负数时,对角元素为0;当其对应输入为正数时,对角元素为1。于是,多层神经网络(假设是N层)输入输出关系可以写成:
f ( x ) = D N W N ⋯ D 1 W 1 x ( 6 ) f(\mathbf x)=D_NW_N\cdots D_1W_1 \mathbf x \qquad(6) f(x)=DN​WN​⋯D1​W1​x(6)
Lipschitz约束是对 f ( x ) f(\mathbf x) f(x) 的梯度提出的要求:
∥ ∇ x ( f ( x ) ) ∥ 2 = ∥ D N W N ⋯ D 1 W 1 ∥ 2 ≤ ∥ D N ∥ 2 ∥ W N ∥ 2 ⋯ ∥ D 1 ∥ 2 ∥ W 1 ∥ 2 ( 7 ) \Vert \nabla_x(f(\mathbf x)) \Vert_2 = \Vert D_NW_N\cdots D_1W_1 \Vert_2\le \Vert D_N \Vert_2 \Vert W_N\Vert_2\cdots \Vert D_1\Vert_2 \Vert W_1 \Vert_2 \qquad(7) ∥∇x​(f(x))∥2​=∥DN​WN​⋯D1​W1​∥2​≤∥DN​∥2​∥WN​∥2​⋯∥D1​∥2​∥W1​∥2​(7)
此处 ∥ W ∥ \Vert W \Vert ∥W∥ 表示矩阵W的谱范数,它的定义如下:
σ ( A ) : = max ⁡ ∥ h ∥ ≠ 0 ∥ A h ∥ 2 ∥ h ∥ 2 = max ⁡ ∥ h ∥ = 1 ∥ A ∥ 2 ( 8 ) \sigma(A) :=\max_{\Vert h \Vert\neq0} \frac{\Vert Ah \Vert_2}{\Vert h \Vert_2}=\max_{\Vert h \Vert = 1} \Vert A \Vert_2 \qquad(8) σ(A):=∥h∦​=0max​∥h∥2​∥Ah∥2​​=∥h∥=1max​∥A∥2​(8)
σ ( W ) \sigma(W) σ(W)是矩阵W的最大奇异值,对于对角矩阵D,有 σ ( D ) = max ⁡ ( d 1 , ⋯   , d n ) \sigma(D) =\max(d_1,\cdots,d_n) σ(D)=max(d1​,⋯,dn​),即对角元素上最大的元素。由此,(7)可表示为:
∥ ∇ x ( f ( x ) ) ∥ 2 ≤ ∏ i = 1 N σ ( W i ) ( 9 ) \Vert \nabla_x(f(\mathbf x)) \Vert_2 \le \prod_{i=1}^N \sigma(W_i) \qquad(9) ∥∇x​(f(x))∥2​≤i=1∏N​σ(Wi​)(9)
因为,ReLU所对应的对角矩阵的谱范数最大为1。为使 f ( x ) f(\mathbf x) f(x) 满足Lipschitz约束,可对(7)进行归一化:
∥ ∇ x ( f ( x ) ) ∥ 2 = ∥ D N W N σ ( W N ) ⋯ D 1 W 1 σ ( W 1 ) ∥ 2 ≤ ∏ i = 1 N σ ( W i ) σ ( W i ) = 1 ( 10 ) \Vert \nabla_x(f(\mathbf x)) \Vert_2 = \Vert D_N \frac {W_N}{\sigma(W_N)}\cdots D_1\frac {W_1}{\sigma(W_1)} \Vert_2 \le \prod_{i=1}^N \frac {\sigma(W_i)}{\sigma(W_i)} =1\qquad(10) ∥∇x​(f(x))∥2​=∥DN​σ(WN​)WN​​⋯D1​σ(W1​)W1​​∥2​≤i=1∏N​σ(Wi​)σ(Wi​)​=1(10)
由此可见,只需让每层网络的网络参数除以该层参数矩阵的谱范数即可满足Lipschitz=1的约束,由此诞生了谱归一化(Spectral Normailization)。

三、谱归一化的实现

为获得每层参数矩阵的谱范数,需要求解 W i W_i Wi​ 的奇异值,这将耗费大量的计算资源,因而可采用“幂迭代法”来近似求取,其迭代过程如下:
1 、 v l 0 ← a random Gaussian vector 2 、 loop k : u l k ← W l v l k − 1 , normalization:  u l k ← u l k ∥ u l k ∥ , v l k ← ( W l ) T u l k , normalization:  v l k ← v l k ∥ v l k ∥ , end loop 3 、 σ l ( W ) = ( u l k ) T W v l k 1、v_l^{0} \leftarrow \text{ a random Gaussian vector} \\ 2、\text{loop k :} \\ u_l^{k}\leftarrow W_lv_l^{k-1}, \text{ normalization: } u_l^{k}\leftarrow \frac{u_l^{k}}{\Vert u_l^{k} \Vert},\\ v_l^k\leftarrow (W_l)^Tu_l^k , \text{ normalization: } v_l^{k}\leftarrow \frac{v_l^{k}}{\Vert v_l^{k} \Vert},\\ \text{end loop} \\ 3、\sigma_l(W)= (u_l^k)^T W v_l^k 1、vl0​← a random Gaussian vector2、loop k :ulk​←Wl​vlk−1​, normalization: ulk​←∥ulk​∥ulk​​,vlk​←(Wl​)Tulk​, normalization: vlk​←∥vlk​∥vlk​​,end loop3、σl​(W)=(ulk​)TWvlk​
求得谱范数后,每个参数矩阵上的参数皆除以它,以达到归一化目的。其实,上述算法在迭代了足够次数后, u k \mathbf u^k uk就是该矩阵( W W W)的最大奇异值对应的特征矢量,有:
W W T u = σ ( W ) ⋅ u ⇒ u T W W T u = 1 ⋅ σ ( W ) , as  ∥ u ∥ = 1 σ ( W ) = u T W v , as  v = W T u WW^T \mathbf u=\sigma(W)\cdot \mathbf u \Rightarrow \mathbf u^TWW^T \mathbf u = 1\cdot \sigma(W), \text{ as } \Vert \mathbf u \Vert=1\\ \sigma(W) = \mathbf u^TW\mathbf v, \text{ as } \mathbf v=W^T \mathbf u WWTu=σ(W)⋅u⇒uTWWTu=1⋅σ(W), as ∥u∥=1σ(W)=uTWv, as v=WTu
谱归一具体的pytorch实现代码可以参考【3】,以下摘抄部分如下:
1、计算谱范数

import torch
import torch.nn.functional as F#define _l2normalization
def _l2normalize(v, eps=1e-12):return v / (torch.norm(v) + eps)def max_singular_value(W, u=None, Ip=1):"""power iteration for weight parameter"""#xp = W.dataif not Ip >= 1:raise ValueError("Power iteration should be a positive integer")if u is None:u = torch.FloatTensor(1, W.size(0)).normal_(0, 1).cuda()_u = ufor _ in range(Ip):_v = _l2normalize(torch.matmul(_u, W.data), eps=1e-12)_u = _l2normalize(torch.matmul(_v, torch.transpose(W.data, 0, 1)), eps=1e-12)sigma = torch.sum(F.linear(_u, torch.transpose(W.data, 0, 1)) * _v)return sigma, _u

2、构造带归一化的层
线性层:

class SNLinear(Linear):def __init__(self, in_features, out_features, bias=True):super(SNLinear, self).__init__(in_features, out_features, bias)self.register_buffer('u', torch.Tensor(1, out_features).normal_())@propertydef W_(self):w_mat = self.weight.view(self.weight.size(0), -1)sigma, _u = max_singular_value(w_mat, self.u)self.u.copy_(_u)return self.weight / sigmadef forward(self, input):return F.linear(input, self.W_, self.bias)

卷积层:

class SNConv2d(conv._ConvNd):def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):kernel_size = _pair(kernel_size)stride = _pair(stride)padding = _pair(padding)dilation = _pair(dilation)super(SNConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation,False, _pair(0), groups, bias)self.register_buffer('u', torch.Tensor(1, out_channels).normal_())@propertydef W_(self):w_mat = self.weight.view(self.weight.size(0), -1)sigma, _u = max_singular_value(w_mat, self.u)self.u.copy_(_u)return self.weight / sigmadef forward(self, input):return F.conv2d(input, self.W_, self.bias, self.stride,self.padding, self.dilation, self.groups)

由这两个层的构造可看到:谱范数的计算和应用谱范数的归一化层。这些层可以加到Discriminator中,如下:

class ResBlock(nn.Module):def __init__(self, in_channels, out_channels, hidden_channels=None, use_BN = False, downsample=False):super(ResBlock, self).__init__()#self.conv1 = SNConv2d(n_dim, n_out, kernel_size=3, stride=2)hidden_channels = in_channelsself.downsample = downsampleself.resblock = self.make_res_block(in_channels, out_channels, hidden_channels, use_BN, downsample)self.residual_connect = self.make_residual_connect(in_channels, out_channels)def make_res_block(self, in_channels, out_channels, hidden_channels, use_BN, downsample):model = []if use_BN:model += [nn.BatchNorm2d(in_channels)]model += [nn.ReLU()]model += [SNConv2d(in_channels, hidden_channels, kernel_size=3, padding=1)]model += [nn.ReLU()]model += [SNConv2d(hidden_channels, out_channels, kernel_size=3, padding=1)]if downsample:model += [nn.AvgPool2d(2)]return nn.Sequential(*model)def make_residual_connect(self, in_channels, out_channels):model = []model += [SNConv2d(in_channels, out_channels, kernel_size=1, padding=0)]if self.downsample:model += [nn.AvgPool2d(2)]return nn.Sequential(*model)else:return nn.Sequential(*model)def forward(self, input):return self.resblock(input) + self.residual_connect(input)class OptimizedBlock(nn.Module):def __init__(self, in_channels, out_channels):super(OptimizedBlock, self).__init__()self.res_block = self.make_res_block(in_channels, out_channels)self.residual_connect = self.make_residual_connect(in_channels, out_channels)def make_res_block(self, in_channels, out_channels):model = []model += [SNConv2d(in_channels, out_channels, kernel_size=3, padding=1)]model += [nn.ReLU()]model += [SNConv2d(out_channels, out_channels, kernel_size=3, padding=1)]model += [nn.AvgPool2d(2)]return nn.Sequential(*model)def make_residual_connect(self, in_channels, out_channels):model = []model += [SNConv2d(in_channels, out_channels, kernel_size=1, padding=0)]model += [nn.AvgPool2d(2)]return nn.Sequential(*model)def forward(self, input):return self.res_block(input) + self.residual_connect(input)class SNResDiscriminator(nn.Module):def __init__(self, ndf=64, ndlayers=4):super(SNResDiscriminator, self).__init__()self.res_d = self.make_model(ndf, ndlayers)self.fc = nn.Sequential(SNLinear(ndf*16, 1), nn.Sigmoid())def make_model(self, ndf, ndlayers):model = []model += [OptimizedBlock(3, ndf)]tndf = ndffor i in range(ndlayers):model += [ResBlock(tndf, tndf*2, downsample=True)]tndf *= 2model += [nn.ReLU()]return nn.Sequential(*model)def forward(self, input):out = self.res_d(input)out = F.avg_pool2d(out, out.size(3), stride=1)out = out.view(-1, 1024)return self.fc(out)

生成器SNResDiscriminator 用到两个构建模块ResBlock、OptimizedBlock,这两个模块都用SNConv2d层来构建带有谱归一化的卷积层。在SNConv2d实现中,用到@property def W_(self),是我第一次见到的,接下来要好好研究研究。

小结:

Gan要想训练稳定进行,就需要其Discriminator的映射函数满足Lipschitz约束,[1]提出谱范数可作为Lipschitz约束的实施方法,进而给出归一化的实现思路,整个过程十分精巧,值得学习。


[1] Spectral Normalization for Generative Adversarial Networks, Takeru Miyato, 2018.2, (arXiv:1802.05957v1)
[2] Spectral Norm Regularization for Improving the Generalizability of Deep Learning, Yuchi Yoshida, National Institute of Informatics, 2017. 5, (arXiv: 1705.10941v1)
[3] https://github.com/godisboy/SN-GAN
[4] Towards principled methods for training generative adversarial networks
[5] Wasserstein GAN

谱归一化(Spectral Normalization)的理解相关推荐

  1. 详解GAN的谱归一化(Spectral Normalization)

    作者丨尹相楠 学校丨里昂中央理工博士在读 研究方向丨人脸识别.对抗生成网络 本文主要介绍谱归一化这项技术,详细论文参考 Spectral Normalization for Generative Ad ...

  2. 图卷积网络 GCN Graph Convolutional Network(谱域GCN)的理解和详细推导

    文章目录 1. 为什么会出现图卷积神经网络? 2. 图卷积网络的两种理解方式 2.1 vertex domain(spatial domain):顶点域(空间域) 2.2 spectral domai ...

  3. Batch Normalization深入理解

    Batch Normalization深入理解 1. BN的提出背景是什么? 统计学习中的一个很重要的假设就是输入的分布是相对稳定的.如果这个假设不满足,则模型的收敛会很慢,甚至无法收敛.所以,对于一 ...

  4. 深度剖析 | SN 可微分学习的自适配归一化 (Switchable Normalization)

    补充:NIPS 2018 | MIT新研究参透批归一化原理 根据最新的研究,BN层的成功和协方差什么的没有关联!证明这种层输入分布稳定性与 BatchNorm 的成功几乎没有关系.相反,我们发现 Ba ...

  5. 批归一化(Batch Normalization)详解

    批归一化(Batch Normalization)详解 文章目录 批归一化(Batch Normalization)详解 前言 一.数据归一化 二.BN解决的问题:Internal Covariate ...

  6. 谱聚类(Spectral Clustering)详解

    原文地址为: 谱聚类(Spectral Clustering)详解 谱聚类(Spectral Clustering)详解 谱聚类(Spectral Clustering, SC)是一种基于图论的聚类方 ...

  7. 22 谱聚类 Spectral Clustering

    1 Background 本章节主要是描述的一种聚类算法,谱聚类(Spectral Clustering).对机器学习有点了解的同学对聚类算法肯定是很熟悉的,那么谱聚类和之前普通的聚类算法有什么不一样 ...

  8. (转载)深度剖析 | 可微分学习的自适配归一化 (Switchable Normalization)

    深度剖析 | 可微分学习的自适配归一化 (Switchable Normalization) 作者:罗平.任家敏.彭章琳 编写:吴凌云.张瑞茂.邵文琪.王新江 转自:知乎.原论文参考arXiv:180 ...

  9. Spectral normalization及torch实现

    Spectral normalization及torch实现 Spectral normalization含义 torch实现一个带有spectral nomalization的2D卷积 Spectr ...

最新文章

  1. Microbiome:韦中组揭示根际原生动物群落是决定植物健康的关键因素
  2. 如何通过Geth、Node.js和UNIX/PHP访问以太坊节点
  3. bios设置_老富士通bios设置启动项方法是什么 富士通bios设置u盘启动的方法
  4. iis 6.0上部署.net 2.0和4.0网站
  5. css:text-overflow属性
  6. 我的AutoHotkey配置
  7. 线性筛法 欧拉筛c语言,[洛谷P3383][模板]线性筛素数-欧拉筛法
  8. ThreadLocal线程本地存储
  9. 谈谈应届生应聘的一点看法
  10. 【笔记】线性代数的本质
  11. Vue开发实例(01)之环境搭建nodejs与运行第一个Vue项目
  12. 比CMD更强大的命令行:WMIC后渗透利用(系统命令)
  13. Linux 上格式化ssd硬盘方法
  14. python中的异常、模块、文件
  15. SpringMVC大体流程框架类图版
  16. ROG 冰刃 3 枪神 2 Plus 第二时间上手体验
  17. mysql建表 括号_MySQL建表的问题,关于索引
  18. linux安装trac+svn+apache+wike,搭建apache+svn+trac平台
  19. mysql如何查询成绩前5名_sql 语句查询 前5名后5名的成绩
  20. 上海市“专精特新”中小企业认定

热门文章

  1. Linux系统编程之捕捉SIGCHLD
  2. 【Unity3D】AudioSource组件
  3. DXP_protel2004_原理图设计基础_新建和添加原理图库文件
  4. Hello New World 写在 Conflux 网络 Tethys 上线之际
  5. PHP+Redis令牌桶算法 接口限流
  6. 如何在浏览器中支持H265/HEVC
  7. 数学之美读书笔记第一章
  8. 《python密码学编程》笔记
  9. 如何选择优化器 optimizer
  10. java核心之类和对象