github:https://github.com/SPECTRELWF/pytorch-GAN-study
个人主页:liuweifeng.top:8090

网络结构

最近在疯狂补深度学习一些基本架构的基础,看了一下大佬的GAN的原始论文,说实话一头雾水,不是能看的很懂。推荐B站李宏毅老师的机器学习2021的课程,听完以后明白多了。原始论文中就说了一个generator和一个discriminator的结构,并没有细节的说具体是怎么去定义的,对新手不太友好,参考了Github的Pytorch-Gan-master仓库的代码,做了一下照搬吧,照着敲一边代码就明白了GAN的思想了。网上找了一张稍微好点的网络结构图:

因为生成对抗网络需要去度量两个分布之间的距离,原始的GAN并没有一个很好的度量,具体细节可以看李宏毅老师的课。导致GAN的训练会比较困难,而且整个LOSS是基本无变化的,但从肉眼上还是能清楚的看到生成的结果在变好。

数据集介绍

使用的是经典的MNIST数据集,后期会拿一些人脸数据集来做实验。

generator

# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(* block(opt.latent_dim,128,normalize=False),* block(128,256),* block(256,512),* block(512,1024),nn.Linear(1024,int(np.prod(image_shape))),nn.Tanh())def forward(self,z):img = self.model(z)img = img.view(img.size(0),*image_shape)return img

discriminator

class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(image_shape)),512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256,1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0),-1)validity = self.model(img_flat)return validity

完整代码:

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/14 下午3:05import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs('new_images', exist_ok=True)parser = argparse.ArgumentParser()  # 添加参数parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=1024, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")opt = parser.parse_args()
print(opt)image_shape = (opt.channels, opt.img_size, opt.img_size)cuda = True if torch.cuda.is_available() else False# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(* block(opt.latent_dim,128,normalize=False),* block(128,256),* block(256,512),* block(512,1024),nn.Linear(1024,int(np.prod(image_shape))),nn.Tanh())def forward(self,z):img = self.model(z)img = img.view(img.size(0),*image_shape)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(image_shape)),512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256,1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0),-1)validity = self.model(img_flat)return validity# lossadversarial_loss = torch.nn.BCELoss()#初始化G和D
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()# loaddata
os.makedirs("data/mnist",exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("data/mnist",train = True,download=True,transform = transforms.Compose([transforms.Resize(opt.img_size),transforms.ToTensor(),transforms.Normalize([0.5],[0.5]),])),batch_size=opt.batch_size,shuffle = True
)optimizer_G = torch.optim.Adam(generator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor#train
for epoch in range(opt.n_epochs):for i ,(imgs,_) in enumerate(dataloader):valid = Variable(Tensor(imgs.size(0),1).fill_(1.0),requires_grad = False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)real_imgs = Variable(imgs.type(Tensor))optimizer_G.zero_grad()z = Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],opt.latent_dim))))gen_imgs = generator(z)g_loss = adversarial_loss(discriminator(gen_imgs),valid)g_loss.backward()optimizer_G.step()#train Discriminatoroptimizer_D.zero_grad()real_loss = adversarial_loss(discriminator(real_imgs),valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),fake)d_loss = (real_loss+fake_loss)/2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:1024], "new_images/%d.png" % batches_done, nrow=32, normalize=True)torch.save(generator.state_dict(),"G.pth")
torch.save(discriminator.state_dict(),"D.pth")

结果

GAN网络的训练是比较困难的,我设置批大小为1024,训练了两百个epoch,给出一些结果。
第0次迭代:

基本上就是纯纯噪声了,初始数据采样来源于标准正态分布。

第400次迭代:

第10000次迭代:

第186800次迭代:

此时就已经基本有了数字的样子了

