自动编码器和变分自动编码器,不管是哪一个,都是通过计算生成图像和输入图像在每个像素点的误差来生成 loss,这一点是特别不好的,因为不同的像素点可能造成不同的视觉结果,但是可能他们的 loss 是相同的,所以通过单个像素点来得到 loss 是不准确的,这个时候我们需要一种全新的 loss 定义方式,就是通过对抗进行学习。

这个网络是由两部分组成的,第一部分是生成,第二部分是对抗。简单来说,就是有一个生成网络和一个判别网络,通过训练让两个网络相互竞争,生成网络来生成假的数据,对抗网络通过判别器去判别真伪,最后希望生成器生成的数据能够以假乱真。

Discriminator Network¶

首先我们来讲一下对抗过程,因为这个过程更加简单。

对抗过程简单来说就是一个判断真假的判别器,相当于一个二分类问题,我们输入一张真的图片希望判别器输出的结果是1,输入一张假的图片希望判别器输出的结果是0。这其实已经和原图片的 label 没有关系了,不管原图片到底是一个多少类别的图片,他们都统一称为真的图片,label 是 1 表示真实的;而生成的假的图片的 label 是 0 表示假的。

我们训练的过程就是希望这个判别器能够正确的判出真的图片和假的图片,这其实就是一个简单的二分类问题,对于这个问题可以用我们前面讲过的很多方法去处理,比如 logistic 回归,深层网络,卷积神经网络,循环神经网络都可以。

Generator Network¶

接着我们看看生成网络如何生成一张假的图片。首先给出一个简单的高维的正态分布的噪声向量,这个时候我们可以通过仿射变换,也就是 xw+b 将其映射到一个更高的维度,然后将他重新排列成一个矩形,这样看着更像一张图片,接着进行一些卷积、转置卷积、池化、激活函数等进行处理,最后得到了一个与我们输入图片大小一模一样的噪音矩阵,这就是我们所说的假的图片。

这个时候我们如何去训练这个生成器呢?这就需要通过对抗学习,增大判别器判别这个结果为真的概率,通过这个步骤不断调整生成器的参数,希望生成的图片越来越像真的,而在这一步中我们不会更新判别器的参数,因为如果判别器不断被优化,可能生成器无论生成什么样的图片都无法骗过判别器。

import torch
from torch import nn
from torch.autograd import Variableimport torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNISTimport numpy as npimport matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'def show_images(images): # 定义画图工具images = np.reshape(images, [images.shape[0], -1])sqrtn = int(np.ceil(np.sqrt(images.shape[0])))sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))fig = plt.figure(figsize=(sqrtn, sqrtn))gs = gridspec.GridSpec(sqrtn, sqrtn)gs.update(wspace=0.05, hspace=0.05)for i, img in enumerate(images):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(img.reshape([sqrtimg,sqrtimg]))return def preprocess_img(x):x = tfs.ToTensor()(x)return (x - 0.5) / 0.5def deprocess_img(x):return (x + 1.0) / 2.0class ChunkSampler(sampler.Sampler): # 定义一个取样的函数"""Samples elements sequentially from some offset. Arguments:num_samples: # of desired datapointsstart: offset where we should start selecting from"""def __init__(self, num_samples, start=0):self.num_samples = num_samplesself.start = startdef __iter__(self):return iter(range(self.start, self.start + self.num_samples))def __len__(self):return self.num_samplesNUM_TRAIN = 50000
NUM_VAL = 5000NOISE_DIM = 96
batch_size = 128train_set = MNIST('./mnist', train=True, download=True, transform=preprocess_img)train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))val_set = MNIST('./mnist', train=True, download=True, transform=preprocess_img)val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)

简单版本的生成对抗网络¶

通过前面我们知道生成对抗网络有两个部分构成,一个是生成网络,一个是对抗网络,我们首先写一个简单版本的网络结构,生成网络和对抗网络都是简单的多层神经网络

判别网络¶

