本文来源于PyTorch中文网。

一直想了解GAN到底是个什么东西,却一直没能腾出时间来认真研究,前几日正好搜到一篇关于PyTorch实现GAN训练的文章,特将学习记录如下,本文主要包含两个部分:GAN原理介绍和技术层面实现。

一、什么是GAN

2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文。没错,我说的就是《Generative Adversarial Nets》,这标志着生成对抗网络(GAN)的诞生,而这是通过对计算图和博弈论的创新性结合。他们的研究展示,给定充分的建模能力,两个博弈模型能够通过简单的反向传播(backpropagation)来协同训练。

这两个模型的角色定位十分鲜明。给定真实数据集 R,G 是生成器(generator),它的任务是生成能以假乱真的假数据;而 D 是判别器 (discriminator),它从真实数据集或者 G 那里获取数据, 然后做出判别真假的标记。Ian Goodfellow 的比喻是,G 就像一个赝品作坊,想要让做出来的东西尽可能接近真品,蒙混过关。而 D 就是文物鉴定专家,要能区分出真品和高仿(但在这个例子中,造假者 G 看不到原始数据,而只有 D 的鉴定结果——前者是在盲干)。

理想情况下,D 和 G 都会随着不断训练,做得越来越好——直到 G 基本上成为了一个“赝品制造大师”,而 D 因无法正确区分两种数据分布输给 G。

实践中,Ian Goodfellow 展示的这项技术在本质上是:G 能够对原始数据集进行一种无监督学习,找到以更低维度的方式(lower-dimensional manner)来表示数据的某种方法。而无监督学习之所以重要,就好像 Yann LeCun 的那句话:“无监督学习是蛋糕的糕体”。这句话中的蛋糕,指的是无数学者、开发者苦苦追寻的“真正的 AI”。

二、用PyTorch训练GAN

Dev Nag:在表面上,GAN 这门如此强大、复杂的技术,看起来需要编写天量的代码来执行,但事实未必如此。我们使用 PyTorch,能够在 50 行代码以内创建出简单的 GAN 模型。这之中,其实只有五个部分需要考虑:

  • R:原始、真实数据集
  • I:作为熵的一项来源,进入生成器的随机噪音
  • G:生成器,试图模仿原始数据
  • D:判别器,试图区别 G 的生成数据和 R
  • 我们教 G 糊弄 D、教 D 当心 G 的“训练”环。

R:在我们的例子里,从最简单的 R 着手——贝尔曲线(bell curve)。它把平均数(mean)和标准差(standard deviation)作为输入,然后输出能提供样本数据正确图形(从 Gaussian 用这些参数获得 )的函数。在我们的代码例子中,我们使用 4 的平均数和 1.25 的标准差。

def get_distribution_sampler(mu, sigma):return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

I:生成器的输入是随机的,为提高点难度,我们使用均匀分布(uniform distribution )而非标准分布。这意味着,我们的 Model G 不能简单地改变输入(放大/缩小、平移)来复制 R,而需要用非线性的方式来改造数据。

def get_generator_input_sampler():return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

G: 该生成器是个标准的前馈图(feedforward graph)——两层隐层,三个线性映射(linear maps)。我们使用了 ELU (exponential linear unit)。G 将从 I 获得平均分布的数据样本,然后找到某种方式来模仿 R 中标准分布的样本。

class Generator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Generator, self).__init__()self.map1 = nn.Linear(input_size, hidden_size)self.map2 = nn.Linear(hidden_size, hidden_size)self.map3 = nn.Linear(hidden_size, output_size)def forward(self, x):x = F.elu(self.map1(x))x = F.sigmoid(self.map2(x))return self.map3(x)

D: 判别器的代码和 G 的生成器代码很接近。一个有两层隐层和三个线性映射的前馈图。它会从 R 或 G 那里获得样本,然后输出 0 或 1 的判别值,对应反例和正例。这几乎是神经网络的最弱版本了。

