一、前言

本文不花费大量的篇幅来推导数学公式,而是使用一个非常简单的案例来帮助我们了解GAN生成对抗网络。

二、GAN概念

生成对抗网络(Generative Adversarial Networks,GAN)包含生成器(Generator)和鉴别器(Discriminator)两个神经网络。生成器用于生成虚假的数据,经过训练后能够生成以假乱真的数据;鉴别器使用真实数据和虚假数据训练后,能够辨别数据的真假;生成器和鉴别器相互博弈,最终达到鉴别器难以区分生成数据真假的状态。

三、案例实战

我们会创建一个GAN,生成器通过学习训练,来创建符合1010格式规律的值。这个任务比生成图像要简单。通过这个任务,我们可以了解GAN的基本代码框架,观察训练进程,进而帮助我们为接下来生成图像的任务做好准备。

我们先引入依赖库:

import matplotlib.pyplot as plt
import pandas
import torch
import torch.nn as nn

2.1 构造真实数据源

真实数据源可以是一个返回1010格式数据的函数,如下所示:

def generate_real():real_data = torch.FloatTensor([1,0,1,0])return real_data

执行:

generate_real()

结果:

tensor([1., 0., 1., 0.])

但是,在实际生活中,数据往往不是那么精准,我们让其有一定随机性:

def generate_real():real_data = torch.FloatTensor([random.uniform(0.8, 1.0),random.uniform(0.0, 0.2),random.uniform(0.8, 1.0),random.uniform(0.0, 0.2)])return real_data

random.uniform(0.8, 1.0)产生0.8-1.0之间的随机小数。
执行:

generate_real()

结果:

tensor([0.9782, 0.0673, 0.8500, 0.1788])

2.2 构造随机数据

产生4个随机数,可能满足1010格式,也可能不满足,函数如下:

def generate_random(size):random_data = torch.rand(size)return random_data

执行:

generate_random(4)

结果:

tensor([0.4241, 0.0611, 0.7684, 0.2931])

2.3 构造鉴别器

鉴别器是一个神经网络,我们的目的是训练出一个能区分真实数据与随机噪声数据的鉴别器。下面代码定义了一个非常简单的神经网络:输入层有4个节点,用于接受输入的4个值;隐藏层有3个节点;输出层输出0~1的单个值,表示真或假。

class Discriminator(nn.Module):def __init__(self):# 初始化Pytorch父类super().__init__()# 定义神经网络层self.model = nn.Sequential(nn.Linear(4, 3),nn.Sigmoid(),nn.Linear(3, 1),nn.Sigmoid())# 创建损失函数,使用均方误差self.loss_function = nn.MSELoss()# 创建优化器,使用随机梯度下降self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)# 训练次数计数器self.counter = 0# 训练过程中损失值记录self.progress = []# 前向传播函数def forward(self, inputs):return self.model(inputs)# 训练函数def train(self, inputs, targets):# 前向传播,计算网络输出outputs = self.forward(inputs)# 计算损失值loss = self.loss_function(outputs, targets)# 累加训练次数self.counter += 1# 每10次训练记录损失值if (self.counter % 10 == 0):self.progress.append(loss.item())# 每10000次输出训练次数   if (self.counter % 10000 == 0):print("counter = ", self.counter)# 梯度清零, 反向传播, 更新权重self.optimiser.zero_grad()loss.backward()self.optimiser.step()# 绘制损失变化图def plot_progress(self):df = pandas.DataFrame(self.progress, columns=['loss'])df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))

2.4 测试鉴别器

由于还没有创建生成器,所以无法测试能够与其竞争的鉴别器,目前能做的是,检验鉴别器是否能将真实数据与随机数据区分开。

训练

D = Discriminator()
for i in range(10000):# 真实数据D.train(generate_real(), torch.FloatTensor([1.0]))# 随机数据D.train(generate_random(4), torch.FloatTensor([0.0]))

结果:

counter =  10000
counter =  20000

上述代码虽然迭代了10000次,但是在每次迭代中分别对真实数据和随机数据进行了训练,累计训练20000次。

损失值变化

我们来看看训练过程中的损失值变化:

D.plot_progress()


如上图所示,损失值一开始接近0.25,随着训练次数增加,损失值逐渐接近0。

鉴别效果

我们再来测试一下鉴定器的效果,现在分别输入1010格式数据与随机数据,代码和运行结果如下:

