GAN中的Spectral Normalization

  Spectral Normalization 出自 《Spectral Norm Regularization for Improving the Generalizability of Deep Learning》和《Spectral Normalization for Generative Adversarial Networks》,是为了解决GAN训练不稳定的问题,从“层参数”的角度用spectral normalization的方式施加regularization,从而使判别器D具备Lipschitz连续条件


为什么要让D具备Lipschitz连续

为了防止判别器“放飞自我”

根据Wikipedia的定义,Lipschitz连续条件 如下。其意义在于使得 f 足够稳定,在输入发生少量变化时,输出不会有太巨大的变化。如果有图像A,修改少量像素得到图像B,输入判别器D后得到相差非常巨大的判别效果,那么判别器就是不稳定的,它对输入过于敏感。
∣∣f(x1)−f(x2)∣∣∣∣x1−x2∣∣≤K,∀x1,x2\frac{||f(x_1)-f(x_2)||}{||x_1-x_2||}\leq K, \forall x_1,x_2∣∣x1​−x2​∣∣∣∣f(x1​)−f(x2​)∣∣​≤K,∀x1​,x2​
这里的 K 被称为 f(x)f(x)f(x) 的 Lipschitz constantK的最小值(上确界)被称为 ∣∣f∣∣Lip||f||_{Lip}∣∣f∣∣Lip​ ,称 f(x)f(x)f(x) 满足 Lipschitz连续条件

《Wasserstein GAN》 给出了衡量真实分布PrP_rPr​和生成分布PgP_gPg​的 Earth-Mover(EM) 距离:

W(Pr,Pg)=inf⁡γ∈(Pr,Pg)E(x,y)∼γ[∣∣x−y∣∣]W(P_r, P_g)=\inf_{\gamma \in (P_r, P_g)}E_{(x,y)\sim \gamma}[||x-y||]W(Pr​,Pg​)=γ∈(Pr​,Pg​)inf​E(x,y)∼γ​[∣∣x−y∣∣]

Kantorovich-Rubinstein duality 可得EM距离的另一个形式,这里有用到 ∣∣f∣∣Lip||f||_{Lip}∣∣f∣∣Lip​:

W(Pr,Pg)=1Ksup⁡∣∣f∣∣Lip≤KEx∼Pr[f(x)]−Ex∼Pg[f(x)]W(P_r, P_g)=\frac{1}{K} \sup_{||f||_{Lip}\leq K}E_{x\sim P_r}[f(x)]-E_{x\sim P_g}[f(x)]W(Pr​,Pg​)=K1​∣∣f∣∣Lip​≤Ksup​Ex∼Pr​​[f(x)]−Ex∼Pg​​[f(x)]

在∣∣f∣∣Lip≤1||f||_{Lip}\leq 1∣∣f∣∣Lip​≤1的条件下,求两个期望之差的上确界,就是EM距离,推导过程可参考这里。


矩阵范数

根据Wikipedia,矩阵的范数如下,它是右边的分式的上确界。

∣∣A∣∣p=sup⁡x≠0∣∣Ax∣∣p∣∣x∣∣p||A||_p=\sup_{x\neq 0}\frac{||Ax||_p}{||x||_p}∣∣A∣∣p​=x̸​=0sup​∣∣x∣∣p​∣∣Ax∣∣p​​

其中,p=2p=2p=2时被称为 Euclidean NormL2-NormSpectral Norm,它是矩阵的最大奇异值(或者最大特征值的开方),下面 ATA^TAT 表示共轭转置。

σ(A)=∣∣A∣∣2=λmax(ATA)\sigma (A)=||A||_2=\sqrt{\lambda_{max}(A^TA)}σ(A)=∣∣A∣∣2​=λmax​(ATA)​

矩阵范数的性质:

