GAN 是一个近几年比较流行的生成网络形式. 对比起传统的生成模型, 他减少了模型限制和生成器限制, 他具有有更好的生成能力. 人们常用假钞鉴定者和假钞制造者来打比喻, 但是我不喜欢这个比喻, 觉得没有真实反映出 GAN 里面的机理.

所以我的一句话介绍 GAN 就是: Generator 是新手画家, Discriminator 是新手鉴赏家, 你是高级鉴赏家. 你将著名画家的品和新手画家的作品都给新手鉴赏家评定, 并告诉新手鉴赏家哪些是新手画家画的, 哪些是著名画家画的, 新手鉴赏家就慢慢学习怎么区分新手画家和著名画家的画, 但是新手画家和新手鉴赏家是好朋友, 新手鉴赏家会告诉新手画家要怎么样画得更像著名画家, 新手画家就能将自己的突然来的灵感 (random noise) 画得更像著名画家. 我用一个短动画形式来诠释了整个过程 (GAN 动画简介) (如下).

下面是本节内容的效果, 绿线的变化是新手画家慢慢学习如何踏上画家之路的过程. 而能被认定为著名的画作在 upper bound  和 lower bound  之间.

超参数设置

新手画家 (Generator) 在作画的时候需要有一些灵感 (random noise), 我们这些灵感的个数定义为 N_IDEAS . 而一幅画需要有一些规格, 我们将这幅画的画笔数定义一下, N_COMPONENTS  就是一条一元二次曲线(这幅画画)上的点个数. 为了进行批训练, 我们将一整批话的点都规定一下( PAINT_POINTS ).

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plttorch.manual_seed(1)    # reproducible
np.random.seed(1)# 超参数
BATCH_SIZE = 64
LR_G = 0.0001           # learning rate for generator
LR_D = 0.0001           # learning rate for discriminator
N_IDEAS = 5             # think of this as number of ideas for generating an art work (Generator)
ART_COMPONENTS = 15     # it could be total point G can draw in the canvas
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])

著名画家的画

我们需要有很多画是来自著名画家的(real data), 将这些著名画家的画, 和新手画家的画都传给新手鉴赏家, 让鉴赏家来区分哪些是著名画家, 哪些是新手画家的画. 如何区分我们在后面呈现. 这里我们生成一些著名画家的画 (batch 条不同的一元二次方程曲线).

def artist_works():     # painting from the famous artist (real target)a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]paintings = a * np.power(PAINT_POINTS, 2)   (a-1)paintings = torch.from_numpy(paintings).float()return Variable(paintings)

下面就是会产生曲线的一个上限和下限.

神经网络

这里会创建两个神经网络, 分别是 Generator (新手画家), Discriminator(新手鉴赏家). G 会拿着自己的一些灵感当做输入, 输出一元二次曲线上的点 (G 的画).

D 会接收一幅画作 (一元二次曲线), 输出这幅画作到底是不是著名画家的画(是著名画家的画的概率).

G = nn.Sequential(                      # Generatornn.Linear(N_IDEAS, 128),            # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS),     # making a painting from these random ideas
)D = nn.Sequential(                      # Discriminatornn.Linear(ART_COMPONENTS, 128),     # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(),                       # tell the probability that the art work is made by artist
)

训练

接着我们来同时训练 D 和 G. 训练之前, 我们来看看G作画的原理. G 首先会有些灵感, G_ideas 就会拿到这些随机灵感 (可以是正态分布的随机数), 然后 G 会根据这些灵感画画. 接着我们拿着著名画家的画和 G 的画, 让 D 来判定这两批画作是著名画家画的概率.

for step in range(10000):artist_paintings = artist_works()           # real painting from artistG_ideas = Variable(torch.randn(BATCH_SIZE, N_IDEAS))    # random ideasG_paintings = G(G_ideas())                  # fake painting from G (random ideas)prob_artist0 = D(artist_paintings)          # D try to increase this probprob_artist1 = D(G_paintings)               # D try to reduce this prob

