Generative Adversarial Nets


论文

Abstract

提出了新的对抗生成网络,a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than G.(一个G和一个D,具体是什么?)G让D尽可能犯错误(犯错误?)G是用来恢复图片数据分布的,而D最后会等于1/2(等于1/2?)最后G和D都是通过multilayer perceptron来实现的。

1 Introduction

这里不写,你们自己看

2 Related work

同上

3 Adversarial nets

为了学习到图像(也就是数据集,真实的数据)的数据分布(数据分布的定义), 这里先定义一个随机分布的噪声变量,然后通过这个噪声变量输入到G当中产生一个映射(这里我们可以得知G是一个generator,其作用就是通过输入一个噪声变量,映射到所要产生的图像分布上)。

所图所示,输入一个随机分布,输出所需图像的数据分布,这里只不过是将数据分布变成了图片

与此同时还定义了一个,D代表这个数据x来源于的可能性。

所图所示,D接受一张图片之后,如果这张图片是来源于数据集的,那么它输出的可能性会高一些,而来源于G生成的假图片的话,所输出的可能性只有0.1

综上所述,作者因此要训练D以确保D能够正确的分辨出所输入的图像是来源于真实数据库的还是G所产生的,而同时G要最小化。这里作者给出了一条公式,意思就是我们从D中最大化后面那条公式,然后选出D最小的G:

这条公式怎么来的呢?

1、我们已经知道,G是尽可能让输出的图片分布和真实的图片分布一样,因此可以得到一下G*的公式,找到一个G,能让PG和Pdata的偏差(DIV)尽可能小。那么DIV要怎样算?

2、通过D,也就是我们的判别器来计算DIV。我们假设G目前是固定的,因此输入到D当中有两种情况,一种是图像是来源于数据库的,我们要求输入到D中的话可能性要尽可能地高,所以要尽可能高;而加上个log不影响吧,也方便求导,所以也变成尽可能地高;最后要用到所有真实数据,求个均值也不影响,所以也尽可能地高。

反之,第二种图像是来源于G产生的图像,所以要尽可能地小,但是为了统一公式所以尽可能地高就等价于前面要尽可能地小了,其他是一样的

接下来作者给出了一张图来解释对抗生成网络的过程

这张图怎样理解呢?

黑线代表真实数据集的数据分布,绿线代表G所产生图像的数据分布,蓝线代表D所判断的概率。

一开始a是代表真实的数据分布和绿线的数据分布先固定好,然后D也没有经过训练,它所判断的效果并不大行。

a->b:代表我们通过已知的两个分布去训练D,使得D在假的分布上概率较低,而在真实的分布上概率较高

b->c:代表训练好D之后,我们接着训练G,让虚假分布接近真实分布,让G能更好地骗过D

c->d:代表经过多次迭代之后,我们G所产生的分布接近于真实的分布,这时候D已经不能区别出这两个分布了,所以等价于1/2(因为乱猜有50%可能性猜对),而G所产生的图像也接近于真实的图像。

同时作者还给出它的训练过程是优化几次D的同时优化G,而不能直接优化好D再优化G,其一是在WGAN上证明了它会导致梯度消失,其二是本文所说的在数据集上会导致过拟合(过拟合就是G只产生同一张照片就可以骗过G了)。

同时优化G时,一开始D是很容易判别出哪个是真实数据哪个是G产生的虚假数据,所以G的最小化会导致梯度消失,作者用了最大化来代替它。

4 Theoretical Results

作者先给出了整个算法的流程图,具体过程请配合代码观看,这样方便理解。

4.1 Global Optimality of pg= pdata

1、假设G是不变的,那么D的最优解是

证明过程:

我们已知D就是为了最大化V(D,G),那我们怎样计算V(D,G)呢?

这里直接用到积分公式求均值,这时候最大化V就变成了最大化V中里面积分的公式

最后带入a,b,D进入到积分公式,就能得到D的最优值

2、G的最小值为-log4当且仅当pg= pdata。

证明过程

我们已知D的最优值,那么只需要将D的最优质带入回V当中化简即可,这里我们可以看到V的所有取值为-2log2+KL。也就是G要找V的最小值的话,需要pg= pdata使得KL为0,则G的最小值为-2log2=-log4

4.2 Convergence of Algorithm

这部分证明的Pg能够收敛于Pdata,具体详情请看论文,俺也看不懂。

5 Experiments

自个看吧,感觉已经用不上这些标准了

6 Advantages and disadvantages

自己看hhhh


创新点

1、提出了新的对抗生成网络


代码解读

Generator

传入的是z的通道数[b, latent_dim],最后输出的是图片[b, h, w],注意这里的代码只有一个通道,是mnist灰度图像数据集,所以没有设置三通道。

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return img

Discriminator

输入的是图像[b, h, w], 输出一个分数,代表可能性

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity

训练过程

具体解释在注释中

