变分编码器(Variational AutoEncoder)是自动编码器的升级版本, 其结构跟自动编码器是类似的, 也由编码器和解码器构成。

回忆一下, 自动编码器有个问题, 就是并不能任意生成图片, 因为我们没有办法自己去构造隐藏向量, 需要通过一张图片输入编码我们才知道得到的隐含向量是什么, 这时我们就可以通过变分自动编码器来解决这个问题。

其实原理特别简单, 只需要在编码过程给它增加一些限制, 迫使其生成的隐含向量能够粗略的遵循一个标准正态分布, 这就是其与一般的自动编码器最大的不同。这样我们生成一张新图片就很简单了, 我们只需要给它一个标准正态分布的随机隐含向量, 这样通过解码器就能够生成我们想要的图片, 而不需要给它一张原始图片先编码。

一般来讲, 我们通过 encoder 得到的隐含向量并不是一个标准的正态分布, 为了衡量两种分布的相似程度, 我们使用 KL divergence, 这是用来衡量两种分布相似程度的统计量,它越小,表示两种概率分布越接近。

在实际情况中,需要在模型的准确率和encoder得到的隐含向量服从标准正态分布之间做一个权衡,所谓模型的准确率就是指解码器生成的图片与原始图片的相似程度。可以让神经网络自己做这个决定,只需要将两者都做一个loss,然后求和作为总的loss,这样网络就能够自己选择如何做才能使这个总的loss下降。

为了避免计算 KL divergence 中的积分, 我们使用重参数的技巧, 不是每次产生一个隐含向量, 而是生成两个向量, 一个表示均值, 一个表示标准差, 这里我们默认编码之后的隐含向量服从一个正态分布的之后, 就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布, 最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布, 也就是希望均值为 0, 方差为 1

所以标准的变分自动编码器VAE如下

import os
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import save_image
from visdom import Visdomclass VAE(nn.Module):def __init__(self):super(VAE, self).__init__()self.fc1 = nn.Linear(784, 400)self.fc21 = nn.Linear(400, 20) # mean 均值self.fc22 = nn.Linear(400, 20) # var  标准差self.fc3 = nn.Linear(20, 400)self.fc4 = nn.Linear(400, 784)def encode(self, x):x = self.fc1(x)h1 = F.relu(x)mean = self.fc21(h1)var = self.fc22(h1)return mean, var#重参数化def reparametrize(self, mean, logvar):std = logvar.mul(0.5).exp_()normal = torch.FloatTensor(std.size()).normal_() #生成标准正态分布if torch.cuda.is_available():normal = torch.tensor(normal.cuda())else:normal = torch.tensor(normal)return normal.mul(std).add_(mean)  #标准正态分布乘上标准差再加上均值#这里返回的结果就是我们encoder得到的编码,也就是我们decoder要decode的编码def decode(self, z):z = self.fc3(z)z = F.relu(z)z = self.fc4(z)z = torch.tanh(z)return zdef forward(self, x):mean, logvar = self.encode(x) # 编码z = self.reparametrize(mean, logvar) # 重新参数化成正态分布return self.decode(z), mean, logvar # 解码, 同时输出均值方差def loss_function(recon_image, image, mean, logvar):"""recon_x: generating imagesx: origin imagesmu: latent meanlogvar: latent log variance"""reconstruction_function = nn.MSELoss(reduction='sum')MSE = reconstruction_function(recon_image, image)# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)KLD_element = mean.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)KLD = torch.sum(KLD_element).mul_(-0.5)# KL divergencereturn MSE + KLDdef to_img(x):'''定义一个函数将最后的结果转换回图片'''x = 0.5 * (x + 1.)x = x.clamp(0, 1)x = x.view(x.shape[0], 1, 28, 28)return ximg_transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5]) # 标准化
])train_set = MNIST(root='dataset/', transform=img_transforms
)
train_data = DataLoader(dataset=train_set, batch_size=128, shuffle=True
)net = VAE() # 实例化网络
if torch.cuda.is_available():net = net.cuda()optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
viz = Visdom()
viz.line([0.], [0.], win='loss', opts=dict(title='loss'))for epoch in range(100):for image, _ in train_data:image = image.view(image.shape[0], -1)image = torch.tensor(image)if torch.cuda.is_available():image = image.cuda()recon_image, mean, logvar = net(image)loss = loss_function(recon_image, image, mean, logvar) / image.shape[0] # 将 loss 平均optimizer.zero_grad()loss.backward()optimizer.step()print('epoch: {}, Loss: {:.4f}'.format(epoch, loss.item()))save = to_img(recon_image.cpu().data)if not os.path.exists('./vae_img'):os.mkdir('./vae_img')save_image(save, './vae_img/image_{}.png'.format(epoch))viz.line([loss.item()], [epoch], win='loss', update='append')

运行100个eopch之后,可以看出来结果比自动编码器清晰一点,本质上VAE就是在encoder的结果添加了高斯噪声,通过训练要使得decoder对噪声有一定的鲁棒性,这样的话我们生成一张图片就没有必须用一张图片先做编码了,可以想象,我们只需要利用训练好的encoder对一张图片编码得到其分布后,符合这个分布的隐含向量理论上都可以通过decoder得到类似这张图片的图片。

KL越小,噪声越大(可以这麽理解,我们强行让z的分布符合正态分布,其和N(0,1)越接近,KL越小,相当于我们添加的噪声越大),所以直觉上来想loss合并后的训练过程:

  • 当 decoder 还没有训练好时(重构误差远大于 KL loss),就会适当降低噪声(KL loss 增加),使得拟合起来容易一些(重构误差开始下降);
  • 反之,如果 decoder 训练得还不错时(重构误差小于 KL loss),这时候噪声就会增加(KL loss 减少),使得拟合更加困难了(重构误差又开始增加),这时候 decoder 就要想办法提高它的生成能力了。

