前言

在训练VAE模型时,当我们使用过于过于强大的decoder时,尤其是自回归式的decoder比如LSTM时,存在一个非常大的问题就是,decoder倾向于不从latent variable z中学习,而是独立地重构数据,这个时候,同时伴随着的就是KL(p(z|x)‖q(z))倾向于0。先不考虑说,从损失函数的角度来说,这种情况不是模型的全局最优解这个问题(可以将其理解为一个局部最优解)。单单从VAE模型的意义上来说,这种情况也是我们不愿意看到的。VAE模型的最重要的一点就是其通过无监督的方法构建数据的编码向量(即隐变量z)的能力。而如果出现posterior collapse的情况,就意味着后验概率退化为与先验概率一致,即N(0,1)。此时encoder的输出近似于一个常数向量,不再能充分利用输入x的信息。decoder则变为了一个普通的language model,尽管它依然很强。

因此,不管从哪个方面来看,解决它都是必须要面临的课题,事实上,从2016年开始就有很多文章提出了不同的解决方案,这里重点介绍一下使用Batch Normalization来解决这个问题的思路。这篇文章全名A Batch Normalized Inference Network Keeps the KL Vanishing Away发表于2020年,还算是一篇比较新的文章。下面我们开始。

方法介绍

Expectation of the KL’s Distribution

首先基于隐变量空间为高维高斯分布的假设,对于一个mini-batch的数据来说,我们可以计算KL divergence的表达式如下:

其中b代表的是mini-batch的样本个数,n代表的隐变量z的维度。同时作者还假设对于每个不同的维度,其都遵循某个特定的分布,各个维度可以不同。

假设我们认为上述的样本均值可以近似等于总体期望,那么我们可以将上述的样本均值用期望来代替,又因为我们有如下基本等式

最终我们可以得到KL divergence的期望表达式如下。

上述不等式是因为e^x-x>=1恒成立。那么这么一来,我们就有关于KL divergence的一个lower bound。这个lower bound只与隐变量的维度n和μi的分布有关。

Normalizing Parameters of the Posterior

接下来,我们要考虑的问题就是如何来构建每个μi的分布,使其保证这个lower bound的值恒为正,也就间接保证了KL divergence不会变为0。这里用到的方法就是Batch Normalization。

我们熟知的Batch Normalization往往用在神经网络模型中,通过控制每个隐藏层的数据的分布使得训练更加平稳。

但是在这里我们使用它来转换μi的分布,将其控制在一个合理的范围内,从而保证lower bound的值为正。具体如下

其中μBi 和 σBi 分别表示通过mini-batch计算的 μi的均值和标准差。γ 和 β分别是scale和shift参数。通过合理地控制这两个参数,我们可以将lower bound近似地转换为如下式子。

下面是完整的算法流程。

在原文中还有涉及到对参数设置的进一步拓展,大家可以参考苏剑林老师的这篇博客

Torch 实现

在苏剑林老师的博客中,他用keras实现了文章中的关键内容,在这里,我用torch实现了一下,供大家参考。

import torch
import torch.nn as nn# reference paper:https://arxiv.org/abs/2004.12585
class BN_Layer(nn.Module):def __init__(self,dim_z,tau,mu=True):super(BN_Layer,self).__init__()self.dim_z=dim_zself.tau=torch.tensor(tau) # tau : float in range (0,1)self.theta=torch.tensor(0.5,requires_grad=True)self.gamma1=torch.sqrt(self.tau+(1-self.tau)*torch.sigmoid(self.theta)) # for muself.gamma2=torch.sqrt((1-self.tau)*torch.sigmoid((-1)*self.theta)) # for varself.bn=nn.BatchNorm1d(dim_z)self.bn.bias.requires_grad=Falseself.bn.weight.requires_grad=Trueif mu:with torch.no_grad():self.bn.weight.fill_(self.gamma1)else:with torch.no_grad():self.bn.weight.fill_(self.gamma2)def forward(self,x): # x:(batch_size,dim_z)x=self.bn(x)return x

参考

A Batch Normalized Inference Network Keeps the KL Vanishing Away
变分自编码器(五):VAE + BN = 更好的VAE