class Discriminator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Discriminator, self).__init__()self.map1 = nn.Linear(input_size, hidden_size)self.map2 = nn.Linear(hidden_size, hidden_size)self.map3 = nn.Linear(hidden_size, output_size)def forward(self, x):x = F.elu(self.map1(x))x = F.elu(self.map2(x))return F.sigmoid(self.map3(x))

最后,训练环在两个模式中变幻:第一步,用被准确标记的真实数据 vs. 假数据训练 D;随后,训练 G 来骗过 D,这里是用的不准确标记。道友们,这是正邪之间的较量。

for epoch in range(num_epochs):for d_index in range(d_steps):# 1. Train D on real+fakeD.zero_grad()#  1A: Train D on reald_real_data = Variable(d_sampler(d_input_size))d_real_decision = D(preprocess(d_real_data))d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = trued_real_error.backward() # compute/store gradients, but don't change params#  1B: Train D on faked_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labelsd_fake_decision = D(preprocess(d_fake_data.t()))d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = faked_fake_error.backward()d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()for g_index in range(g_steps):# 2. Train G on D's response (but DO NOT train D on these labels)G.zero_grad()gen_input = Variable(gi_sampler(minibatch_size, g_input_size))g_fake_data = G(gen_input)dg_fake_decision = D(preprocess(g_fake_data.t()))g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuineg_error.backward()g_optimizer.step()  # Only optimizes G's parameters

即便你从没接触过 PyTorch,大概也能明白发生了什么。在第一部分,我们让两种类型的数据经过 D,并对 D 的猜测 vs. 真实标记执行不同的评判标准。这是 “forward” 那一步;随后我们需要 “backward()” 来计算梯度,然后把这用来在 d_optimizer step() 中更新 D 的参数。这里,G 被使用但尚未被训练。

在最后的部分,我们对 G 执行同样的操作——注意我们要让 G 的输出穿过 D (这其实是送给造假者一个鉴定专家来练手)。但在这一步,我们并不优化、或者改变 D。我们不想让鉴定者 D 学习到错误的标记。因此,我们只执行 g_optimizer.step()。

这就是全部了。还有一些其他样板代码,但GAN特定的东西只是那5个组件,没有别的了。

在 D 和 G 之间几千轮交手之后,我们会得到什么?判别器 D 会快速改进,而 G 的进展要缓慢许多。但当模型达到一定性能之后,G 才有了个配得上的对手,并开始提升,巨幅提升。

两万轮训练之后,G 的输入平均值超过 4,但会返回到相当平稳、合理的范围(左图)。同样的,标准差一开始在错误的方向降低,但随后攀升至理想中的 1.25 区间(右图),达到 R 的层次。

所以,基础数据最终会与 R 吻合。那么,那些比 R 更高的时候呢?数据分布的形状看起来合理吗?毕竟,你一定可以得到有 4.0 的平均值和 1.25 标准差值的均匀分布,但那不会真的符合 R。我们一起来看看 G 生成的最终分布。

结果是不错的。左侧的尾巴比右侧长一些,但偏离程度和峰值与原始 Gaussian 十分相近。G 接近完美地再现了原始分布 R——D 落于下风,无法分辨真相和假相。而这就是我们想要得到的结果——使用不到 50 行代码。

该说的都说完了,老司机请上 GitHub 把玩全套代码。

地址:https://github.com/devnag/pytorch-generative-adversarial-networks