print(D.forward(generate_real()).item())
print(D.forward(generate_random(4)).item())

结果:

0.8134430050849915
0.05087679252028465

得出的结果分别接近1和0,这说明鉴别器能够区分真实数据与随机噪声。

2.5 构造生成器

生成器也是一个神经网络,目的是尽量生成满足1010格式的4个值。为了使生成器与鉴别器不相伯仲地相互竞争与提高,生成器与鉴别器的结构正好相反:输入层只有1个节点;隐藏层有3个节点;输出层有4个节点,输出4个值。
代码如下,注意训练函数稍有不同,引入了鉴别器的损失函数进行反向传播,进而更新生成器权重

class Generator(nn.Module):def __init__(self):# 初始化Pytorch父类super().__init__()# 定义神经网络层self.model = nn.Sequential(nn.Linear(1, 3),nn.Sigmoid(),nn.Linear(3, 4),nn.Sigmoid())# 注意这里没有损失函数,在训练时使用鉴别器的损失函数。# 创建优化器,使用随机梯度下降self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)# 训练次数计数器self.counter = 0# 训练过程中损失值记录self.progress = []# 前向传播函数def forward(self, inputs):return self.model(inputs)# 训练函数def train(self, D, inputs, targets):# 前向传播,计算网络输出g_output = self.forward(inputs)# 将生成器输出,传入鉴别器,输出分类结果d_output = D.forward(g_output)# 计算鉴别误差loss = D.loss_function(d_output, targets)# 累加训练次数self.counter += 1# 每10次训练记录损失值if (self.counter % 10 == 0):self.progress.append(loss.item())# 梯度清零, 反向传播, 更新权重。注意这里是对鉴别器的误差进行反向传播,但只更新生成器的权重self.optimiser.zero_grad()loss.backward()self.optimiser.step()# 绘制损失变化图def plot_progress(self):df = pandas.DataFrame(self.progress, columns=['loss'])df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))

2.6 检查生成器输出

同样地,我们也可以单独对生成器进行测试,以检查是否正常工作:

G = Generator()
G.forward(torch.FloatTensor([0.5]))

结果:

tensor([0.6172, 0.5979, 0.5700, 0.6622], grad_fn=<SigmoidBackward0>)

可以看到输出了4个值,但不符合1010格式,因为我们还没有对其进行训练。

2.7 训练GAN

训练

先看代码:

D = Discriminator()
G = Generator()for i in range(10000):# 用真实样本数据训练鉴别器D.train(generate_real(), torch.FloatTensor([1.0]))# 用生成数据训练鉴别器# 此处训练是为了更新鉴别器权重,不需要更新生成器权重,使用detach()以避免计算生成器中的梯度D.train(G.forward(torch.FloatTensor([0.5])).detach(), torch.FloatTensor([0.0]))# 训练生成器,更新生成器权重G.train(D, torch.FloatTensor([0.5]), torch.FloatTensor([1.0]))

在迭代过程中,每次循环都会重复训练GAN的3个步骤:

  1. 用真实样本数据训练鉴别器,更新鉴别器权重
  2. 用生成的数据训练鉴别器,更新鉴别器权重。此处不需要更新生成器权重,detach()的作用是将其从计算图中分离出来
  3. 训练生成器,更新生成器权重

损失值变化

训练完成后,我们来看看鉴别器损失值的变化:

D.plot_progress()


这是一个非常有意思的结果,损失值最终保持在0.25附近。这说明鉴别器无法判断数据是真实的还是伪造的,于是输出0.5,由于我们损失函数使用的是均方误差,所以损失值是0.5的平方,即0.25。

下图是生成器的损失图,与鉴别器损失是互补的:

G.plot_progress()

生成数据

现在我们用训练好的生成器来生成数据:

G.forward(torch.FloatTensor([0.5]))

结果:

tensor([0.9537, 0.0367, 0.9493, 0.0507], grad_fn=<SigmoidBackward0>)

可以看到生成的数据符合1010格式。效果相当不错!

通过上面的训练,相信你已经熟悉GAN的结构了,后面我们将使用GAN来实现手写数字生成等更加酷炫的任务

