Pytorch GAN实战 MINIST手写数字识别分布解析

前言、准备工作

本案例需要导入的包, 没有下载的通过pip install来下载

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import numpy as np
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

GAN网络模型结果

一、数据集导入

MINIST数据集就不过多介绍了, 这里主要是做了个Normalize的操作,将0~1之间的数据转换为-1~1之间的

这里的transform中的操作是为了将输入的图像转为-1~1之间的数值, 是因为在GAN中生成器的最后一层往往会使用nn.Tanh()效果比较好, 而nn.Tanh()返回的数据就是-1~1之间, 为了使得生成器生成的图像数据能和真是的图像数据都在同一个范围内,便于比较,因此在加载数据集的transform操作中将其归一化到-1~1之间

transform = transforms.Compose([transforms.ToTensor(),  # 0~1, [channel,h,w]transforms.Normalize(0.5, 0.5)  # 0~1 => -1~1
])
train_ds = torchvision.datasets.MNIST(r'D:\Source\Datasets',train=True,transform=transform,download=False)
loader = DataLoader(train_ds, batch_size=64, shuffle=True)
x, _ = iter(loader).__next__()
print(x.shape)


这里可以随便打印一下MINIST中的数字

二、生成器构建

本案例中生成器使用长度为 100的噪声(正态分布随机数)作为输入
(1, 28, 28)的图片作为输出

其中各个层的维度变换作用如下所示, 并在期间交替使用激活函数, 这里要注意的是最后一层的激活函数要使用tanh()

  • linear 1: 100->256
  • linear 2: 256->512
  • linear 3: 512->28*28
  • reshape: 28-28->(1, 28, 28)
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 28 * 28),nn.Tanh(),  # 生成器最后一层的激活要用tanh (-1~1))def forward(self, x):# x为长度为100的噪声输入img = self.main(x)img = img.view(-1, 1, 28, 28)return img

‘’‘判别器’’’

三、判别器构建

判别器输入为图片(1, 28, 28)

输出为二分类的概率值, 使用sigmoid激活0~1

用BCEloss计算交叉熵损失

nn.LeakyReLU: x<0时返回α*x x>0时返回x

在判别器中推荐用LeakyReLU的原因是, 让负值产生梯度, 利于更新

这里判别器其实与之前的图像识别网络类似, 不过多解释了

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(28 * 28, 256),nn.LeakyReLU(0.2),nn.Linear(256, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, x):# x为28*28的图片x = x.view(-1, 28 * 28)x = self.main(x)return x

四、初始化工作

'''初始化模型,优化器,损失函数'''
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
d_optim = optim.Adam(dis.parameters(), lr=1e-4)
g_optim = optim.Adam(gen.parameters(), lr=1e-4)
loss_fn = nn.BCELoss()'''绘图函数'''
def gen_img_plot(model, test_input):predict = np.squeeze(model(test_input).detach().cpu().numpy())plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow((predict[i] + 1) / 2)plt.axis('off')plt.show()
# 随机生成16batch,100长的噪声作为Generator的输入
test_input = torch.randn(16, 100).to(device)

五、网络训练

GAN的网络训练还是比较特殊的, 我们可以将其分为两个阶段


  • 训练判别器
        d_optim.zero_grad()# 在判别器训练真实图片real_output = dis(img)  # 判别器输入真实的图片, 希望real_output是1d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 得到判别器在真实数据上的损失d_real_loss.backward()# 在判别器训练生成的假图片gen_img = gen(random_noise)# 此时优化的对象是判别器, 要把生成器的梯度截断fake_output = dis(gen_img.detach())  # 判别器输入生成的图片 fake_output对生成的图片预测d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 得到判别器在生成数据上的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optim.step()

这边有两个步骤

  • 判别器训练真实图片

    • 对于输入数据为真实图片, 判别器需要识别的结尾是1, 因此我们的loss是与torch.ones_like相比较
  • 判别器训练生成的假图片
    • 这里需要注意的是在训练生成器生成的图片过程中我们要截断生成器的反向传播过程, 使用detach来达到目的
    • 对于输入的数据为Generator生成的图片, 判别器要做的是尽量将它判断为0, 也就是与torch.zeros_like相比较

最终两个loss加起来便是判别器总的loss


  • 训练生成器
        g_optim.zero_grad()fake_output = dis(gen_img)  # 希望其骗过disg_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # 得到生成器的损失g_loss.backward()g_optim.step()

生成器的目的就是骗过判别器, 也就是让自己生成的图片在输入判别器之后尽可能的接近1, 因此与torch.ones_like相比较

总体训练网络代码如下

'''GAN的训练'''
D_loss = []
G_loss = []
# 训练循环
for epoch in range(50):d_epoch_loss = 0g_epoch_loss = 0count = len(loader)for step, (img, _) in enumerate(loader):img = img.to(device)size = img.size(0)random_noise = torch.randn(size, 100).to(device)d_optim.zero_grad()# 在判别器训练真实图片real_output = dis(img)  # 判别器输入真实的图片, 希望real_output是1d_real_loss = loss_fn(real_output, torch.ones_like(real_output))  # 得到判别器在真实数据上的损失d_real_loss.backward()# 在判别器训练生成的假图片gen_img = gen(random_noise)# 此时优化的对象是判别器, 要把生成器的梯度截断fake_output = dis(gen_img.detach())  # 判别器输入生成的图片 fake_output对生成的图片预测d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))  # 得到判别器在生成数据上的损失d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optim.step()# 对生成器进行优化g_optim.zero_grad()fake_output = dis(gen_img)  # 希望其骗过disg_loss = loss_fn(fake_output, torch.ones_like(fake_output))  # 得到生成器的损失g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_losswith torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print(f'Epoch: {epoch}, d_epoch_loss: {d_epoch_loss}, g_epoch_loss:{g_epoch_loss}')gen_img_plot(gen, test_input)

