Jittor实现Conditional GAN
Generative Adversarial Nets(GAN)提出了一种新的方法来训练生成模型。然而,GAN对于要生成的图片缺少控制。Conditional GAN(CGAN)通过添加显式的条件或标签,来控制生成的图像。本文讲解了CGAN的网络结构、损失函数设计、使用CGAN生成一串数字、从头训练CGAN、以及在mnist手写数字数据集上的训练结果。
CGAN网络架构
通过在生成器generator和判别器discriminator中添加相同的额外信息y,GAN就可以扩展为一个conditional模型。y可以是任何形式的辅助信息,例如类别标签或者其他形式的数据。可以通过将y作为额外输入层,添加到生成器和判别器来完成条件控制。
在生成器generator中,除了y之外,还额外输入随机一维噪声z,为结果生成提供更多灵活性。

损失函数
GAN的损失函数
在解释CGAN的损失函数之前,首先介绍GAN的损失函数。下面是GAN的损失函数设计。

对于判别器D,要训练最大化这个loss。如果D的输入是来自真实样本的数据x,则D的输出D(x)要尽可能地大,log(D(x))也会尽可能大。如果D的输入是来自G生成的假图片G(z),则D的输出D(G(z))应尽可能地小,从而log(1-D(G(z))会尽可能地大。这样可以达到max D的目的。
对于生成器G,要训练最小化这个loss。对于G生成的假图片G(z),希望尽可能地骗过D,让它觉得生成的图片就是真的图片,这样就达到了G“以假乱真”的目的。那么D的输出D(G(z))应尽可能地大,从而log(1-D(G(z))会尽可能地小。这样可以达到min G的目的。
D和G以这样的方式联合训练,最终达到G的生成能力越来越强,D的判别能力越来越强的目的。
CGAN的损失函数
下面是CGAN的损失函数设计。

很明显,CGAN的loss跟GAN的loss的区别就是多了条件限定y。D(x/y)代表在条件y下,x为真的概率。D(G(z/y))表示在条件y下,G生成的图片被D判别为真的概率。
Jittor代码数字生成
首先,导入需要的包,并且设置好所需的超参数:
import jittor as jt
from jittor import nn
import numpy as np
import pylab as pl

%matplotlib inline

隐空间向量长度

latent_dim = 100

类别数量

n_classes = 10

图片大小

img_size = 32

图片通道数量

channels = 1

图片张量的形状

img_shape = (channels, img_size, img_size)
第一步,定义生成器G。该生成器输入两个一维向量y和noise,生成一张图片。
class Generator(nn.Module):
def init(self):
super(Generator, self).init()
self.label_emb = nn.Embedding(n_classes, n_classes)

    def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2))return layersself.model = nn.Sequential(*block((latent_dim + n_classes), 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh())def execute(self, noise, labels):gen_input = jt.contrib.concat((self.label_emb(labels), noise), dim=1)img = self.model(gen_input)img = img.view((img.shape[0], *img_shape))return img

第二步,定义判别器D。D输入一张图片和对应的y,输出是真图片的概率。
class Discriminator(nn.Module):
def init(self):
super(Discriminator, self).init()
self.label_embedding = nn.Embedding(n_classes, n_classes)
self.model = nn.Sequential(
nn.Linear((n_classes + int(np.prod(img_shape))), 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(0.2),
nn.Linear(512, 1))

def execute(self, img, labels):d_in = jt.contrib.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1)validity = self.model(d_in)return validity

第三步,使用CGAN生成一串数字。
代码如下。可以使用训练好的模型来生成图片,也可以使用提供的预训练参数: 模型预训练参数下载:https://cloud.tsinghua.edu.cn/d/fbe30ae0967942f6991c/。

下载提供的预训练参数

!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/generator_last.pkl
!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/discriminator_last.pkl
生成自定义的数字:

定义模型

generator = Generator()
discriminator = Discriminator()
generator.eval()
discriminator.eval()

加载参数

generator.load(’./generator_last.pkl’)
discriminator.load(’./discriminator_last.pkl’)

定义一串数字

number = “201962517”
n_row = len(number)
z = jt.array(np.random.normal(0, 1, (n_row, latent_dim))).float32().stop_grad()
labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad()
gen_imgs = generator(z,labels)

pl.imshow(gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1)))
生成结果如下,测试的完整代码在https://github.com/Jittor/gan-jittor/blob/master/models/cgan/test.py。