判别网络的结构非常简单,就是一个二分类器,结构如下:

  • 全连接(784 -> 256)
  • leakyrelu, α 是 0.2
  • 全连接(256 -> 256)
  • leakyrelu, α 是 0.2
  • 全连接(256 -> 1)

其中 leakyrelu 是指 f(x) = max(α x, x)

def discriminator():net = nn.Sequential(        nn.Linear(784, 256),nn.LeakyReLU(0.2),nn.Linear(256, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1))return net

生成网络¶

接下来我们看看生成网络,生成网络的结构也很简单,就是根据一个随机噪声生成一个和数据维度一样的张量,结构如下:

  • 全连接(噪音维度 -> 1024)
  • relu
  • 全连接(1024 -> 1024)
  • relu
  • 全连接(1024 -> 784)
  • tanh 将数据裁剪到 -1 ~ 1 之间
def generator(noise_dim=NOISE_DIM):   net = nn.Sequential(nn.Linear(noise_dim, 1024),nn.ReLU(True),nn.Linear(1024, 1024),nn.ReLU(True),nn.Linear(1024, 784),nn.Tanh())return net

如果我们把 D(x) 看成真实数据的分类得分,那么 D(G(z)) 就是假数据的分类得分,所以上面判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1 下面我们来实现一下

bce_loss = nn.BCEWithLogitsLoss()def discriminator_loss(logits_real, logits_fake): # 判别器的 losssize = logits_real.shape[0]true_labels = Variable(torch.ones(size, 1)).float().cuda()false_labels = Variable(torch.zeros(size, 1)).float().cuda()loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)return lossdef generator_loss(logits_fake): # 生成器的 loss  size = logits_fake.shape[0]true_labels = Variable(torch.ones(size, 1)).float().cuda()loss = bce_loss(logits_fake, true_labels)return loss# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))return optimizer

下面我们开始训练一个这个简单的生成对抗网络

def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, noise_size=96, num_epochs=10):iter_count = 0for epoch in range(num_epochs):for x, _ in train_data:bs = x.shape[0]# 判别网络real_data = Variable(x).view(bs, -1).cuda() # 真实数据logits_real = D_net(real_data) # 判别网络得分sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布g_fake_seed = Variable(sample_noise).cuda()fake_images = G_net(g_fake_seed) # 生成的假的数据logits_fake = D_net(fake_images) # 判别网络得分d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 lossD_optimizer.zero_grad()d_total_error.backward()D_optimizer.step() # 优化判别网络# 生成网络g_fake_seed = Variable(sample_noise).cuda()fake_images = G_net(g_fake_seed) # 生成的假的数据gen_logits_fake = D_net(fake_images)g_error = generator_loss(gen_logits_fake) # 生成网络的 lossG_optimizer.zero_grad()g_error.backward()G_optimizer.step() # 优化生成网络if (iter_count % show_every == 0):print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.data[0], g_error.data[0]))imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())show_images(imgs_numpy[0:16])plt.show()print()iter_count += 1D = discriminator().cuda()
G = generator().cuda()D_optim = get_optimizer(D)
G_optim = get_optimizer(G)train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

Deep Convolutional GANs 深度卷积生成对抗网络特别简单,就是将生成网络和对抗网络都改成了卷积网络的形式,下面我们来实现一下

卷积判别网络¶

卷积判别网络就是一个一般的卷积网络,结构如下

  • 32 Filters, 5x5, Stride 1, Leaky ReLU(alpha=0.01)
  • Max Pool 2x2, Stride 2
  • 64 Filters, 5x5, Stride 1, Leaky ReLU(alpha=0.01)
  • Max Pool 2x2, Stride 2
  • Fully Connected size 4 x 4 x 64, Leaky ReLU(alpha=0.01)
  • Fully Connected size 1