∣∣AB∣∣p≤∣∣A∣∣p∣∣B∣∣p||AB||_p\leq ||A||_p||B||_p∣∣AB∣∣p​≤∣∣A∣∣p​∣∣B∣∣p​
σ(AB)≤σ(A)σ(B)\sigma (AB)\leq \sigma (A)\sigma (B)σ(AB)≤σ(A)σ(B)


Spectral Normalization

考虑采用了非线性激活函数的MLP网络构成的判别器D,xl=al(Wlxl−1+bl)x^l=a^l(W^lx^{l-1}+b^l)xl=al(Wlxl−1+bl),把它的参数写成 θ={Wl,bl}l=1L\theta=\{W^l, b^l\}_{l=1}^{L}θ={Wl,bl}l=1L​,使的判别器可以写成 fθ(x0)=xLf_{\theta}(x^0)=x^Lfθ​(x0)=xL。只考虑xxx的一个很小的邻域,可以将判别器看作一个线性函数, fθ(x)=Wθ,xx+bθ,xf_{\theta}(x)=W_{\theta,x}x+b_{\theta,x}fθ​(x)=Wθ,x​x+bθ,x​

∣∣fθ(x+δ)−fθ(x)∣∣2∣∣δ∣∣2=∣∣Wθ,xδ∣∣2∣∣δ∣∣2≤σ(Wθ,x)=sup⁡δ≠0∣∣Wθ,xδ∣∣2∣∣δ∣∣2\frac{||f_{\theta}(x+\delta)-f_{\theta}(x)||_2}{||\delta||_2}=\frac{||W_{\theta,x}\delta||_2}{||\delta||_2}\leq \sigma(W_{\theta,x})=\sup_{\delta\neq 0}\frac{||W_{\theta,x}\delta||_2}{||\delta||_2}∣∣δ∣∣2​∣∣fθ​(x+δ)−fθ​(x)∣∣2​​=∣∣δ∣∣2​∣∣Wθ,x​δ∣∣2​​≤σ(Wθ,x​)=δ̸​=0sup​∣∣δ∣∣2​∣∣Wθ,x​δ∣∣2​​

若激活函数使用 ReLU,可以把 al(xl−1)a^l(x^{l-1})al(xl−1) 看作 Dθ,xlxD^l_{\theta, x}xDθ,xl​x,对角阵 Dθ,xlD^l_{\theta, x}Dθ,xl​ 在 xl−1x^{l-1}xl−1 非负的对应位置上为1,其他地方是0,这样 σ(Dθ,xl)=∣∣Dθ,xl∣∣2≤1\sigma(D^l_{\theta, x})=||D^l_{\theta, x}||_2\leq 1σ(Dθ,xl​)=∣∣Dθ,xl​∣∣2​≤1。

Wθ,x=Dθ,xLWθ,xL⋯Dθ,x1W1W_{\theta,x}=D^L_{\theta,x}W^{L}_{\theta,x}\cdots D^1_{\theta,x}W^1Wθ,x​=Dθ,xL​Wθ,xL​⋯Dθ,x1​W1

σ(Wθ,x)≤σ(Dθ,xL)σ(WL)⋯σ(Dθ,x1)σ(W1)≤∏l=1Lσ(Wl)\sigma(W_{\theta,x})\leq \sigma(D^L_{\theta,x})\sigma(W^{L})\cdots\sigma(D^1_{\theta,x})\sigma(W^1)\leq\prod_{l=1}^L\sigma(W^l)σ(Wθ,x​)≤σ(Dθ,xL​)σ(WL)⋯σ(Dθ,x1​)σ(W1)≤l=1∏L​σ(Wl)

若对判别器D的每一层都做 Spectral Normalization:

W^SNl=Wlσ(Wl)\hat W^l_{SN}=\frac{W^l}{\sigma(W^l)}W^SNl​=σ(Wl)Wl​

σ(W^SNl)=σ(W)σ(W)=1\sigma(\hat W^l_{SN})=\frac{\sigma(W)}{\sigma(W)}=1σ(W^SNl​)=σ(W)σ(W)​=1