附所有代码供参考:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable# Data params
data_mean = 4
data_stddev = 1.25# Model params
g_input_size = 1     # Random noise dimension coming into generator, per output vector
g_hidden_size = 50   # Generator complexity
g_output_size = 1    # size of generated output vector
d_input_size = 100   # Minibatch size - cardinality of distributions
d_hidden_size = 50   # Discriminator complexity
d_output_size = 1    # Single dimension for 'real' vs. 'fake'
minibatch_size = d_input_sized_learning_rate = 2e-4  # 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)
num_epochs = 30000
print_interval = 200
d_steps = 1  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1# ### Uncomment only one of these
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)print("Using data [%s]" % (name))# ##### DATA: Target data and generator input datadef get_distribution_sampler(mu, sigma):return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussiandef get_generator_input_sampler():return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian# ##### MODELS: Generator model and discriminator modelclass Generator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Generator, self).__init__()self.map1 = nn.Linear(input_size, hidden_size)self.map2 = nn.Linear(hidden_size, hidden_size)self.map3 = nn.Linear(hidden_size, output_size)def forward(self, x):x = F.elu(self.map1(x))x = F.sigmoid(self.map2(x))return self.map3(x)class Discriminator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Discriminator, self).__init__()self.map1 = nn.Linear(input_size, hidden_size)self.map2 = nn.Linear(hidden_size, hidden_size)self.map3 = nn.Linear(hidden_size, output_size)def forward(self, x):x = F.elu(self.map1(x))x = F.elu(self.map2(x))return F.sigmoid(self.map3(x))def extract(v):return v.data.storage().tolist()def stats(d):return [np.mean(d), np.std(d)]def decorate_with_diffs(data, exponent):mean = torch.mean(data.data, 1)mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])diffs = torch.pow(data - Variable(mean_broadcast), exponent)return torch.cat([data, diffs], 1)d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)
criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)for epoch in range(num_epochs):for d_index in range(d_steps):# 1. Train D on real+fakeD.zero_grad()#  1A: Train D on reald_real_data = Variable(d_sampler(d_input_size))d_real_decision = D(preprocess(d_real_data))d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = trued_real_error.backward() # compute/store gradients, but don't change params#  1B: Train D on faked_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labelsd_fake_decision = D(preprocess(d_fake_data.t()))d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = faked_fake_error.backward()d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()for g_index in range(g_steps):# 2. Train G on D's response (but DO NOT train D on these labels)G.zero_grad()gen_input = Variable(gi_sampler(minibatch_size, g_input_size))g_fake_data = G(gen_input)dg_fake_decision = D(preprocess(g_fake_data.t()))g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuineg_error.backward()g_optimizer.step()  # Only optimizes G's parametersif epoch % print_interval == 0:print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,extract(d_real_error)[0],extract(d_fake_error)[0],extract(g_error)[0],stats(extract(d_real_data)),stats(extract(d_fake_data))))