class build_dc_classifier(nn.Module):def __init__(self):super(build_dc_classifier, self).__init__()self.conv = nn.Sequential(nn.Conv2d(1, 32, 5, 1),nn.LeakyReLU(0.01),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, 5, 1),nn.LeakyReLU(0.01),nn.MaxPool2d(2, 2))self.fc = nn.Sequential(nn.Linear(1024, 1024),nn.LeakyReLU(0.01),nn.Linear(1024, 1))def forward(self, x):x = self.conv(x)x = x.view(x.shape[0], -1)x = self.fc(x)return x

卷积生成网络¶

卷积生成网络需要将一个低维的噪声向量变成一个图片数据,结构如下

  • Fully connected of size 1024, ReLU
  • BatchNorm
  • Fully connected of size 7 x 7 x 128, ReLU
  • BatchNorm
  • Reshape into Image Tensor
  • 64 conv2d^T filters of 4x4, stride 2, padding 1, ReLU
  • BatchNorm
  • 1 conv2d^T filter of 4x4, stride 2, padding 1, TanH
class build_dc_generator(nn.Module): def __init__(self, noise_dim=NOISE_DIM):super(build_dc_generator, self).__init__()self.fc = nn.Sequential(nn.Linear(noise_dim, 1024),nn.ReLU(True),nn.BatchNorm1d(1024),nn.Linear(1024, 7 * 7 * 128),nn.ReLU(True),nn.BatchNorm1d(7 * 7 * 128))self.conv = nn.Sequential(nn.ConvTranspose2d(128, 64, 4, 2, padding=1),nn.ReLU(True),nn.BatchNorm2d(64),nn.ConvTranspose2d(64, 1, 4, 2, padding=1),nn.Tanh())def forward(self, x):x = self.fc(x)x = x.view(x.shape[0], 128, 7, 7) # reshape 通道是 128,大小是 7x7x = self.conv(x)return xdef train_dc_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, noise_size=96, num_epochs=10):iter_count = 0for epoch in range(num_epochs):for x, _ in train_data:bs = x.shape[0]# 判别网络real_data = Variable(x).cuda() # 真实数据logits_real = D_net(real_data) # 判别网络得分sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布g_fake_seed = Variable(sample_noise).cuda()fake_images = G_net(g_fake_seed) # 生成的假的数据logits_fake = D_net(fake_images) # 判别网络得分d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 lossD_optimizer.zero_grad()d_total_error.backward()D_optimizer.step() # 优化判别网络# 生成网络g_fake_seed = Variable(sample_noise).cuda()fake_images = G_net(g_fake_seed) # 生成的假的数据gen_logits_fake = D_net(fake_images)g_error = generator_loss(gen_logits_fake) # 生成网络的 lossG_optimizer.zero_grad()g_error.backward()G_optimizer.step() # 优化生成网络if (iter_count % show_every == 0):print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.data[0], g_error.data[0]))imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())show_images(imgs_numpy[0:16])plt.show()print()iter_count += 1D_DC = build_dc_classifier().cuda()
G_DC = build_dc_generator().cuda()D_DC_optim = get_optimizer(D_DC)
G_DC_optim = get_optimizer(G_DC)train_dc_gan(D_DC, G_DC, D_DC_optim, G_DC_optim, discriminator_loss, generator_loss, num_epochs=5)