∣∣fθ(x+δ)−fθ(x)∣∣2∣∣δ∣∣2≤σ(Wθ,x)≤1\frac{||f_{\theta}(x+\delta)-f_{\theta}(x)||_2}{||\delta||_2}\leq\sigma(W_{\theta,x})\leq1∣∣δ∣∣2​∣∣fθ​(x+δ)−fθ​(x)∣∣2​​≤σ(Wθ,x​)≤1


Spectral Normalization 实现

Spectral Normalization实际上在做的事情,是将每层的参数矩阵除以自身的最大奇异值,本质上是一个逐层SVD的过程,但是真的去做SVD就太耗时了,所以采用幂迭代的方法求(参见Wikipedia)。

def spectral_norm(w, iteration=10, name="sn"):'''Ref: https://github.com/taki0112/Spectral_Normalization-Tensorflow/blob/65218e8cc6916d24b49504c337981548685e1be1/spectral_norm.py'''w_shape = w.shape.as_list() # [KH, KW, Cin, Cout] or [H, W]w = tf.reshape(w, [-1, w_shape[-1]]) # [KH*KW*Cin, Cout] or [H, W]u = tf.get_variable(name+"_u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False)s = tf.get_variable(name+"_sigma", [1, ], initializer=tf.random_normal_initializer(), trainable=False)u_hat = u # [1, Cout] or [1, W]v_hat = None for _ in range(iteration):v_hat = tf.nn.l2_normalize(tf.matmul(u_hat, tf.transpose(w))) # [1, KH*KW*Cin] or [1, H]u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w)) # [1, Cout] or [1, W]u_hat = tf.stop_gradient(u_hat)v_hat = tf.stop_gradient(v_hat)sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) # [1,1]sigma = tf.reshape(sigma, (1,))with tf.control_dependencies([u.assign(u_hat), s.assign(sigma)]):# ops here run after u.assign(u_hat)w_norm = w / sigma w_norm = tf.reshape(w_norm, w_shape)return w_norm

若输入矩阵是全连接层的参数,尺寸为[H,W][H,W][H,W],则spectral_norm在效果上会直接对该二维矩阵求最大奇异值,但如果输入矩阵为卷积层的卷积核,其尺寸应该是[KH,KW,Cin,Cout][K_H, K_W, C_{in}, C_{out}][KH​,KW​,Cin​,Cout​],spectral_norm会先将该矩阵reshape成一个大小为[KHKWCin,Cout][K_HK_W C_{in}, C_{out}][KH​KW​Cin​,Cout​]的矩阵,再用迭代法对该二维矩阵求最大奇异值。不论哪种情况,求得的奇异值都是一个单值scalar


带SN的卷积层

def conv2d(x, channel, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name='conv2d'):with tf.variable_scope(name):w = tf.get_variable('w', [k_h, k_w, x.get_shape()[-1], channel], initializer=tf.truncated_normal_initializer(stddev=stddev))w_sn = spectral_norm(w, iteration=3)conv = tf.nn.conv2d(x, filter=w_sn, strides=[1, d_h, d_w, 1], padding='VALID')biases = tf.get_variable('biases', [channel], initializer=tf.constant_initializer(0.0))conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())return conv

带SN的FC层

def dense(x, output_size, stddev=0.02, bias_start=0.0, activation=None, sn=False, reuse=False, name='dense'):shape = x.get_shape().as_list()with tf.variable_scope(name, reuse=reuse):W = tf.get_variable('weights', [shape[1], output_size], tf.float32, tf.random_normal_initializer(stddev=stddev))bias = tf.get_variable('biases', [output_size], initializer=tf.constant_initializer(bias_start))if sn:W = spectral_norm(W, 20, name="sn")out = tf.matmul(x, W) + bias if activation is not None:out = activation(out)return out

检验SpectralNorm

  用dense层检验,首先将W的初始值改为tf.ones_initializer()变成全1矩阵,然后用大小(3,3)的全1输入,在sn=Truesn=False的条件下分别输出结果。

