这里的代码都是,参考网上其他的博文学习的,今天是我第一次学习GAN,心情难免有些激动,想着赶快跑一个生成MNIST数字图像的来瞅瞅效果,看看GAN的神奇。
参考博文是如下三个:
https://www.jb51.net/article/178171.htm
https://blog.csdn.net/happyday_d/article/details/84961175
https://blog.csdn.net/weixin_41278720/article/details/80861284

代码不是原创,只是学习和看明白了。能让我们很直观看到GAN是如何训练的,以及产生的效果。

一:实例一
导入必要的包,以及定义一些图像处理的函数,比如展示图像的函数,加载MNIST数据集,并且将数据集转变成成128批量大小的批次,这个加载数据集和转换批次的操作是之前我做其他BP,CNN网络练习的时候见到过的,再次强调一下:MNIST数据加再进来后默认就是[1, 28, 28]的维度,需要变成784维度向量的话得后续自己view函数处理。

import torch
from torch import nn
from torch.autograd import Variableimport torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNISTimport numpy as npimport matplotlib.pyplot as plt
import matplotlib.gridspec as gridspecplt.rcParams['figure.figsize'] = (10.0, 8.0)  # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'def show_images(images):  # 定义画图工具images = np.reshape(images, [images.shape[0], -1])sqrtn = int(np.ceil(np.sqrt(images.shape[0])))sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))fig = plt.figure(figsize=(sqrtn, sqrtn))gs = gridspec.GridSpec(sqrtn, sqrtn)gs.update(wspace=0.05, hspace=0.05)for i, img in enumerate(images):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(img.reshape([sqrtimg, sqrtimg]))returndef preprocess_img(x):x = tfs.ToTensor()(x)return (x - 0.5) / 0.5def deprocess_img(x):return (x + 1.0) / 2.0NUM_TRAIN = 60000NOISE_DIM = 100
batch_size = 128train_set = MNIST('./data', train=True, transform=preprocess_img)
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze()  # 可视化图片效果
# 这里可以先看到128 batch_size 的一部分图片
print(imgs.shape)
show_images(imgs)

定义判别网络,这一步其实就是构造一个数字识别网络,只不过略微有些区别,这里不是识别具体的数字,而是识别是不是真实的图片,输出只有两个(0或者1),1代表是真实的图片,0代表的是构造的虚假图片。输出其实是个概率值。

# 判别网络
class discriminator(torch.nn.Module):def __init__(self, noise_dim=NOISE_DIM):# 调用父类的初始化函数,必须要的super(discriminator, self).__init__()self.net = nn.Sequential(nn.Linear(784, 256),nn.LeakyReLU(0.2),nn.Linear(256, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1))def forward(self, img):img = self.net(img)return img

构造生成网络。看似是跟判别网络很类似,其实这里的结构可以任意自行变换,输入是一个100维度的向量,向量值都是随机产生的随机数。最后生了一个784维度的图像数据,这个理的数据将会别送到判别网络中去做判别。

# 生成网络
class generator(torch.nn.Module):def __init__(self, noise_dim=NOISE_DIM):# 调用父类的初始化函数,必须要的super(generator, self).__init__()self.net = nn.Sequential(nn.Linear(noise_dim, 256),nn.ReLU(True),nn.Linear(256, 256),nn.ReLU(True),nn.Linear(256, 784),nn.Tanh())def forward(self, img):img = self.net(img)return img

定义损失函数和优化器,这里优化器采用了Adam优化器,损失函数采用了二分类的交叉熵损失函数

# 二分类的交叉熵损失函数
bce_loss = nn.BCEWithLogitsLoss()# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))return optimizer

定义两个函数,分别计算判别网络和生成网络的代价估算,对于判别网络来说,希望真实的图片预测都是输出1,期望标签是1,对于假的图片希望都是模型输出0,期望标签是0。
而对于生成网络来说,希望模型输出是1,因此期望标签是1。

def discriminator_loss(logits_real, logits_fake):  # 判别器的 losssize = logits_real.shape[0]true_labels = Variable(torch.ones(size, 1)).float()size = logits_fake.shape[0]false_labels = Variable(torch.zeros(size, 1)).float()loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)return lossdef generator_loss(logits_fake):  # 生成器的 losssize = logits_fake.shape[0]true_labels = Variable(torch.ones(size, 1)).float()loss = bce_loss(logits_fake, true_labels)return loss

