对于GAN的原理,我这里就不多讲了,网上很多。这里主要讲代码,以及调试的踩得坑。
本文参考:
https://blog.csdn.net/qxqsunshine/article/details/84105948
首先导入相关的包。

import torch
import torchvision
import torch.utils.data
import torch.nn
import torch.autograd.variable
from torch.autograd import Variable
from torchvision.utils import save_image

数据的处理

transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])

在网上其他代码中,通常会加入如下的代码:

torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

但是会出现通道不匹配的错误提示:
RuntimeError: output with shape [1, 28, 28] doesn’t match the broadcast shape [3, 28, 28]。
原因是因为MINST数据集是单通道的图像(灰度图),但是Normalize均值方差都是三个。解决上述错误,一般会加入如下代码:

torchvision.transforms.Lambda(lambda x: x.repeat(3, 1, 1)),#将单通道,转化为多

这样问题可以暂时解决,但后面又会出现元素数量不对的情况:
RuntimeError: size mismatch, m1: [100 x 2352], m2: [784 x 256] at C:/w/1/s/windows/pytorch/aten/src\THC/generic/THCTensorMathBlas.cu:268
很烦。所以索性不用Normalize,这样所有的问题完美解决。。。

加载数据集

#数据集
test_data=torchvision.datasets.MNIST(root='./data/',#路径transform=transform,#数据处理train=False,#使用测试集,这个看心情download=True#下载
)

将数据放进加载器,作用就是包装一下
以下句子来自:
https://www.cnblogs.com/demo-deng/p/10623334.html
数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。

#数据加载器, DataLoader就是用来包装所使用的数据,每次抛出一批数据
test_data_load=torch.utils.data.DataLoader(dataset=test_data,shuffle=True,#每次打乱顺序batch_size=100#批大小,这里根据数据的样本数量而定,最好是能整除
)

向量转图片。
其中view()函数,我理解的还不够透彻。。。

#将1*784向量,转换成28*28图片
def to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return out

生成器
采用线性网络,ReLU()作为激活函数,最后一程=层使用Tanh()。至于为什么这么写,只是说实验结果比较好,我没有仔细研究这个。

#生成器
class Generater(torch.nn.Module):def __init__(self):super(Generater, self).__init__();self.G_lay1=torch.nn.Sequential(torch.nn.Linear(100,128),torch.nn.ReLU(),torch.nn.Linear(128,256),torch.nn.ReLU(),torch.nn.Linear(256,784),torch.nn.Tanh())def forward(self, x):return self.G_lay1(x)

判别器
使用LeakyReLU作为激活函数,最后一层Sigmoid们可以理解为二分类器。另外判别器和生成器最好是对称的

#判别器
class Discriminator(torch.nn.Module):def __init__(self):super(Discriminator, self).__init__()self.D_lay2=torch.nn.Sequential(torch.nn.Linear(784,256),torch.nn.LeakyReLU(0.2),torch.nn.Linear(256,128),torch.nn.LeakyReLU(0.2),torch.nn.Linear(128,1),torch.nn.Sigmoid())def forward(self, x):return self.D_lay2(x)

训练前的准备

#实例化
g_net=Generater().cuda()
d_net=Discriminator().cuda()#损失函数,优化器
loss_fun=torch.nn.BCELoss()
g_optimizer=torch.optim.Adam(g_net.parameters(),lr=0.0002,betas=(0.5,0.999))
d_optimizer=torch.optim.Adam(d_net.parameters(),lr=0.0002,betas=(0.5,0.999))
epoch_n=20

这里采用beta1为0.5,在多数的时候一般会使用0.9,这个还是要看具体情况。
我这里只train了20次,电脑太渣。

训练