for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):'''设置真的标签和假的标签,真的为1,也就是可能性为1'''# Adversarial ground truths  valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)# Configure input  real_imgs = Variable(imgs.type(Tensor))# -----------------#  Train Generator  训练G网络,其实也可以先训练D网络# -----------------optimizer_G.zero_grad()# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))# Generate a batch of images  输入噪声gen_imgs = generator(z)# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator  训练D网络# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid)  # 传入真实图像和真标签fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)  # 传入生成的假图像和假标签d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()

完整代码

import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)img_shape = (opt.channels, opt.img_size, opt.img_size)cuda = True if torch.cuda.is_available() else Falseclass Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity# Loss function
adversarial_loss = torch.nn.BCELoss()# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# ----------
#  Training
# ----------for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):# Adversarial ground truthsvalid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(Tensor))# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))# Generate a batch of imagesgen_imgs = generator(z)# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
© 2021 GitHub, Inc.

备注

1、图片大部分来源于李哥的PPT

2、代码来源于这里

Generative Adversarial Nets论文理解和代码讲解相关推荐

  1. Generative Adversarial Nets 论文翻译

    论文原文: https://dl.acm.org/doi/pdf/10.1145/3422622 代码: GitHub - goodfeli/adversarial: Code and hyperpa ...

  2. Generative Adversarial Nets 论文笔记

    论文地址 Generative Adversarial Nets 摘要 首先,在论文中提出了一个新的框架:生成对抗网络框架,这个框架是为了通过对抗的过程实现评估生成模型. 处理过程中,我们同时训练两个 ...

  3. Generative Adversarial Nets论文翻译

    Abstract We propose a new framework for estimating generative models via an adversarial process, 我们提 ...

  4. Conditional Structure Generation throughGraph Variational Generative Adversarial Nets 论文阅读

    目标 基于语义条件生成图 (1)基于条件生成尽可能相似的图. (2)有条件的生成新的图. 解决的问题 (1)基于语义有条件的生成图 (2)如何处理图在生成过程中的顺序的问题 解决问题的方法 (1)提出 ...

  5. 生成对抗网络Generative Adversarial Nets(译)

    仅供参考,如有翻译不到位的地方敬请指出. 论文地址:Generative Adversarial Nets  论文翻译:XlyPb(http://blog.csdn.net/wspba/article ...

  6. CGAN论文解读:Conditional Generative Adversarial Nets

    论文链接:Conditional Generative Adversarial Nets 代码解读:Keras-CGAN_MNIST 代码解读 目录 一.前言 二.相关工作 三.网络结构 CGAN N ...

  7. GAIN: Missing Data Imputation using Generative Adversarial Nets(基于生成对抗网络的缺失数据填补)论文详解

    目录 一.背景分析 1.1 缺失数据 1.2 填补算法 二.GAIN 2.1 GAIN网络架构 2.2 符号描述(Symbol Description) 2.3 生成器模型 2.4 判别器模型 2.5 ...

  8. GaitGAN: Invariant Gait Feature Extraction Using Generative Adversarial Networks论文翻译以及理解

    GaitGAN: Invariant Gait Feature Extraction Using Generative Adversarial Networks论文翻译以及理解 格式:一段英文,一段中 ...

  9. 生成式对抗网络GAN(Generative Adversarial Nets)论文笔记

    1.介绍 本文基本从2014年<Generative Adversarial Nets>翻译总结的. GAN(Generative Adversarial Nets),生成式对抗网络.包含 ...

最新文章

  1. SQL性能--left join和inner join的运行速度与效率
  2. Go语言程序结构分析初探
  3. 大规模中文多模态评测基准MUGE发布
  4. JavaWeb课程复习资料(八)——添加功能
  5. SQL SERVER 运维日记
  6. 电脑屏幕卡住了按什么都没反应_手机突然“死机”了关机也不行,怎么按都没反应,怎么办?...
  7. jasmine-JavaScript单元测试工具
  8. 新地址 贴吧_建议收藏 | 新媒体人必备5大工具
  9. Oracle往表里插入系统当前时间
  10. 排序数据图-R/python
  11. 访问samba文件 指定网络名不再可用
  12. java js 二级联动下拉列表_最简单js代码实现select二级联动下拉菜单
  13. 安防监控直播中无插件web直播方案中实现快照抓取的功能
  14. 数据库审计方案简介和功能对比
  15. R-squared居然是负数
  16. 朋友圈投票活动-刷票案例实现与分析
  17. Latex引用参考文献的5种方式
  18. 国内免费(开源)CMS系统【大全】
  19. VB--. 和 ! ?
  20. 电动汽车集群并网的分布式鲁棒优化调度matlab

热门文章

  1. SharePoint 2010 图片库缩略图 Web Part Web 部件
  2. 【fly-iot 飞驰物联】(6):通过docker镜像使用gitbook启动ActorCloud项目文档,发现是个IOT功能非常丰富的项目,可以继续研究下去。
  3. 阿里巴巴达摩院Topic推荐-AMiner
  4. Trie树的构建和应用
  5. wx-open-launch-weapp 修改样式
  6. 使用 easyexcel
  7. 干货!XDR产品安全检测体系如何更好的落地?
  8. [Qt C++]对酷狗krc文件进行解码
  9. [Java]NaN与NaN比较是否相等
  10. Seaborn 基本使用