定义训练流程函数

def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,noise_size=NOISE_DIM, num_epochs=10):iter_count = 0for epoch in range(num_epochs):for x, _ in train_data:bs = x.shape[0]# 判别网络real_data = Variable(x).view(bs, -1)  # 真实数据logits_real = D_net(real_data)  # 判别网络得分sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # -1 ~ 1 的均匀分布g_fake_seed = Variable(sample_noise)fake_images = G_net(g_fake_seed)  # 生成的假的数据logits_fake = D_net(fake_images)  # 判别网络得分d_total_error = discriminator_loss(logits_real, logits_fake)  # 判别器的 lossD_optimizer.zero_grad()d_total_error.backward()D_optimizer.step()  # 优化判别网络# 生成网络g_fake_seed = Variable(sample_noise)fake_images = G_net(g_fake_seed)  # 生成的假的数据gen_logits_fake = D_net(fake_images)g_error = generator_loss(gen_logits_fake)  # 生成网络的 lossG_optimizer.zero_grad()g_error.backward()G_optimizer.step()  # 优化生成网络if (iter_count % show_every == 0):print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())show_images(imgs_numpy[0:16])plt.show()print()iter_count += 1print('iter_count: ', iter_count)

开始进行训练

D = discriminator()
G = generator()D_optim = get_optimizer(D)
G_optim = get_optimizer(G)train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

代码清晰明了,对于初学者跑出一个GAN很有直观上的印象,以及怎么训练GAN也有很清晰的认识。

看看几个效果图:




总体趋势是随着迭代次数的增加,图像会变得稍微清晰一点点,数字的轮廓也明显一些。

图像十分不清晰,只能看到大概的样子,但是起码也有了数字的大致轮廓了,如果加上去雾处理的话可能效果会再好一些。

二:实例二
实例一用的是BP全连接网络结构,其他的都不动,我们把判别网络和生成网络的模型改成CNN卷积的模型,如下:

class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.conv = nn.Sequential(nn.Conv2d(1, 32, 5, 1),nn.LeakyReLU(0.01),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, 5, 1),nn.LeakyReLU(0.01),nn.MaxPool2d(2, 2))self.fc = nn.Sequential(nn.Linear(1024, 1024),nn.LeakyReLU(0.01),nn.Linear(1024, 1))def forward(self, x):x = self.conv(x)x = x.view(x.shape[0], -1)x = self.fc(x)return xclass generator(nn.Module):def __init__(self, noise_dim=NOISE_DIM):super(generator, self).__init__()self.fc = nn.Sequential(nn.Linear(noise_dim, 1024),nn.ReLU(True),nn.BatchNorm1d(1024),nn.Linear(1024, 7 * 7 * 128),nn.ReLU(True),nn.BatchNorm1d(7 * 7 * 128))self.conv = nn.Sequential(nn.ConvTranspose2d(128, 64, 4, 2, padding=1),nn.ReLU(True),nn.BatchNorm2d(64),nn.ConvTranspose2d(64, 1, 4, 2, padding=1),nn.Tanh())def forward(self, x):x = self.fc(x)x = x.view(x.shape[0], 128, 7, 7)  # reshape 通道是 128,大小是 7x7x = self.conv(x)return x

效果确实比BP网络的要好多了,生成的图像更加清晰。
来看下效果变化:




总体上看,图像更加清晰,对着迭代次数的增加,图像越清晰。