【PyTorch】50行代码实现GAN——PyTorch相关推荐

  1. python爬虫实战:利用scrapy,短短50行代码下载整站短视频

    近日,有朋友向我求助一件小事儿,他在一个短视频app上看到一个好玩儿的段子,想下载下来,可死活找不到下载的方法.这忙我得帮,少不得就抓包分析了一下这个app,找到了视频的下载链接,帮他解决了这个小问题 ...

  2. python实现50行代码_50行代码实现python计算器主要功能

    实现功能:计算带有括号和四则运算的式子 3*( 4+ 50 )-(( 100 + 40 )*5/2- 3*2* 2/4+9)*((( 3 + 4)-4)-4) 基本思路:使用正则表达式提取出每一层小括 ...

  3. 如何用50行代码构建情感分类器

    选自Toward Data Science,作者:Rohith Gandhi,机器之心编译. 本文介绍了如何构建情感分类器,从介绍自然语言处理开始,一步一步讲述构建过程. 自然语言处理简介 语言把人类 ...

  4. pyquery获取不到网页完整源代码_爬虫神器之PyQuery实用教程(二),50行代码爬取穷游网...

    爬虫神器之PyQuery实用教程(二),50行代码爬取穷游网 前言 上篇文章 PyQuery (一) 回顾.今天来介绍具体 PyQuery 的使用方法. 穷游网目标与分析 开始之前,按照之前的套路一步 ...

  5. python跑酷游戏源码_HTML5游戏实战(1):50行代码实现正面跑酷游戏

    前段时间看到一个"熊来了"的HTML5跑酷游戏,它是一个典型的正面2D跑酷游戏,这里借用它来介绍一下用Gamebuilder+CanTK开发正面跑酷游戏的基本方法. CanTK(C ...

  6. 转:目标50行代码之内完成3d编辑器功能

    1024程序员节刚过,手痒想实现一个html的3d编辑器,看了three.js  同时还看了网上流传已久的<<基于 HTML5 Canvas 的简易 2D 3D 编辑器>>,都 ...

  7. 利用scrapy,短短50行代码下载整站短视频

    一.撕开爬虫的面纱--爬虫是什么,它能做什么 爬虫是什么 爬虫就是一段能够从互联网上高效获取数据的程序. 我们每天都在从互联网上获取数据.当打开浏览器访问百度的时候,我们就从百度的服务器获取数据,当拿 ...

  8. 50行代码实现的艺术签名设计微信小程序,轻松对接公众号,涨粉神器,学习赚钱两不误.微信公众号引流工具.html,python学习小项目.艺术签名设计微信小程序,前端学习小项目有趣的项目

    50行代码实现的艺术签名设计微信小程序,轻松对接公众号,涨粉神器,学习赚钱两不误 先看效果 这个小程序实现艺术签名设计的功能 对接到公众号之后,相当于给你的公众号添加了一个功能,别人关注公众号后,可以 ...

  9. python pyquery不规则数据的抓取_爬虫神器之PyQuery实用教程(二),50行代码爬取穷游网...

    爬虫神器之PyQuery实用教程(二),50行代码爬取穷游网 前言 上篇文章 PyQuery (一) 回顾.今天来介绍具体 PyQuery 的使用方法. 穷游网目标与分析 开始之前,按照之前的套路一步 ...

最新文章

  1. a,b互换,不使用中间变量
  2. 022_Jedis的事物
  3. DB2 SQLCODE=-1585的问题解决
  4. apk的签名文件(两次Hash+加密)
  5. oracle存储过程关键字有哪些,ORACLESTREAMS存储过程中的一些参数有哪些?
  6. React开发(165):ant design validateFields
  7. 如何用notepad写php,notepad新手怎么使用
  8. tidb时间转字符串_如何使用TiDB节省时间
  9. managed code和unmanaged code混合debug
  10. sql2018 ssas_如何使用SQL Server Analysis Services(SSAS)从头开始构建多维数据集
  11. 单例模式访问mysql设计类图_利用单例模式设计数据库连接Model类
  12. AcWing 866. 试除法判定质数(素数判定)
  13. toastr 在js中的用法
  14. html图片轮播代码 贴吧,JS实现简易图片轮播效果的方法
  15. CDR排钻教程-CorelDRAW服装设计中的排钻技术
  16. 激光雷达相关技术方案介绍
  17. linux 服务器运维常用命令
  18. JAVA水晶报表从环境搭建到创建动态水晶报表
  19. Loaded plugins: fastestmirror, refresh-packagekit, security Loading mirror speeds from cached
  20. 使用Android Studio写一个发短信的小案例

热门文章

  1. python学习 day7_字符串、列表的相关操作
  2. Fedora 34 Workstation安装后的配置
  3. 俄勒冈大学计算机科学专业,俄勒冈大学计算机与信息科学详解 热门专业还等什么...
  4. sim7600ce拨号上网
  5. 【毕业设计_课程设计】在线免费小说微信小程序的设计与实现(源码+论文)
  6. GAMES101作业1-VS2019
  7. KindEditor 图片上传功能实现
  8. 用友U8 cloud,信创云ERP的数智先锋
  9. cavium CN71XX芯片 GSER Interface总结
  10. Pytorch的rand、randn和normal的用法及区别