简述

其实是根据我之前写的两个代码改的。(之前已经有过非常详细的解释了,可以去看看)

  • 【GANs入门】pytorch-GANs任务迁移-单个目标(数字的生成)
  • 【Gans入门】Pytorch实现Gans代码详解【70+代码】

同时,在结合了我之前写的DCGANs的时候,实现的一份代码

  • (深度卷积生成对抗神经网络)DCGANs论文阅读与实现pytorch

MNIST上选特定的数值,是根据下面的这篇文章得到的。

  • MNIST选取特定数值的训练集

之前的代码上都有非常详细的解释。这里只是基于上面的一点点改进而已。就不给出特别详细的解释。但是代码中任然保留有注释部分。

图形演变过程

代码

import torch
import torch.nn as nn
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt
import os
import shutil
import imageio
PNGFILE = './png/'
if not os.path.exists(PNGFILE):os.mkdir(PNGFILE)
else:shutil.rmtree(PNGFILE)os.mkdir(PNGFILE)# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001  # learning rate for generator
LR_D = 0.0001  # learning rate for discriminator
N_IDEAS = 100  # think of this as number of ideas for generating an art work (Generator)
target_num = 0  # target Number
EPOCH = 10  # 训练整批数据多少次
DOWNLOAD_MNIST = False  # 已经下载好的话,会自动跳过的
ART_COMPONENTS = 28 * 28# Mnist 手写数字class myMNIST(torchvision.datasets.MNIST):def __init__(self, root, train=True, transform=None, target_transform=None, download=False, targetNum=None):super(myMNIST, self).__init__(root,train=train,transform=transform,target_transform=target_transform,download=download)if targetNum != None:self.train_data = self.train_data[self.train_labels == targetNum]self.train_data = self.train_data[:int(self.__len__() / BATCH_SIZE) * BATCH_SIZE]self.train_labels = self.train_labels[self.train_labels == targetNum][:int(self.__len__() / BATCH_SIZE) * BATCH_SIZE]def __len__(self):if self.train:return self.train_data.shape[0]else:return 10000train_data = myMNIST(root='./mnist/',  # 保存或者提取位置train=True,  # this is training datatransform=torchvision.transforms.ToTensor(),  # 转换 PIL.Image or numpy.ndarray 成# torch.FloatTensor (C x H x W), 训练的时候 normalize 成 [0.0, 1.0] 区间download=DOWNLOAD_MNIST,  # 没下载就下载, 下载了就不用再下了targetNum=target_num
)
print(len(train_data))
# print(train_data.shape)# 训练集丢BATCH_SIZE个, 图片大小为28*28
train_loader = Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True  # 是否打乱顺序
)G = nn.Sequential(  # Generatornn.Linear(N_IDEAS, 128),  # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideasnn.ReLU(),
)D = nn.Sequential(  # Discriminatornn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(),  # tell the probability that the art work is made by artist
)# loss & optimizer
optimD = torch.optim.Adam(D.parameters(), lr=LR_D)
optimG = torch.optim.Adam(G.parameters(), lr=LR_G)label_Real = torch.FloatTensor(BATCH_SIZE).data.fill_(1)
label_Fake = torch.FloatTensor(BATCH_SIZE).data.fill_(0)filePath = []for epoch in range(EPOCH):for step, (images, imagesLabel) in enumerate(train_loader):G_ideas = torch.randn((BATCH_SIZE, N_IDEAS))G_paintings = G(G_ideas)images = images.reshape(BATCH_SIZE, -1)prob_artist0 = D(images)  # D try to increase this probprob_artist1 = D(G_paintings)D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))G_loss = torch.mean(torch.log(1. - prob_artist1))optimD.zero_grad()D_loss.backward(retain_graph=True)optimD.step()optimG.zero_grad()G_loss.backward(retain_graph=True)optimG.step()if step % 20 == 0:plt.cla()picture = torch.squeeze(G_paintings[0]).detach().numpy().reshape((28, 28))plt.imshow(picture, cmap=plt.cm.gray_r)plt.savefig(PNGFILE + '%d-%d.png' % (epoch, step))filePath.append(PNGFILE + '%d-%d.png' % (epoch, step))generated_images = []
for png_path in filePath:generated_images.append(imageio.imread(png_path))
shutil.rmtree(PNGFILE)
imageio.mimsave('gan-mnist.gif', generated_images, 'GIF', duration=0.1)