for epoch in range(epoch_n):for i,(img,_) in enumerate(test_data_load):img_num=img.size(0)#训练D——————————————————————————————————————————————————#真图——真标签img=img.view(img_num,-1)real_img=Variable(img).cuda()real_output=d_net(real_img)real_lab= Variable(torch.ones(img_num)).cuda()real_loss=loss_fun(real_output,real_lab)#假图——假标签noise=Variable(torch.randn(img_num, 100)).cuda()fake_img=g_net(noise)fake_lab=Variable(torch.zeros(img_num)).cuda()fake_output=d_net(fake_img)fake_loss=loss_fun(fake_output,fake_lab)d_loss=real_loss+fake_lossd_optimizer.zero_grad()d_loss.backward()d_optimizer.step()#训练G————————————————————————————————————————————#假图——真标签g_noise = Variable(torch.randn(img_num, 100)).cuda()g_img = g_net(noise)g_lab=Variable(torch.ones(img_num)).cuda()g_output=d_net(g_img)g_loss=loss_fun(g_output,g_lab)g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''real_loss: {:.6f}, fake_loss: {:.6f}'.format(epoch, epoch_n, d_loss.data.item(), g_loss.data.item(),real_loss.data.mean(), fake_loss.data.mean()))if epoch == 0:real_images = to_img(real_img.cpu().data)save_image(real_images, './img/real_images.png')#保存生成的图片fake_images = to_img(fake_img.cpu().data)save_image(fake_images,'./img/fake_images-{}.png'.format(epoch + 1),nrow=10)#保存模型
torch.save(g_net.state_dict(), './generator.pth')
torch.save(d_net.state_dict(), './discriminator.pth')

其中save_image中nrow,表示将一个批次的100张图,按照每行10个排列。

整体代码:

import torch
import torchvision
import torch.utils.data
import torch.nn
import torch.autograd.variable
from torch.autograd import Variable
from torchvision.utils import save_image#图像读入与处理
transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),])#数据集
test_data=torchvision.datasets.MNIST(root='./data/',#路径transform=transform,#数据处理train=False,#使用测试集,这个看心情download=True#下载
)#数据加载器, DataLoader就是用来包装所使用的数据,每次抛出一批数据
test_data_load=torch.utils.data.DataLoader(dataset=test_data,shuffle=True,#每次打乱顺序batch_size=100#批大小,这里根据数据的样本数量而定,最好是能整除
)#将1*784向量,转换成28*28图片
def to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return out#生成器
class Generater(torch.nn.Module):def __init__(self):super(Generater, self).__init__();self.G_lay1=torch.nn.Sequential(torch.nn.Linear(100,128),torch.nn.ReLU(),torch.nn.Linear(128,256),torch.nn.ReLU(),torch.nn.Linear(256,784),torch.nn.Tanh())def forward(self, x):return self.G_lay1(x)#判别器
class Discriminator(torch.nn.Module):def __init__(self):super(Discriminator, self).__init__()self.D_lay2=torch.nn.Sequential(torch.nn.Linear(784,256),torch.nn.LeakyReLU(0.2),torch.nn.Linear(256,128),torch.nn.LeakyReLU(0.2),torch.nn.Linear(128,1),torch.nn.Sigmoid())def forward(self, x):return self.D_lay2(x)#实例化
g_net=Generater().cuda()
d_net=Discriminator().cuda()#损失函数,优化器
loss_fun=torch.nn.BCELoss()
g_optimizer=torch.optim.Adam(g_net.parameters(),lr=0.0002,betas=(0.5,0.999))
d_optimizer=torch.optim.Adam(d_net.parameters(),lr=0.0002,betas=(0.5,0.999))
epoch_n=20for epoch in range(epoch_n):for i,(img,_) in enumerate(test_data_load):img_num=img.size(0)#训练D——————————————————————————————————————————————————#真图——真标签img=img.view(img_num,-1)real_img=Variable(img).cuda()real_output=d_net(real_img)real_lab= Variable(torch.ones(img_num)).cuda()real_loss=loss_fun(real_output,real_lab)#假图——假标签noise=Variable(torch.randn(img_num, 100)).cuda()fake_img=g_net(noise)fake_lab=Variable(torch.zeros(img_num)).cuda()fake_output=d_net(fake_img)fake_loss=loss_fun(fake_output,fake_lab)d_loss=real_loss+fake_lossd_optimizer.zero_grad()d_loss.backward()d_optimizer.step()#训练G————————————————————————————————————————————#假图——真标签g_noise = Variable(torch.randn(img_num, 100)).cuda()g_img = g_net(noise)g_lab=Variable(torch.ones(img_num)).cuda()g_output=d_net(g_img)g_loss=loss_fun(g_output,g_lab)g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''real_loss: {:.6f}, fake_loss: {:.6f}'.format(epoch, epoch_n, d_loss.data.item(), g_loss.data.item(),real_loss.data.mean(), fake_loss.data.mean()))if epoch == 0:real_images = to_img(real_img.cpu().data)save_image(real_images, './img/real_images.png')#保存生成的图片fake_images = to_img(fake_img.cpu().data)save_image(fake_images,'./img/fake_images-{}.png'.format(epoch + 1),nrow=10)torch.save(g_net.state_dict(), './generator.pth')
torch.save(d_net.state_dict(), './discriminator.pth')

