欢迎大家来到咱们的深度学习CV项目实战专栏,GAN是当下非常热门的技术,本次我们给大家介绍如何来训练自己的第1个生成对抗网络项目。

作者&编辑 | 言有三

本项目结果展示

本文篇幅:4000字

背景要求:会使用Python和Pytorch深度学习开源框架

附带资料:开源代码一份,支持Pytorch

数据一份:文末有获取方法

1 项目背景

GAN无疑是这几年深度学习领域里最酷的技术,不管是理论的研究,还是GAN在图像生成,图像翻译,语音图像等基础领域的应用,都非常的丰富。我们公众号从很早以前开始就输出过非常多的GAN相关资源,包括免费与付费的视频课,知识星球中的GAN模型原理解读专题,公众号的GAN付费专栏,大家可以阅读下面的文章了解详情。

【总结】从视频到图文,代码实战,有三AI-GAN学习资料汇总!

2 项目解读

为了让大家能够快速上手,本次我们给大家介绍一个非常适合新手入门的项目,使用DCGAN来进行图片生成,项目效果如文章开头的图片。

2.1 数据获取

本次我们完成一个人脸表情图像生成的任务,使用的数据集也是多次在咱们的项目中出现过的数据集,我们选择了其中一类表情的数据,如下:

数据的读取非常简单,直接使用torchvision的ImageFolder接口即可,与我们以前介绍过的图像分类任务相同,核心代码如下,不再赘述。

## 读取数据

dataroot = "mouth/"

dataset = datasets.ImageFolder(root=dataroot,

transform=transforms.Compose([

transforms.Resize(image_size),

transforms.CenterCrop(image_size),

transforms.ToTensor(),

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

]))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,

shuffle=True, num_workers=workers)

不熟悉的朋友可以阅读:【CV实战】年轻人的第一个深度学习CV项目应该是什么样的?(支持13大深度学习开源框架)

2.2 判别器定义

接下来我们再看判别器的定义,它就是一个图像分类模型,与原始DCGAN论文中的参数配置略有差异。

class Discriminator(nn.Module):

def __init__(self, ndf=64, nc=3):

super(Discriminator, self).__init__()

self.ndf = ndf

self.nc = nc

self.main = nn.Sequential(

# 输入图片大小 (nc) x 64 x 64,输出 (ndf) x 32 x 32

nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),

nn.LeakyReLU(0.2, inplace=True),

# 输入(ndf) x 32 x 32,输出(ndf*2) x 16 x 16

nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),

nn.BatchNorm2d(ndf * 2),

nn.LeakyReLU(0.2, inplace=True),

# 输入(ndf*2) x 16 x 16,输出 (ndf*4) x 8 x 8

nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),

nn.BatchNorm2d(ndf * 4),

nn.LeakyReLU(0.2, inplace=True),

# 输入(ndf*4) x 8 x 8,输出(ndf*8) x 4 x 4

nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),

nn.BatchNorm2d(ndf * 8),

nn.LeakyReLU(0.2, inplace=True),

# 输入(ndf*8) x 4 x 4,输出1 x 1 x 1

nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),

nn.Sigmoid()

)

def forward(self, input):

return self.main(input)

以上代码定义了一个包含5层卷积,其中前面4个卷积层的卷积核大小为4×4,宽和高的步长等于2,使用了padding技术,padding大小为1,每经过一次卷积,图像长宽都降低为原来的1/2。每一个卷积层后都跟随一个batch normalization层和lrelu层。

输出层也是一个卷积层,其特征图大小空间尺寸为4×4,卷积核大小也是4×4,所以输出空间层维度为1,使用sigmoid激活函数,输出就是一个0到1之间的概率值。

2.3 生成器定义

接下来我们再看生成器的定义,它输入一维的噪声向量,输出二维的图像。

## 定义生成器与判别器

class Generator(nn.Module):

def __init__(self, nz=100, ngf=64, nc=3):

super(Generator, self).__init__()

self.ngf = ngf

self.nz = nz

self.nc = nc

self.main = nn.Sequential(

# 输入噪声向量Z,(ngf*8) x 4 x 4特征图

nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),

nn.BatchNorm2d(ngf * 8),

