GAN

  • 网络结构
  • GAN 公式的理解
  • 简单线性 GAN 代码如下
  • 卷积 GAN 代码如下
  • Ref

网络结构

GAN 公式的理解

minGmaxDV(D,G)=Ex∼Pdata(x)[logD(x)]+Ez∼Pz(z)[log(1−D(G(z)))]min_Gmax_D V(D,G) = E_{x\sim P_{data}(x)} [logD(x)] + E_{z\sim P_{z}(z)}[log(1-D(G(z)))]minG​maxD​V(D,G)=Ex∼Pdata​(x)​[logD(x)]+Ez∼Pz​(z)​[log(1−D(G(z)))]

理解 GAN 公式是进一步理解 GAN 的必经过程,所以下面就来简单讲讲该公式。一开始我们要定义出判别器和生成器,这里将 DDD 定义为判别器,将 GGG 定义成生成器。接着要做的就是训练判别器,让它可以识别真实数据,也就有了 GAN 公式的前半部分。

Ex∼Pdata(x)[logD(x)]E_{x\sim P_{data}(x)}[logD(x)]Ex∼Pdata​(x)​[logD(x)]

其中,Ex∼Pdata(x)E_{x\sim P_{data}(x)}Ex∼Pdata​(x)​ 表示期望 xxx 从 PdataP_{data}Pdata​ 分布中获取;xxx 表示真实数据, PdataP_{data}Pdata​ 表示真实数据的分布。

前半部分的意思就是:判别器判别出真实数据的概率,判别器的目的就是要最大化这一项,简单来说,就是对于服从 PdataP_{data}Pdata​ 分布的 xxx,判别器可以准确得出 D(x)≈1D(x)\approx 1D(x)≈1。

接着看 GAN 公式略微复杂的后半部分。

Ez∼Pz(z)[log(1−D(G(z)))]E_{z\sim P_z(z)} [log(1-D(G(z)))]Ez∼Pz​(z)​[log(1−D(G(z)))]

其中,Ez∼Pz(z)E_{z\sim P_z(z)}Ez∼Pz​(z)​ 表示期望 zzz 是从 Pz(z)P_z(z)Pz​(z) 分布中获取;zzz 表示生成数据;Pz(z)P_z(z)Pz​(z) 表示生成数据的分布。

对于判别器 DDD 而言,如果向其输入的是生成数据,即 D(G(z))D(G(z))D(G(z)),判别器的目标就是最小化 D(G(z))D(G(z))D(G(z)),即判别器希望 D(G(z))≈0D(G(z))\approx 0D(G(z))≈0,也就是判别器希望 log(1−D(G(z)))log(1-D(G(z)))log(1−D(G(z))) 最大化。

但对生成器来说,它的目标却与判别器相反,生成器希望自己生成的数据被判别器打上高分,即希望 D(G(z))≈1D(G(z))\approx 1D(G(z))≈1,也就是最小化 log(1−D(G(z)))log(1-D(G(z)))log(1−D(G(z)))。生成器只能影响 GAN 公式的后半部分,对前半部分没有影响。

现在可以理解公式 V(D,G)=Ex∼Pdata(x)[logD(x)]+Ez∼Pz(z)[log(1−D(G(z)))]V(D,G) = E_{x\sim P_{data}(x)}[logD(x)] + E_{z\sim P_z(z)}[log(1-D(G(z)))]V(D,G)=Ex∼Pdata​(x)​[logD(x)]+Ez∼Pz​(z)​[log(1−D(G(z)))],但为什么 GAN 公式中还有 minGmaxDmin_Gmax_DminG​maxD​ 呢?

要理解 minGmaxDmin_Gmax_DminG​maxD​,就要先回忆一下 GAN 的训练流程。一开始,固定生成器 GGG 的参数专门去训练判别器 DDD。GAN 公式表达的意思也一样,先针对判别器 DDD 去训练,也就是最大化 D(x)D(x)D(x) 和 log(1−D(G(z)))log(1-D(G(z)))log(1−D(G(z))) 的值,从而达到最大化 V(D,G)V(D,G)V(D,G) 的目的,表达如下:

DG⋆=argmaxDV(D,G)D_G^\star = argmax_D V(D,G)DG⋆​=argmaxD​V(D,G)

当训练完判别器 DDD 后,就会固定判别器 DDD 的参数去训练生成器 GGG,因为此时判别器已经经过一次训练了,所以生成器 GGG 的目标就变成:当 D=DG⋆D=D_G^\starD=DG⋆​ 时,最小化 log(1−D(G(z)))log(1-D(G(z)))log(1−D(G(z))) 的值,从而达到最小化 V(D,G)V(D,G)V(D,G)的目的。表达如下:

G⋆=argminGV(G,DG⋆)G^\star = argmin_G V(G,D_G^\star)G⋆=argminG​V(G,DG⋆​)

通过上面分成两步的分析,我们可以理解 minGmaxDmin_Gmax_DminG​maxD​ 的含义,简单来说,就是先从判别器 DDD 的角度最大化 V(D,G)V(D,G)V(D,G),再从生成器 GGG 的角度最小化 V(D,G)V(D,G)V(D,G)。

上边公式讲解中,大量使用对数,对数函数在它的定义域内是单调增函数,数据取对数后,并不会改变数据间的相对关系,这里使用对数是为了让计算更加方便。

Ref:《深入浅出GAN生成对抗网络》-廖茂文

简单线性 GAN 代码如下

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import osif not os.path.exists('./img'):os.mkdir('./img')def to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return outbatch_size = 128
num_epoch = 100
z_dimension = 100# Image processing
img_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5), std=(0.5))
])
# MNIST dataset
mnist = datasets.MNIST(root='./data/', train=True, transform=img_transform, download=True)
# Data loader
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True)# Discriminator
class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.dis = nn.Sequential(nn.Linear(784, 256),nn.LeakyReLU(0.2),nn.Linear(256, 256),nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid())def forward(self, x):x = self.dis(x)return x# Generator
class generator(nn.Module):def __init__(self):super(generator, self).__init__()self.gen = nn.Sequential(nn.Linear(100, 256),nn.ReLU(True),nn.Linear(256, 256), nn.ReLU(True), nn.Linear(256, 784), nn.Tanh())def forward(self, x):x = self.gen(x)return xD = discriminator()
G = generator()
if torch.cuda.is_available():D = D.cuda()G = G.cuda()
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)# Start training
for epoch in range(num_epoch):for i, (img, _) in enumerate(dataloader):num_img = img.size(0)# =================train discriminatorimg = img.view(num_img, -1)real_img = Variable(img).cuda()real_label = Variable(torch.ones(num_img, 1)).cuda()fake_label = Variable(torch.zeros(num_img, 1)).cuda()# compute loss of real_imgreal_out = D(real_img)d_loss_real = criterion(real_out, real_label)real_scores = real_out  # closer to 1 means better# compute loss of fake_imgz = Variable(torch.randn(num_img, z_dimension)).cuda()fake_img = G(z)fake_out = D(fake_img)d_loss_fake = criterion(fake_out, fake_label)fake_scores = fake_out  # closer to 0 means better# bp and optimized_loss = d_loss_real + d_loss_faked_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# ===============train generator# compute loss of fake_imgz = Variable(torch.randn(num_img, z_dimension)).cuda()fake_img = G(z)output = D(fake_img)g_loss = criterion(output, real_label)# bp and optimizeg_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epoch, d_loss.item(), g_loss.item(),real_scores.data.mean(), fake_scores.data.mean()))if epoch == 0:real_images = to_img(real_img.cpu().data)save_image(real_images, './img/real_images.png')fake_images = to_img(fake_img.cpu().data)save_image(fake_images, './img/fake_images-{}.png'.format(epoch + 1))torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

卷积 GAN 代码如下

