论文:Generative Adversarial Networks
作者:Ian J. Goodfellow
年份:2014年

从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简单记录一下,有时间会补充。
更多关于GAN的可以看我另一篇:https://blog.csdn.net/demo_jie/article/details/106724016

直接讲代码实现部分,这个代码是用pytorch训练GAN,基于MNIST数据集
真实图片:

代码:

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import osif not os.path.exists('img'):os.mkdir('img')def to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)   #输出限制在0,1范围内out = out.view(-1, 1, 28, 28)return out# 初始化参数
batch_size = 128
num_epoch = 10
z_dimension = 100
# 对图片进行一些前期处理操作
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# img_transform = transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
# ]
# 数据集下载
mnist = datasets.MNIST(root='E:/low-light/deep learning/GAN/data/', train=True, transform=img_transform, download=True)
# 数据集加载
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)# 判别网络
class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.dis = nn.Sequential(nn.Linear(784, 256),nn.LeakyReLU(0.2),nn.Linear(256, 256),nn.LeakyReLU(0.2), nn.Linear(256, 1),nn.Sigmoid())  # sigmoid激活函数得到一个0到1之间的概率进行二分类def forward(self, x):x = self.dis(x)return x# 生成器
class generator(nn.Module):def __init__(self):super(generator, self).__init__()self.gen = nn.Sequential(nn.Linear(100, 256),nn.ReLU(True),nn.Linear(256, 256),nn.ReLU(True),nn.Linear(256, 784),nn.Tanh())  # Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间。def forward(self, x):x = self.gen(x)return xD = discriminator()
G = generator()
if torch.cuda.is_available():D = D.cuda()G = G.cuda()
# 判别器的训练由两部分组成,第一部分是真的图像判别为真,第二部分是假的图片判别为假,在这两个过程中,生成器的参数不参与更新。
# 二进制交叉熵损失和优化器
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
# 开始训练
for epoch in range(num_epoch):for i, (img, _) in enumerate(dataloader):num_img = img.size(0)# ========================================================================训练判别器img = img.view(num_img, -1)  # # 将图片展开乘28x28=784# real_img = Variable(img).cuda()# real_label = Variable(torch.ones(num_img)).cuda()# fake_label = Variable(torch.zeros(num_img)).cuda()real_img = Variable(img)real_label = Variable(torch.ones(num_img))  # 定义真实label为1fake_label = Variable(torch.zeros(num_img))  # 定义假label为1# 计算 real_img 的损失real_out = D(real_img)  # 将真实的图片放入判别器中d_loss_real = criterion(real_out, real_label)  # 得到真实图片的lossreal_scores = real_out  # 越接近一越好# 计算 fake_img的损失# z = Variable(torch.randn(num_img, z_dimension)).cuda()z = Variable(torch.randn(num_img, z_dimension))  # 随机生成一些噪声fake_img = G(z)  # 放入生成网络生成一张假的图片fake_out = D(fake_img)  ## 判别器判断假的图片d_loss_fake = criterion(fake_out, fake_label)  ## 得到假的图片的lossfake_scores = fake_out  # 越接近0越好# 反向传播和优化d_loss = d_loss_real + d_loss_fake  # 将真假图片的loss加起来d_optimizer.zero_grad()  # 每次梯度归零d_loss.backward()  # 反向传播d_optimizer.step()  # 更新参数# =====================================================================训练生成器# 计算fake_img损失# z = Variable(torch.randn(num_img, z_dimension)).cuda()z = Variable(torch.randn(num_img, z_dimension))  # 得到随机噪声fake_img = G(z)  # 生成假的图片output = D(fake_img)  # 经过判别器得到结果g_loss = criterion(output, real_label)  ##得到假的图片与真实图片label的loss# 反向传播和优化g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f},D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epoch, d_loss.item(), g_loss.item(),real_scores.data.mean(), fake_scores.data.mean()))if epoch == 0:real_images = to_img(real_img.cpu().data)save_image(real_images, 'real_images.png')fake_images = to_img(fake_img.cpu().data)save_image(fake_images, 'fake_images-{}.png'.format(epoch + 1))
torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')

运行结果:

这次一共跑了10次,以下是生成的噪声图片,分别是跑了1,3,5,7,9,10次的图片(训练次数太少了,所以效果不明显,可以自己设置训练次数)



