有关条件GAN(cgan)的相关原理,可以参考:

GAN系列之CGAN原理简介以及pytorch项目代码实现

其他类型的GAN原理介绍以及应用,可以查看我的GANs专栏

一、数据集介绍,加载数据

依旧使用到的是我们的老朋友-----MNIST手写数字数据集,  本文不再详细做介绍

相关数据集介绍可以参考:深度学习入门--MNIST数据集及创建自己的手写数字数据集

传统GAN生成手写数字参考:入门GAN实战---生成MNIST手写数据集代码实现pytorch

DCGAN生成手写数字参考:Pytorch 使用DCGAN生成MNIST手写数字 入门级教程

# 独热编码
def one_hot(x, class_count=10):return torch.eye(class_count)[x, :] transform =transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])# 这个数据集其实包含两部分,第一部分是数据,第二部分是标签 print(dataset[0])
#这个第二部分就是我们所需要的condition,这个condition是数值类型,1就是1,2就是2。
#作为输入的condition并不是很合适,一种处理方法就是作为一种向量输入,就是独热编码化。
#比如说现在有10个类别,10个类别将被独热编码为长度为10的tensor,使用这个tensor作为我们的condition是比较合适的dataset = torchvision.datasets.MNIST('data',train=True,transform=transform,target_transform=one_hot,download=False)
dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)

这里有个小技巧是作者用到独热编码化

One-Hot编码,又称为一位有效编码,主要是采用位状态寄存器来对个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候只有一位有效。独热编码 是利用0和1表示一些参数,使用N位状态寄存器来对N个状态进行编码。

例如:参考数字手写体识别中:如数字字体识别0~9中,6的独热编码为

0000001000

二、定义生成器

# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear1 = nn.Linear(10, 128 * 7 * 7)self.bn1 = nn.BatchNorm1d(128 * 7 * 7)self.linear2 = nn.Linear(100, 128 * 7 * 7)self.bn2 = nn.BatchNorm1d(128 * 7 * 7)self.deconv1 = nn.ConvTranspose2d(256, 128,kernel_size=(3, 3),padding=1)self.bn3 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64,kernel_size=(4, 4),stride=2,padding=1)self.bn4 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 1,kernel_size=(4, 4),stride=2,padding=1)def forward(self, x1, x2): # X1代表label X2代表imagex1 = F.relu(self.linear1(x1))x1 = self.bn1(x1)x1 = x1.view(-1, 128, 7, 7)x2 = F.relu(self.linear2(x2))x2 = self.bn2(x2)x2 = x2.view(-1, 128, 7, 7)x = torch.cat([x1, x2], axis=1)x = F.relu(self.deconv1(x))x = self.bn3(x)x = F.relu(self.deconv2(x))x = self.bn4(x)x = torch.tanh(self.deconv3(x))return x

三、定义判别器

# 定义判别器
# input:1,28,28的图片以及长度为10的condition
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.linear = nn.Linear(10, 1*28*28)self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2)self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)self.bn = nn.BatchNorm2d(128)self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值def forward(self, x1, x2):x1 =F.leaky_relu(self.linear(x1))x1 = x1.view(-1, 1, 28, 28)x = torch.cat([x1, x2], axis=1)x = F.dropout2d(F.leaky_relu(self.conv1(x)))x = F.dropout2d(F.leaky_relu(self.conv2(x)))x = self.bn(x)x = x.view(-1, 128*6*6)x = torch.sigmoid(self.fc(x))return x

四、完整代码展示

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils import data
import os
import glob
from PIL import Image# 独热编码
# 输入x代表默认的torchvision返回的类比值,class_count类别值为10
def one_hot(x, class_count=10):return torch.eye(class_count)[x, :]  # 切片选取,第一维选取第x个,第二维全要transform =transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])dataset = torchvision.datasets.MNIST('data',train=True,transform=transform,target_transform=one_hot,download=False)
dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear1 = nn.Linear(10, 128 * 7 * 7)self.bn1 = nn.BatchNorm1d(128 * 7 * 7)self.linear2 = nn.Linear(100, 128 * 7 * 7)self.bn2 = nn.BatchNorm1d(128 * 7 * 7)self.deconv1 = nn.ConvTranspose2d(256, 128,kernel_size=(3, 3),padding=1)self.bn3 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64,kernel_size=(4, 4),stride=2,padding=1)self.bn4 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 1,kernel_size=(4, 4),stride=2,padding=1)def forward(self, x1, x2):x1 = F.relu(self.linear1(x1))x1 = self.bn1(x1)x1 = x1.view(-1, 128, 7, 7)x2 = F.relu(self.linear2(x2))x2 = self.bn2(x2)x2 = x2.view(-1, 128, 7, 7)x = torch.cat([x1, x2], axis=1)x = F.relu(self.deconv1(x))x = self.bn3(x)x = F.relu(self.deconv2(x))x = self.bn4(x)x = torch.tanh(self.deconv3(x))return x# 定义判别器
# input:1,28,28的图片以及长度为10的condition
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.linear = nn.Linear(10, 1*28*28)self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2)self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)self.bn = nn.BatchNorm2d(128)self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值def forward(self, x1, x2):x1 =F.leaky_relu(self.linear(x1))x1 = x1.view(-1, 1, 28, 28)x = torch.cat([x1, x2], axis=1)x = F.dropout2d(F.leaky_relu(self.conv1(x)))x = F.dropout2d(F.leaky_relu(self.conv2(x)))x = self.bn(x)x = x.view(-1, 128*6*6)x = torch.sigmoid(self.fc(x))return x# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)# 损失计算函数
loss_function = torch.nn.BCELoss()# 定义优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)# 定义可视化函数
def generate_and_save_images(model, epoch, label_input, noise_input):predictions = np.squeeze(model(label_input, noise_input).cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i + 1)plt.imshow((predictions[i] + 1) / 2, cmap='gray')plt.axis("off")plt.show()
noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)
print(label_seed)
# print(label_seed_onehot)# 开始训练
D_loss = []
G_loss = []
# 训练循环
for epoch in range(150):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader.dataset)# 对全部的数据集做一次迭代for step, (img, label) in enumerate(dataloader):img = img.to(device)label = label.to(device)size = img.shape[0]random_noise = torch.randn(size, 100, device=device)d_optim.zero_grad()real_output = dis(label, img)d_real_loss = loss_function(real_output,torch.ones_like(real_output, device=device))d_real_loss.backward() #求解梯度# 得到判别器在生成图像上的损失gen_img = gen(label,random_noise)fake_output = dis(label, gen_img.detach())  # 判别器输入生成的图片,f_o是对生成图片的预测结果d_fake_loss = loss_function(fake_output,torch.zeros_like(fake_output, device=device))d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optim.step()  # 优化# 得到生成器的损失g_optim.zero_grad()fake_output = dis(label, gen_img)g_loss = loss_function(fake_output,torch.ones_like(fake_output, device=device))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss += d_loss.item()g_epoch_loss += g_loss.item()with torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)if epoch % 10 == 0:print('Epoch:', epoch)generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)

