1、摘要

本文主要讲解:GAN变种ACGAN利用手写数字识别mnist数据集进行训练,最终生成手写数字图片
主要思路:

  1. Initialize generator and discriminator
  2. Initialize weights
  3. Configure data loader
  4. Optimizers Adam
  5. Train Generator
  6. Train Discriminator
  7. Saves a grid of generated digits ranging from 0 to 9

2、数据介绍

mnist手写数字识别数据集

MNIST数据集由Yann LeCun搜集,是一个大型的手写体数字数据库,通常用于训练各种图像处理系统,也被广泛用于机器学习领域的训练和测试。MNIST数字文字识别数据集数据量不会太多,而且是单色的图像,较简单,适合深度学习初学者练习建立模型、训练、预测。MNIST数据库中的图像集是NIST(National Institute of Standards and Technology)的两个数据库的组合:专用数据库1和特殊数据库3。数据集是有250人手写数字组成,一半是高中生,一半是美国人口普查局。

3、相关技术

本文主要使用pytorch实现ACGAN
pytorch在GitHub上的星星多余tensorflow了,tensorflow升级到2.0版导致以前的很多优秀库不兼容,这是硬伤

ACGAN是在CGAN基础上的进一步拓展,采用辅助分类器(Auxiliary Classifier)使得GAN获取的图像分类的功能。

CGAN通过结合标签信息来提高生成数据的质量,SGAN通过重建标签信息来提高生成数据的质量,那么我们可不可以两者都用,答案是显然的,因为ACGAN就是这样干的。更加详细的内容可以参见论文:Conditional Image Synthesis with Auxiliary Classifier GANs
————————————————
版权声明:本文为CSDN博主「时光碎了天」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:ACGAN 简介与代码实战

4、完整代码和步骤

ACGAN代码在4400个回合生成的手写数字如下:

主运行程序入口

import argparse
import osimport numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_imageos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)cuda = True if torch.cuda.is_available() else Falsedef weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)self.init_size = opt.img_size // 4  # Initial size before upsamplingself.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise, labels):gen_input = torch.mul(self.label_emb(labels), noise)out = self.l1(gen_input)out = out.view(out.shape[0], 128, self.init_size, self.init_size)img = self.conv_blocks(out)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, bn=True):"""Returns layers of each discriminator block"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return blockself.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# The height and width of downsampled imageds_size = opt.img_size // 2 ** 4# Output layersself.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()auxiliary_loss.cuda()# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensordef sample_image(n_row, batches_done):"""Saves a grid of generated digits ranging from 0 to n_classes"""# Sample noisez = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))# Get labels ranging from 0 to n_classes for n rowslabels = np.array([num for _ in range(n_row) for num in range(n_row)])labels = Variable(LongTensor(labels))gen_imgs = generator(z, labels)save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)for epoch in range(opt.n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# Adversarial ground truthsvalid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(FloatTensor))labels = Variable(labels.type(LongTensor))# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()# Sample noise and labels as generator inputz = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))# Generate a batch of imagesgen_imgs = generator(z, gen_labels)# Loss measures generator's ability to fool the discriminatorvalidity, pred_label = discriminator(gen_imgs)g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Loss for real imagesreal_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# Loss for fake imagesfake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# Total discriminator lossd_loss = (d_real_loss + d_fake_loss) / 2# Calculate discriminator accuracypred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:sample_image(n_row=10, batches_done=batches_done)

5、学习链接

既能生成图像又能进行分类的ACGAN