然后计算有多少来之画家的画猜对了, 有多少来自 G 的画猜对了, 我们想最大化这些猜对的次数. 这也就是 log(D(x)) log(1-D(G(z))  在论文中的形式. 而因为 torch 中提升参数的形式是最小化误差, 那我们把最大化 score 转换成最小化 loss, 在两个 score 的合的地方加一个符号就好. 而 G 的提升就是要减小 D 猜测 G 生成数据的正确率, 也就是减小 D_score1.

D_loss = - torch.mean(torch.log(prob_artist0)   torch.log(1\. - prob_artist1))
G_loss = torch.mean(torch.log(1\. - prob_artist1))

最后我们在根据 loss  提升神经网络就好了.

opt_D.zero_grad()
D_loss.backward(retain_variables=True)      # retain_variables 这个参数是为了再次使用计算图纸
opt_D.step()opt_G.zero_grad()
G_loss.backward()
opt_G.step()

上面的全部代码内容在我的 github.

可视化训练过程

可视化的代码很简单, 在这里就不会意义叙说了, 大家直接看代码 吧. 在本节的最上面就是这次的动图效果, 最后达到收敛时, 下过如下, G 能成功的根据自己的”灵感”, 产生出一条很像 artist画出的曲线, 而 D 再也没有能力猜出这到底是 G 的画作还是 artist 的画作, 他只能一半时间猜是 G 的, 一半时间猜是 artist的.

GAN (Generative Adversarial Nets 生成对抗网络)相关推荐

  1. Generative Adversarial Nets 生成对抗网络

    Generative Adversarial Nets 生成对抗网络 论文作者 Yan 跟随论文精读 (bilibili李沐) 同时会训练模型 G,生成模型要对整个数据的分布进行建模,就是想生成 尽量 ...

  2. 【GAN ZOO阅读】Generative Adversarial Nets 生成对抗网络 原文翻译 by zk

    Ian J. Goodfellow, Jean Pouget-Abadie ∗ , Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair † ...

  3. GAN(Generative Adversarial Nets (生成对抗网络))

    一.GAN 1.应用 GAN的应用十分广泛,如图像生成.图像转换.风格迁移.图像修复等等. 2.简介 生成式对抗网络是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块:生成 ...

  4. Generative Adversarial Networks 生成对抗网络的简单理解

    1. 引言 在对抗网络中,生成模型与判别相竞争,判别模型通过学习确定样本是来自生成模型分布还是原始数据分布.生成模型可以被认为是类似于一组伪造者,试图产生假币并在没有检测的情况下使用它,而判别模型类似 ...

  5. 简述一下生成对抗网络GAN(Generative adversarial nets)模型?

    简述一下生成对抗网络GAN(Generative adversarial nets)模型? 生成对抗网络GAN是由蒙特利尔大学Ian Goodfellow在2014年提出的机器学习架构. 要全面理解生 ...

  6. 【深度学习】Generative Adversarial Network 生成式对抗网络(GAN)

    文章目录 一.神经网络作为生成器 1.1 什么是生成器? 1.2 为什么需要输出一个分布? 1.3 什么时候需要生成器? 二.Generative Adversarial Network 生成式对抗网 ...

  7. 深度学习 | GAN,什么是生成对抗网络

    文章目录 GAN学习笔记 前言 1. GAN原理 2. GAN实例 3. DCGAN原理 4. DCGAN实例 5. WGAN原理 GAN学习笔记 前言 2014年,arXiv上面刊载了一篇关于生成对 ...

  8. 学习笔记:Controllable Artistic Text Style Transfer via Shape-Matching GAN 基于形状匹配生成对抗网络的可控艺术文本风格迁移

    [ICCV-2019] Controllable Artistic Text Style Transfer via Shape-Matching GAN 基于形状匹配生成对抗网络的可控艺术文本风格迁移 ...

  9. GAN Zoo:千奇百怪的生成对抗网络,都在这里了

    自从Goodfellow2014年提出这个想法之后,生成对抗网络(GAN)就成了深度学习领域内最火的一个概念,包括LeCun在内的许多学者都认为,GAN的出现将会大大推进AI向无监督学习发展的进程. ...

  10. The first GAN——Generative Adversarial Nets

    目录 (一)预备知识 (二)总体介绍 (三)相关工作 (四)Adversarial nets && Theoretical Results (五)实验结果 (六)新框架的优缺点分析: ...

最新文章

  1. 里签名boot有什么用_面膜里的塑料纸有什么用?原来这里大有学问
  2. buffsize 缓冲区的大小多少合适_6人餐桌尺寸规格一般是多少
  3. C++描述杭电OJ 2012.素数判定 ||
  4. 对比关系生成模型(Comparative Relation Generative Model)
  5. Job for slapd.service failed because the control process exited with error code. See systemctl stat
  6. Linux Kernel中irq handler, softirq handler 和 tasklet
  7. 传智燕青学成在线项目视频分享
  8. phpwind安装空白问题解决
  9. 从一个html页面传值到另一个页面,两个html之间的值传递(js location.search用法)
  10. 服务器处理蜘蛛抓取网页的过程,搜索引擎蜘蛛抓取页面过程图解
  11. 慢节奏的和府,能否掌握资本带来的“加速度”
  12. 单元测试框架NUnit 之 Attributes特性(一)
  13. windows下mysql-5.7.30-winx64解压安装步骤
  14. 字、字节、字长、存储单元、bit、byte的关系
  15. touchgfx程序_基于TouchGFX和FreeRTOS的智能家居解决方案
  16. oracle 11g ora31626,expdp时候出错:ORA-31626,ORA-31637,ORA-39062,ORA-31613
  17. CRMEB标准版v4.7 新增的通联支付你了解吗?
  18. Windows主机信息搜集
  19. Verilog实现---时钟信号的90°相移
  20. CPU参数中的TPD(热设计功耗)的含义

热门文章

  1. java设计计算器_Java复数计算器的设计
  2. 微信公众号开发之网页授权获取用户基本信息
  3. 剖析Apple Pay 它与支付宝究竟暧昧什么?
  4. 一个IT工薪族的7年奋斗成果:天鸟之路,天鸟有财,天鸟有度
  5. 深入理解操作系统实验——bomb lab(作弊方法2)
  6. VR全景,带您“飞临”探秘北京2022年冬奥会
  7. 《关键对话》读书笔记
  8. Mobile - 小米手机如何开通应用分身?应用多开?
  9. 数据分析(2)——数据分析的流程 数据类型及数据收集和整理方法
  10. android 百度绑定身份证