从头训练Condition GAN
从头训练 Condition GAN 的完整代码在https://github.com/Jittor/gan-jittor/blob/master/models/cgan/cgan.py,下载下来看看!
!wget https://raw.githubusercontent.com/Jittor/gan-jittor/master/models/cgan/cgan.py
!python3.7 ./cgan.py --help

选择合适的batch size,运行试试

运行命令: !python3.7 ./cgan.py --batch_size 8

下载下来的代码里面定义损失函数、数据集、优化器。损失函数采用MSELoss、数据集采用MNIST、优化器采用Adam 如下(此段代码仅仅用于解释意图,不能运行,需要运行请运行完整文件cgan.py):

此段代码仅仅用于解释意图,不能运行,需要运行请运行完整文件cgan.py

Define Loss

adversarial_loss = nn.MSELoss()

Define Model

generator = Generator()
discriminator = Discriminator()

Define Dataloader

from jittor.dataset.mnist import MNIST
import jittor.transform as transform
transform = transform.Compose([
transform.Resize(opt.img_size),
transform.Gray(),
transform.ImageNormalize(mean=[0.5], std=[0.5]),
])
dataloader = MNIST(train=True, transform=transform).set_attrs(batch_size=opt.batch_size, shuffle=True)

optimizer_G = nn.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = nn.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
模型训练的代码如下(此段代码仅仅用于解释意图,不能运行,需要运行请运行完整文件cgan.py):

此段代码仅仅用于解释意图,不能运行,需要运行请运行完整文件cgan.py

valid表示真,fake表示假

valid = jt.ones([batch_size, 1]).float32().stop_grad()
fake = jt.zeros([batch_size, 1]).float32().stop_grad()

真实图像和对应的标签

real_imgs = jt.array(imgs)
labels = jt.array(labels)

#########################################################

训练生成器G

- 希望生成的图片尽可能地让D觉得是valid

#########################################################

随机向量z和随机生成的标签

z = jt.array(np.random.normal(0, 1, (batch_size, opt.latent_dim))).float32()
gen_labels = jt.array(np.random.randint(0, opt.n_classes, batch_size)).float32()

随机向量z和随机生成的标签经过生成器G生成的图片,希望判别器能够认为生成的图片和生成的标签是一致的,以此优化生成器G的生成能力。

gen_imgs = generator(z, gen_labels)
validity = discriminator(gen_imgs, gen_labels)
g_loss = adversarial_loss(validity, valid)
g_loss.sync()
optimizer_G.step(g_loss)

#########################################################

训练判别器D

- 尽可能识别real_imgs为valid

- 尽可能识别gen_imgs为fake

#########################################################

真实的图片和标签经过判别器的结果,要尽可能接近valid。

validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real, valid)

G生成的图片和对应的标签经过判别器的结果,要尽可能接近fake。

