细说VAE的来龙去脉 (Variational Autoencoder)
目标是啥? 假如有下面这样手写数字图片集,我们想构造一个model可以无限生成类似的(同分布)的手写数字图片.
第一反映是做一个d纬度的向量Z,其中z1代表数字,z2代表粗细,z3代表倾斜度,....
完后把Z,输入一个带有参数的model,取训练集合里面符合Z的图片,作为目标,通过Supervised learning学习model的参数就好了.
But,一个大问题是人为手工的给Z的每个纬度构造含义,非常费时费力,容易漏纬度,而且构造出来的纬度之间很难确保不是相互纠缠的.
怎么办?答案是 ,不要手工人为的构造Z人为的赋予每个纬度的含义,而是固定Z的长度为d,通过网络学习出这个Z里面每个纬度的含义.
也就是说Z的前面会有若干层的网络,那么最开始层的输入是什么?输入z是固定的一个常向量?不行,那么只能输出一张稳定的图片了.输入z是一个d维的随机向量?
也不行,因为我们要输出某一个特定分布的手写数字图片,所以z要服从某一个特定分布,那么z选择什么分布合适,大量实验发现标准正态分布N(0;I)就非常有效,
并且数学上证明了输入服从N(0;I)经过复杂的function或NN后,可以转换成任何一种特定的分布,就比如我们要生成的手写数字图片的分布.
下面数学化描述我们的问题.
z~P(z) is N(0;I)
X 是图片gt数据集的随机变量
想找到一个模型,使得左边的P(X)最大(根据maximize the likelihood of training data);
这个积分即全概率公式的积分形式,
如果z~P(z) 真的从N(0;I)里面取值,会有大量P(X|z)为零,增大了z的取值范围和计算复杂度。
z从Q(z|X)生成,再把生成的z放入到P(X|z)生成X,增大了P(X|z),缩小了z的取值范围,这时的z也更靠谱了。
那么Q(z|X)是否等于原来的P(z|X)呢?显然不完全等于,Q(z|X)偏小,再者P(z|X)服从于正太分布,一个编码X生成z的网络却可以生成各种不同分布,
所以我们要对Q(z|X)进行约束使其服从与正太分布,且尽可能通过增加参数来扩大它的容量。
所以构造 Q(z|X) =
上面的u是正太分布的均值,是一个带参数的固定网络,是均方差,一个带参数的固定网络。
那么这样的Q(z|X)有多接近原来的P(z|X)呢?嗯,这个距离是我们后面调参时需要缩小的一个目标。
代到KL divergence里面如下,
下面做一系列神奇的数学变换,
好了到了(5)式,左边:第一项就是我们最后要最大化的目标项,第二项是附加项,为了使Q(z|X)尽可能的接近P(z|X)
也就是说左边就是我们要最大化的目标了。
那右边自然就是我们要构造的含参数的function了,里面的参数通过用梯度下降法,loop训练数据集来调节使得左边的目标最大,问题即求解了。
来看看右边的第一项是个什么?嗯,它的输入是Q生成的z,输出P(X|z),这显然是一个解码器。
右边的第二项是个什么?是两个正太分布的KL divergence,展开看看。
这时,看一眼展开的结果,也就是(7)的右边,仅仅是带参数的固定网络和u输出的一种组合而已。
初步搭一个网络看看,
要从下往上看,先给X编码,经过带参数的固定网络和u, 此时通过第一个蓝框KL网络来构造(5)的第二项,嗯
(5)的右边第一项呢?网络和u的输出在红框部分做sample操作,sample分布为u和的正太分布,得到z
把z喂给一个解码器网络P,得到f(z), 让它尽可能贴近X,等价于log P(X|z)越大。网络搭建完毕,计算出两个蓝框的loss ,
通过梯度下降使之变小,调节参数,从而得到我们目标的model function。
可是在梯度下降的过程中,这个网络有一个sample节点,也就是那个红色的方框,是不连续的function,loss传回的梯度在这里断掉了?怎么办?
答案,用参数化的网络节点替代掉那个红框。替代后如下,
这里的从一个标准正太分布随机生成,是一个新的输入,的上面就是参数化后的网络,完美的替换掉了sample u和分布的那个操作,
且这里保证了网络的连续性,loss向后传播在这里也不用受阻了。
训练就简单了,把我们数据集里面的图片X依次喂入网络,向前传播计算loss,向后传播,update参数,一直训练下去,直到loss趋于稳定,
即得到我们想要的model 了。
到这里有人要问了,可可,Encoder(Q) ,u(X)和(X),还有Decoder(P)这些网络到底是什么样子的阿...
其实这些网络模块的具体设计要看你生成的X到底是什么类型的数据拉.
如果是手写数字图片,我们用vanila版本的VAE就可以就决问题,用pytorch表示的网络结构大概如下:
self.encoder
Sequential((0): Sequential((0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))(1): Sequential((0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))(2): Sequential((0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))(3): Sequential((0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))(4): Sequential((0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))
)self.decoder
Sequential((0): Sequential((0): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))(1): Sequential((0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))(2): Sequential((0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))(3): Sequential((0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))
)
注意,这里的encoder只是从图片提取了非常丰富的features,比如
144张图片从torch.Size([144, 3, 64, 64])到torch.Size([144, 512, 2, 2]),
拉平变成torch.Size([144, 2048]) ,
完后就要送给 u(X)和(X)网络,如下:
self.fc_mu
Linear(in_features=2048, out_features=128, bias=True)
self.fc_var
Linear(in_features=2048, out_features=128, bias=True)
输出为
mu.shape = {Size: 2} torch.Size([144, 128]) # mean of the latent Gaussian [B * D]
log_var.shape = {Size: 2} torch.Size([144, 128]) # deviation of the latent Gaussian [B * D]
用参数化模块实现
sample from N(mu, var) from N(0,1)
这个参数化模块什么样子?
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
这里面的关键函数torch.
randn_like解释如下:
pytorch官方解释原文:
Returns a tensor with the same size as
input
that is filled with random numbers from a normal distribution with mean 0 and variance 1.
还有个地方要注意,这个版本的VAE
std = torch.exp(0.5 * logvar)
编码网络没有l直接让它生成std, 生成的是var取log
不管怎样,经过了参数化模块, 输出的服从N(u(X),(X))的z
shape为torch.Size([144, 128]).
把z喂给decoder P模块之前先作如下变形:
result = Linear(in_features=128, out_features=2048, bias=True)
result = result.view(-1, 512, 2, 2)
此时的tensor为 torch.Size([144, 512, 2, 2])
经过decoder生成图片.
Sequential((0): Sequential((0): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))(1): Sequential((0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))(2): Sequential((0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))(3): Sequential((0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01))
)
Sequential((0): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.01)(3): Conv2d(32, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(4): Tanh()
)
完后要做的是计算loss,
f(z) 和 X 通过 F.mse_loss 计算loss, 非常简单不用戏说.
最后剩下是如何计算N(u(X),(X))与N(0,I) 的KLd loss呢?
mu.shape = {Size: 2} torch.Size([144, 128]) # mean of the latent Gaussian [B * D]
log_var.shape = {Size: 2} torch.Size([144, 128]) # deviation of the latent Gaussian [B * D]
引用两个高斯分布的KLd loss公式,如下:
(10)
经过了一系列的转换,精简到了(10)公式,取我们encoder网络生成的mu和log_var 代进去,代码如下:
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
完毕!
细说VAE的来龙去脉 (Variational Autoencoder)相关推荐
- 详解变分自编码器VAE(Variational Auto-Encoder)
前言 过去虽然没有细看,但印象里一直觉得变分自编码器(Variational Auto-Encoder,VAE)是个好东西.趁着最近看概率图模型的三分钟热度,我决定也争取把 VAE 搞懂. 于是乎照样 ...
- 变分自动编码器(VAE variational autoencoder)
文章目录 自动编码器 AutoEncoder 变分推断 Variational Inference 变分自动编码器 Variational AutoEncoder 条件变分自动编码器 Conditio ...
- 变分自编码器(VAE,Variational Auto-Encoder)
变分自编码器(Variational auto-encoder,VAE)是一类重要的生成模型(Generative Model) 除了VAEs,还有一类重要的生成模型GANs VAE 跟 GAN 比较 ...
- Variational AutoEncoder(VAE)变分自编码器
[本文转载自博客]:解析Variational AutoEncoder(VAE): https://www.jianshu.com/p/ffd493e10751 文章目录 1. 模型总览 1.1 Au ...
- MATLAB实现自编码器(四)——变分自编码器实现图像生成Train Variational Autoencoder (VAE) to Generate Images
本文是对Train Variational Autoencoder (VAE) to Generate Images网页的翻译,该网页实现了变分自编码的图像生成,以MNIST手写数字为训练数据,生成了 ...
- 从AE(Auto-encoder)到VAE(Variational Auto-Encoder)
一.Auto-Encoder (AE) 1)Auto-encoder概念 自编码器要做的事:将高维的信息通过encoder压缩到一个低维的code内,然后再使用decoder对其进行重建." ...
- 【论文解读 WWW 2019 | MVAE】Multimodal Variational Autoencoder for Fake News Detection
论文题目:MVAE: Multimodal Variational Autoencoder for Fake News Detection 论文来源:WWW 2019 论文链接:https://doi ...
- VARIATIONAL AUTOENCODER FOR SPEECH ENHANCEMENT WITH A NOISE-AWARE ENCODER
文章目录 0. 摘要 1. Introduction 2. Problem Formulation 2.1 Mixture model 2.2 Speech model 2.3 Noise model ...
- 【DL笔记】Tutorial on Variational AutoEncoder——中英文对照(更新中)
更新时间:2018-09-25 Abstract In just three years, Variational Autoencoders (VAEs) have emerged as one of ...
- 【Donut论文】Unsupervised anomaly detection via variational auto-encoder for seasonal kpis...
简述 本文提出的 Donut,基于 VAE(代表性的深层生成模型)的无监督异常检测算法,伴有理论解释,可以无标签或偶尔提供的标签下学习. 本文贡献 1,Donut 里的三项技术:改进的 ELBO,缺失 ...
最新文章
- 1136 A Delayed Palindrome 需再做
- 智能合约重构社会契约 (2)雅阁项目智能合约
- stm32看门狗_「正点原子NANO STM32开发板资料连载」第十一章 独立看门狗实验
- 从家书到小票!看到海尔智家的转型是真的
- 葡萄城 SpreadJS 表格控件 V11 产品白皮书
- 【Python3网络爬虫开发实战】3.3-正则表达式
- android 本地日历,Android日历提供商:如何删除自己的本地日历?
- springboot 前缀_springboot插件式开发框架
- 将ERF格式转换成PCAP格式
- 计算机辅助制造期末试题答案,西工大《计算机辅助制造》期末试题2006-2007A答案.doc...
- linux ipk,openwrt下ipk生成过程及原理
- ffmpeg 命令转vp9
- Python爬取文件的11种方式
- 关于无线传输功率和距离的问题
- 【定时任务】SpringBoot多线程并发动态执行定时任务
- 图论介绍和PyTorch Geometric(PyG)库基础知识
- java左手画圆右手画方_左手画圆右手画方可以同时进行吗?
- Springboot项目架构设计
- macos最新版本是什么_macOS的最新版本是什么?
- 怎么一心多用高效处理工作琐事?用敬业签同时处理多个任务
热门文章
- 挑战性题目DSCT103:客观指标评价问题
- Luogu4116 Qtree3
- 单引号、双引号、倒引号
- c语言 键盘 屏幕,c语言之键盘输入语屏幕输出.pptx
- c语言编程下雪,C语言怎么 实现 下雪效果
- android 获取service 实例化,在Activity中,如何获取service对象?a.可以通过直接实例化得到。b.可以通过绑定得到。c.通过star - 众答网问答...
- java for 变量赋值_Java 如何引用变量赋值?
- WEBPACK+ES6+REACT入门(4/7)-评论列表DEMO以及CSS样式
- echarts无数据时显示无数据_无服务器数据库竞技,哪家云服务落伍了?
- Angr安装与使用之使用篇(十)