2014年,蒙特利尔大学的伊恩·古德费洛(Ian Goodfellow)和他的同事发表了一篇惊人的论文,向世界介绍了GAN(即生成性对抗网络)。通过将计算图和博弈论进行创新的组合,他们表明,只要具有足够的建模能力,两个相互竞争的模型就可以通过简单的反向传播进行协同训练。

这些模型扮演两个不同的角色(从字面上看是对抗角色)。给定一些真实数据集R,G是生成器,试图创建看起来像真实数据的伪数据,而D是鉴别器,从真实集或G中获取数据并标记差异。古德费勒的比喻(也是一个很好的比喻)是:G就像是一群伪造者,他们试图将真实的绘画与其输出进行匹配,而D则是一群侦探,试图说出区别。(除非在这种情况下,伪造者G永远不会看到原始数据——只有d的判断。他们就像盲目的伪造者。)

在理想情况下,随着时间的推移,D和G都会变得更好,直到G基本上成为真品的“伪造大师”,而D则不知所措,“无法区分这两种分布”。

在实践中,古德费洛展示的是G可以在原始数据集上执行某种形式的无监督学习,找到某种方式(可能)以较低维度表示该数据。正如Yann LeCun所说的那样,无监督学习是真正AI的“蛋糕”。

这项功能强大的技术似乎仅需要很多代码才能入门,对吗?不。使用PyTorch,我们实际上可以在50行以下的代码中创建一个非常简单的GAN。实际上只有5个组件需要考虑:

  • R:原始的真实数据集
  • I:随机噪声作为噪声源进入发生器
  • G:试图复制/模仿原始数据集的生成器
  • D:鉴别器试图区分R的G输出
  • .在实际的“训练”循环中,我们教G欺骗D,而D要小心G。

1.)R:在我们的例子中,我们将从最简单的R(钟形曲线)开始。该函数取一个平均值和一个标准差,然后返回一个函数,该函数可提供具有这些参数的高斯样本数据的正确形状。在示例代码中,我们将使用平均值4.0和标准偏差1.25。

def get_distribution_sampler(mu, sigma):return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

2.)I:生成器的输入也是随机的,但是为了使我们的工作更加困难,让我们使用统一分布而不是正态分布。这意味着我们的模型G不能简单地移动/缩放输入以复制R,而是必须以非线性方式重塑数据。

def get_generator_input_sampler():return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

3.)G:生成器是标准的前馈图-两个隐藏层,三个线性映射。我们使用的是双曲正切激活函数,因为我们像这样老派。摹会得到均匀的分布数据样本我,不知怎么模仿的通常从分布的样本[R -而没有看到[R 。

class Generator(nn.Module):def __init__(self, input_size, hidden_size, output_size, f):super(Generator, self).__init__()self.map1 = nn.Linear(input_size, hidden_size)self.map2 = nn.Linear(hidden_size, hidden_size)self.map3 = nn.Linear(hidden_size, output_size)self.f = fdef forward(self, x):x = self.map1(x)x = self.f(x)x = self.map2(x)x = self.f(x)x = self.map3(x)return x

4.)D:鉴别器代码与G的生成器代码非常相似;具有两个隐藏层和三个线性映射的前馈图。这里的激活函数是一个sigmoid——没什么特别的,各位。它将从R或G中获取样本,并将输出介于0和1之间的单个标量,解释为“假”与“真实”。换句话说,这是神经网络所能做的最简单的事情。

class Discriminator(nn.Module):def __init__(self, input_size, hidden_size, output_size, f):super(Discriminator, self).__init__()self.map1 = nn.Linear(input_size, hidden_size)self.map2 = nn.Linear(hidden_size, hidden_size)self.map3 = nn.Linear(hidden_size, output_size)self.f = fdef forward(self, x):x = self.f(self.map1(x))x = self.f(self.map2(x))return self.f(self.map3(x))

5.)最后,训练循环在两种模式之间交替进行:第一种是真实数据训练,另一种是带有精确标签的虚假数据训练(可以将其视为警察学院);然后用不准确的标签训练G去愚弄D(这更像是《十一罗汉》里的准备蒙太奇)。这是一场正义与邪恶的战争。