Pytorch:GAN生成对抗网络实现MNIST手写数字的生成相关推荐

  1. 生成对抗网络(GAN)——MNIST手写数字生成

    前言 正文 一.什么是GAN 二.GAN的应用 三.GAN的网络模型 对抗生成手写数字 一.引入必要的库 一.引入必要的库 二.进行准备工作 三.定义生成器和判别器模型 四.设置损失函数和优化器,以及 ...

  2. GAN (生成对抗网络) 手写数字图片生成

    GAN (生成对抗网络) 手写数字图片生成 文章目录 GAN (生成对抗网络) 手写数字图片生成 Discriminator Network Generator Network 简单版本的生成对抗网络 ...

  3. GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字

    有关条件GAN(cgan)的相关原理,可以参考: GAN系列之CGAN原理简介以及pytorch项目代码实现 其他类型的GAN原理介绍以及应用,可以查看我的GANs专栏 一.数据集介绍,加载数据 依旧 ...

  4. 对抗生成网络学习(十三)——conditionalGAN生成自己想要的手写数字(tensorflow实现)

    一.背景 其实我原本是不打算做这个模型,因为conditionalGAN能做的,infoGAN也能做,infoGAN我在之前的文章中写到了:对抗神经网络学习(五)--infoGAN生成宽窄不一,高低各 ...

  5. 用PyTorch实现MNIST手写数字识别(非常详细)

    ​​​​​Keras版本: Keras入门级MNIST手写数字识别超级详细教程 2022/4/17 更新修复下代码.完善优化下文章结构,文末提供一个完整版代码. 可以在这里下载源码文件(免积分): 用 ...

  6. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  7. PyTorch基础入门五:PyTorch搭建多层全连接神经网络实现MNIST手写数字识别分类

    )全连接神经网络(FC) 全连接神经网络是一种最基本的神经网络结构,英文为Full Connection,所以一般简称FC. FC的准则很简单:神经网络中除输入层之外的每个节点都和上一层的所有节点有连 ...

  8. 【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!

    今天我们将使用 Pytorch 来实现 LeNet-5 模型,并用它来解决 MNIST数据集的识别. 正文开始! 一.使用 LeNet-5 网络结构创建 MNIST 手写数字识别分类器 MNIST是一 ...

  9. 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类

    多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...

最新文章

  1. java怎么导入别人的代码_怎么用eclipse将图标导入到java代码中
  2. matlab 级联cic,Matlab中CIC滤波器的应用
  3. mongodb 安装、开启服务 和 php添加mongodb扩展
  4. js 每隔四位加一个空格
  5. Linux内核和应用层程序通信get/setsockopt示例
  6. PHP中call user func()和call_user_func_array()调用自定义函数小结
  7. 欧拉路径(Euler_Path)和欧拉回路(Euler_Loop)
  8. Repeater的嵌套结合用户控件的使用
  9. 漫画:凌晨2点,老板在工作群@了我...
  10. JS获取对象的第一个值
  11. 使用Java语言借助Quartz jar包实现定时器的方法
  12. 何为监督学习、无监督学习、强化学习、弱监督学习、半监督学习、多示例学习?
  13. win10 无法安装Hp1020和HP1106打印机问题
  14. 公网远程开机(唤醒家庭PC)
  15. matlab模拟投硬币实验,利用几何画板模拟抛硬币实验
  16. 图片:“给你五十行代码把我变成字符画!” 程序:“太多了,一半都用不完!”
  17. matlab解决力学问题程序,力学专业程序实践:用MATLAB解决力学问题的方法与实例...
  18. 太阳诱电 | 汽车用金属功率电感器MCOIL™ LCEN 系列实现商品化
  19. 模拟退火算法+大规模邻域算法求解大规模固定节点的路径规划问题matlab代码
  20. 钢铁集团的混合云灾备

热门文章

  1. 微信小程序实现文件下载 以及微信小程序保存Excel
  2. 大讲台大数据特训学习笔记
  3. 零样本性能超越GPT-3!谷歌提出1370亿参数自回归语言模型
  4. 从6篇经典论文看问题生成及其相关技术
  5. QT接收Linux内核,嵌入式linux上QT标准键盘输入的实现
  6. python读取word指定内容_python解析html提取数据,并生成word文档实例解析
  7. Docker操作容器2
  8. MyBatis——insert错误[Could not set property ‘id‘ of ‘class‘ with value ‘xxx‘]解决方案
  9. 《计算机组成原理》课程设计报告——TEC-2实验系统——微程序设计
  10. VS Code——Live Server的简介、安装与使用