GAN变种ACGAN利用手写数字识别mnist生成手写数字相关推荐

  1. 手写数字识别案例、手写数字图片处理

    python_手写数字识别案例.手写数字图片处理 1.手写数字识别案例 步骤: 收集数据 带有标签的训练数据集来源于trainingDigits文件夹里面所有的文件,接近2000个文件,每个文件中有3 ...

  2. Pytorch实战1:LeNet手写数字识别 (MNIST数据集)

    版权说明:此文章为本人原创内容,转载请注明出处,谢谢合作! Pytorch实战1:LeNet手写数字识别 (MNIST数据集) 实验环境: Pytorch 0.4.0 torchvision 0.2. ...

  3. android 手写字体识别,一种基于Android系统的手写数学公式识别及生成MathML的方法...

    专利名称:一种基于Android系统的手写数学公式识别及生成MathML的方法 技术领域: 本发明属于模式识别技术领域,涉及数学公式中字符间的空间结构分析,具体涉及一种基于Android系统的手写数学 ...

  4. PyTorch手写字体识别MNIST

    手写字体识别MNIST 1.准备工作 可以看这个老师的视频进行学习,讲解的非常仔细:视频学习 2.项目代码 2.1 导入模块 # 1.加载相关库 import torch import torch.n ...

  5. matlab朴素贝叶斯手写数字识别_从“手写数字识别”学习分类任务

    机器学习问题可以分为回归问题和分类问题,回归问题已经在线性回归讲过,本文学习分类问题.分类问题跟回归问题有明显的区别,回归问题是连续的数值,而分类问题是离散的类别,比如将性别分为[男,女],将图片分为 ...

  6. Keras搭建CNN(手写数字识别Mnist)

    MNIST数据集是手写数字识别通用的数据集,其中的数据是以二进制的形式保存的,每个数字是由28*28的矩阵表示的. 我们使用卷积神经网络对这些手写数字进行识别,步骤大致为: 导入库和模块 我们导入Se ...

  7. 使用Pytorch实现手写数字识别(Mnist数据集)

    目标 知道如何使用Pytorch完成神经网络的构建 知道Pytorch中激活函数的使用方法 知道Pytorch中torchvision.transforms中常见图形处理函数的使用 知道如何训练模型和 ...

  8. Python 手写数字识别 MNIST数据集下载失败

    目录 一.MNIST数据集下载失败 1 失败的解决办法(经验教训): 2 亲测有效的解决方法: 一.MNIST数据集下载失败 场景复现:想要pytorch+MINIST数据集来实现手写数字识别,首先就 ...

  9. 手写字体识别 --MNIST数据集

    Matlab 手写字体识别 忙过这段时间后,对于上次读取的Matlab内部数据实现的识别,我回味了一番,觉得那个实在太小.所以打算把数据换成[MNIST数据集][1]. 基础思想还是相同的,使用Tre ...

最新文章

  1. 老李分享:HTTP协议之请求和响应
  2. 转:设置session过期时间
  3. 一文弄懂SSD目标检测算法
  4. C#WindowsForm之创建窗体
  5. ERROR 2002 (HY000): Can't connect to local MySQL server through socket '/var/lib/mysql/mysq
  6. 线性代数可以速成吗_怎么在一个晚上搞定线性代数?
  7. Java学习笔记:Javaweb的服务器介绍
  8. xmake 新站发布:xmake.io
  9. [JavaScript]项目优化总结
  10. java8安装_安装jenkins
  11. Java类的域初始化_Java域的初始化
  12. 你想要的宏基因组-微生物组知识全在这(2020.03)
  13. 中标麒麟(NeoKylin7)下达梦数据库(DM8)的安装部署
  14. 缓冲流、转换流、序列化流、装饰设计模式、comms-io工具包
  15. 穿山甲android对接错误码40029,头条 穿山甲广告 错误码列表
  16. 记录一次GeoTIFF文件二进制源码阅读
  17. 微信小程序开发工具win10下编译非常慢解决方法
  18. 弹性布局flex(兼容不同浏览器)
  19. 程序员职场规划:你的命运不是一头骡子
  20. 页面无限刷新,JS修改当前页面地址,是界面不再刷新

热门文章

  1. 双十一入手什么最好,经常失眠必入的助眠好物
  2. 互联网小白在网络上的成长。
  3. 大数据 hadoop
  4. 小度Wifi,360随身Wifi2,小米Wifi树莓派驱动下载
  5. 行业应用|工业AI视觉系统,助力物流行业智慧分拣加速升级
  6. 社区运营的道、法、术、器
  7. MacOS中afconvert的使用(音频格式转换)
  8. 子网掩码换算和计算网络号
  9. 世界上最常用的几种语言
  10. XHTML基础题及答案20道——必刷前端题目(背)