使用Batch Normalization解决VAE训练中的后验坍塌(posterior collapse)问题相关推荐

  1. 神经网络中使用Batch Normalization 解决梯度问题

    BN本质上解决的是反向传播过程中的梯度问题. 详细点说,反向传播时经过该层的梯度是要乘以该层的参数的,即前向有: 那么反向传播时便有: 那么考虑从l层传到k层的情况,有: 上面这个 便是问题所在.因为 ...

  2. 解决:jssip中接通后 PC没有声音但是话机有声音

    由于jssip各类解决文章太少的 赶紧记下来.... audioEle.srcObject=session.connection.getRemoteStreams()[0]; audioEle就是页面 ...

  3. Batch Normalization原理与实战

    作者:天雨粟 链接:https://zhuanlan.zhihu.com/p/34879333 来源:知乎 著作权归作者所有.商业转载请联系作者获得授权,非商业转载请注明出处. 前言 本期专栏主要来从 ...

  4. 【深度学习理论】(4) 权重初始化,Batch Normalization

    各位同学好,最近学习了CS231N斯坦福计算机视觉公开课,讲的太精彩了,和大家分享一下. 1. 权重初始化 1.1 相同的初始化权重 神经网络中的所有权重都能通过梯度下降和反向传播来优化和更新.现在问 ...

  5. 原理解释|直觉与实现:Batch Normalization

    https://www.toutiao.com/a6707566287964340747/ 作者:Harrison Jansma编译:ronghuaiyang 在本文中,我会回顾一下batch nor ...

  6. NLP、CV经典论文:Batch Normalization 笔记

    NLP.CV经典论文:Batch Normalization 笔记 论文 介绍 优点 缺点 模型结构 文章部分翻译 Abstract 1 Introduction 2 Towards Reducing ...

  7. 【文章阅读】BN(2015)理解Batch Normalization批标准化

    Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift Brief 该 ...

  8. Batch Normalization原理

    batch normalization是指在神经网络中激活函数的前面,对每个神经元根据mini bach中统计的m个x=wu+b进行normalization变换,即: ,这种思想源于一种理论:当变量 ...

  9. Batch Normalization原文详细解读

    这篇博客分为两部分, 一部分是[3]中对于BN(Batch Normalization的缩写)的大致描述 一部分是原文[1]中的完整描述 .####################先说下书籍[3]## ...

最新文章

  1. 面试官:为什么需要 Hystrix?
  2. 当 Python 遇到了你的微信好友
  3. win2008win7设置自动登录系统的方法(三步搞定)
  4. Spring Cloud构建微服务架构:服务容错保护(Hystrix服务降级)
  5. 动态加载DLL(C#)
  6. android 开发工具篇之android studio(二)使用篇
  7. adobe reader java_使用PDF框设置的表单字段值在Adobe Reader中不可见
  8. Spring MVC学习总结(5)——SpringMVC项目关于安全的一些配置与实现方式
  9. GitHub一夜爆火的阿里高并发技术小册究竟有什么魅力?
  10. python源码中的学习笔记_第12章_编码格式与文件操作
  11. linux oracle ojdbc,Maven无法下载Oracle驱动ojdbc的解决方式
  12. WinForm界面控件DevExpress入门指南 - Window Service
  13. 乐高机器人编程自学入门
  14. XGBOOST_航班延误预测
  15. 高斯消元(解线性方程组)
  16. Bluetooth HCI介绍
  17. signature=735f4378ec01919f23285d0d2557be19,OPENSSL编程 第二十章 椭圆曲线
  18. 产品设计体会(0013)产品经理应该是管理者么
  19. java webservice测试_搭建Soap webservice api接口测试案例系统
  20. ACM篇:UVA220黑白棋总结

热门文章

  1. 树莓派4B关于cpu降频的查看与修改
  2. 卡塔尔航空为中国留学生提供返校包机服务
  3. Export_Parent父子项目的搭建
  4. 先行者螺旋式水下地形勘探机器人设计
  5. 非递减有序排列C语言,非递减有序顺序表的排序
  6. html dt dd dl英文,dl dt dd是什么的缩写
  7. [附源码]计算机毕业设计springboot学生宿舍管理系统
  8. 如何办理工作居住证续签
  9. 基于Matlab模拟菲涅尔公式
  10. popwindow显示在控件左方