基于MNIST的GANs实现【Pytorch】相关推荐

  1. 基于MNIST实现GAN(pytorch)

    基于MNIST实现生成对抗网络(pytorch逐行实现) 本文是pytorch逐行实现GAN网络,作为一个基础GAN框架来学习,以后编写复杂的GAN的衍生网络框架都是同样的思想 import nump ...

  2. pytorch训练GAN的代码(基于MNIST数据集)

    论文:Generative Adversarial Networks 作者:Ian J. Goodfellow 年份:2014年 从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简 ...

  3. DL之CNN可视化:利用SimpleConvNet算法【3层,im2col优化】基于mnist数据集训练并对卷积层输出进行可视化

    DL之CNN可视化:利用SimpleConvNet算法[3层,im2col优化]基于mnist数据集训练并对卷积层输出进行可视化 导读 利用SimpleConvNet算法基于mnist数据集训练并对卷 ...

  4. 基于Anaconda安装GPU版PyTorch深度学习开发环境

    基于Anaconda安装GPU版PyTorch深度学习开发环境 1 安装Anaconda 2 安装GPU计算驱动 2.1 检查是否有合适的GPU 2.2 下载CUDA和cuDNN 2.3 安装CUDA ...

  5. GAN生成对抗网络基本概念及基于mnist数据集的代码实现

    本文主要总结了GAN(Generative Adversarial Networks) 生成对抗网络的基本原理并通过mnist数据集展示GAN网络的应用. GAN网络是由两个目标相对立的网络构成的,在 ...

  6. 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

    机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

  7. 基于MNIST手写体数字识别--含可直接使用代码【Python+Tensorflow+CNN+Keras】

    基于MNIST手写体数字识别--[Python+Tensorflow+CNN+Keras] 1.任务 2.数据集分析 2.1 数据集总体分析 2.2 单个图片样本可视化 3. 数据处理 4. 搭建神经 ...

  8. 神经网络--基于mnist数据集取得最高的识别准确率

    前言: Hello大家好,我是Dream. 今天来学习一下如何基于mnist数据集取得最高的识别准确率,本文是从零开始的,如有需要可自行跳至所需内容~ 本文目录: 1.调用库函数 2.调用数据集 3. ...

  9. 【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

    「@Author:Runsen」 GAN 是使用两个神经网络模型训练的生成模型.一种模型称为生成网络模型,它学习生成新的似是而非的样本.另一个模型被称为判别网络,它学习区分生成的例子和真实的例子. 生 ...

最新文章

  1. 输入两个整数a和b,计算a+b的和
  2. startActivityForResult()
  3. 洛谷1527(bzoj2738)矩阵乘法——二维树状数组+整体二分
  4. arcalet云服务平台支持Unity3D开发实时多人联机游戏
  5. Java中对象的复制
  6. 使用java多线程分批处理数据工具类
  7. Java导入导出Excel工具类ExcelUtil
  8. windows7 shift+右键 “在此处打开命令窗口”
  9. krita绘图_如何使用Krita制作动画视频
  10. 万圣节头像挂件微信小程序前端
  11. html5 main form 结合,web组件之表单(HTML5)
  12. GOTC 大会预告 | Apache Pulsar PMC 成员翟佳:Apache Pulsar 架构设计与原理
  13. 该怎么回答面试官问“你有什么优缺点?”
  14. java + jfreechart + itextpdf创建折线图饼图并导出为pdf
  15. 搭建通过openOCD下载mini2440程序的调试平台
  16. 2021年后国内互联网发展趋势预测
  17. GD32F4xx uIP协议栈移植记录
  18. 安装KALI里面的翻译工具
  19. jquery的ajax()方法与生命周期
  20. 口布杯花的60种叠法_杯花折叠方法

热门文章

  1. 14条建议,使你的IT职业生涯更上一层楼
  2. 使用Silverlight for Embedded开发绚丽的界面(4)
  3. 如果C++程序要调用已经被编译后的C函数,该怎么办?
  4. WINCE6.0+S3C2443自动重启的实现
  5. 论文笔记之:Multiple Feature Fusion via Weighted Entropy for Visual Tracking
  6. java.两个例子充分阐述多态的可拓展性
  7. Bootstrap 手风琴搭配导航条实现常用菜单栏
  8. sql server left join 重复数据原因图
  9. 解决不是有效的win32应用程序
  10. 【STM32 .Net MF开发板学习-14】红外遥控器编码识别