for epoch in range(num_epochs):for d_index in range(d_steps):# 1. Train D on real+fakeD.zero_grad()#  1A: Train D on reald_real_data = Variable(d_sampler(d_input_size))d_real_decision = D(preprocess(d_real_data))d_real_error = criterion(d_real_decision, Variable(torch.ones([1,1])))  # ones = trued_real_error.backward() # compute/store gradients, but don't change params#  1B: Train D on faked_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labelsd_fake_decision = D(preprocess(d_fake_data.t()))d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = faked_fake_error.backward()d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]for g_index in range(g_steps):# 2. Train G on D's response (but DO NOT train D on these labels)G.zero_grad()gen_input = Variable(gi_sampler(minibatch_size, g_input_size))g_fake_data = G(gen_input)dg_fake_decision = D(preprocess(g_fake_data.t()))g_error = criterion(dg_fake_decision, Variable(torch.ones([1,1])))  # Train G to pretend it's genuineg_error.backward()g_optimizer.step()  # Only optimizes G's parametersge = extract(g_error)[0]

即使您以前从未看过PyTorch,也可以判断发生了什么。在第一(d_index)部分中,我们将两种类型的数据都通过D,然后对D的猜测与实际标签应用可区分的标准。推动是“前进”的一步;然后,我们明确调用“ backward()”以计算梯度,然后将其用于在d_optimizer step()调用中更新D的参数。使用G,但此处未进行训练。

然后,在最后一个(g_index)部分中,我们对G执行相同的操作-请注意,我们也通过D运行了G的输出(本质上是给伪造者提供了一个可以进行练习的侦探),但我们并未优化或更改D在这一步。我们不希望侦探D学习错误的标签。因此,我们仅调用g_optimizer.step()。

而且…仅此而已。还有其他一些样板代码,但GAN特定的东西只是这5个组件,仅此而已。

在D和G之间这种禁止的舞蹈进行了数千回合之后,我们会得到什么?判别器D很快变得很好(而G缓慢上升),但是一旦达到一定程度的力量,G便成为了一个有价值的对手并开始进步。真正提高。

超过5,000轮训练,每轮训练D 20次,然后训练G 20次,G的输出平均值超过4.0,但随后回到相当稳定的正确范围内(左)。同样,标准偏差最初会朝错误的方向下降,然后上升到所需的1.25范围(右),与R匹配。

好的,因此基本属性最终匹配R。更高的时刻怎么样?分布的形状看起来正确吗?毕竟,您当然可以具有均值为4.0且标准偏差为1.25的均匀分布,但这与R并不完全匹配。让我们看一下G发出的最终分布:

不错。右尾比左尾稍微胖一点,但是偏度和峰度是原始高斯分布的再现。

G几乎完美地恢复了原始分布R,而D则退缩在角落,喃喃自语,无法从小说中分辨出事实。这正是我们想要的行为(请参见古德费洛中的图1)。从少于50行的代码开始。

现在,警告一下:GAN可能会很挑剔。而且脆弱。而且当他们进入怪异状态时,他们常常会在没有一点哄骗的情况下出来。运行我的示例代码十次(每次超过5,000发)显示了以下十个分布:

10次测试中有8次获得了非常好的最终分布——近似于高斯分布,均值为4,标准差在正确的范围内。但是两次运行没有——在第5次运行中,有一个凹分布,平均值在6.0左右,在最后一次运行(第10次),在11处有一个狭窄的峰值!当你开始在几乎任何环境中应用GANs时,你会发现这个现象——GANs并不像一般的监督学习工作流那样稳定。但当它们发挥作用时,它们看起来几乎是神奇的。
去看看代码