五、结果展示

随机生成的条件

根据这个条件生成的手写数字

GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字相关推荐

  1. Pytorch入门——MNIST手写数字识别代码

    MNIST手写数字识别教程 本文仅仅放出该教程的代码 具体教程请看 Pytorch入门--手把手教你MNIST手写数字识别 import torch import torchvision from t ...

  2. GAN (生成对抗网络) 手写数字图片生成

    GAN (生成对抗网络) 手写数字图片生成 文章目录 GAN (生成对抗网络) 手写数字图片生成 Discriminator Network Generator Network 简单版本的生成对抗网络 ...

  3. MOOC网深度学习应用开发1——Tensorflow基础、多元线性回归:波士顿房价预测问题Tensorflow实战、MNIST手写数字识别:分类应用入门、泰坦尼克生存预测

    Tensorflow基础 tensor基础 当数据类型不同时,程序做相加等运算会报错,可以通过隐式转换的方式避免此类报错. 单变量线性回归 监督式机器学习的基本术语 线性回归的Tensorflow实战 ...

  4. 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 (zz)

    用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 我想写一系列深度学习的简单实战教程,用mxnet做实现平台的实例代码简单讲解深度学习常用的一些技术方向和实战样例.这 ...

  5. 用PyTorch实现MNIST手写数字识别(非常详细)

    ​​​​​Keras版本: Keras入门级MNIST手写数字识别超级详细教程 2022/4/17 更新修复下代码.完善优化下文章结构,文末提供一个完整版代码. 可以在这里下载源码文件(免积分): 用 ...

  6. FlyAi实战之MNIST手写数字识别练习赛(准确率99.55%)

    欢迎关注WX公众号:[程序员管小亮] 文章目录 欢迎关注WX公众号:[程序员管小亮] 一.介绍 二.代码实现 1_数据加载 2_归一化 3_定义网络结构 4_设置优化器和退火函数 5_数据增强 6_拟 ...

  7. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  8. PyTorch基础入门五:PyTorch搭建多层全连接神经网络实现MNIST手写数字识别分类

    )全连接神经网络(FC) 全连接神经网络是一种最基本的神经网络结构,英文为Full Connection,所以一般简称FC. FC的准则很简单:神经网络中除输入层之外的每个节点都和上一层的所有节点有连 ...

  9. 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别

    一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...

最新文章

  1. Java继承Exception自定义异常类教程以及Javaweb中用Filter拦截并处理异常
  2. handler消息机制入门
  3. 阿里云服务器(BT面板)Vue+Node(Egg)部署流程
  4. 连发Science和Nature, 王二涛研究员:推倒教科书里的“围墙”
  5. Java面试之线程池详细
  6. java中的VO、PO、BO、DAO、POJO
  7. 423.从英文中重建数字
  8. 【全套完结】数字信号处理----全套Matlab实验报告【建议保存】
  9. pcie16x能插1x的卡嘛?_存储先锋,雷克沙633x TF卡评测
  10. Excel表VLOOKUP多个条件匹配数据
  11. sqlite报错database is locked
  12. VS2008 简体中文正式版序列号(到期解决办法)
  13. 融云集成一个聊天室页面(vue版本)
  14. 概率论在实际生活的例子_生活中有趣的概率论例子
  15. PDF文件太大了,如何免费压缩变小?
  16. 天猫2月咖啡行业数据分析(咖啡品牌销量排行)
  17. HTML5 基础练习题总结(一)
  18. 什么是工业AGV导航读码器?用在什么地方?
  19. 多元线性回归分析理论详解及SPSS结果分析
  20. css弹性盒模型详解----flex-direction

热门文章

  1. time()等时间函数的使用
  2. Navicat查询结果中复制字段名和值
  3. 买三种文具编程C语言,学生党公认“最没用”的三种文具,学生:中看不中用,谁买谁吃亏...
  4. 分享本周所学——Transformer模型详解
  5. 人脸检测算法对比分析
  6. ERDAS IMAGINE 2014 32位 破解安装
  7. 百趣生物技术介绍 | iTRAQ/TMT标记定量蛋白质组研究
  8. java二叉树的深度_Java实现二叉树的深度计算
  9. centos7 转换为lvm_从CentOS7默认安装的/home中转移空间到根目录/ - LVM操作简明教程...
  10. GitBook插件整理