六、训练结果

可以看到随着epoch的增加, 生成的效果越来越好


Pytorch GAN实战 MINIST手写数字识别分布解析相关推荐

  1. 基于CNN的MINIST手写数字识别项目代码以及原理详解

    文章目录 项目简介 项目下载地址 项目开发软件环境 项目开发硬件环境 前言 一.数据加载的作用 二.Pytorch进行数据加载所需工具 2.1 Dataset 2.2 Dataloader 2.3 T ...

  2. tensorflow入门之MINIST手写数字识别

    最近在学tensorflow,看了很多资料以及相关视频,有没有大佬推荐一下比较好的教程之类的,谢谢.最后还是到了官方网站去,还好有官方文档中文版,今天就结合官方文档以及之前看的教程写一篇关于MINIS ...

  3. tensorflow2 minist手写数字识别数据训练

    ✨ 博客主页:小小马车夫的主页 ✨ 所属专栏:Tensorflow 文章目录 前言 一.tenosrflow minist手写数字识别代码 二.输出 三.参考资料 总结 前言 刚开始学习tensorf ...

  4. matlab 对mnist手写数字数据集进行判决分析_人工智能TensorFlow(十四)MINIST手写数字识别...

    MNIST是一个简单的视觉计算数据集,它是像下面这样手写的数字图片: MNIST 每张图片还额外有一个标签记录了图片上数字是几,例如上面几张图的标签就是:5.0.4.1. MINIST数据 MINIS ...

  5. 基于Paddle的计算机视觉入门教程——第7讲 实战:手写数字识别

    B站教程地址 https://www.bilibili.com/video/BV18b4y1J7a6/ 任务介绍 手写数字识别是计算机视觉的一个经典项目,因为手写数字的随机性,使用传统的计算机视觉技术 ...

  6. python手写数字识别实验报告_机器学习python实战之手写数字识别

    看了上一篇内容之后,相信对K近邻算法有了一个清晰的认识,今天的内容--手写数字识别是对上一篇内容的延续,这里也是为了自己能更熟练的掌握k-NN算法. 我们有大约2000个训练样本和1000个左右测试样 ...

  7. TensorFlow2 入门指南 | 04 分类问题实战之手写数字识别

    前言: 本专栏在保证内容完整性的基础上,力求简洁,旨在让初学者能够更快地.高效地入门TensorFlow2 深度学习框架.如果觉得本专栏对您有帮助的话,可以给一个小小的三连,各位的支持将是我创作的最大 ...

  8. Pytorch 学习 (一)Minst手写数字识别(含特定函数解析)

    目录 本人目前在跟随csdn博主 "K同学啊"进行365天深度学习训练营进行学习,这是打卡内容 也作为本人学习的记录. 一.准备部分 三.训练模型 四.正式训练 五.输出 MNIS ...

  9. PyTorch:MNIST数据集手写数字识别

    MNIST 包括6万张28x28的训练样本,1万张测试样本,很多教程都会对它"下手"几乎成为一个 "典范",可以说它就是计算机视觉里面的Hello World. ...

  10. C++元编程——CNN进行Minist手写数字识别

    Minist数据来源: MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges 数据的格式如下: C ...

最新文章

  1. 【STL】string的增删改查
  2. Android2.2 r1 API 中文文档系列(10) —— CheckBox
  3. Struts-config.xml配置文件《action-mappings》元素的详解
  4. 医生的小助手,医疗AI赋能诊断新冠肺炎新方案!
  5. python条件判断true_Python中的True,False条件判断实例分析
  6. jquery和php上传文件进度条,jQuery实现文件上传进度条特效_jquery
  7. 面试AI岗,为什么我在100人中拿到了唯一年薪70万的offer?
  8. 研究鸟类迁徙的目的和意义
  9. arcgis去除遥感影像黑边
  10. 年记 2018,新年快乐
  11. 巨型机是一种什么的超级计算机,把计算机分为巨型机、大中型机按照什么分的...
  12. 解决js newDate()苹果手机日期格式显示NaN
  13. 软件测试用例——三角形
  14. 爬虫2_2019年549所中国大学排名
  15. laradock 安装使用 kafka
  16. 如何使Android录音实现内录功能
  17. Unity使用FairyGUI切换Spine
  18. 蓝海彤翔董事长鲁永泉荣获太湖科学城功能片区2022年度表彰
  19. java屠龙_Java中的屠龙之术——如何修改语法树
  20. Atitit r2017 r3 doc list on home ntpc.docx

热门文章

  1. jeb java_jeb2 java 脚本插件
  2. GD32创建工程与启动文件选择
  3. 微管理——给你一个技术团队,你该怎么管
  4. python怎么爬取新浪微博数据中心_新浪微博数据爬取研究
  5. 单片机破解方法(摘录)
  6. java gps_用Java解析GPS经纬度
  7. sftp非交互式每日定时拉取增量数据文件至本地合并至存量
  8. linux修改sybase数据库密码,巧用Sybase数据库的超级用户密码
  9. 从零开始编译LEDE固件 默认中文material主题
  10. SWMM模型及自主开发城市内涝一维二维耦合软件的复杂城市排水系统建模技术及在城市排涝、海绵城市等领域实践应用