以上代码,并非全部原创,只是站在巨人肩膀上,如果侵权,请评论或私信联系 ,如果有错误,请评论或私信,谢谢。

Pytorch学习——GAN——MINST相关推荐

  1. Pytorch 学习 (一)Minst手写数字识别(含特定函数解析)

    目录 本人目前在跟随csdn博主 "K同学啊"进行365天深度学习训练营进行学习,这是打卡内容 也作为本人学习的记录. 一.准备部分 三.训练模型 四.正式训练 五.输出 MNIS ...

  2. 2_初学者快速掌握主流深度学习框架Tensorflow、Keras、Pytorch学习代码(20181211)

    初学者快速掌握主流深度学习框架Tensorflow.Keras.Pytorch学习代码 一.TensorFlow 1.资源地址: 2.资源介绍: 3.配置环境: 4.资源目录: 二.Keras 1.资 ...

  3. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  4. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

    文章目录 1 测试鉴别器 2 建立生成器 3 测试生成器 4 训练生成器 5 使用生成器 6 内存查看 上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类. 使用PyTorch构建GAN生 ...

  5. 新手必备 | 史上最全的PyTorch学习资源汇总

    目录: PyTorch学习教程.手册 PyTorch视频教程 PyTorch项目资源      - NLP&PyTorch实战      - CV&PyTorch实战 PyTorch论 ...

  6. pytorch学习资源汇总

    pytorch学习资源汇总 https://pytorchchina.com/2019/05/07/awesome-pytorch-chinese/ PyTorch学习教程.手册 PyTorch英文版 ...

  7. PyTorch学习教程、手册

    文章目录 PyTorch学习教程.手册 PyTorch视频教程 NLP&PyTorch实战 CV&PyTorch实战 PyTorch论文推荐 PyTorch书籍推荐 PyTorch学习 ...

  8. Pytorch 使用GAN实现二次元人物头像生成 保姆级教程(数据集+实现代码+数学原理)

    Pytorch 使用DCGAN实现二次元人物头像生成(实现代码+公式推导) GAN介绍   算法主体   推导证明(之后将补全完整过程)   随机梯度下降训练D,G   DCGAN介绍及相关原理 Py ...

  9. Pytorch实现GAN 生成动漫头像

    什么是GAN? ​ 生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习]最具前景的方法之一.GAN的核心思想来 ...

最新文章

  1. python 作物识别_Python-OpenCV —— 物体识别(TrainCascadeClassification)
  2. jvm执行引擎全解,java解释器即时编译器,全都讲明白
  3. 【Window / 浏览器】 常用 快捷键 整理
  4. linux系统中的目录讲解
  5. mysql avg 报错_MySQL报错汇总
  6. Lua的upvalue和闭包
  7. Apache Cassandra和低延迟应用程序
  8. Mysql ERROR 1418 (HY000): This function has none of DETERMINISTIC, NO SQL, or READS SQL DATA
  9. 参加kaggle比赛
  10. h5跳转小程序_微信小程序吞掉H5?
  11. 搭建MyBatis框架
  12. 分区字段不在SQL过滤中,悲剧
  13. Linux基础:linux网络接口
  14. 执行sc query mysql,sc delete mysql没有反应
  15. 3dmax shift用来复制对象
  16. 构建路径_深度学习的幸福课堂构建从评价细则中找“路径”——基于深度学习的幸福课堂构建实践研究...
  17. 微信小程序开发之微信小程序交互
  18. 专升本english
  19. C语言象棋马的遍历程序,马走日遍历
  20. 国庆头像生成器小程序源码

热门文章

  1. 微型计算机及win7,不一样的微型电脑,加上win7系统,完美无击,还可以放口袋...
  2. Python实现GeoHash算法
  3. python编译安装详解_linux 编译安装python3.6的教程详解
  4. Unity 工具 之 VText 简单快速实现 文字 3D 效果,VText 的导入设置和简单使用(可支持中文字体)
  5. 操作系统 第4章 习题整理
  6. Oracle创建同义词及dblink
  7. 二进制与十进制的转换教案
  8. 如何使用webshell方式登录腾讯云Linux轻量应用服务器实例?
  9. 三相同步电动机的平衡方程式
  10. 剑侠情缘(网络版)---开发回顾 ------赵青