x = tf.ones(shape=(3,3), dtype=tf.float32)
y1 = dense_ones(x, output_size=4, sn=False)
y2 = dense_ones(x, output_size=4, sn=True)
with tf.Session() as sess:sess.run(tf.global_variables_initializer())print sess.run(y1)print sess.run(y2)

  dense层的W是一个大小为(3,4)的全1矩阵,用matlabnp.linalg.svd可以算出它的最大奇异值为3.464,而3.0/3.464=0.8660,从下面的结果可见,在sn=True时,FC层输出的结果等于在sn=False时的结果除以FC层权值矩阵W的最大奇异值,该结果本质上是由于权值矩阵W自己除以最大奇异值导致的。
  因此,Spectral Norm的实现是正确的。

# y1
[[3. 3. 3. 3.][3. 3. 3. 3.][3. 3. 3. 3.]]
# y2
[[0.86602545 0.86602545 0.86602545 0.86602545][0.86602545 0.86602545 0.86602545 0.86602545][0.86602545 0.86602545 0.86602545 0.86602545]]

如何查看sigma的值

  有时候希望查看迭代法求得的最大奇异值是否正确,或者用它作为某种指示指标(比如用作GAN里判别器是否放飞自我的标准),希望以Tensor的形式获取它。只需在spectral_norm中专门声明一个variable用于存储它的值,在迭代法计算完之后把sigma赋给它即可。

def spectral_norm(...):...s = tf.get_variable("sn_sigma", [1, 1], initializer=tf.random_normal_initializer(), trainable=False)...with tf.control_dependencies([u.assign(u_hat), s.assign(sigma)]):...

  为了用sess.run获取它实时的值,需要先得到这个Tensor,一种简单暴力的方法是首先用tf.global_variables()查看所有variable的名字,找到该Tensor的名字,用tf.get_collection获取。要注意Tensor的名字往往前面带有很多variable_scope,就像文件路径一样,需要完整地传入tf.get_collection

with tf.Session() as sess:sess.run(tf.global_variables_initializer())v = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, '你的前缀/sn_sigma:0')[0]sigma = sess.run(v, feed_dict={...})print sigma[0, 0]

control_dependencies的作用

  在我找到的所有Spectral Normalization的实现代码中,都有control_dependencies存在,其作用在于每次调用spectral_norm时用上一次调用中进行迭代得到的u来初始化本次调用的u,暂时不明白其必要性,如果为了更加精确而多次调用输入同一个权值矩阵,倒不如将iteration设的大一点(20以上就已经非常精确了),如果是不同的输入,则这样初始化没有意义,并不会变得更精确。
  这里演示一下control_dependencies、assign和identity的配合使用。
  在下面的代码中,control_dependencies保证了在每次out=tf.identity(u)前都会先将u_var更新为u的值,只有在最开始u_var的初始值为1.0。

u_var = tf.Variable(1.0)
u = u_var u = u * 3with tf.control_dependencies([tf.assign(u_var, u, name='update_u')]):out = tf.identity(u)with tf.Session() as sess:sess.run(tf.global_variables_initializer())for _ in range(3):print out.eval()
'''
3.0
9.0
27.0
'''

关于迭代次数

  迭代法计算最大奇异值,迭代次数达到25即可非常精确,但为了节省时间,一般不需要这么多。


Reference

Spectral Norm Regularization for Improving the Generalizability of Deep Learning
Spectral Normalization for Generative Adversarial Networks
BatchNormalization LSTMCell Tensorflow
spectral normalization Tensorflow
compare_gan spectral normalization Tensorflow
Wikipedia Matrix norm