validity_fake = discriminator(gen_imgs.stop_grad(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake, fake)

d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.sync()
optimizer_D.step(d_loss)
MNIST数据集训练结果
下面展示了Jittor版CGAN在MNIST数据集的训练结果。下面分别是训练0 epoch和90 epoches的结果。

使用Jittor实现Conditional GAN相关推荐

  1. 【李宏毅2020 ML/DL】P75 Generative Adversarial Network | Conditional GAN

    我已经有两年 ML 经历,这系列课主要用来查缺补漏,会记录一些细节的.自己不知道的东西. 这是我的李宏毅老师 GAN 系列第2篇笔记,GAN 系列: 1:Basic Idea 本节内容综述 本节课讲解 ...

  2. 四天搞懂生成对抗网络(二)——风格迁移的“精神始祖”Conditional GAN

    点击左上方蓝字关注我们 [飞桨开发者说]吕坤,唐山广播电视台,算法工程师,喜欢研究GAN等深度学习技术在媒体.教育上的应用. 从"自由挥洒"到"有的放矢" 1. ...

  3. GAN小结(BEGAN EBGAN WGAN CycleGAN conditional GAN DCGAN PGGAN VAEGAN)

    断断续续看了生成对抗网络一些日子,下面把我比较感兴趣也算是我认为效果比较好的GAN进行简单梳理,其中会参考众多前辈的文章,主要包括 1.EBGAN 原文 https://arxiv.org/pdf/1 ...

  4. 李弘毅老师GAN笔记(二),Conditional GAN

    本文记录的是李弘毅老师的GAN的第二节课,bilibili视频链接为:https://www.bilibili.com/video/av24011528/?p=2,相关的课程课件链接为:http:// ...

  5. 【飞桨PaddlePaddle】四天搞懂生成对抗网络(二)——风格迁移的“精神始祖”Conditional GAN

    从"自由挥洒"到"有的放矢" 1.给GAN加个"按钮" 上一篇<四天搞懂生成对抗网络(一)--通俗理解经典GAN>中,我们实现了 ...

  6. CycleGAN和Conditional GAN(cGAN)

    当谈到CycleGAN和Conditional GAN(cGAN)时,我们涉及到生成对抗网络(GAN)的不同变体.让我逐步介绍它们的原理和应用. **CycleGAN**: CycleGAN是一种无监 ...

  7. Conditional GAN

    1. 概述 GAN的出现为数据生成提供了一种新的思路,以图像生成为例,在GAN框架中,不再需要其他复杂的操作,如Markov Chain等,只需要对定义好的网络进行backpropagation即可完 ...

  8. 【轩说AI】生成模型(2)—— GAN生成对抗网络 + WGAN + Conditional GAN + Cycle GAN

    文章目录 GAN生成对抗网络(Generative Adversarial Network) 神经网络的本质就是一个函数,一个用于拟合的函数 生成模型面临的前所未有的问题 GAN解决这一问题的思想 O ...

  9. 2020李宏毅机器学习笔记-Generative Adversarial Network - Conditional GAN

    目录 摘要 1. Text-to-Image 1.1 Traditional supervised approach 1.2 Conditional GAN 1.3 Conditional GAN - ...

最新文章

  1. 用mysqldump远程备份mysql数据库
  2. matlab把图片转为base64
  3. 美国国防部CIO:企业IT可考虑用开源做数据分析
  4. Android中Touch事件分析--解决HorizontalScrollView滑动和按钮事件触发问题
  5. 7-37 图形卡片排序游戏 (40 分)
  6. 里程碑 .Net7再更新,从此彻底碾压Java!
  7. 学棋五年的我输给了昇腾CANN
  8. 479 Largest Palindrome Product 最大回文数乘积
  9. 运维,请警惕脚本灾难!
  10. matlab绘制vti群速度,声波测井中的相速度与群速度讨论.PDF
  11. 1.ESP32文件夹配置,创建自己的工程 Vscode+Idf插件
  12. Mac相关配置(本地host,端口被占用)
  13. 抖音直播伴侣和快手直播伴侣如何实现无人直播还能防封号
  14. Java Fork/Join与协程
  15. 【golang】time: missing Location in call to Date
  16. eclispe如何运行jsp文件
  17. 陈力:传智播客古代 珍宝币 泡泡龙游戏开发第48讲:PHP绘图技术
  18. 基于Dotnetty和Protobuf的Unity客户端开发(一)
  19. 后摩尔定律时代的计算机性能提升之道
  20. 飞机大战游戏---Pygame

热门文章

  1. mac OneNote恢复历史记录
  2. 【Docker】registry部署docker私有镜像仓库
  3. HJ86 求最大连续bit数
  4. C/C++ 编程规范(02)— 标识符命名
  5. 2022-2028年中国氟橡胶预混胶行业市场行情监测及未来前景规划报告
  6. Redis缓存穿透击穿雪崩
  7. 视觉导航的神经拓扑SLAM
  8. 每秒能捕捉万亿帧的相机
  9. 深度学习调参体验(一)
  10. 量子力学在计算机上的应用,量子力学在医学科学中的应用