深度学习 GAN生成对抗网络-1010格式数据生成简单案例相关推荐

  1. 基于生成对抗网络的医学数据域适应研究

    点击上方蓝字关注我们 基于生成对抗网络的医学数据域适应研究 于胡飞, 温景熙, 辛江, 唐艳 中南大学计算机学院,湖南 长沙 410083   摘要:在医疗影像辅助诊断研究中,研究者通常使用不同医院( ...

  2. GAIN: Missing Data Imputation using Generative Adversarial Nets(基于生成对抗网络的缺失数据填补)论文详解

    目录 一.背景分析 1.1 缺失数据 1.2 填补算法 二.GAIN 2.1 GAIN网络架构 2.2 符号描述(Symbol Description) 2.3 生成器模型 2.4 判别器模型 2.5 ...

  3. 深度学习100例-生成对抗网络(DCGAN)生成动漫人物 | 第20天

    文章目录 一.前言 二.什么是生成对抗网络? 1. 设置GPU 2. 加载和准备数据集 三.创建模型 1. 生成器 2. 判别器 四.定义损失函数和优化器 1. 判别器损失 2. 生成器损失 五.保存 ...

  4. 【论文分享】MAD-GAN :基于生成对抗网络的时间序列数据多元异常检测

    2019年ICANN文章 MAD-GAN: Multivariate Anomaly Detection for Time Series Data with Generative Adversaria ...

  5. 深度学习之生成式对抗网络 GAN(Generative Adversarial Networks)

    一.GAN介绍 生成式对抗网络GAN(Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.它源于2014年发表的论文:& ...

  6. 深度学习之生成式对抗网络GAN

    一.GAN介绍 生成式对抗网络GAN(Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块 ...

  7. [李宏毅老师深度学习视频] 生成式对抗网络(GAN)【持续更新】

    从零开始GAN 1.生成式对抗网络 - 基本概念介绍 1.1.引入生成式对抗网络 1.2.Generative Adversarial Network(GAN) 2.GAN理论介绍+WGAN 上上个星 ...

  8. 深度学习:生成式对抗网络,让机器在博弈中实现“自我成长”

    点击阅读原文 深度神经网络在判别模型领域的进步远比在生成模型领域进步快得多,其主要原因就在于相对于生成式模型来说,判别模型目标清晰.逻辑相对简单,实现起来容易. 用通俗的比喻来说,判别模型相当于是来料 ...

  9. 【深度学习】生成式对抗网络的损失函数的理解

    生成式对抗网络即GAN由生成器和判别器组成.原论文中,关于生成器和判别器的损失函数是写成以下形式: 首先,第一个式子我们不看梯度符号的话即为判别器的损失函数,logD(xi)为判别器将真实数据判定为真 ...

最新文章

  1. 独家 | Python处理海量数据集的三种方法
  2. 源代码遭泄露,大疆员工被罚20万,判刑半年。
  3. python高级应用_Python高级编程技巧
  4. [Python人工智能] 三.theano实现分类神经网络及机器学习基础
  5. 每个人都要在自己的“时区”里找到自己的快乐
  6. Docker 存储选型,这些年我们遇到的坑
  7. Node.js 和 Python之间如何进行选择?教你一招搞定
  8. Google DeepMind 团队发布新算法,下一个被 AI 虐哭的是谁?
  9. java动态变量名_Java||第一篇:了解Java并搭建环境
  10. vue slot插槽_Vue之路 | 08vue插槽slot使用
  11. angular.injector()
  12. java 后台接受json参数的几种方式_java后台发送及接收json数据
  13. html (第四本书第九章参考)
  14. 遍历出List<Map>的Key / Value
  15. ECharts 饼图颜色设置教程 - 4 种方式设置饼图颜色
  16. 速读 OSI合作的《2022全球开源趋势报告》
  17. 安卓进阶之android系统架构
  18. ImportError: Couldn‘t import Django
  19. [音视频] wav 格式
  20. Win10 wsl-安装教程

热门文章

  1. 【C语言训练】自由落体问题
  2. 检测xposed框架实现
  3. Android Studio自带图标制作利器 Image Asset Studio
  4. 那个三本的程序员妹子,凉了
  5. python实现猫抓老鼠
  6. Primitives vs Objects
  7. C# Dev利用TreeList设置菜单导航并双击节点打开模块窗体
  8. ORA-01455: converting column overflows integer datatype
  9. COGS 1043. [Clover S2] Freda的迷宫
  10. SpringBoot使用mybatis-autogenerator时,显示Failure to find org.eclipse.m2e:lifecycle-mapping:pom:1.0.0错误