__author__ = 'ShelockLiao'import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
import osif not os.path.exists('./dc_img'):os.mkdir('./dc_img')def to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)out = out.view(-1, 1, 28, 28)return outbatch_size = 128
num_epoch = 100
z_dimension = 100  # noise dimensionimg_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5), (0.5))
])mnist = datasets.MNIST('./data', transform=img_transform)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True,num_workers=4)class discriminator(nn.Module):def __init__(self):super(discriminator, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 32, 5, padding=2),  # batch, 32, 28, 28nn.LeakyReLU(0.2, True),nn.AvgPool2d(2, stride=2),  # batch, 32, 14, 14)self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 14, 14nn.LeakyReLU(0.2, True),nn.AvgPool2d(2, stride=2)  # batch, 64, 7, 7)self.fc = nn.Sequential(nn.Linear(64*7*7, 1024),nn.LeakyReLU(0.2, True),nn.Linear(1024, 1),nn.Sigmoid())def forward(self, x):'''x: batch, width, height, channel=1'''x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)x = self.fc(x)return xclass generator(nn.Module):def __init__(self, input_size, num_feature):super(generator, self).__init__()self.fc = nn.Linear(input_size, num_feature)  # batch, 3136=1x56x56self.br = nn.Sequential(nn.BatchNorm2d(1),nn.ReLU(True))self.downsample1 = nn.Sequential(nn.Conv2d(1, 50, 3, stride=1, padding=1),  # batch, 50, 56, 56nn.BatchNorm2d(50),nn.ReLU(True))self.downsample2 = nn.Sequential(nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, 56, 56nn.BatchNorm2d(25),nn.ReLU(True))self.downsample3 = nn.Sequential(nn.Conv2d(25, 1, 2, stride=2),  # batch, 1, 28, 28nn.Tanh())def forward(self, x):x = self.fc(x)x = x.view(x.size(0), 1, 56, 56)x = self.br(x)x = self.downsample1(x)x = self.downsample2(x)x = self.downsample3(x)return xD = discriminator().cuda()  # discriminator model
G = generator(z_dimension, 3136).cuda()  # generator modelcriterion = nn.BCELoss()  # binary cross entropyd_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)# train
for epoch in range(num_epoch):for i, (img, _) in enumerate(dataloader):num_img = img.size(0)# =================train discriminatorreal_img = Variable(img).cuda()real_label = Variable(torch.ones(num_img, 1)).cuda()fake_label = Variable(torch.zeros(num_img, 1)).cuda()# compute loss of real_imgreal_out = D(real_img)d_loss_real = criterion(real_out, real_label)real_scores = real_out  # closer to 1 means better# compute loss of fake_imgz = Variable(torch.randn(num_img, z_dimension)).cuda()fake_img = G(z)fake_out = D(fake_img)d_loss_fake = criterion(fake_out, fake_label)fake_scores = fake_out  # closer to 0 means better# bp and optimized_loss = d_loss_real + d_loss_faked_optimizer.zero_grad()d_loss.backward()d_optimizer.step()# ===============train generator# compute loss of fake_imgz = Variable(torch.randn(num_img, z_dimension)).cuda()fake_img = G(z)output = D(fake_img)g_loss = criterion(output, real_label)# bp and optimizeg_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if (i+1) % 100 == 0:print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} ''D real: {:.6f}, D fake: {:.6f}'.format(epoch, num_epoch, d_loss.item(), g_loss.item(),real_scores.data.mean(), fake_scores.data.mean()))if epoch == 0:real_images = to_img(real_img.cpu().data)save_image(real_images, './dc_img/real_images.png')fake_images = to_img(fake_img.cpu().data)save_image(fake_images, './dc_img/fake_images-{}.png'.format(epoch+1))torch.save(G.state_dict(), './generatorConv.pth')
torch.save(D.state_dict(), './discriminatorConv.pth')

Ref

  1. https://github.com/L1aoXingyu/pytorch-beginner/tree/master/09-Generative%20Adversarial%20network