GAN中的Spectral Normalization相关推荐

  1. Self-Attention GAN 中的 self-attention 机制

    作者丨尹相楠 学校丨里昂中央理工博士在读 研究方向丨人脸识别.对抗生成网络 Self Attention GAN 用到了很多新的技术.最大的亮点当然是 self-attention 机制,该机制是 N ...

  2. 详解GAN的谱归一化(Spectral Normalization)

    作者丨尹相楠 学校丨里昂中央理工博士在读 研究方向丨人脸识别.对抗生成网络 本文主要介绍谱归一化这项技术,详细论文参考 Spectral Normalization for Generative Ad ...

  3. 谱归一化(Spectral Normalization)的理解

    <Spectral Normalization for Generative Adversarial Networks>[1]是Takeru Miyato在2018年2月发表的一篇将谱理论 ...

  4. Spectral normalization及torch实现

    Spectral normalization及torch实现 Spectral normalization含义 torch实现一个带有spectral nomalization的2D卷积 Spectr ...

  5. 【GAN优化】详解GAN中的一致优化问题

    GAN的训练是一个很难解决的问题,上期其实只介绍了一些基本的动力学概念以及与GAN的结合,并没有进行过多的深入.动力学是一门比较成熟的学科,有很多非常有用的结论,我们将尝试将其用在GAN上,来得到一些 ...

  6. 国内外对于GaN中Fe相关点缺陷结构的局域特性的研究进展

    国内外研究GaN中Fe相关点缺陷结构的局域特性的研究情况相对较少,但近年来发表的文献表明,通过改变Fe掺杂量,可以改变GaN中Fe相关点缺陷结构的局域特性.此外,还有一些研究表明,Fe可以提高GaN材 ...

  7. 对抗生成网络GAN系列——Spectral Normalization原理详解及源码解析

  8. 深度学习中眼花缭乱的Normalization学习总结

    点击下方标题,迅速定位到你感兴趣的内容 前言 相关知识 Batch Normalization(BN) Layer Normalization(LN) Weight Normalization(WN) ...

  9. 深度神经网络中的Batch Normalization介绍及实现

    之前在经典网络DenseNet介绍_fengbingchun的博客-CSDN博客_densenet中介绍DenseNet时,网络中会有BN层,即Batch Normalization,在每个Dense ...

最新文章

  1. day18-Map和Collection应用
  2. 皮一皮:叫车就要叫这样的,霸气...
  3. hdu 5019 第k大公约数
  4. Python量化(八)下影线选股法
  5. free text search - enterprise search
  6. $(“#addLowForm“).serialize()同时提交其它参数的写法
  7. cad常用字体包_水利设计CAD基础篇(一)
  8. mysql时长用什么类型_MySQL 日期时间类型怎么选?千万不要乱用!
  9. 高并发下如何保证数据的一致性
  10. matlab_多目标遗传算法
  11. 云初起微方案中下单人、联系人、下载者三者之间是什么关系?
  12. html保存快捷,保存文件的快捷键 PS里面的保存快捷键是哪个?
  13. 《黑马程序员》C++基础入门(一)
  14. word交叉引用后,移动文章结构,修改引用顺序到符合引用先后
  15. WiFI Display介绍
  16. Transformer课程 第8课NER案例代码笔记-部署简介
  17. POSIX Timer
  18. 西北乱跑娃 --- requests爬虫五大反反爬机制
  19. 验证码过期(小功能)
  20. 在WIN2003安装TD8,问题汇总

热门文章

  1. 云虚拟机和普通虚拟机有什么区别
  2. 大多数人不知道淘宝天猫有内部优惠卷,能省钱的公众号,购物省钱妙招
  3. 数据分析离不开商业分析
  4. HDU 1847 Good Luck in CET-4 Everybody! 尼姆博弈
  5. Mixamo使用笔记
  6. Linux——进程管理(crontab实例傻瓜教程)
  7. php拼车网源码,PHP拼车网源码 微信拼车源码 手机拼车源码 PC+微信双终端
  8. 牛逼!这届WWDC依旧展现了那个让你无法复制的苹果!
  9. 2021-02-23
  10. 克服弱点,愈发完美-自我篇——《人性的弱点》读后感