生成的真实图片:

pytorch训练GAN的代码(基于MNIST数据集)相关推荐

  1. 【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

    「@Author:Runsen」 GAN 是使用两个神经网络模型训练的生成模型.一种模型称为生成网络模型,它学习生成新的似是而非的样本.另一个模型被称为判别网络,它学习区分生成的例子和真实的例子. 生 ...

  2. GAN生成对抗网络基本概念及基于mnist数据集的代码实现

    本文主要总结了GAN(Generative Adversarial Networks) 生成对抗网络的基本原理并通过mnist数据集展示GAN网络的应用. GAN网络是由两个目标相对立的网络构成的,在 ...

  3. DL之CNN可视化:利用SimpleConvNet算法【3层,im2col优化】基于mnist数据集训练并对卷积层输出进行可视化

    DL之CNN可视化:利用SimpleConvNet算法[3层,im2col优化]基于mnist数据集训练并对卷积层输出进行可视化 导读 利用SimpleConvNet算法基于mnist数据集训练并对卷 ...

  4. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  5. 神经网络--基于mnist数据集取得最高的识别准确率

    前言: Hello大家好,我是Dream. 今天来学习一下如何基于mnist数据集取得最高的识别准确率,本文是从零开始的,如有需要可自行跳至所需内容~ 本文目录: 1.调用库函数 2.调用数据集 3. ...

  6. 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

    机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

  7. Fate集群 | 基于MNIST数据集的模型训练+模型预测 详细过程

    文章目录 一.获取数据集并简单处理 1.分割数据集 2.拷贝数据集 二.模型训练 1.上传数据 1)host方 2)guest方 2.构建模型 3.修改配置文件 1)DSL简介 2)DSL配置文件 3 ...

  8. 基于 MNIST 数据集的 Pytorch 卷积自动编码器

    自编码器 自动编码器是一种无监督的深度学习算法,它学习输入数据的编码表示,然后重新构造与输出相同的输入.它由编码器和解码器两个网络组成.编码器将高维输入压缩成低维潜在代码(也称为潜在代码或编码空间) ...

  9. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型...

    「@Author:Runsen」 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先 ...

最新文章

  1. [FaceBook]测试、发布和分享小游戏
  2. Linux CentOS 6+复制本地前端文件压缩包解压到服务器端指定目录
  3. Spring源码分析【9】-SpringSecurity密码Remove原理
  4. 如何使dropship第三方销售是基于发货数量,而不是基于LIV发票校验的数量
  5. 百度之星12月30号题目之维基解密
  6. Hdu5015 233 Matrix矩阵
  7. 2.1 为什么要进行实例探究-深度学习第四课《卷积神经网络》-Stanford吴恩达教授
  8. E - Water Distribution
  9. linux系统创建windows启动盘
  10. 外卖和快递行业数据_下周一起,整治全面启动!锁定全市外卖、快递行业!
  11. eclipse 使用指南
  12. lucene中write.lock索引锁机制的原理
  13. webservice使用EF生成的model序列化问题
  14. python定期自动运行_令人惊叹的8个Python新手工具
  15. java html文件转换pdf文件_Java实现HTML转换为PDF的常见方法
  16. 一些常用单位之间的换算
  17. java word jar包_java操作word书签生成word模板不用jar包
  18. 3.1 埃拉托色尼筛选
  19. 禁止所有搜索引擎蜘蛛的爬行收录
  20. linux的nfs配置文件的编写信息(学习day1)

热门文章

  1. Redis 进阶笔记
  2. android3d画廊自动切换,Android实例(一)—— 3D画廊
  3. 从董明珠雷军世纪之赌中看到什么样的格力和小米?
  4. 奇幻之旅,全世界畅游
  5. 在今日头条有粉丝17万月入万把块,究竟今日头条要怎么挣钱?
  6. 开源解决方案搭建统一日志平台
  7. 自然语言处理(三):传统RNN(NvsN,Nvs1,1vsN,NvsM)pytorch代码解析
  8. Java语言基础(常见对象3(数组排序(冒泡排序、选择排序、直接插入排序、快速排序、归并排序)、Arrays、基本数据类型包装类、Integer类、自动拆箱、自动装箱))
  9. 如何通过拍照识别植物?试试这几个软件
  10. 方法重写的 两同 两小 一大