变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成。
回忆一下,自动编码器有个问题,就是并不能任意生成图片,因为我们没有办法自己去构造隐藏向量,需要通过一张图片输入编码我们才知道得到的隐含向量是什么,这时我们就可以通过变分自动编码器来解决这个问题。
其实原理特别简单,只需要在编码过程给它增加一些限制,迫使其生成的隐含向量能够粗略的遵循一个标准正态分布,这就是其与一般的自动编码器最大的不同。
这样我们生成一张新图片就很简单了,我们只需要给它一个标准正态分布的随机隐含向量,这样通过解码器就能够生成我们想要的图片,而不需要给它一张原始图片先编码。
一般来讲,我们通过 encoder 得到的隐含向量并不是一个标准的正态分布,为了衡量两种分布的相似程度,我们使用 KL divergence,利用其来表示隐含向量与标准正态分布之间差异的 loss,另外一个 loss 仍然使用生成图片与原图片的均方误差来表示。
KL divergence 的公式如下

重参数 为了避免计算 KL divergence 中的积分,我们使用重参数的技巧,不是每次产生一个隐含向量,而是生成两个向量,一个表示均值,一个表示标准差,这里我们默认编码之后的隐含向量服从一个正态分布的之后,就可以用一个标准正态分布先乘上标准差再加上均值来合成这个正态分布,最后 loss 就是希望这个生成的正态分布能够符合一个标准正态分布,也就是希望均值为 0,方差为 1

所以最后我们可以将我们的 loss 定义为下面的函数,由均方误差和 KL divergence 求和得到一个总的 loss

def loss_function(recon_x, x, mu, logvar):"""recon_x: generating imagesx: origin imagesmu: latent meanlogvar: latent log variance"""MSE = reconstruction_function(recon_x, x)# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)KLD = torch.sum(KLD_element).mul_(-0.5)# KL divergencereturn MSE + KLD

用 mnist 数据集来简单说明一下变分自动编码器