nn.ReLU(True),

# 输入(ngf*8) x 4 x 4特征图,输出(ngf*8) x 4 x 4

nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),

nn.BatchNorm2d(ngf * 4),

nn.ReLU(True),

# 输入(ngf*4) x 8 x 8,输出(ngf*2) x 16 x 16

nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),

nn.BatchNorm2d(ngf * 2),

nn.ReLU(True),

# 输入(ngf*2) x 16 x 16,输出(ngf) x 32 x 32

nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),

nn.BatchNorm2d(ngf),

nn.ReLU(True),

# 输入(ngf) x 32 x 32,输出(nc) x 64 x 64

nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),

nn.Tanh()

)

def forward(self, input):

return self.main(input)

可以看出,总共包含5个上采样层。其中前4个上采样层后接有BN层,和ReLU层。第一个上采样层将输入的一维噪声向量经过一个上采样层生成4×4大小的图,然后经过后面的4个上采样层得到输出。前面4层的激活函数为ReLU,最后一层的激活函数为Tanh。

2.4 优化目标与方法定义

损失函数使用了BCE交叉熵损失,真样本和假样本标签分别为1和0。

# 损失函数

criterion = nn.BCELoss()

# 真假标签

real_label = 1.0

fake_label = 0.0

判别器和生成器都采用了Adam方法作为优化器,且使用了同样的配置,定义如下:

lr = 0.0003

beta1 = 0.5

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))

optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

3 模型训练

接下来进行模型训练,添加可视化部分,缓存中间结果,核心的迭代代码如下:

for epoch in range(num_epochs):

lossG = 0.0

lossD = 0.0

for i, data in enumerate(dataloader, 0):

############################

# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))

###########################

## 训练真实图片

netD.zero_grad()

real_data = data[0].to(device)

b_size = real_data.size(0)

label = torch.full((b_size,), real_label, device=device)

output = netD(real_data).view(-1)

# 计算真实图片损失,梯度反向传播

errD_real = criterion(output, label)

errD_real.backward()

D_x = output.mean().item()

## 训练生成图片

# 产生latent vectors

noise = torch.randn(b_size, nz, 1, 1, device=device)

# 使用G生成图片

fake = netG(noise)

label.fill_(fake_label)

output = netD(fake.detach()).view(-1)

# 计算生成图片损失,梯度反向传播

errD_fake = criterion(output, label)

errD_fake.backward()

D_G_z1 = output.mean().item()

# 累加误差,参数更新

errD = errD_real + errD_fake

optimizerD.step()

############################

# (2) Update G network: maximize log(D(G(z)))

###########################

netG.zero_grad()

label.fill_(real_label)  # 给生成图赋标签

# 对生成图再进行一次判别

output = netD(fake).view(-1)

# 计算生成图片损失,梯度反向传播

errG = criterion(output, label)

errG.backward()

D_G_z2 = output.mean().item()

optimizerG.step()

# Output training stats

if i % 50 == 0:

print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'

