1.简介

这篇文章主要是介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码。数据集使用的是开源人脸图像数据集img_align_celeba,共1.34G。生成器与判别器模型均采用简单的卷积结构,代码参考了pytorch官网。

建议对pytorch和神经网络原理还不熟悉的同学,可以先看下之前的文章了解下基础:pytorch基础教学简单实例(附代码)_Lizhi_Tech的博客-CSDN博客_pytorch实例


2.GAN原理

简而言之,生成对抗网络可以归纳为以下几个步骤:

  1. 随机噪声输入进生成器,生成虚假图片。

  1. 将带标签的虚假图片和真实图片输入进判别器进行更新,最大化 log(D(x)) + log(1 - D(G(z)))。

  1. 根据判别器的输出结果更新生成器,最大化 log(D(G(z)))。


3.代码

from __future__ import print_function
import random
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt# 设置随机算子
manualSeed = 999
random.seed(manualSeed)
torch.manual_seed(manualSeed)# 数据集位置
dataroot = "data/celeba"# dataloader的核数
workers = 2# Batch大小
batch_size = 128# 图像缩放大小
image_size = 64# 图像通道数
nc = 3# 隐向量维度
nz = 100# 生成器特征维度
ngf = 64# 判别器特征维度
ndf = 64# 训练轮数
num_epochs = 5# 学习率
lr = 0.0002# Adam优化器的beta系数
beta1 = 0.5# gpu个数
ngpu = 1# 加载数据集
dataset = dset.ImageFolder(root=dataroot,transform=transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))
# 创建dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,shuffle=True, num_workers=workers)# 使用cpu还是gpu
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")# 初始化权重
def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)# 生成器
class Generator(nn.Module):def __init__(self, ngpu):super(Generator, self).__init__()self.ngpu = ngpuself.main = nn.Sequential(# input is Z, going into a convolutionnn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),# state size. (ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# state size. (ngf*4) x 8 x 8nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# state size. (ngf*2) x 16 x 16nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# state size. (ngf) x 32 x 32nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),nn.Tanh()# state size. (nc) x 64 x 64)def forward(self, input):return self.main(input)# 实例化生成器并初始化权重
netG = Generator(ngpu).to(device)
netG.apply(weights_init)# 判别器
class Discriminator(nn.Module):def __init__(self, ngpu):super(Discriminator, self).__init__()self.ngpu = ngpuself.main = nn.Sequential(# input is (nc) x 64 x 64nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):return self.main(input)# 实例化判别器并初始化权重
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)# 损失函数
criterion = nn.BCELoss()# 随机输入噪声
fixed_noise = torch.randn(64, nz, 1, 1, device=device)# 真实标签与虚假标签
real_label = 1.
fake_label = 0.# 创建优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))# 开始训练
img_list = []
G_losses = []
D_losses = []
iters = 0print("Starting Training Loop...")
for epoch in range(num_epochs):for i, data in enumerate(dataloader, 0):############################# (1) 更新D: 最大化 log(D(x)) + log(1 - D(G(z)))############################ 使用真实标签的batch训练netD.zero_grad()real_cpu = data[0].to(device)b_size = real_cpu.size(0)label = torch.full((b_size,), real_label, dtype=torch.float, device=device)output = netD(real_cpu).view(-1)errD_real = criterion(output, label)errD_real.backward()D_x = output.mean().item()# 使用虚假标签的batch训练noise = torch.randn(b_size, nz, 1, 1, device=device)fake = netG(noise)label.fill_(fake_label)output = netD(fake.detach()).view(-1)errD_fake = criterion(output, label)errD_fake.backward()D_G_z1 = output.mean().item()errD = errD_real + errD_fake# 更新DoptimizerD.step()############################# (2) 更新G: 最大化 log(D(G(z)))###########################netG.zero_grad()label.fill_(real_label)output = netD(fake).view(-1)errG = criterion(output, label)errG.backward()D_G_z2 = output.mean().item()# 更新GoptimizerG.step()# 输出训练状态if i % 50 == 0:print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'% (epoch, num_epochs, i, len(dataloader),errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))# 保存每轮lossG_losses.append(errG.item())D_losses.append(errD.item())# 记录生成的结果if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):with torch.no_grad():fake = netG(fixed_noise).detach().cpu()img_list.append(vutils.make_grid(fake, padding=2, normalize=True))iters += 1# loss曲线
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()# 生成效果图
real_batch = next(iter(dataloader))# 真实图像
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))# 生成的虚假图像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