PyTorch 深度学习:36分钟快速入门——GAN相关推荐

  1. PyTorch 深度学习:34分钟快速入门——自动编码器

    自动编码器最开始是作为一种数据压缩方法,同时还可以在卷积网络中进行逐层预训练,但是随后更多结构复杂的网络,比如 resnet 的出现使得我们能够训练任意深度的网络,自动编码器就不再使用在这个方面,下面 ...

  2. PyTorch 深度学习:32分钟快速入门——ResNet

    ResNet 当大家还在惊叹 GoogLeNet 的 inception 结构的时候,微软亚洲研究院的研究员已经在设计更深但结构更加简单的网络 ResNet,并且凭借这个网络子在 2015 年 Ima ...

  3. PyTorch 深度学习:37分钟快速入门——FCN 做语义分割

    语义分割是一种像素级别的处理图像方式,对比于目标检测其更加精确,能够自动从图像中划分出对象区域并识别对象区域中的类别 在 2015 年 CVPR 的一篇论文 Fully Convolutional N ...

  4. PyTorch 深度学习:33分钟快速入门——VGG

    CIFAR 10¶ cifar 10 这个数据集一共有 50000 张训练集,10000 张测试集,两个数据集里面的图片都是 png 彩色图片,图片大小是 32 x 32 x 3,一共是 10 分类问 ...

  5. PyTorch 深度学习:32分钟快速入门——DenseNet

    DenseNet¶ 因为 ResNet 提出了跨层链接的思想,这直接影响了随后出现的卷积网络架构,其中最有名的就是 cvpr 2017 的 best paper,DenseNet. DenseNet ...

  6. PyTorch 深度学习:30分钟快速入门

    卷积¶ 卷积在 pytorch 中有两种方式,一种是 torch.nn.Conv2d(),一种是 torch.nn.functional.conv2d(),这两种形式本质都是使用一个卷积操作 这两种形 ...

  7. PyTorch 深度学习:38分钟快速入门——RNN 做图像分类

    RNN 特别适合做序列类型的数据,那么 RNN 能不能想 CNN 一样用来做图像分类呢?下面我们用 mnist 手写字体的例子来展示一下如何用 RNN 做图像分类,但是这种方法并不是主流,这里我们只是 ...

  8. PyTorch 深度学习:35分钟快速入门——变分自动编码器

    变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成. 回忆一下,自动编码器有个问题,就是并不能任意生成图片,因为我们没有办法自己去构造隐藏向量,需要通过一张图片输入编 ...

  9. PyTorch 深度学习:31分钟快速入门——Batch Normalization

    Batch Normalization¶ 前面在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好.但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相 ...

最新文章

  1. LeetCode 06Z字形变换07整数反转
  2. 故障解决:error while loading shared libraries: libncurses.so.5
  3. C#生成CHM文件(中级篇)
  4. 如何用texstudio下载ctex_公众号素材库视频如何下载,用这种方法就可以哦
  5. 世界上最伟大的十个公式,看看你懂得几个?
  6. 如何理解SVM | 支持向量机之我见
  7. html marquee初始空白_前端开发必会的HTML/CSS硬知识
  8. vb 6 MDI窗体图片自适应源码
  9. HDU-1087 Super Jumping! Jumping! Jumping!
  10. mvc npoi将List实体导出excel的最简单方法
  11. 好用的 windows10 软件推荐
  12. 有效软件测试 - 50条建议 - 需求阶段
  13. ubuntu下安装绿联的AC650网卡驱动
  14. sql重启oracle数据库,oracle重启数据库sql
  15. 【NLP】模型压缩与蒸馏!BERT的忒修斯船
  16. 平安新一贷怎么被拒了及原因是什么?你都知道吗?
  17. 判断一个数为奇偶数的三种方法
  18. 002 Figuring in C/C++
  19. 如何计算EEG信号的香农熵Shannon entropy(附Matlab程序)
  20. Python学习 Day31 DOM

热门文章

  1. 今日恐慌与贪婪指数为66 贪婪程度有所缓解
  2. Axure高保真web端后台管理系统/垃圾回收分类系统/垃圾回收高保真原型设计 /垃圾分类后台管理系统/垃圾回收分类平台//垃圾回收分类智慧管理系统/订单管理/财务管理/系统管理/库存管理/设备管理
  3. Axure高保真移动端智能数据监控+用户画像+饼状图+条形图+折线图数据统计+抖音直播app用户数据统计+智慧移动端主播粉丝、评论、播放量大数据统计+套餐购买、续费套餐prd流程
  4. Samba服务器安装测试
  5. laradock 进入 工作区
  6. CentOS6 安装Sendmail + Dovecot + Squirrelmail
  7. 利用Gulp实现JSDoc 3的文档编写过程中的实时解析和效果预览
  8. jQuery lazyload
  9. 简便无刷新文件上传系统
  10. 是谁在我的心里打了个结(二十一)托付