变分自动编码器虽然比一般的自动编码器效果要好, 而且也限制了其输出的编码(code) 的概率分布, 但是它仍然是通过直接计算生成图片和原始图片的均方误差来生成 loss, 这个方式并不好。

在之后生成对抗网络中, 我们会讲一讲这种方式计算 loss 的局限性, 然后会介绍一种新的训练办法, 就是通过生成对抗的训练方式来训练网络而不是直接比较两张图片的每个像素点的均方误差

变分自编码器VAE:原来是这么一回事 | 附开源代码 - 知乎

Pytorch之经典神经网络Generative Model(二) —— VAE (MNIST)相关推荐

  1. Pytorch之经典神经网络CNN(七) —— GoogLeNet(InceptionV1)(Bottleneck)(全局平均池化GAP)(1*1卷积)(多尺度)(flower花卉数据集)

    2014年 Google提出的 是和VGG同年出现的,在ILSVRC(ImageNet) 2014中获得冠军,vgg屈居第二 GoogLeNet也称Inception V1.之所以叫GoogLeNet ...

  2. Pytorch之经典神经网络CNN(三) —— AlexNet(CIFAR-10) (LRN)

    2012年 多伦多大学Hinton提出的 AlexNet AlexNet是第一个large-scale CNN, 从AlexNet之后CNN开始变得火了起来 贡献是提出了用多层最小卷积叠加来替换单个大 ...

  3. 深度学习入门笔记(二十):经典神经网络(LeNet-5、AlexNet和VGGNet)

    欢迎关注WX公众号:[程序员管小亮] 专栏--深度学习入门笔记 声明 1)该文章整理自网上的大牛和机器学习专家无私奉献的资料,具体引用的资料请看参考文献. 2)本文仅供学术交流,非商用.所以每一部分具 ...

  4. 【深度学习】基于Pytorch进行深度神经网络计算(二)

    [深度学习]基于Pytorch进行深度神经网络计算(二) 文章目录 1 延后初始化 2 Pytorch自定义层2.1 不带参数的层2.2 带参数的层 3 基于Pytorch存取文件 4 torch.n ...

  5. 【深度学习】基于Pytorch的卷积神经网络概念解析和API妙用(二)

    [深度学习]基于Pytorch的卷积神经网络API妙用(二) 文章目录1 Padding和Stride 2 多输入多输出Channel 3 1*1 Conv(笔者在看教程时,理解为降维和升维) 4 池 ...

  6. 经典神经网络论文超详细解读(二)——VGGNet学习笔记(翻译+精读)

    前言 上一篇我们介绍了经典神经网络的开山力作--AlexNet:经典神经网络论文超详细解读(一)--AlexNet学习笔记(翻译+精读) 在文章最后提及了深度对网络结果很重要.今天我们要读的这篇VGG ...

  7. 经典神经网络论文超详细解读(三)——GoogLeNet InceptionV1学习笔记(翻译+精读+代码复现)

    前言 在上一期中介绍了VGG,VGG在2014年ImageNet 中获得了定位任务第1名和分类任务第2名的好成绩,而今天要介绍的就是同年分类任务的第一名--GoogLeNet . 作为2014年Ima ...

  8. 经典神经网络论文超详细解读(八)——ResNeXt学习笔记(翻译+精读+代码复现)

    前言 今天我们一起来学习何恺明大神的又一经典之作: ResNeXt(<Aggregated Residual Transformations for Deep Neural Networks&g ...

  9. SinGAN: Learning a Generative Model from a Single Natural Image

    Abstract(摘要) We introduce SinGAN, an unconditional generative model that can be learned from a singl ...

最新文章

  1. h5是什么 www.php.cn,20分钟看懂html5 看看H5都有啥新特性
  2. jzoj3888-正确答案【字符串hash,dfs】
  3. 图论 —— 生成树 —— 最小树形图
  4. ACM Robot Motion
  5. 在 Windows 7 中安装和使用Windows XP Mode
  6. 【第四组】用例文档+功能说明书+技术说明书:查看导入的图片,工作序号:001,2017/7/11...
  7. cron表达式解析 3秒执行一次
  8. 常用传感器讲解四--水位传感器(water sensor)
  9. 路由器不开机——维修更换MT7621AT CPU
  10. C语言利用指针实现字符串逆序输出
  11. 2018年 第九届 蓝桥杯省赛 C/C++ B 组
  12. 轻量级的无线抓包(microsoft network monitor)
  13. linux 远程端口号,linux远程端口查看
  14. 全世界所有程序员都会犯的错误-蔡学镛
  15. VMware虚拟机安装win10系统
  16. 20162328WJH实验五网络编程与安全实验报告
  17. 怎么在word和python中输入对号
  18. android服务实现播放器,Android实现简单音乐播放器(MediaPlayer)
  19. java如何设置成中文字体,Java程序中文字体配置
  20. python入门教材带视频_Python全套,从入门到进阶。视频,电子书

热门文章

  1. PYRE 人物剧情 和 游戏技巧小结
  2. 解决Pandoc wasn't found.pdflatex not found on PATH
  3. 报错h is not defind
  4. spring boot 项目中遇到的错误(tomcat 400)
  5. 利用计算机窃听,神奇 | 以色列研究人员实现利用计算机风扇噪音窃听
  6. lce损失与软交叉熵损失函数
  7. #深入解读# 机器学习中的指数函数和对数函数的作用
  8. 河南中医药大学计算机科学与技术,我校信息技术学院成功举办“隐结构及其在中医药研究的应用暨计算机科学与技术学科建设”学术论坛...
  9. 软件技巧--手机软件
  10. SDIO驱动(14)card的CIS读取及解析