4.结果

真实图像与生成图像

loss曲线


完整的代码与数据集可以在我的github上找到:lizhiTech/pytorch_GAN_simple_example: 介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码 (github.com)

或者csdn下载:

我们也提供包括深度学习、计算机视觉、机器学习等其他方向的其他代码及辅导服务,有需求可以通过csdn私聊或github上的联系方式联系我们。

pytorch生成对抗网络GAN的基础教学简单实例(附代码数据集)相关推荐

  1. [Python图像识别] 四十九.图像生成之什么是生成对抗网络GAN?基础原理和代码普及

    该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...

  2. [Python人工智能] 二十九.什么是生成对抗网络GAN?基础原理和代码普及(1)

    从本专栏开始,作者正式研究Python深度学习.神经网络及人工智能相关知识.前一篇文章分享了Keras实现经典的深度学习文本分类算法,包括LSTM.BiLSTM.BiLSTM+Attention和CN ...

  3. 生成对抗网络GAN损失函数loss的简单理解

    原始的公式长这样: min⁡Gmax⁡DV(D,G)=Ex∼pdata (x)[log⁡D(x)]+Ez∼pz(z)[log⁡(1−D(G(z)))]\min _{G} \max _{D} V(D, ...

  4. [论文阅读] (06) 万字详解什么是生成对抗网络GAN?经典论文及案例普及

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

  5. 万字详解什么是生成对抗网络GAN

    摘要:这篇文章将详细介绍生成对抗网络GAN的基础知识,包括什么是GAN.常用算法(CGAN.DCGAN.infoGAN.WGAN).发展历程.预备知识,并通过Keras搭建最简答的手写数字图片生成案. ...

  6. [Pytorch系列-72]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型训练CycleGAN模型

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  7. [Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:[Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleG ...

  8. [Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型测试pix2pix模型

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  9. 【CV秋季划】生成对抗网络GAN有哪些研究和应用,如何循序渐进地学习好(2022年言有三一对一辅导)?...

    GAN自从被提出来后,技术发展就非常迅猛,已经被落地于众多的方向,其应用涉及图像与视频生成,数据仿真与增强,各种各样的图像风格化任务,人脸与人体图像编辑,图像质量提升. 那我们究竟如何去长期学好相关的 ...

最新文章

  1. 绘制你的世界:探索构图和真实的深度感
  2. 一步步实现 Redis 搜索引擎
  3. javaee_SSH
  4. php用json交换二维数组,PHP和Javascript的JSON交互(处理一个二维数组)
  5. 给tomcat 配置https
  6. python对csv数据提取某列的某些行_python pandas获取csv指定行 列的操作方法
  7. 在flex中显示gif
  8. Android Thread 官方说明
  9. 多线程之终止线程的四种方法
  10. linux永久禁止进程,SELinux如何永久禁用 SELinux如何永久禁用
  11. 学计算机的学后感,计算机学习心得体会(通用10篇)
  12. Windows ZIP Archive安装和卸载MySQL 8.0
  13. 小程序即时配送配置指南
  14. 京东大数据平台进化之路
  15. 截屏与截长图功能的实现
  16. 书摘—拆掉思维里的墙
  17. python量化实战 顾比倒数线_顾比倒数线 主图源码
  18. #汇编语言字符串的输出(dosbox运行时输出乱码问题解决)
  19. CTF密码学(Crypto)一些在线解密网站
  20. C#打字游戏案例(纯代码实现),新手入门必备!

热门文章

  1. 万圣节头像框生成工具微信小程序源码
  2. 金蝶中间件部署报栈溢出_全网最全、最新消息中间件面试题(2020最新版)
  3. 实用的 jquery 弹出窗口 插件winbox
  4. CAD对象分解命令怎么用?CAD对象分解命令使用技巧
  5. 爬取矿大教务系统成绩
  6. h5免签聚合支付系统yy支付-y币+yy陪玩-系统源码
  7. 建设一个网站需要什么
  8. 参加百度轻应用编程马拉松总结
  9. 2020年计算机应用基础本试卷号1200,计算机应用基础试卷2020年.pdf
  10. 数码相框_通过随机选择的媒体文件轻松加载数码相框和MP3播放器