GAN详解与MINIST手写数字生成实战

  • GAN简介
  • GAN论文原理
  • MINIST手写数字生成实战
    • 1、导入MINIST数据集。
    • 2、构建辨别器和生成器
    • 3、训练模型

GAN简介

GAN(Generative Adversarial Nets) 用中文来说就是 生成对抗网络,它是Ian J. Goodfellow在2014年提出的一种深度学习网络模型。它包含两个模型:生成模型和辨别模型。生成模型是用来捕捉真实数据分布来生成符合原始数据分布的新的数据,辨别模型是用来辨别真实数据和生成模型生成的数据。生成模型的目的是为了来让辨别模型犯错,辨别模型的目的是为了区分生成数据和真实数据。这就好像是两个模型在互相对抗,在对抗中不断吸取经验从而来让自身得到提升,可以类比博弈论中的两人对抗游戏。GAN 可以通过**MLP(多层感知机)**来进行误差反传进行训练。这样就比使用马尔科夫链或者对近似推理过程展开更加简单。由于GAN在生成图像方面有很好的效果,所以得到了很广泛的应用,比如生成名人小时候的照片,将真实人物变成卡通形式,甚至还可以生成世界上不存在的人的脸部照片。

GAN论文原理

    对于生成器,以生成图片为例,我们需要输入一个噪声zzz,就类似一个一百维的变量吧,然后zzz通过从真实数据xxx中学习到的分布pgp_gpg​去进行映射,就可以生成一张图片G(z)G(z)G(z)
    对于辨别器D,我们就是去对生成数据和真实数据进行分类,类似一个两类的分类器。设D(x)D(x)D(x)表示xxx是来自真实数据而不是pgp_gpg​的概率。
    根据GAN的需求,我们需要尽可能让辨别器能够辨别真实数据和生成数据并且让生成器生成数据让辨别器尽可能犯错。简而言之就是最大化log(D(x))log(D(x))log(D(x)),最小化log(1−D(G(x)))log(1-D(G(x)))log(1−D(G(x))),所以我们可以得到如下公式:

这样DDD和GGG就在好像进行两人对抗游戏。

    这是GAN的训练过程,其中绿色的线为生成器生成的数据,黑色的点为真实数据,蓝色的点为辨别器的结果。从a-b-c-d可以看出,生成器生成的数据在不断向真实数据拟合,辨别的结果也在不断改变,最后黑色的点和绿色的线完全拟合时辨别器无法辨别真实数据和生成数据时,此时辨别器的曲线值为0.5(0表示生成数据,1表示真实数据),就无法通过辨别器的值来辨别数据来源。

下面介绍GAN的算法:

    这里比较重要的就是kkk的取值,它会关系到我们模型训练的好坏。kkk的取值不能太小,也不能太大。如果kkk的取值太小,这样每次更新生成器后辨别器得不到充分的更新,无法很好辨别真实数据和生成数据,这时就算不更新生成器也能糊弄辨别器,此时更新生成器的意义不大;如果kkk的取值太大,意味着生成器更新后辨别器会被更新得很好,此时上述生成器梯度公式中log(1−D(G(z(i))))log(1-D(G(z^{(i)})))log(1−D(G(z(i))))就是0,这是就对0求梯度,这样在生成模型的更新上会有困难。这里我们类比一个例子更好理解:假设辨别器就是警察,生成器就是造假者。如果警察太厉害,那么造假者生产一点假钞就被一锅端了,那么造假者就没法赚到钱,不能去进一步改进工艺;如果警察太无力,无法比较好分辨真钞和假钞,那么造假者随便生产点东西都能赚到钱,这样生产者就不会想着去改进工艺。所以两方面都不行,最好的就是两方实力相当,这样都能互相促进进步。

MINIST手写数字生成实战

这是一个利用手写数据集进行训练得到的GAN,生成器接收随机噪声作为输入,然后输出一张手写数字图像;判别器的输入则是两幅图像,分别是真的手写数字图像和生成器生成的假图像,然后输出对这两幅图像的判别结果。

1、导入MINIST数据集。

train_data = dataloader.DataLoader(datasets.MNIST(root='data/', train=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
]), download=True), shuffle=True, batch_size=batch_sz)

2、构建辨别器和生成器

辨别器:

class discrimination(nn.Module):def __init__(self):super(discrimination, self).__init__()self.hidden0 = nn.Sequential(nn.Linear(784, 1024),nn.LeakyReLU(0.2),)self.hidden1 = nn.Sequential(nn.Linear(1024, 512),nn.LeakyReLU(0.2),)self.hidden2 = nn.Sequential(nn.Linear(512, 256),nn.LeakyReLU(0.2),)self.out = nn.Sequential(nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = self.hidden0(x)x = self.hidden1(x)x = self.hidden2(x)x = self.out(x)return x

这里我们对辨别器网络构造采用4层,前三层用LeakyReLu,最后一层用sigmoid。使用LeakyReLu是因为不会将零以下的数全部置为零,所以使用LeakyReLU 激活函数相比使用ReLU 能够更好地使梯度流过网络。使用sigmoid是因为能够将输出值约束在区间[0, 1]

生成器:

class generate(nn.Module):def __init__(self):super(generate, self).__init__()self.hidden0 = nn.Sequential(nn.Linear(100, 256),nn.LeakyReLU(0.2))self.hidden1 = nn.Sequential(nn.Linear(256, 512),nn.LeakyReLU(0.2))self.hidden2 = nn.Sequential(nn.Linear(512, 1024),nn.LeakyReLU(0.2))self.out = nn.Sequential(nn.Linear(1024, 784),nn.Tanh())def forward(self, x):x = self.hidden0(x)x = self.hidden1(x)x = self.hidden2(x)x = self.out(x)return x

生成器前面三层和辨别器一样,最后一层采用tanh激活函数,是为了与对MNIST 数据进行的归一化同步,以将其值转换到[-1, 1] 中,以便判别器始终获取数据点处于相同值域的数据集。

3、训练模型

def train_discriminator(optimizer, loss_fn, real_data, fake_data):optimizer.zero_grad()discriminator_real_data = discriminator(real_data)loss_real = loss_fn(discriminator_real_data, torch.ones(real_data.size(0), 1).to(device))loss_real.backward()discriminator_fake_data = discriminator(fake_data)loss_fake = loss_fn(discriminator_fake_data, torch.zeros(fake_data.size(0), 1).to(device))loss_fake.backward()optimizer.step()return loss_real + loss_fake, discriminator_real_data, discriminator_fake_datadef train_generator(optimizer, loss_fn, fake_data):optimizer.zero_grad()output_discriminator = discriminator(fake_data)loss = loss_fn(output_discriminator, torch.ones(output_discriminator.size(0), 1).to(device))loss.backward()optimizer.step()return lossfor epoch in range(num_epoch):for train_idx, (input_real_batch, _) in enumerate(train_data):real_data = images2vectors(input_real_batch).to(device)generated_fake_data = generator(noise(real_data.size(0))).detach()d_loss, discriminated_real, discriminated_fake = train_discriminator(d_optimizer, loss_fn, real_data,generated_fake_data)generated_fake_data = generator(noise(real_data.size(0)))g_loss = train_generator(g_optimizer, loss_fn, generated_fake_data)if train_idx == len(train_data) - 1:print(epoch, 'd_loss: ', d_loss.item(), 'g_loss: ', g_loss.item())

train_discriminator和train_generator是为了分别对辨别器和生成器求loss,并进行反向传播、参数优化。辨别器涉及到真实数据和生成数据俩方面的误差(上面图中有提到),所以将他们相加起来。
源代码见 MINIST手写数字生成

小伙伴喜欢文章的话记得 点赞加关注 哦,后面会更新其他深度学习的文章。
如果有什么写得有问题的地方希望大家能值出,谢谢。

GAN详解与PyTorch MINIST手写数字生成实战相关推荐

  1. pytorch 之手写数字生成网络

    EPOCH = 10 BATCH_SIZE = 64 LR = 0.005 # learning rate DOWNLOAD_MNIST = False N_TEST_IMG = 5# Mnist d ...

  2. 基于CNN的MINIST手写数字识别项目代码以及原理详解

    文章目录 项目简介 项目下载地址 项目开发软件环境 项目开发硬件环境 前言 一.数据加载的作用 二.Pytorch进行数据加载所需工具 2.1 Dataset 2.2 Dataloader 2.3 T ...

  3. tensorflow入门之MINIST手写数字识别

    最近在学tensorflow,看了很多资料以及相关视频,有没有大佬推荐一下比较好的教程之类的,谢谢.最后还是到了官方网站去,还好有官方文档中文版,今天就结合官方文档以及之前看的教程写一篇关于MINIS ...

  4. 使用Pytorch实现手写数字识别

    使用Pytorch实现手写数字识别 1. 思路和流程分析 流程: 准备数据,这些需要准备DataLoader 构建模型,这里可以使用torch构造一个深层的神经网络 模型的训练 模型的保存,保存模型, ...

  5. 用PyTorch进行手写数字识别

    目录 数据准备 网络模型 完整实现 数据准备 torch.utils.data.Datasets是PyTorch用来表示数据集的类,它是用PyTorch进行手写数字识别的关键. 下面是加载mnist数 ...

  6. pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练

    文章目录 1. MNIST 手写数字识别 2. 聚焦数据集扩充后的模型训练 3. pytorch 手写数字识别基本实现 3.1完整代码及 MNIST 测试集测试结果 3.1.1代码 3.1.2 MNI ...

  7. tensorflow2 minist手写数字识别数据训练

    ✨ 博客主页:小小马车夫的主页 ✨ 所属专栏:Tensorflow 文章目录 前言 一.tenosrflow minist手写数字识别代码 二.输出 三.参考资料 总结 前言 刚开始学习tensorf ...

  8. 深度学习(4)手写数字识别实战

    深度学习(4)手写数字识别实战 Step0. 数据及模型准备 1. X and Y(数据准备) 2. out=relu{relu{relu[X@W1+b1]@W2+b2}@W3+b3}out=relu ...

  9. 生成式对抗网络实战(一)——手写数字生成(CPU本地版)完整代码加详解

    [注1]代码的原文来自以下网址,修改部分及增添注释(基本上都注释了).修改版整体见最后,原版下方链接,均可以跑通,有问题欢迎交流.生成对抗网络GAN---生成mnist手写数字图像示例( 附代码)_陶 ...

  10. 使用PyTorch进行手写数字识别,在20 k参数中获得99.5%的精度。

    In this article we'll build a simple convolutional neural network in PyTorch and train it to recogni ...

最新文章

  1. python教学视频-Python入门视频课程
  2. source code compiled install mongodb
  3. [Leetcode][第39题][JAVA][组合总和][回溯][dfs][剪枝]
  4. [转载] 全本张广泰——第九回 出世见师兄 广泰走江湖
  5. 大学生计算机知识竞赛,大学生计算机基础知识竞赛题库(试题附答案).docx
  6. repost ACM算法学习三境界---王国维人间词话
  7. 计算机硬盘能否做u盘用怎么用,教你怎么用移动硬盘做原系统的启动硬盘图文教程...
  8. Keras或者Tensorflow出现:Optimization loop failed: Cancelled: Operation was cancelled
  9. 【Writeup】BUUCTF_Web_高明的黑客
  10. Python实现主播人气排行榜,带你发现人气王
  11. 营销型网站文案写作的8个技巧
  12. 海信Vidda S65 2023款和2020款有什么区别?哪个更好
  13. GPS基础知识(五)、GPS导航电文
  14. Win7系统中,如何关闭Windows默认的防火墙? win7如何关闭防火墙
  15. Springboot+阿里云kafka踩坑实录
  16. 360浏览器 | 如何从360浏览器中恢复你的密码
  17. android11铃声pixel,Android 11 Beta版1发布,谷歌Pixel系列尝鲜,到底是亲儿子
  18. iOS15-16绕过激活锁,屏幕锁完美隐藏工具老虎V4.5,支持最新iOS16.1.1系统
  19. 当前端连后端主机连不上时的可能原因
  20. Ruby 简单入门(一)

热门文章

  1. 金融交易学——一个专业交易者…
  2. tomcat6到tomcat9解压版(64位)随意下载
  3. 安装虚拟机不支持i686 cpu的解决办法
  4. 丁磊推荐《你的灯亮着吗》为三大管理必读书
  5. 软件测试报告模板怎么写,这篇文章告诉你
  6. UE4 蓝图 循环调用
  7. 2017-06-11 Padavan 完美适配newifi mini【adbyby+SS+KP ...】youku L1 /小米mini
  8. 几何画板如何绘制动态正切函数图像
  9. 研发项目wbs分解简单案例_做项目WBS(工作分解结构)
  10. 商场平面 html5,收集50张商场平面图,看购物中心动线规划