% (epoch, num_epochs, i, len(dataloader),

errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

# 存储损失

nbatch = nbatch + 1

lossG = lossG + errG.item()

lossD = lossD + errD.item()

# 对固定的噪声向量,存储生成的结果

if (iters % 20 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):

with torch.no_grad():

fake = netG(fixed_noise).detach().cpu()

img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

i = vutils.make_grid(fake, padding=2, normalize=True)

fig = plt.figure(figsize=(8, 8))

plt.imshow(np.transpose(i, (1, 2, 0)))

plt.axis('off')  # 关闭坐标轴

plt.savefig("out/%d_%d.png" % (epoch, iters))

plt.close(fig)

iters += 1

writer.add_scalar('data/lossG', lossG, epoch)

writer.add_scalar('data/lossD', lossD, epoch)

torch.save(netG.state_dict(),'models/netG.pt')

训练结果曲线如下:

由于标准GAN的损失与模型的生成结果之间的关系不像以前介绍的任务那么直观,我们应该以实际生成的图片结果为准,下面从左到右分别是第0,10,100个epoch的生成结果。

从图结果来看,随着训练的进行,逐渐生成了许多有意义且非常逼真的样本,在10个epoch的时候生成图片都有明显的瑕疵,到100个epoch时已经开始生成一些逼真的样本。不过最终生成的图像仍然有一部分效果很差,这是因为DCGAN本身模型性能所限,后续可以使用更好的模型进行改进。

4 模型测试

上面已经训练好了模型,我们接下来的目标,就是要用它来做推理,真正把模型用起来。

import torch

import torch.nn as nn

import torchvision.utils as vutils

import matplotlib.pyplot as plt

from net import Generator

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

netG = Generator().to(device)

## 载入模型权重

modelpath = sys.argv[1]

savepath = sys.argv[2]

netG.load_state_dict(torch.load(modelpath,map_location=lambda storage,loc: storage))

netG.eval() ## 设置推理模式,使得dropout和batchnorm等网络层在train和val模式间切换

torch.no_grad() ## 停止autograd模块的工作,以起到加速和节省显存

nz = 100

for i in range(0,100):

noise = torch.randn(64, nz, 1, 1, device=device)

fake = netG(noise).detach().cpu()

rows = vutils.make_grid(fake, padding=2, normalize=True)

fig = plt.figure(figsize=(8, 8))

plt.imshow(np.transpose(rows, (1, 2, 0)))

plt.axis('off')  # 关闭坐标轴

plt.savefig(os.path.join(savepath,"%d.png" % (i)))

plt.close(fig)

推理的核心代码就是使用torch.load函数载入生成器模型,然后输入随机的噪声向量,得到生成的结果。

下面展示了一些结果。

从图中我们可以看到,总体的生成结果还是不错的,不过本次的任务还有许多可以提升的空间,包括但不限于:(1) 做更多的数据增强。(2) 改进模型。这些就留给读者去进行实验。

 5 资源获取和拓展学习

本文的完整代码,可以在我们的开源项目中获取,项目地址如下。

https://github.com/longpeng2008/yousan.ai

由于数据集较大,如果想要获得数据集,请到有三AI知识星球中下载,当然你也完全可以将其替换成自己的数据集:

有三AI知识星球链接与介绍如下:

【杂谈】有三AI知识星球指导手册出炉!和公众号相比又有哪些内容?

如果想要学习更多GAN的内容,请阅读下面的文章了解生态的相关资料。

【总结】从视频到图文,代码实战,有三AI-GAN学习资料汇总!

总结

本次我们完成了一个GAN图像生成项目的全部流程,本次任务和前面的CV分类实战与分割实战是一脉相承的,使用了相同的数据集,大家可以一起阅读学习。

转载文章请后台联系

侵权必究

往期文章

  • 【CV夏季划】2021年有三AI-CV夏季划出炉,冲刺秋招,从CV基础到模型优化彻底掌握

  • 【CV秋季划】模型优化很重要,如何循序渐进地学习好?

  • 【CV秋季划】人脸算法那么多,如何循序渐进地学习好?

  • 【CV秋季划】图像质量提升与编辑有哪些研究和应用,如何循序渐进地学习好?

  • 【CV秋季划】生成对抗网络GAN有哪些研究和应用,如何循序渐进地学习好?

【CV实战】年轻人的第一个GAN项目应该是什么样的(Pytorch框架)?相关推荐

  1. 【阿里云课程】如何从零开始完成第一个GAN项目

    大家好,继续更新有三AI与阿里天池联合推出的深度学习系列课程,本次更新内容为实践课,介绍如下: 从零使用GAN进行图片生成 本次课程是阿里天池联合有三AI推出的深度学习系列课程实践课第2期,从零完成G ...

  2. 【Vue 3 实战一】搭建一个新项目并上传至gitee

    提示:专栏内容均为原创,搬运必究 文章目录 一.Vue3的新特性? 二.创建新项目 1. 利用脚手架 2.项目配置选择 (建议与下方一致) 三.运行项目 1. 安装依赖 2. 运行项目 3. 上传代码 ...

  3. [vue] 从0到1自己构架一个vue项目,说说有哪些步骤、哪些重要插件、目录结构你会怎么组织

    [vue] 从0到1自己构架一个vue项目,说说有哪些步骤.哪些重要插件.目录结构你会怎么组织 1 项目类型 前端的项目目前来看主要分为小程序开发,H5页面开发.PC官网.后台管理系统开发.Nativ ...

  4. 【视频课】永久免费!5小时快速掌握Pytorch框架入门及实战

    前言 PyTorch是深度学习的主流框架之一,新手入门相对容易.为了帮助初学者解决PyTorch入门及实践的问题,有三AI推出<深度学习之PyTorch-入门及实战>课程,课程将算法.模型 ...

  5. [PyTorch] 译+注:一个例子,让你明白PyTorch框架

    文章目录 Introduction Motivation Table of Contents A Simple Regression Problem (一个简单的线性回归) Data Generati ...

  6. 【CV实战】年轻人的第一个深度学习图像分割项目应该是什么样的(Pytorch框架)?...

    我们上次给新手们介绍了第一个合适入门的深度学习CV项目,可阅读[CV实战]年轻人的第一个深度学习CV项目应该是什么样的?(支持13大深度学习开源框架),本次我们再给大家介绍一个新的任务,图像分割,包括 ...

  7. 【CV实战】年轻人的第一个深度学习CV项目应该是什么样的?(支持13大深度学习开源框架)...

    计算机视觉发展至今,许多技术已经非常成熟了,在各行各业落地业务非常多,因此不断的有新同学入行.本次我们就来介绍,对于新手来说,如何做一个最合适的项目.本次讲述一个完整的工业级别图像分类项目的标准流程, ...

  8. 数百个CV实战项目与必备7本书5000页中英文CV书籍免费送啦~

    数百个CV实战项目与必备7本5000页书籍,公众号[深度学习冲鸭]的后台回复关键字[CV入坑必备](建议复制~)获得: CV学习书籍汇总 1:<学习OpenCV中文版> 2:<图像处 ...

  9. 数百个CV实战项目与必备7本书5000页中英文CV书籍以及算法工程师必备资料免费送啦~...

    数百个CV实战项目与必备7本5000页书籍,公众号[深度学习冲鸭]的后台回复关键字[CV入坑必备](建议复制~)获得: CV学习书籍汇总 1:<学习OpenCV中文版> 2:<图像处 ...

最新文章

  1. [日记]一个人去散步
  2. jsp页面展示更加商品的分类,控制商品的显示
  3. 《用户故事与敏捷方法》阅读笔记一
  4. LINQ从方法中返回查询
  5. python制作gif动画_实用的Python(2)利用Python制作gif动图
  6. java接口课程_用java定义一个接口,用于查询课程
  7. 图解分布式架构的发展和演进 | 技术干货
  8. C#中将原表复制到新表
  9. 从linux使用sz命令下载大于4g的文件到windows
  10. 苹果mac光标自行移动如何解决?
  11. 违章查询源码 php,PHP教程:php车辆违章查询数据示例
  12. 精确率(查准率)、召回率(查全率)和F1值
  13. cad插件_CAD插件自动标注安装教程
  14. 多模态知识图谱构建和推理技术 王萌 东南大学
  15. matlab抛物柱面画图,抛物柱面 - calculus的日志 - 网易博客
  16. 学习C语言的必备书籍-从入门到精通
  17. R2S铝合金外壳散热测试
  18. 远程公司内网服务器【内网穿透】
  19. 无线蓝牙耳机哪个品牌音质好?性价比高音质好的蓝牙耳机排行榜
  20. SCI论文写作常用词汇短语总结

热门文章

  1. vs2013下oracle proc配置
  2. 算法每日学打卡:java语言基础题目打卡(01-10)
  3. 【struts2+hibernate+spring项目实战】Spring计时器任务 Spring整合JavaMail(邮件发送)(ssh)
  4. UI组件之AdapterView及其子类(三)Spinner控件详解
  5. 【排序算法】— 手写堆排序
  6. Java设计模式——桥模式
  7. php判断表单修改内容,JavaScript判断用户是否对表单进行了修改的方法_javascript技巧...
  8. python怎么启动mne_MNE-Python专辑 | MNE-Python详细安装与使用(更新)
  9. 服务器2012怎么换桌面背景,2012年职称计算机Windows XP:更改桌面背景和颜色
  10. 99%网工都会遇到的10道经典面试问题