《Generative Adversarial Nets》

  • 生成式对抗网络;
  • 作者:lan Goodfellow;
  • 单位:加拿大蒙特利尔大学;
  • 发表会议及时间:NeurlPS(NIPS) 2014;

核心要点

  1. 提出了一个基于对抗的 新生成式模型,由一个生成器和一个判别器组成;
  2. 生成器的目标是学习到样本的数据分布,从而能生成样本欺骗判别器;判别器的目标是判断输入样本时生成/真实的概率;
  3. GAN模型等同于博弈论中的二人零和博弈;
  4. 对于任意的生成器和判别器,都存在一个独特的全局最优解;
  5. 在本文中,生成器和判别器都是由多层感知机实现,整个网络可以用反向传播算法来训练;
  6. 通过实验的定性与定量分析显示,GAN具备很大的潜力;

研究背景

1、零和博弈

  • 一方的收益必然意味着另一方的损失,博弈各方的收益和损失相加总和永远为“零”,双方不存在合作的可能;
  • 在零和博弈中,为了使己方达到最优解,所以把目标设为让对方的最大化收益最小化;

2、使用数据集

  • MNIST:手写数据集,源自NIST;28*28的灰度图,训练集60000张,测试集10000张;

  • TFD:The Toronro face dataset,人脸数据集;

  • CIFAR-10:32*32彩图,10个类别,每类6000张图,训练集50000张,测试集10000张;

3、GAN价值函数

价值函数
minGmaxDV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]min_G max_D V(D,G)=E_{x\sim p_{data}(x)}[log D(x)]+E_{z\sim p_z(z)}[log(1-D(G(z)))]minG​maxD​V(D,G)=Ex∼pdata​(x)​[logD(x)]+Ez∼pz​(z)​[log(1−D(G(z)))]

  • datadatadata:真实数据;
  • DDD:判别器,输出值为[0,1],代表输入来自真实数据的概率;
  • zzz:随机噪声;
  • GGG:生成器,输出为合成数据;

判别器DDD的目的是最大化价值函数VVV,对数函数log在底数大于1时为单调递增函数,最大化VVV就是最大化D(x)D(x)D(x)和1−D(G(z))1-D(G(z))1−D(G(z)),对于任意的x,都有D(x)=1D(x)=1D(x)=1,对于任意的zzz都有D(G(z))=0D(G(z))=0D(G(z))=0。

生成器GGG的目的是针对特定的DDD,去最小化价值函数VVV;最小化价值函数VVV,就是最小化D(x)D(x)D(x)和1−D(G(z))1-D(G(z))1−D(G(z));对于任意的zzz,都有D(G(z))=1D(G(z))=1D(G(z))=1。

训练小trick

  • 在开始训练的时候,生成器GGG的性能较差,D(G(z))D(G(z))D(G(z))接近于0,此时价值函数中的log(1−D(G(z)))log(1-D(G(z)))log(1−D(G(z)))的梯度值较小,而log(D(G(z)))log(D(G(z)))log(D(G(z)))的梯度值较大,所以可以把生成器GGG的目标改为最大化logD(G(z))logD(G(z))logD(G(z)),这样可以在早期学习中提供更强的梯度。

4、训练流程

  • 使用mini-batch梯度下降(带momentum);
  • 训练k次判别器(本论文实验中k=1);
  • 训练1次生成器;


根据伪代码可以知道,对应两个神经网络模型——生成器GGG和判别器DDD,首先会固定生成器GGG的参数,使用生成器GGG生成的数据和真实的数据训练判别器DDD,训练k次判别器DDD后,固定判别器DDD的参数,训练生成器GGG。

理想情况下,判别器的最优解为:DG∗(x)=Pdata(x)Pdata(x)+Pg(x)D^*_{G}(x)=\frac{P_{data}(x)}{P_{data}(x)+P_g(x)}DG∗​(x)=Pdata​(x)+Pg​(x)Pdata​(x)​判别器取得最优解时,生成器的最优解为:Pg=PdataP_g=P_{data}Pg​=Pdata​此时价值函数的值为C∗=−log(4)C^*=-log(4)C∗=−log(4)

模型优劣势

缺点:

  • 没有显式表示的Pg(x)P_g(x)Pg​(x);
  • 必须同步训练G和D,可能会发生模式崩溃;

优点:

  • 不使用马尔科夫链,在学习过程中不需要推理;
  • 可以将多种函数合并到模型中;
  • 可以表示非常尖锐、甚至退化的分布;
  • 不是直接使用数据来计算loss更新生成器,而是使用判别器的梯度,所以数据不会直接复制到生成器的参数中;

Pytorch代码

