「@Author:Runsen」

GAN 是使用两个神经网络模型训练的生成模型。一种模型称为生成网络模型,它学习生成新的似是而非的样本。另一个模型被称为判别网络,它学习区分生成的例子和真实的例子。

生成性对抗网络

2014,蒙特利尔大学的Ian Goodfellow和他的朋友发明了生成性对抗网络(GAN)。自它出版以来,有许多它的变体和客观功能来解决它的问题

论文在这里找到.

论文提出了两种模型:生成模型和判别模型。两个模型竞争,以产生真实和假的样本。2016年,Yann LeCun将GANs描述为“过去二十年机器学习中最酷的想法”。

GAN 的大部分研究和应用都集中在计算机视觉领域。

其原因是卷积神经网络 (CNN) 等深度学习模型在过去 5 到 7 年中在计算机视觉领域取得了巨大成功,例如在具有挑战性的任务(如对象检测和人脸识别。

GAN 的典型例子是生成新的逼真的照片,最令人吃惊的是生成照片般逼真的人脸的例子。

在本教程中,我们将实现一个简单的GAN生成假的MNIST样本。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoaderimport torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utilsimport numpy as np
import matplotlib.pyplot as plt
# CPU / GPU Setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)  #cuda

使用MNIST数据集,具有最小大小的数据集。

它由60000个训练图像和10000个测试图像组成,每个图像有28*28的大小和一个彩色通道。

# Define a transform
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean = (0.5, ), std = (0.5, ))
])# batch_size是一个前向和后向传播过程中的图像数。
batch_size = 100mnist = datasets.MNIST('./data/MNIST', download = True, train = True, transform = transform)mnist_loader = DataLoader(dataset = mnist, batch_size = batch_size, shuffle = True)
# CPU
def imshow(img, title):img = utils.make_grid(img.cpu().detach())img = (img+1)/2npimg = img.detach().numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.title(title)plt.show()
#GPU
def imshow(img, title):npimg = img.detach().numpy()fig = plt.figure(figsize = (10, 10))plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.title(title)plt.show()images, labels = iter(mnist_loader).next()
imshow(images[0:16, :, :], "MNIST Images")

建立一个GANs模型。一个Generator和Discriminator

GANs由完全连接的层组成。它将从100维高斯分布采样的噪声转换为MNIST图像。鉴别器网络也由完全连接的层组成,用于区分输入数据是真是假。

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()latent_size = 100output = 28*28self.main = nn.Sequential(nn.Linear(latent_size, 128),nn.ReLU(inplace=True),nn.Linear(128, 256),nn.ReLU(inplace=True),nn.Linear(256, 512),nn.ReLU(inplace=True),nn.Linear(512, output),nn.Tanh())def forward(self, x):out = self.main(x)out = out.view(-1, 1, 28, 28)return outclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()n_features = 28 * 28n_out = 1self.main = nn.Sequential(nn.Linear(n_features, 512),nn.ReLU(inplace=True),nn.Linear(512, 256),nn.ReLU(inplace=True),nn.Linear(256, 128),nn.ReLU(inplace=True),nn.Linear(128, 64),nn.ReLU(inplace=True),nn.Linear(64, n_out),nn.Sigmoid()        )def forward(self, x):x = x.view(-1, 28*28)out = self.main(x)return outG = Generator().to(device)
D = Discriminator().to(device)

生成性对抗网络训练过程的损失函数是二进制交叉熵损失,由torch.nn.BCELoss实现。

这两种模型都使用torch.optim.Adam作为优化工具,学习率设置为0.002。

# Objective Function
criterion = nn.BCELoss()# Optimizer
G_optimizer = optim.Adam(G.parameters(), lr = 0.0002)
D_optimizer = optim.Adam(D.parameters(), lr = 0.0002)# Constants
noise_dim = 100
num_epochs = 50
total_batch = len(mnist_loader)# Lists
G_losses = []
D_losses = []# Noise
sample_size = 16
fixed_noise = torch.randn(sample_size, noise_dim).to(device)# Train
for epoch in range(num_epochs):for i, (images, labels) in enumerate(mnist_loader):# Images #images = images.reshape(batch_size, -1).float().to(device)# Labels #ones = torch.ones(batch_size, 1).to(device)zeros = torch.zeros(batch_size, 1).to(device)# Noise #noise = torch.randn(batch_size, noise_dim).to(device)# Initialize OptimizersD_optimizer.zero_grad()G_optimizer.zero_grad()######################## Train Discriminator ######################### Forward Images #prob_real = D(images)D_real_loss = criterion(prob_real, ones)# Generate Samples #fake_images = G(noise)prob_fake = D(fake_images)# Forward Fake Samples and Calculate Discriminator Loss #D_fake_loss = criterion(prob_fake, zeros)D_loss = (D_real_loss + D_fake_loss).mean()# Back Propagation and UpdateD_loss.backward()D_optimizer.step()#################### Train Generator ####################fake_images = G(noise)prob_fake = D(fake_images)# According to the p 3 in paper,# early in learning, when G is very poor, D can reject samples from G.# In this case, log(1-D(G(z))) saturates. # thus, train G to maximiaze log(D(G(z))) instead of minimizing log(1-D(G(z)))G_loss = criterion(prob_fake, ones)# Back Propagation and UpdateG_loss.backward()G_optimizer.step()# Save Losses for Plotting LaterG_losses.append(G_loss.item())D_losses.append(D_loss.item())# Print Statistics #if (i + 1) % 100 == 0:print("Epoch [%d/%d] Iter [%d/%d], D_Loss: %.4f G_Loss: %.4f"%(epoch+1, num_epochs, i+1, total_batch, D_loss.item(), G_loss.item()))# Generate Samples #if epoch % 1 == 0:fake_samples = G(fixed_noise)imshow(fake_samples, "Generated MNIST Images")# Save Model Weights for Digit Generation
torch.save(G.state_dict(), './data/GAN.pkl')

