CGAN-条件GAN

原始GAN的缺点:

生成的图像是随机的,不可预测的,无法控制网络的输出特定的图片,生成目标不明确,可控性不强。

针对原始GAN不能生成具有特定属性的图片的问题

CGAN核心在于将属性信息融入生成器G和判别器D中,属性y可以使任何标签信息,例如图像的类别、人脸图像的面部表情等
CGAN中心思想是希望可以控制GAN生成的图片,而不是单出的随机生成图片。具体来说,Conditional GAN在生成器和判别器的输入中增加了额外的条件信息,生成器生成的图片只有足够真实且与条件相符,才能够通过判别器

模型部分

在判别器和生成器中都添加了额外信息y, y可以是类别标签或者其他类型的数据,可以将y作为一个额外的输入层丢入判别器和生成器
生成器中,z和y连在一起隐含表示,带约束条件这个简单直接的改进被证明非常有效。
论文作者在cGAN对于图像自动标注的多模态学习上的应用,在MIRFlickr25000数据集上,以图像特征为条件变量,生成该图像的tag的词向量。

缺陷:

cGAN生成的图像边缘模糊,分辨率低,但是为后面的pix2pixGAN和Cycle-GAN开拓了道路
cGAN将无监督学习转变为有监督学习

模型实现

由于引入类别y,将数据集中target字段值变为独热标签形式。(注意返回向量格式)
并在创建数据集时,写入target_transfrom字段中
注意:由于有y的存在,所以数据集加载过程中都需要完整的批次,加入drop_last字段值

def onehot(x, class_count=10):return torch.eye(class_count)[x, :]
real_dataset = datasets.MNIST("mnist", train=True, target_transform=onehot,transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]))
real_dataload = DataLoader(real_dataset,batch_size=64, shuffle=True, drop_last=True)

生成器

由于添加了类别标签,数据集只有10个类别,类别标签y大小也为[10]
噪声100长度向量通过线性扩大为128 * 7 * 7,y同理扩大为128 * 7 * 7,易于合并成[256, 7, 7]大小向量,之后和DCGAN模型同理。