import osimport torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoaderfrom torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_imageim_tfs = tfs.Compose([tfs.ToTensor(),tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])train_set = MNIST('./mnist', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
class VAE(nn.Module):def __init__(self):super(VAE, self).__init__()self.fc1 = nn.Linear(784, 400)self.fc21 = nn.Linear(400, 20) # meanself.fc22 = nn.Linear(400, 20) # varself.fc3 = nn.Linear(20, 400)self.fc4 = nn.Linear(400, 784)def encode(self, x):h1 = F.relu(self.fc1(x))return self.fc21(h1), self.fc22(h1)def reparametrize(self, mu, logvar):std = logvar.mul(0.5).exp_()eps = torch.FloatTensor(std.size()).normal_()if torch.cuda.is_available():eps = Variable(eps.cuda())else:eps = Variable(eps)return eps.mul(std).add_(mu)def decode(self, z):h3 = F.relu(self.fc3(z))return F.tanh(self.fc4(h3))def forward(self, x):mu, logvar = self.encode(x) # 编码z = self.reparametrize(mu, logvar) # 重新参数化成正态分布return self.decode(z), mu, logvar # 解码,同时输出均值方差net = VAE() # 实例化网络
if torch.cuda.is_available():net = net.cuda()
x, _ = train_set[0]
x = x.view(x.shape[0], -1)
if torch.cuda.is_available():x = x.cuda()
x = Variable(x)
_, mu, var = net(x)print(mu)Variable containing:  Columns 0 to 9  -0.0307 -0.1439 -0.0435  0.3472  0.0368 -0.0339  0.0274 -0.5608  0.0280  0.2742  Columns 10 to 19  -0.6221 -0.0894 -0.0933  0.4241  0.1611  0.3267  0.5755 -0.0237  0.2714 -0.2806 [torch.cuda.FloatTensor of size 1x20 (GPU 0)]

可以看到,对于输入,网络可以输出隐含变量的均值和方差,这里的均值方差还没有训练 下面开始训练

reconstruction_function = nn.MSELoss(size_average=False)def loss_function(recon_x, x, mu, logvar):"""recon_x: generating imagesx: origin imagesmu: latent meanlogvar: latent log variance"""MSE = reconstruction_function(recon_x, x)# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)KLD = torch.sum(KLD_element).mul_(-0.5)# KL divergencereturn MSE + KLDoptimizer = torch.optim.Adam(net.parameters(), lr=1e-3)def to_img(x):'''定义一个函数将最后的结果转换回图片'''x = 0.5 * (x + 1.)x = x.clamp(0, 1)x = x.view(x.shape[0], 1, 28, 28)return xfor e in range(100):for im, _ in train_data:im = im.view(im.shape[0], -1)im = Variable(im)if torch.cuda.is_available():im = im.cuda()recon_im, mu, logvar = net(im)loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均optimizer.zero_grad()loss.backward()optimizer.step()if (e + 1) % 20 == 0:print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.data[0]))save = to_img(recon_im.cpu().data)if not os.path.exists('./vae_img'):os.mkdir('./vae_img')save_image(save, './vae_img/image_{}.png'.format(e + 1))epoch: 20, Loss: 61.5803 epoch: 40, Loss: 62.9573 epoch: 60, Loss: 63.4285 epoch: 80, Loss: 64.7138 epoch: 100, Loss: 63.3343

变分自动编码器虽然比一般的自动编码器效果要好,而且也限制了其输出的编码 (code) 的概率分布,但是它仍然是通过直接计算生成图片和原始图片的均方误差来生成 loss,这个方式并不好,生成对抗网络中,我们会讲一讲这种方式计算 loss 的局限性,然后会介绍一种新的训练办法,就是通过生成对抗的训练方式来训练网络而不是直接比较两张图片的每个像素点的均方误差

PyTorch 深度学习:35分钟快速入门——变分自动编码器相关推荐

  1. PyTorch 深度学习:36分钟快速入门——GAN

    自动编码器和变分自动编码器,不管是哪一个,都是通过计算生成图像和输入图像在每个像素点的误差来生成 loss,这一点是特别不好的,因为不同的像素点可能造成不同的视觉结果,但是可能他们的 loss 是相同 ...

  2. PyTorch 深度学习:34分钟快速入门——自动编码器

    自动编码器最开始是作为一种数据压缩方法,同时还可以在卷积网络中进行逐层预训练,但是随后更多结构复杂的网络,比如 resnet 的出现使得我们能够训练任意深度的网络,自动编码器就不再使用在这个方面,下面 ...

  3. PyTorch 深度学习:32分钟快速入门——ResNet

    ResNet 当大家还在惊叹 GoogLeNet 的 inception 结构的时候,微软亚洲研究院的研究员已经在设计更深但结构更加简单的网络 ResNet,并且凭借这个网络子在 2015 年 Ima ...

  4. PyTorch 深度学习:37分钟快速入门——FCN 做语义分割

    语义分割是一种像素级别的处理图像方式,对比于目标检测其更加精确,能够自动从图像中划分出对象区域并识别对象区域中的类别 在 2015 年 CVPR 的一篇论文 Fully Convolutional N ...

  5. PyTorch 深度学习:33分钟快速入门——VGG

    CIFAR 10¶ cifar 10 这个数据集一共有 50000 张训练集,10000 张测试集,两个数据集里面的图片都是 png 彩色图片,图片大小是 32 x 32 x 3,一共是 10 分类问 ...

  6. PyTorch 深度学习:32分钟快速入门——DenseNet

    DenseNet¶ 因为 ResNet 提出了跨层链接的思想,这直接影响了随后出现的卷积网络架构,其中最有名的就是 cvpr 2017 的 best paper,DenseNet. DenseNet ...

  7. PyTorch 深度学习:30分钟快速入门

    卷积¶ 卷积在 pytorch 中有两种方式,一种是 torch.nn.Conv2d(),一种是 torch.nn.functional.conv2d(),这两种形式本质都是使用一个卷积操作 这两种形 ...

  8. PyTorch 深度学习:38分钟快速入门——RNN 做图像分类

    RNN 特别适合做序列类型的数据,那么 RNN 能不能想 CNN 一样用来做图像分类呢?下面我们用 mnist 手写字体的例子来展示一下如何用 RNN 做图像分类,但是这种方法并不是主流,这里我们只是 ...

  9. PyTorch 深度学习:31分钟快速入门——Batch Normalization

    Batch Normalization¶ 前面在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好.但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相 ...

最新文章

  1. 数据流模式、转换、格式与操作
  2. python 中文转unicode编码_Python 解码 Unicode 转义字符串
  3. 关联关系、依赖关系总结
  4. 安卓USB开发教程 五 安卓 AOA 1.0
  5. Flex布局实现筛子3
  6. MySQL导入与导出备份详解
  7. 【图像加密】基于matlab RSA图像加密解密【含Matlab源码 1442期】
  8. 4.计蒜客ACM题库.A1947 An Olympian Math Problem
  9. Unity3D加密工具:Virbox Protector Unity3D版
  10. FlashFXP设置文件传输速度
  11. P2P直播软件设计的技术原理和改进
  12. 街头篮球服务器未响应,我的生涯我做主 《街头篮球》生涯联赛FAQ
  13. (原创)[联觉][类比推理的应用]震惊!声音也有温度和冷暖?什么是冷声和暖声?无处不在的联觉,色彩、声音的频率与温度之间的通感,色彩和声音的冷暖(类比冷色和暖色)
  14. 2021年电子竞赛四天三夜征程—-信号失真度测量装置(A题)
  15. 网络攻击机制和技术发展综述
  16. logo免费在线设计,给自己的logo寻找灵感
  17. 六,iOS中的金额格式化和金额大小写转换
  18. 阿里 M8 级大神整理出 SQL 手册:收获不止 SQL 优化,抓住 SQL 的本质
  19. Navicate无法连接,提示is not allowed to connect to this mysql server
  20. 使用javascript制作 滚动字幕及时钟

热门文章

  1. Velo 实验室集成 Chainlink 预言机喂价
  2. Zcash已发布ZIP 313提案
  3. 以太坊2.0质押地址余额超过170万枚
  4. SAP License:SAP生产订单中的统计指标运用
  5. 如何浅显得理解风控模型中的特征筛选|附实操细节(全)
  6. sqoop-import 并行抽数及数据倾斜解决
  7. DataFrame的级联合并操作
  8. 关闭ArcGIS9.3时 .NET Framework出现尝试读取或写入受保护的内存问题
  9. 背景透明,文字不透明效果
  10. 2018.10.26多校