plt.figure(figsize = (8, 6))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="Generator")
plt.plot(D_losses, label="Discriminator")
plt.xlabel("Iterations")
plt.ylabel("Losses")
plt.legend()
plt.show()

sample_size = 64
noise_dim = 100noise = torch.randn(sample_size, noise_dim).to(device)G.load_state_dict(torch.load('GAN.pkl'))
fake_samples = G(fixed_noise)
imshow(fake_samples, "Generated MNIST Images")

GAN生成性对抗网络的运用

  • 将语义图像翻译成城市景观和建筑物的照片。

  • 将卫星照片翻译成地图。

  • 从白天到晚上的照片翻译。

  • 将黑白照片翻译成彩色。

- 论文在这里找到:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

- 上述代码的论文:https://arxiv.org/abs/1511.06434

- 上述代码:https://github.com/yihui-he/GAN-MNIST


往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑黄海广老师《机器学习课程》课件合集
本站qq群851320808,加入微信群请扫码:

【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络相关推荐

  1. DL之CNN可视化:利用SimpleConvNet算法【3层,im2col优化】基于mnist数据集训练并对卷积层输出进行可视化

    DL之CNN可视化:利用SimpleConvNet算法[3层,im2col优化]基于mnist数据集训练并对卷积层输出进行可视化 导读 利用SimpleConvNet算法基于mnist数据集训练并对卷 ...

  2. 【小白学习PyTorch教程】五、在 PyTorch 中使用 Datasets 和 DataLoader 自定义数据

    「@Author:Runsen」 有时候,在处理大数据集时,一次将整个数据加载到内存中变得非常难. 因此,唯一的方法是将数据分批加载到内存中进行处理,这需要编写额外的代码来执行此操作.对此,PyTor ...

  3. 【小白学习PyTorch教程】十七、 PyTorch 中 数据集torchvision和torchtext

    @Author:Runsen 对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext. 之前使用 torchDataLoader类直接加载图像并将其转换为张量. ...

  4. 【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度...

    「@Author:Runsen」 上次基于CIFAR-10 数据集,使用PyTorch构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段. imp ...

  5. 【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度

    @Author:Runsen 上次基于CIFAR-10 数据集,使用PyTorch ​​构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段. im ...

  6. 【论文翻译】Auto-painter:基于条件Wasserstein生成性对抗网络的草图卡通形象生成

    Auto-painter: Cartoon image generation from sketch by using conditional Wasserstein generative adver ...

  7. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型...

    「@Author:Runsen」 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先 ...

  8. 【小白学习PyTorch教程】七、基于乳腺癌数据集​​构建Logistic 二分类模型

    「@Author:Runsen」 在逻辑回归中预测的目标变量不是连续的,而是离散的.可以应用逻辑回归的一个示例是电子邮件分类:标识为垃圾邮件或非垃圾邮件.图片分类.文字分类都属于这一类. 在这篇博客中 ...

  9. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型

    @Author:Runsen 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先了解 ...

最新文章

  1. 栈溢出原理 小记 基础中的基础~~
  2. vueRouter-命名视图
  3. 华硕笔记本,宽带连上,可以上网, 但收到不无线
  4. python 爬取大乐透开奖结果
  5. 【java】docker容器内使用jstack等命令报错 The VM does not support the attach mechanism
  6. 判断wifi连接是否可用
  7. sqoop遇到的问题
  8. java cas原理_Java中的锁原理、锁优化、CAS、AQS,看这篇就对了!
  9. Nginx为什么会比Apache Httpd高效
  10. 小学steam计算机课程案例,基于STEAM教育的小学信息技术课程案例开发
  11. 【sduoj】前端JSZip库的使用
  12. Excel做数据分析?是真的很强
  13. Java移除出界敌机,java实现飞机大战案例详解
  14. 安卓是属于全人类的还是谷歌的私有产品?
  15. Apache Hadoop
  16. lazada发货_lazada怎么发货?
  17. 51单片机常用波特率设置
  18. 柳维尔定理与代数基本定理
  19. 支招:苹果电脑Mac版如何快速解压缩软件
  20. python网络安全设计_专为渗透测试人员设计的 Python 工具大合集

热门文章

  1. 如何实现DataGridView实时更新数据【Z】
  2. 一些关于“数据挖掘介”技术的有用文档
  3. JavaScript对象和数组
  4. Spring Boot 使用slf4j+logback记录日志配置
  5. LINUX修改文件权限 学习
  6. html 多行多列列表格,HTML跨多行跨多列表格.doc
  7. Python DataFrame删除某一列中包含的特定元素所在的行
  8. 病例对照研究和队列研究详解
  9. 数据结构与算法分析 C++语言描述第四版.Mark Allen Weiss
  10. Vscode Python输出窗口中文乱码的解决办法