pytorch学习之GAN生成MNIST手写数字
0.简单介绍:
学深度学习的人必然知道,最基本的GAN模型由一个生成器 G 和判别器 D 组成。生成器用于生成假样本,判别器用于判断样本是真实的还是假的。
在整个训练过程中,生成器努力地让生成的图像更加真实,而判别器则努力地去识别出图像的真假,这个过程相当于一个二人博弈,随着时间的推移,生成器和判别器在不断地进行对抗,最终期望两个网络达到一个动态均衡:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近0.5(相当于随机猜测类别)。
以下是作为初学者的我 了解GAN的结构和运作机制的代码:
1.必要的库函数:
import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs("images", exist_ok=True) #生成的数字最后会放到image文件夹,没有则创建
2.参数设置
使用argparse模块主要用来为脚本传递命令参数功能**,代码更加灵活:**
parser = argparse.ArgumentParser() #创建一个参数对象
#调用 add_argument() 方法给 ArgumentParser对象添加程序所需的参数信息
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args() # parse_args()返回我们定义的参数字典
print(opt)img_shape = (opt.channels, opt.img_size, opt.img_size) #(1,28,28)cuda = True if torch.cuda.is_available() else False #是否使用“cuda”
3.定义生成器:
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)] #该例只是用全连接层,未卷积if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8)) #BatchNorm:在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布,momentum=0.8layers.append(nn.LeakyReLU(0.2, inplace=True)) #inplace = True ,直接覆盖原输入数据的值return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False), #opt.latent_dim,100维的随机噪声*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))), #np.prod(img_shape),返回1*28*28nn.Tanh() #使用Tanh()激活函数)def forward(self, z): #前向传播img = self.model(z) #高斯噪声信号z调用model(),model()调用block(),完成生成图像操作img = img.view(img.size(0), *img_shape) #img.size(0)为784的像素,转换为(1,28,28)的图像return img
3.1通过输入噪声图片,generator输出一个与真实图片一样大小的图像。
3.2.生成器生成图像只用了全连接层哦,没有进行复杂的卷积操作
3.3隐层激活函数采用的是Leaky ReLU,了解各种激活函数,可参考:https://zhuanlan.zhihu.com/p/88429934?from_voters_page=true
3.4在输出层我们使用tanh函数,这是因为tanh在这里相比sigmoid的结果会更好一点
4.定义判别器:
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1) validity = self.model(img_flat)return validity
4.1判别器接收一张图片,输出层为1个结点,输出为1的概率
4.2同样隐层使用了Leaky ReLU
5.定义损失函数,初始化一个生成器和一个判别器对象,加载数据,定义优化器
# Loss function
adversarial_loss = torch.nn.BCELoss()# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()# Configure data loader
os.makedirs("./data/mnist", exist_ok=True) #不存在则下载
dataloader = torch.utils.data.DataLoader(datasets.MNIST("./data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] #transforms.Resize重置图像分辨率),),batch_size=opt.batch_size, #一个batch:128shuffle=True,
)# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) #Betas是动量梯度的下降Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
6.训练网络
对于生成器来说,传给辨别器的生成图片,生成器希望辨别器打上标签1。因为它要不断训练减小损失,以期望骗过判别器。
对于判别器来说,给定的真实图片,辨别器要为其打上标签1;给定的生成图片,辨别器要为其打上标签0;它要能够识别真假。
for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):# Adversarial ground truthsvalid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) #fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) ## Configure inputreal_imgs = Variable(imgs.type(Tensor))# -----------------# Train Generator# -----------------optimizer_G.zero_grad() #梯度置0# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) #输入从0到1之间,形状为imgs.shape[0], opt.latent_dim的随机高斯数据。# Generate a batch of imagesgen_imgs = generator(z) #生成图像# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid) #计算生成图像的损失,一开始很大g_loss.backward() #算梯度optimizer_G.step() #更新权重# ---------------------# Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid) #计算真实图像的损失fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) #计算生成图像的损失#noise 从 generator 输入,到discriminator 输出,计算 generator 损失,回传,这一步更新了 generator 的参数,并释放了计算图。# 下一步更新 discriminator 的参数时,generator 的输出经过 detach 后,又通过了一遍 discriminator,相当于,generator 的输出前后两次通过了 discriminator ,得到相同的输出d_loss = (real_loss + fake_loss) / 2 #计算判别器的损失d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0: #已完成的batch是400的倍数save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) #将生成的图片的25张保存下来
6.1生成器端,g_loss表示它希望让判别器对自己生成的图片尽可能输出为1,相当于它在于判别器进行对抗。
6.2判别器端,real_loss对应着真实图片的loss,它尽可能让判别器的输出接近于1,real_loss与 fake_loss加起来就是整个判别器的损失。
6.3 我的例子是先训练生成器,再训练判别器。有的方案是反过来的,可以自己找来参考。
看看我们生成的数据:
马马虎虎。代码清楚了。
深部学习小白,欢迎交流。
pytorch学习之GAN生成MNIST手写数字相关推荐
- GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字
有关条件GAN(cgan)的相关原理,可以参考: GAN系列之CGAN原理简介以及pytorch项目代码实现 其他类型的GAN原理介绍以及应用,可以查看我的GANs专栏 一.数据集介绍,加载数据 依旧 ...
- Pytorch入门——MNIST手写数字识别代码
MNIST手写数字识别教程 本文仅仅放出该教程的代码 具体教程请看 Pytorch入门--手把手教你MNIST手写数字识别 import torch import torchvision from t ...
- 全连接神经网络实现MNIST手写数字识别
在对全连接神经网络的基本知识(全连接神经网络详解)学习之后,通过MNIST手写数字识别这个小项目来学习如何实现全连接神经网络. MNIST数据集 对于深度学习的任何项目来说,数据集是其中最为关键的部分 ...
- 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别
一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...
- pytorch 预测手写体数字_深度学习之PyTorch实战(3)——实战手写数字识别
如果需要小编其他论文翻译,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 上一节,我们已经 ...
- 用PyTorch实现MNIST手写数字识别(非常详细)
Keras版本: Keras入门级MNIST手写数字识别超级详细教程 2022/4/17 更新修复下代码.完善优化下文章结构,文末提供一个完整版代码. 可以在这里下载源码文件(免积分): 用 ...
- PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类
文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...
- 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类
多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...
- 使用PYTORCH复现ALEXNET实现MNIST手写数字识别
网络介绍: Alexnet网络是CV领域最经典的网络结构之一了,在2012年横空出世,并在当年夺下了不少比赛的冠军,下面是Alexnet的网络结构: 网络结构较为简单,共有五个卷积层和三个全连接层,原 ...
最新文章
- 如何在github上fork一个项目来贡献代码以及同步原作者的修改
- boost::process::async_system相关的测试程序
- 我的技术回顾那些与ABP框架有关的故事-2018年
- [数据结构-严蔚敏版]P37定义一个带头结点的线性链表
- Flowable 数据库表结构 ACT_RE_MODEL
- js 获取URL参数乱码解决
- java冒泡排序代码_JAVA
- 看老外程序员如何向妻子解释OOD (转载)
- 小米关联公司被列入经营异常
- mysql之查询前几条或者中间某几行数据
- 【高斯消元】[JSOI2008]球形空间产生器sphere
- 电脑没声音解决方法(重启/声卡设置/升级声卡驱动)
- Ubuntu 16.04 快捷键截图
- 【转】 计算机视觉、图像处理学习资料汇总
- HDOJ 2010 水仙花数
- 实时操作系统与分时操作系统(或称非实时操作系统)的区别
- 用通俗易懂的语言去解释JDK的动态代理
- 适合大学生浏览的网站
- 即刻打造数字化工厂2020
- iPhone4平台上实时音视频对话(经验)