# 代码来源:https://github.com/eriklindernoren/PyTorch-GAN
import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")  # 迭代次数
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")  # 批量大小
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")  # adam的学习率
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")  # 动量法
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")  # 生成器输入维度
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")  # 照片的尺寸
parser.add_argument("--channels", type=int, default=1, help="number of image channels")  # 通道数,1表示灰度图
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")  # 采样照片频率
opt = parser.parse_args()
print(opt)img_shape = (opt.channels, opt.img_size, opt.img_size)cuda = True if torch.cuda.is_available() else Falseclass Generator(nn.Module):  # 生成器def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return imgclass Discriminator(nn.Module):  # 判别器def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity# Loss function
adversarial_loss = torch.nn.BCELoss()# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,  # 训练模式download=True,  # 如果MNIST没有下载则直接下载transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),  # 照片处理方式),  # 数据集batch_size=opt.batch_size,  # 训练数据批量大小shuffle=True,  # 是否打乱
)# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))  # 生成器的优化器
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))  # 判别器的优化器Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# ----------
#  Training
# ----------for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):# Adversarial ground truthsvalid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)  # 真实数据的labelfake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)  # 生成数据的label# Configure inputreal_imgs = Variable(imgs.type(Tensor))  # 真实照片# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()  ## Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))  # 生成随机分布的数据# Generate a batch of imagesgen_imgs = generator(z)  # 生成器生成伪照片# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid)  # 生成器的目的是骗过判别器,所以希望生成器生成的照片被预测为1g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid)  # 判别器希望真实的照片预测为1fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)  # 判别器希望伪造的照片预测为0d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)os.makedirs("model", exist_ok=True)torch.save(generator, 'model/generator.pkl')torch.save(discriminator, 'model/discriminator.pkl')

GAN —— 《Generative Adversarial Nets》相关推荐

  1. 《generative adversarial nets》的解读以及实现

    一 论文解读 1.1论文基本信息 <Generative Adversarial Nets>是Ian J.Goodfellow发表在NIPS 2014上的一篇论文,也是GANs的开山之作. ...

  2. 重读经典:《Generative Adversarial Nets》

    GAN论文逐段精读[论文精读] 这是李沐博士论文精读的第五篇论文,这次精读的论文是 GAN.目前谷歌学术显示其被引用数已经达到了37000+.GAN 应该是机器学习过去五年上头条次数最多的工作,例如抖 ...

  3. GAN(Generative Adversarial Nets)研究进展

    想与大家分享的是图像生成中一些工作. 这些工作都基于一大类模型,Generative Adversarial Networks(GAN).从模型名称上甚至都可以看出一些发展轨迹:GAN->CGA ...

  4. 生成式对抗网络GAN(Generative Adversarial Nets)论文笔记

    1.介绍 本文基本从2014年<Generative Adversarial Nets>翻译总结的. GAN(Generative Adversarial Nets),生成式对抗网络.包含 ...

  5. 《Generative Adversarial Networks》

    目录 参考资料 模型和算法 理论证明 参考资料 Generative Adversarial Nets (nips.cc) goodfeli/adversarial: Code and hyperpa ...

  6. GAN(Generative Adversarial Nets (生成对抗网络))

    一.GAN 1.应用 GAN的应用十分广泛,如图像生成.图像转换.风格迁移.图像修复等等. 2.简介 生成式对抗网络是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块:生成 ...

  7. 【论文学习】《Parallel WaveGAN: A fast waveform generation model based on generative adversarial networks》

    <Parallel WaveGAN : A fast waveform generation model based on generative adversarial networks wit ...

  8. (转)ICLR 2019 《GAN DISSECTION: VISUALIZING AND UNDERSTANDING GENERATIVE ADVERSARIAL NETWORKS》

    ICLR 2019 <GAN DISSECTION: VISUALIZING AND UNDERSTANDING GENERATIVE ADVERSARIAL NETWORKS> ICLR ...

  9. 深度学习之生成式对抗网络 GAN(Generative Adversarial Networks)

    一.GAN介绍 生成式对抗网络GAN(Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.它源于2014年发表的论文:& ...

最新文章

  1. body click js 委托_自动化测试之selenium调用JS语句
  2. layui的表格可以动态添加行吗_答疑分享052:插入表格,数据分析更方便
  3. ES6新特性之字符串扩展
  4. 史上最全的stm32资料库4---常见问题及编译工具篇
  5. java弹窗 触发事件_关于ElementUI中MessageBox弹框的取消键盘触发事件(enter,esc)关闭弹窗(执行事件)的解决方法...
  6. MOTOMAN-SV3X运动学建模验证图
  7. 【Pytorch神经网络实战案例】15 WGAN-gp模型生成Fashon-MNST模拟数据
  8. 【转载】RocketMQ优秀文章
  9. Unity OnTriggerEnter不调用
  10. 04-----赋值运算符
  11. 设计模式 创建型 工厂方法模式
  12. 【认知femto】femtocell的认知无线电频谱感知算法性能仿真
  13. 安卓开发中关于软键盘处理的一些问题
  14. 计算机设备没有音频,电脑没有音频设备怎么办
  15. C++ Toolkit zz
  16. linux skb机制,skb 的分配细节
  17. Linux:安装Debian最新10.x操作系统(超详细)
  18. 【演讲实录+视频】走近40+世界级AI专家!第三届中国人工智能大会资料分享(直播进行中_不断更新)
  19. 启动电脑时出现0xc000000f错误的解决办法
  20. BZOJ4735 你的生命已如风中残烛(组合数学)

热门文章

  1. nusoap传递数组对象
  2. Linux中文件查找技术大全
  3. MySQL迁移到ClickHouse方案
  4. 面试官:怎么改进哈希算法实现负载均衡的扩展性和容错性?我:...
  5. HTTP 请求之URLs 与 URNs
  6. Docker for windows 容器内网通过独立IP直接访问的方法
  7. 结构设计模式 - 代理设计模式
  8. 高性能异步批量ping的golang实现
  9. @Select的使用说明
  10. js this指向问题,同级this指向同级,非同级this指向全局