导入必要的包

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_imageimport matplotlib.pyplot as plt
import matplotlib.image as mpimg
# 设备配置
# torch.cuda.set_device(1)
# 这句用来设置pytorch在哪块GPU上运行,这里假设使用序号为1的这块GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 在当前目录,创建不存在的目录ave_samples
sample_dir = 'ave_samples'
if not os.path.exists(sample_dir):os.makedirs(sample_dir)

定义一些超参数

image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 30
batch_size = 128
learning_rate = 0.001

下载MNIST训练集

#这里因已下载,故download=False
dataset = torchvision.datasets.MNIST(root='data',train=True,transform=transforms.ToTensor(),download=False)#数据加载
data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)

定义AVE模型

class VAE(nn.Module):def __init__(self, image_size=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.fc1 = nn.Linear(image_size, h_dim)self.fc2 = nn.Linear(h_dim, z_dim)self.fc3 = nn.Linear(h_dim, z_dim)self.fc4 = nn.Linear(z_dim, h_dim)self.fc5 = nn.Linear(h_dim, image_size)def encode(self, x):h = F.relu(self.fc1(x))return self.fc2(h), self.fc3(h)# 用mu,log_var生成一个潜在空间点z,mu,log_var为两个统计参数,并假设这个假设分布能生成图像def reparameterize(self, mu, log_var):std = torch.exp(log_var / 2)eps = torch.randn_like(std) # std为随机采样的return mu + eps * stddef decode(self, z):h = F.relu(self.fc4(z))return F.sigmoid(self.fc5(h))def forward(self, x):mu, log_var = self.encode(x)z = self.reparameterize(mu, log_var)x_reconst = self.decode(z)return x_reconst, mu, log_varmodel = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

开始训练模型

for epoch in range(num_epochs):model.train()for i, (x, _) in enumerate(data_loader):# 前向传播model.zero_grad()x = x.to(device).view(-1, image_size)x_reconst, mu, log_var = model(x)# Compute reconstruction loss and kl divergence# For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())# 反向传播及优化器loss = reconst_loss + kl_div    # 两者相加得总损失optimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}".format(epoch + 1, num_epochs, i + 1, len(data_loader), reconst_loss.item(), kl_div.item()))with torch.no_grad():# 保存采样图像,即潜在向量Z通过解码器生成的新图像z = torch.randn(batch_size, z_dim).to(device)out = model.decode(z).view(-1, 1, 28, 28)save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))# 保存重构图像,即原图像通过解码器生成的图像out, _, _ = model(x)x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))

展示原图像及重构图像

reconsPath = './ave_samples/reconst-30.png'
Image = mpimg.imread(reconsPath)
plt.imshow(Image) # 显示图片
plt.axis('off') # 不显示坐标轴
plt.show()

显示由潜在空间点Z生成的新图像

genPath = './ave_samples/sampled-30.png'
Image = mpimg.imread(genPath)
plt.imshow(Image) # 显示图片
plt.axis('off') # 不显示坐标轴
plt.show()

变分自编码AVE器生成图像(Pytorch)相关推荐

  1. 【Pytorch神经网络实战案例】13 构建变分自编码神经网络模型生成Fashon-MNST模拟数据

    1 变分自编码神经网络生成模拟数据案例说明 变分自编码里面真正的公式只有一个KL散度. 1.1 变分自编码神经网络模型介绍 主要由以下三个部分构成: 1.1.1 编码器 由两层全连接神经网络组成,第一 ...

  2. 【Pytorch神经网络实战案例】14 构建条件变分自编码神经网络模型生成可控Fashon-MNST模拟数据

    1 条件变分自编码神经网络生成模拟数据案例说明 在实际应用中,条件变分自编码神经网络的应用会更为广泛一些,因为它使得模型输出的模拟数据可控,即可以指定模型输出鞋子或者上衣. 1.1 案例描述 在变分自 ...

  3. 从零开始学keras之变分自编码器生成图像

    自编码器由 Kingma 和 Welling 于 2013 年 12 月 a 与 Rezende.Mohamed 和 Wierstra 于 2014 年 1 月 同时发现,它是一种生成式模型,特别适用 ...

  4. 深入理解自编码器(用变分自编码器生成图像)

    文章目录 自编码器 欠完备自编码器 正则自编码器 稀疏自编码器 去噪自编码器 收缩自编码器 变分自编码器 References 内容总结自花书<Deep Learning>以及<Py ...

  5. 【Pytorch神经网络理论篇】 22 自编码神经网络:概述+变分+条件变分自编码神经网络

    1 无监督学习模型的概述 在监督训练中,模型能根据预测结果与标签差值来计算损失,并向损失最小的方向进行收敛. 在无监督训练中,无法通过样本标签为模型权重指定收敛方向,这就要求模型必须有自我监督的功能. ...

  6. 华人一作统一「视觉-语言」理解与生成:一键生成图像标注,完成视觉问答,Demo可玩...

    来源:机器学习研究组订阅 这个 BLIP 模型可以「看图说话」,提取图像的主要内容,不仅如此,它还能回答你提出的关于图像的问题. 视觉 - 语言预训练 (Vision-Language Pre-tra ...

  7. 二元函数图像生成器_GAN生成图像综述

    点击上方"CVer",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者:YTimo(PKU EECS)   研究方向:深度学习,计算机 ...

  8. TensorFlow从1到2(十一)变分自动编码器和图片自动生成

    基本概念 "变分自动编码器"(Variational Autoencoders,缩写:VAE)的概念来自Diederik P Kingma和Max Welling的论文<Au ...

  9. 文本生成图像简述4——扩散模型、自回归模型、生成对抗网络的对比调研

    基于近年来图像处理和语言理解方面的技术突破,融合图像和文本处理的多模态任务获得了广泛的关注并取得了显著成功. 文本生成图像(text-to-image)是图像和文本处理的多模态任务的一项子任务,其根据 ...

最新文章

  1. 辩证看待 iostat
  2. 碎片化学习不是学习碎片,看这篇了解碎片化学习的真相
  3. H - Maximal submatrix HDU - 6957
  4. dmsetup remove_all 这命令干啥的_分一个小知识,服务器上的一个解压与压缩文件的命令....
  5. Hadoop精华问答 | 非大数据的项目能否用Hadoop?
  6. android开源2016_出版商的选择:2016年顶级开源书籍
  7. 【英语学习】【医学】Unit 02 The Brain and Its Functions
  8. 【java】Java 最坑爹的 10 大功能点
  9. 结组开发项目(TD学生助手)
  10. azure api 管理_Azure Cosmos DB和MongoDB API入门
  11. 浅析Mysql的隔离级别及MVCC
  12. 微信开发者工具命令行_微信开发者工具 Linux版
  13. web前端开发是什么?
  14. 建筑行业必看,一招学会工地管理诀窍
  15. Bootstrap 之Table样式
  16. python中的取整
  17. 【消息队列笔记】chp2-如何选择消息队列
  18. 2021-06-28 什么是一清机跟二清机、费率、分润、MCC码_POS机
  19. 奥塔在线:vsftpd服务如何开启访问日志
  20. 美团饿了吗CPS红包,别人领红包下单,你拿推广佣金(送源码)

热门文章

  1. win10电脑用蓝牙实现文件传输,安卓手机通过蓝牙将文件传送到电脑
  2. 【学习挑战赛】经典算法之折半查找
  3. pcb只开窗不镀锡_关于pads中 PCB铺铜开窗镀锡 的操作
  4. Comparable Comparator
  5. 网络舆情数据分析系统技术方案
  6. 【Linux环境搭建】十二、Linux(CentOS7) 时序数据库InfluxDB及Influx-proxy安装配置
  7. 最强大的PDF编辑器Adobe Acrobat DC Pro
  8. NSG44273低侧驱动IC
  9. ISO14067产品碳足迹认证流程是怎么样的?
  10. 3-1存储系统-存储器概述主存储器