PyTorch 实现 GAN 生成式对抗网络 含代码相关推荐

  1. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

    文章目录 1 测试鉴别器 2 建立生成器 3 测试生成器 4 训练生成器 5 使用生成器 6 内存查看 上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类. 使用PyTorch构建GAN生 ...

  2. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

    文章目录 1 数据集描述 2 GPU设置 3 设置Dataset类 4 设置辨别器类 5 辅助函数与辅助类 1 数据集描述 此项目使用的是著名的celebA(CelebFaces Attribute) ...

  3. GANs系列:GAN生成式对抗网络原理以及数学表达式解剖

    一.GAN介绍 生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两 ...

  4. DeepLearing:GAN生成式对抗网络

    GAN 生成对抗网络 文章目录 GAN 生成对抗网络 一.介绍 Q&A 二.GAN模型 GAN结构 生成器("假钞厂") 判别器("验钞机") 生成器与 ...

  5. GAN生成式对抗网络

    阅读Antonio Gulli<Deep Learning with Tensorflow 2 and keras>Second Edition 第六章GAN笔记 GAN生成对抗网络是无监 ...

  6. 不要怂,就是GAN (生成式对抗网络) (六):Wasserstein GAN(WGAN) TensorFlow 代码

    先来梳理一下我们之前所写的代码,原始的生成对抗网络,所要优化的目标函数为: 此目标函数可以分为两部分来看: ①固定生成器 G,优化判别器 D, 则上式可以写成如下形式: 可以转化为最小化形式: 我们编 ...

  7. 不要怂,就是GAN (生成式对抗网络) (五):无约束条件的 GAN 代码与网络的 Graph...

    GAN 这个领域发展太快,日新月异,各种 GAN 层出不穷,前几天看到一篇关于 Wasserstein GAN 的文章,讲的很好,在此把它分享出来一起学习:https://zhuanlan.zhihu ...

  8. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别

    文章目录 1 生成对抗网络基本概念 2 生成对抗网络建模 2.1 建立MnistDataset类 2.2 建立鉴别器 2.3 测试鉴别器 2.4 Mnist生成器制作 3 模型的训练 4 模型表现的判 ...

  9. pytorch基于GAN生成对抗网络的数据集扩充

    文章目录 前言 一.GAN基本原理 1.结构图 2.目标函数 二.实现 1.实现流程图 2.实例 2.1采集少量原始数据 2.2GAN模型训练(注意修改图片路径) 2.3用训练好的模型扩充数据集(生成 ...

最新文章

  1. Android——apk反编译
  2. 【分享】网络民工们你们真的懂防火墙吗?
  3. 拓扑排序 - 项目管理
  4. .NET 6 使用 string.Create 提升字符串创建和拼接性能
  5. Oracle对表的基本操作
  6. palapaweb的mysql无法运行_Mysql 服务无法启动 服务没有报告任何错误
  7. cv2不能读取中文路径
  8. Atiflash显卡BIOS、Nvflsh显卡BIOS、显卡BIOS刷新与超频详细说明教程--转载+BYZ修正...
  9. Fortran入门教程(十)——结构体
  10. 推荐几个适合 新手学习 软件逆向 脱壳破解 的网站
  11. 毕业论文查重过关最强最全规律
  12. 7.计算机系统包括,windows7分几个版本_windows7有哪些版本
  13. background的使用方法
  14. java还原三阶魔方_魔方小站四阶魔方教程2 一看就懂的魔方教程(魔方玩法视频教程+还原公式一步一步图解+3D动画)...
  15. DOCTYPE 的作用是什么
  16. 如何在Google Chrome浏览器中清除浏览历史记录
  17. 一文详细理解计算机网络体系结构(考试和面试必备)
  18. 直击汇佳学校|中考后转轨国际学校 重大改变的他们现在如何?
  19. 万卷书 - 如何成为聪明的父母 [Unlocking Parental Intelligence]
  20. 教育邮箱怎么注册申请,教育电子邮箱注册小妙招

热门文章

  1. python第五十一课——__slots
  2. python-pdf添加水印
  3. python 缺失值处理(Imputation)
  4. 数据切分——Mysql分区表的管理与维护
  5. 如何升级xcode 中的cocos2dx 到v2.2.2以上版本
  6. AMD规范:简单而优雅的动态载入JavaScript代码
  7. Bing Maps Geographic Coverage - Bing地图图像覆盖范围
  8. win32文件读写demo
  9. 认识windows消息机制和Spy++工具
  10. PE文件数据结构汇总