#!/usr/bin/env python# Generative Adversarial Networks (GAN) example in PyTorch. Tested with PyTorch 0.4.1, Python 3.6.7 (Nov 2018)
# See related blog post at https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variablematplotlib_is_available = True
try:from matplotlib import pyplot as plt
except ImportError:print("Will skip plotting; matplotlib is not available.")matplotlib_is_available = False# Data params
data_mean = 4
data_stddev = 1.25# ### Uncomment only one of these to define what data is actually sent to the Discriminator
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
#(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
#(name, preprocess, d_input_func) = ("Data and diffs", lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2)
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)print("Using data [%s]" % (name))# ##### DATA: Target data and generator input datadef get_distribution_sampler(mu, sigma):return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussiandef get_generator_input_sampler():return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian# ##### MODELS: Generator model and discriminator modelclass Generator(nn.Module):def __init__(self, input_size, hidden_size, output_size, f):super(Generator, self).__init__()self.map1 = nn.Linear(input_size, hidden_size)self.map2 = nn.Linear(hidden_size, hidden_size)self.map3 = nn.Linear(hidden_size, output_size)self.f = fdef forward(self, x):x = self.map1(x)x = self.f(x)x = self.map2(x)x = self.f(x)x = self.map3(x)return xclass Discriminator(nn.Module):def __init__(self, input_size, hidden_size, output_size, f):super(Discriminator, self).__init__()self.map1 = nn.Linear(input_size, hidden_size)self.map2 = nn.Linear(hidden_size, hidden_size)self.map3 = nn.Linear(hidden_size, output_size)self.f = fdef forward(self, x):x = self.f(self.map1(x))x = self.f(self.map2(x))return self.f(self.map3(x))def extract(v):return v.data.storage().tolist()def stats(d):return [np.mean(d), np.std(d)]def get_moments(d):# Return the first 4 moments of the data providedmean = torch.mean(d)diffs = d - meanvar = torch.mean(torch.pow(diffs, 2.0))std = torch.pow(var, 0.5)zscores = diffs / stdskews = torch.mean(torch.pow(zscores, 3.0))kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussianfinal = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))return finaldef decorate_with_diffs(data, exponent, remove_raw_data=False):mean = torch.mean(data.data, 1, keepdim=True)mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])diffs = torch.pow(data - Variable(mean_broadcast), exponent)if remove_raw_data:return torch.cat([diffs], 1)else:return torch.cat([data, diffs], 1)def train():# Model parametersg_input_size = 1      # Random noise dimension coming into generator, per output vectorg_hidden_size = 5     # Generator complexityg_output_size = 1     # Size of generated output vectord_input_size = 500    # Minibatch size - cardinality of distributionsd_hidden_size = 10    # Discriminator complexityd_output_size = 1     # Single dimension for 'real' vs. 'fake' classificationminibatch_size = d_input_sized_learning_rate = 1e-3g_learning_rate = 1e-3sgd_momentum = 0.9num_epochs = 5000print_interval = 100d_steps = 20g_steps = 20dfe, dre, ge = 0, 0, 0d_real_data, d_fake_data, g_fake_data = None, None, Nonediscriminator_activation_function = torch.sigmoidgenerator_activation_function = torch.tanhd_sampler = get_distribution_sampler(data_mean, data_stddev)gi_sampler = get_generator_input_sampler()G = Generator(input_size=g_input_size,hidden_size=g_hidden_size,output_size=g_output_size,f=generator_activation_function)D = Discriminator(input_size=d_input_func(d_input_size),hidden_size=d_hidden_size,output_size=d_output_size,f=discriminator_activation_function)criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bcelossd_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)for epoch in range(num_epochs):for d_index in range(d_steps):# 1. Train D on real+fakeD.zero_grad()#  1A: Train D on reald_real_data = Variable(d_sampler(d_input_size))d_real_decision = D(preprocess(d_real_data))d_real_error = criterion(d_real_decision, Variable(torch.ones([1,1])))  # ones = trued_real_error.backward() # compute/store gradients, but don't change params#  1B: Train D on faked_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labelsd_fake_decision = D(preprocess(d_fake_data.t()))d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = faked_fake_error.backward()d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]for g_index in range(g_steps):# 2. Train G on D's response (but DO NOT train D on these labels)G.zero_grad()gen_input = Variable(gi_sampler(minibatch_size, g_input_size))g_fake_data = G(gen_input)dg_fake_decision = D(preprocess(g_fake_data.t()))g_error = criterion(dg_fake_decision, Variable(torch.ones([1,1])))  # Train G to pretend it's genuineg_error.backward()g_optimizer.step()  # Only optimizes G's parametersge = extract(g_error)[0]if epoch % print_interval == 0:print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s),  Fake Dist (%s) " %(epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))if matplotlib_is_available:print("Plotting the generated distribution...")values = extract(g_fake_data)print(" Values: %s" % (str(values)))plt.hist(values, bins=50)plt.xlabel('Value')plt.ylabel('Count')plt.title('Histogram of Generated Distribution')plt.grid(True)plt.show()train()