Pytorch《GAN模型生成MNIST数字》相关推荐

  1. GAN网络生成手写体数字图片

    Keras真香,以前都是用tensorflow来写神经网络,自从用了keras,发现这个Keras也蛮方便的. 目前感觉keras的优点就是方便搭建基于标准网络组件的神经网络,这里的网络组件包括全连接 ...

  2. GAN 模型生成山水画,骗过半数观察者,普林斯顿大学本科生出品

    作者 | 高卫华 出品 | AI科技大本营 近年来,基于生成对抗网络GAN模型,图像生成领域实现了许多有趣的应用,尤其是在绘画创作方面. 英伟达曾在2019年提出一款名叫GauGAN的神经网络作图工具 ...

  3. 深度学习《GAN模型学习》

    前言:今天我们来一起学习下GAN神经网络,上一篇博文我先用pytorch运行了几个网上的代码例子,用于生成MNIST图片,现在我才反过来写GAN的学习,这次反了过来,效果也是很显而易见的,起码有个直观 ...

  4. GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字

    有关条件GAN(cgan)的相关原理,可以参考: GAN系列之CGAN原理简介以及pytorch项目代码实现 其他类型的GAN原理介绍以及应用,可以查看我的GANs专栏 一.数据集介绍,加载数据 依旧 ...

  5. 搭建简单GAN生成MNIST手写体

    Keras搭建GAN生成MNIST手写体 GAN简介 生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前 ...

  6. 【Pytorch神经网络理论篇】 24 神经网络中散度的应用:F散度+f-GAN的实现+互信息神经估计+GAN模型训练技巧

    1 散度在无监督学习中的应用 在神经网络的损失计算中,最大化和最小化两个数据分布间散度的方法,已经成为无监督模型中有效的训练方法之一. 在无监督模型训练中,不但可以使用K散度JS散度,而且可以使用其他 ...

  7. Pytorch 使用DCGAN生成动漫人物头像 入门级实战教程

    有关DCGAN实战的小例子之前已经更新过一篇,感兴趣的朋友可以点击查看 Pytorch 使用DCGAN生成MNIST手写数字 入门级教程 有关DCGAN的相关原理:DCGAN论文解读-----DCGA ...

  8. GAN掉人脸识别系统?GAN模型「女扮男装」

    文章来源 新智元 编辑:LRS [新智元导读]人脸识别技术最近又有新的破解方式!一位斯坦福的学生使用GAN模型生成了几张自己的图片,轻松攻破两个约会软件,最离谱的是「女扮男装」都识别不出来. 真的有人 ...

  9. pytorch学习之GAN生成MNIST手写数字

    0.简单介绍: 学深度学习的人必然知道,最基本的GAN模型由一个生成器 G 和判别器 D 组成.生成器用于生成假样本,判别器用于判断样本是真实的还是假的. 在整个训练过程中,生成器努力地让生成的图像更 ...

最新文章

  1. 盘点数据科学20个最好的Python库(附链接)
  2. tomcat启动慢_Hack下mongodb jdbc driver启动慢
  3. 初次使用VS附加到进程功能
  4. PHP保留小数的相关方法
  5. 直接学python行不行_是否可以直接学python或者java而不学c++?
  6. Cockroach DB 1.0发布
  7. 【Android 修炼手册】Gradle 篇 -- Gradle 源码分析
  8. Java Web App体系结构
  9. 用tsmmc.MSC方式在xp和Win7集中管理多台Win2003服务器
  10. 卢伟冰再曝Redmi Note 8:拍照、续航、屏占比、手感都更好
  11. 计算机如何共享手机网络热点,手机也能做热点 教你如何共享手机3G网
  12. Windows上使用Netbeans进行Linux C开发
  13. 卷积神经网络完整总结
  14. oracle导出为dmp文件,oracle导出dmp文件的2种方法
  15. 垃圾邮件识别(一):用机器学习做中文邮件内容分类
  16. BurpSuite工具-HTTP协议详解部分(不懂就查系列)
  17. 导向滤波与opencv python实现
  18. CCF-CSP-2015年9月-题解
  19. 大数据分析应用领域有哪些
  20. 快速查询中通速递物流,查看未签收单号的最后站点

热门文章

  1. 1.4.2.4. SAVING(Core Data 应用程序实践指南)
  2. 给大一师弟师妹的一些建议
  3. Webhook入门教程:Webhook vs API 它们之间有什么区别?
  4. CSDN开发者周刊 TDengine:专为物联网订制的大数据平台 YugaByte DB:高性能的分布式ACID事务数据库
  5. jQuery封装tab选项卡组件(自定义自动功能和延迟显示功能)
  6. 04737 c++ 自学考试2019版 第五章程序设计题 1
  7. 【JAVA 第三章 流程控制语句】课后习题 找零钱
  8. Windows环境下多个tomcat启动方法
  9. 黑猫警长 stl_如何使用当地警长保护您的信息
  10. haproxy 负载_负载测试HAProxy(第1部分)