class Generate(nn.Module):def __init__(self) -> None:super().__init__()self.genModel1 = nn.Sequential(nn.Linear(100,  128 * 7 * 7 ),nn.ReLU(),nn.BatchNorm1d(128 * 7 * 7),)self.genModel2 = nn.Sequential(nn.Linear(10,  128 * 7 * 7 ),nn.ReLU(),nn.BatchNorm1d(128 * 7 * 7),)self.genModel3 = nn.Sequential(# [128, 28, 28]nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.BatchNorm2d(128),# [64, 14, 14] nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),nn.ReLU(),nn.BatchNorm2d(64),# [1, 28, 28]nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),nn.Tanh())def forward(self, z, tar):z = self.genModel1(z)z = torch.reshape(z, (-1, 128, 7, 7))tar = self.genModel2(tar)tar = torch.reshape(tar, (-1, 128, 7, 7))input = torch.cat([z, tar], 1)input = self.genModel3(input)return input

判别器

判别器同样需要输入标签y,处理思想一样,将y和输入img放缩成相同大小向量,这里通过线性层将y扩大为28 * 28长度,reshape成28 * 28大小图片,将img和y合并为[2, 28, 28]放入判别器,之后和DCGAN处理相同

class Discriminator(nn.Module):def __init__(self) -> None:super().__init__()self.linear = nn.Linear(10, 1 * 28 * 28)self.dismodel1 = nn.Sequential(nn.Conv2d(2, 64, kernel_size=3, stride=2),nn.LeakyReLU(),nn.Dropout2d(),nn.Conv2d(64, 128, kernel_size=3, stride=2),nn.LeakyReLU(),nn.Dropout2d(),nn.BatchNorm2d(128),nn.Flatten(),nn.Linear(128 * 6 * 6,1),nn.Sigmoid())def forward(self, img, tar):tar = self.linear(tar)tar = torch.reshape(tar,(-1, 1, 28, 28))input = torch.cat([img, tar], 1)input = self.dismodel1(input)return input

可视化

和DCGAN处理有点区别,需要将标签y作为参数放入

def generate_and_save_images(model, epoch, label_input, noise_input):predictions = np.squeeze(model(noise_input, label_input).detach().cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)plt.imshow((predictions[i] + 1)/2, cmap='gray')plt.axis('off')plt.savefig(f'./train_mnist/image_at_epoch_{format(epoch)}.png')plt.show()

优化函数及测试用例

gen = Generate()
dis = Discriminator()
gen_optim = optim.Adam(gen.parameters(), lr = 1e-4)
dis_optim = optim.Adam(dis.parameters(), lr = 1e-5)
loss = nn.BCELoss()
write = SummaryWriter('cGAN_log')
noise_input = torch.randn(16, 100)
label_input = torch.randint(0, 10,size=(16,))
label_input_onehot = onehot(label_input)

训练数据集,和GAN没啥区别就是添加了y标签进去

for epoch in range(100):train_step = 0for item in real_dataload:dis_loss = 0dis_optim.zero_grad()img,target = itemimg_dis = dis(img, target)img_dis_loss = loss(img_dis, torch.ones_like(img_dis))img_dis_loss.backward()z = torch.randn((64, 100))z_gen = gen(z, target)z_dis = dis(z_gen.detach(), target.detach())z_dis_loss = loss(z_dis, torch.zeros_like(z_dis))z_dis_loss.backward()dis_optim.step()dis_loss = z_dis_loss + img_dis_lossgen_optim.zero_grad()z_out = dis(z_gen, target)z_gen_loss = loss(z_out, torch.ones_like(z_out))z_gen_loss.backward()gen_optim.step()train_step += 1write.add_scalar("分类器", dis_loss, train_step)write.add_scalar("生产器", z_gen_loss, train_step)print(label_input)generate_and_save_images(gen, epoch, label_input_onehot, noise_input)if epoch > 20 and epoch % 20 == 0:torch.save(gen, f"gen_method_{epoch}.pth")torch.save(dis, f"dis_method_{epoch}.pth")
write.close()

训练结果

[3, 3, 5, 1, 3, 7, 6, 1, 9, 7, 5, 3, 9, 0, 6, 5]
epoch1

epoch40
epoch80:

epoch 120
损失变化:

cGAN网络的基本实现(Mnist数字集)相关推荐

  1. 【Deep Learning】MLP识别手写 MNIST数字集

    # -*- coding: utf-8 -*- ''' Created on 2018年4月19日 @author: userhttps://github.com/mnielsen/neural-ne ...

  2. python识别数字程序_详解python实现识别手写MNIST数字集的程序

    我们需要做的第⼀件事情是获取 MNIST 数据.如果你是⼀个 git ⽤⼾,那么你能够通过克隆这本书的代码仓库获得数据,实现我们的⽹络来分类数字 git clone https://github.co ...

  3. 手写数字集MNIST(1)下载

    ML界的Hello World:手写数字集MNIST(1)下载 下载.显示MNIST数据集 import tensorflow as tf import matplotlib.pyplot as pl ...

  4. 【TensorFlow篇】--Tensorflow框架实现SoftMax模型识别手写数字集

    一.前述 本文讲述用Tensorflow框架实现SoftMax模型识别手写数字集,来实现多分类. 同时对模型的保存和恢复做下示例. 二.具体原理 代码一:实现代码 #!/usr/bin/python ...

  5. DL之DCGAN:基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成

    DL之DCGAN:基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成 目录 基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成 设计思路 ...

  6. 每天一道LeetCode-----将数字集转成字母集,计算有多少种转换方式

    Decode Ways 原题链接Decode Ways 每一个数字和一个字母对应,总共有26个字母,对于每一个或每两个数字,都有可能将其转化成字母,计算有多少中转换形式 以1221为例,所有的转换形式 ...

  7. 【对讲机的那点事】你了解TETRA数字集群通信系统组网的模式吗?

    tetra数字集群通信系统,是一种基于数字时分多址(tdma)技术的无线集群移动通信系统.tetra是目前世界上最先进的陆地集群无线通信系统,被公共安全部门.铁路.交通.大型企业等部门广泛的采用,用于 ...

  8. matlab求点介数程序,matlab_bgl 一个很有用的计算网络中每个节点介数的程序,对 分析 Cloud Computing 云 266万源代码下载- www.pudn.com...

    文件名称: matlab_bgl下载  收藏√  [ 5  4  3  2  1 ] 开发工具: Others 文件大小: 2098 KB 上传时间: 2016-10-26 下载次数: 0 提 供 者 ...

  9. A、B、C三类IP地址的最大网络数和每个网络中的最大主机数

    A.B.C三类IP地址的最大网络数和每个网络中的最大主机数,为便于查找,总结如下: IP地址由两部分组成:网络号net-id 与 主机号host-id. 一.A类IP的最大网络数和每个网络中的最大主机 ...

最新文章

  1. 一个用于录制用户输入操作并实时回放的小工具
  2. [笔记]解决m2eclipse给项目添加maven依赖管理时可能不给项目的build path...
  3. PHP 找出数值数组中不重复最大的10个数和最小的10个数
  4. netcore一键部署到linux服务器以服务方式后台运行
  5. 知识管理系统Data Solution研发日记之十二 网页数据抓取Fetch,呈现Render,导出Export...
  6. codeforces 231A-C语言解题报告
  7. 技术者利用wordpress+阿里云服务器+LAMP新搭建的博客网站:www.youngxy.top
  8. win7一直显示正在关机_LG可编程控制器一直显示正在通信维修选凌科公司规模大...
  9. SQLSERVER中RANK OVER(PARTITION BY)的用法
  10. 小伙子自学C++编程简单DIY,即让你拥有一个屏幕画笔,非常实用!
  11. 牛津高阶字典ld2_奶爸1.6G Mdict词库的补充及在Bluedict中使用的心得
  12. 海康sdk远程门禁_海康威视远程监控Android端SDK调用示例
  13. 动词变名词的变化规则_动词accept变成名词-tion?那就错……多名词变化规律如下总结...
  14. 糗事百科成人版段子爬虫实战
  15. 技术人的职场晋级指南:当心“1万小时定律”毁了你!
  16. ReThought (二): 如何照顾团队中的新人
  17. 股价跳水20%,市值缩水1230亿美元?Facebook财报会议告诉你原因
  18. 【requests库】爬取Pixiv日榜图片 并保存到本地
  19. TCP/IP中的TTL
  20. HashMap的链表结构

热门文章

  1. 电化学传感器(1)原理(2)---设计恒电位电路
  2. Codesys禾川Q1配置SV-X3EB实现单轴控制
  3. vue3调用百度地图标注选择位置并获取经纬度
  4. 数学常数e的含义 (转载)
  5. LeetCode 374题
  6. 一文透析腾讯云如何为企业构建「数据全生命周期保护」
  7. Vim Cscope配置与使用
  8. 2021-07-04 IP地址与子网掩码
  9. MacOS M1 安装riscv toolchain
  10. 计算机中如何取消家长控制用户,Win7系统无法更改家长控制选项怎么解决