PyTorch中的生成对抗网络(GAN)相关推荐

  1. 深度学习中的生成对抗网络GAN

    转载:一文看尽深度学习中的生成对抗网络 | CVHub带你看一看GANs架构发展的8年 (qq.com) 导读 生成对抗网络 (Generative Adversarial Networks, GAN ...

  2. pytorch生成对抗网络GAN的基础教学简单实例(附代码数据集)

    1.简介 这篇文章主要是介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码.数据集使用的是开源人脸图像数据集img_align_celeba,共1.34G.生成器与判 ...

  3. [Pytorch系列-72]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型训练CycleGAN模型

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  4. [Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:[Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleG ...

  5. [Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型测试pix2pix模型

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  6. 一文看懂「生成对抗网络 - GAN」基本原理+10种典型算法+13种应用

    生成对抗网络 – Generative Adversarial Networks | GAN 文章目录 GAN的设计初衷 生成对抗网络 GAN 的基本原理 GAN的优缺点 10大典型的GAN算法 GA ...

  7. [人工智能-深度学习-59]:生成对抗网络GAN - 基本原理(图解、详解、通俗易懂)

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  8. 【CV秋季划】生成对抗网络GAN有哪些研究和应用,如何循序渐进地学习好(2022年言有三一对一辅导)?...

    GAN自从被提出来后,技术发展就非常迅猛,已经被落地于众多的方向,其应用涉及图像与视频生成,数据仿真与增强,各种各样的图像风格化任务,人脸与人体图像编辑,图像质量提升. 那我们究竟如何去长期学好相关的 ...

  9. 简述一下生成对抗网络GAN(Generative adversarial nets)模型?

    简述一下生成对抗网络GAN(Generative adversarial nets)模型? 生成对抗网络GAN是由蒙特利尔大学Ian Goodfellow在2014年提出的机器学习架构. 要全面理解生 ...

  10. 生成对抗网络gan原理_中国首个“芯片大学”即将落地;生成对抗网络(GAN)的数学原理全解...

    开发者社区技术周刊又和大家见面了,萌妹子主播为您带来第三期"开发者技术联播".让我们一起听听,过去一周有哪些值得我们开发者关注的重要新闻吧. 中国首个芯片大学,南京集成电路大学即将 ...

最新文章

  1. 关于STM32驱动DS1302实时时钟的一点思考
  2. 数据库事务原理详解-Spring 事务的传播属性
  3. linux 信号量锁 内核,Linux内核信号量互斥锁应用
  4. elasticsearch存储空间不足导致索引只读,不能创建
  5. 微信公众帐号开发教程第16篇-应用实例之历史上的今天
  6. 12_python基础—函数基础(参数、返回值、调用)
  7. 《linux核心应用命令速查》连载二:lastcomm:显示以前使用过的命令的信息
  8. 去面试字节跳动,你最好有点心理准备!
  9. 拓端tecdat|excel数据分析——贝叶斯分析预测
  10. ffmpeg对H.264进行rtp打包
  11. 大数据技术在银行业中的应用场景,主要有哪些?
  12. matlab 图像范围,Matlab对数范围colorbar图像c
  13. Linux 测试IP和端口是否能访问
  14. 说我菜?那好,我用Python制作电脑与手机游戏脚本来赢你
  15. ThingsBoard教程(十):前端初级定制化
  16. bayes什么意思_Bayes是什么意思
  17. 【VivadoHLS 仿真csim 报错bug】hls video库和math库 的hls::sqrt重定义问题解决
  18. 天乐文本文件按行分割器_v1.0正式版【专业制作极速分割】
  19. python读取电脑识别码
  20. 母带混音插件套装-Acon Digital Mastering Suite 1.2.1 WiN-MAC

热门文章

  1. 用PS设计等高线效果的背景图片
  2. Lync常识之Lync Server有哪些角色
  3. Total Commander工具栏图标 备份
  4. java operator 重载 ==_运算符重载
  5. [二进制拆分]Luogu1833 樱花
  6. Express框架学习笔记-静态资源的处理
  7. 华为手机投屏电脑_手机投屏干货分享:华为如何投屏到电视机?
  8. win10 C语言qt调试,如何在Windows中调试Qt(MSVC)应用程序
  9. html5用户输入后自动显示用户名已重复_IT兄弟连 HTML5教程 HTML5表单 HTML5新增表单元素...
  10. SpringBoot系列(8):SpringBoot中的MVC支持【组件型注解、请求和参数型注解】详解