版权申明:本文章为本人原创内容,转载请注明出处,谢谢合作!

实验环境:

1.Pytorch 0.4.0
2.torchvision 0.2.1
3.Python 3.6
4.Win10+Pycharm
本项目是基于DCGAN的,代码是在《深度学习框架PyTorch:入门与实践》第七章的配套代码上做过大量修改过的。项目所用数据集获取:点击获取 提取码:g5qa,感谢知乎用户何之源爬取的数据。 请将下载的压缩包里的图片完整解压至data/face/目录下。整个项目的代码结构如下图:

其中data/face里是存放训练图片的,imgs/存放的是最终的训练结果,model.py是DCGAN的结构,train.py是主要的训练文件。

首先是,model.py:
import torch.nn as nn
# 定义生成器网络G
class NetG(nn.Module):def __init__(self, ngf, nz):super(NetG, self).__init__()# layer1输入的是一个100x1x1的随机噪声, 输出尺寸(ngf*8)x4x4self.layer1 = nn.Sequential(nn.ConvTranspose2d(nz, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(inplace=True))# layer2输出尺寸(ngf*4)x8x8self.layer2 = nn.Sequential(nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(inplace=True))# layer3输出尺寸(ngf*2)x16x16self.layer3 = nn.Sequential(nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(inplace=True))# layer4输出尺寸(ngf)x32x32self.layer4 = nn.Sequential(nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(inplace=True))# layer5输出尺寸 3x96x96self.layer5 = nn.Sequential(nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),nn.Tanh())# 定义NetG的前向传播def forward(self, x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)return out# 定义鉴别器网络D
class NetD(nn.Module):def __init__(self, ndf):super(NetD, self).__init__()# layer1 输入 3 x 96 x 96, 输出 (ndf) x 32 x 32self.layer1 = nn.Sequential(nn.Conv2d(3, ndf, kernel_size=5, stride=3, padding=1, bias=False),nn.BatchNorm2d(ndf),nn.LeakyReLU(0.2, inplace=True))# layer2 输出 (ndf*2) x 16 x 16self.layer2 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True))# layer3 输出 (ndf*4) x 8 x 8self.layer3 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True))# layer4 输出 (ndf*8) x 4 x 4self.layer4 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True))# layer5 输出一个数(概率)self.layer5 = nn.Sequential(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())# 定义NetD的前向传播def forward(self,x):out = self.layer1(x)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)return out
然后是,train.py:
import argparse
import torch
import torchvision
import torchvision.utils as vutils
import torch.nn as nn
from random import randint
from model import NetD, NetGparser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=64)
parser.add_argument('--imageSize', type=int, default=96)
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--epoch', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--data_path', default='data/', help='folder to train data')
parser.add_argument('--outf', default='imgs/', help='folder to output images and model checkpoints')
opt = parser.parse_args()
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")#图像读入与预处理
transforms = torchvision.transforms.Compose([torchvision.transforms.Scale(opt.imageSize),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])dataset = torchvision.datasets.ImageFolder(opt.data_path, transform=transforms)dataloader = torch.utils.data.DataLoader(dataset=dataset,batch_size=opt.batchSize,shuffle=True,drop_last=True,
)netG = NetG(opt.ngf, opt.nz).to(device)
netD = NetD(opt.ndf).to(device)criterion = nn.BCELoss()
optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))label = torch.FloatTensor(opt.batchSize)
real_label = 1
fake_label = 0for epoch in range(1, opt.epoch + 1):for i, (imgs,_) in enumerate(dataloader):# 固定生成器G,训练鉴别器DoptimizerD.zero_grad()## 让D尽可能的把真图片判别为1imgs=imgs.to(device)output = netD(imgs)label.data.fill_(real_label)label=label.to(device)errD_real = criterion(output, label)errD_real.backward()## 让D尽可能把假图片判别为0label.data.fill_(fake_label)noise = torch.randn(opt.batchSize, opt.nz, 1, 1)noise=noise.to(device)fake = netG(noise)  # 生成假图output = netD(fake.detach()) #避免梯度传到G,因为G不用更新errD_fake = criterion(output, label)errD_fake.backward()errD = errD_fake + errD_realoptimizerD.step()# 固定鉴别器D,训练生成器GoptimizerG.zero_grad()# 让D尽可能把G生成的假图判别为1label.data.fill_(real_label)label = label.to(device)output = netD(fake)errG = criterion(output, label)errG.backward()optimizerG.step()print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G %.3f'% (epoch, opt.epoch, i, len(dataloader), errD.item(), errG.item()))vutils.save_image(fake.data,'%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),normalize=True)torch.save(netG.state_dict(), '%s/netG_%03d.pth' % (opt.outf, epoch))torch.save(netD.state_dict(), '%s/netD_%03d.pth' % (opt.outf, epoch))

实验结果:

跑完第1个epoch的结果:

跑完第25个epoch的结果:

Pytorch实战3:DCGAN深度卷积对抗生成网络生成动漫头像相关推荐

  1. DL之DCGAN:基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成

    DL之DCGAN:基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成 目录 基于keras框架利用深度卷积对抗网络DCGAN算法对MNIST数据集实现图像生成 设计思路 ...

  2. DCGAN——深度卷积生成对抗网络

    译文 | 让深度卷积网络对抗:DCGAN--深度卷积生成对抗网络 原文: https://arxiv.org/pdf/1511.06434.pdf -- 前言:如何把CNN与GAN结合?DCGAN是这 ...

  3. 好像还挺好玩的GAN2——Keras搭建DCGAN利用深度卷积神经网络实现图片生成

    好像还挺好玩的GAN2--Keras搭建DCGAN利用深度卷积神经网络实现图片生成 注意事项 学习前言 什么是DCGAN 神经网络构建 1.Generator 2.Discriminator 训练思路 ...

  4. 深度卷积对抗神经网络 基础 第六部分 缺点和偏见 GANs Disadvantages and Bias

    深度卷积对抗神经网络 基础 第六部分 缺点和偏见 GANs Disadvantages and Bias GANs 综合评估 生成对抗网络(英语:Generative Adversarial Netw ...

  5. 深度卷积对抗神经网络 基础 第七部分 StyleGAN

    深度卷积对抗神经网络 基础 第七部分 StyleGAN 深度卷积神经网络模型已经应用在非常多的领域,但是其总包含了很多潜在的问题,比如说训练速度过慢,生成器与判别器的进化程度不平衡等等.那么,随着各种 ...

  6. 深度卷积对抗生成网络DCGAN论文解读

    论文名称:Unsupervised representation learning with deep convolutional generative adversarial networks 论文 ...

  7. Python 头像动漫化,快来生成女朋友的动漫头像

    很多时候我们都会为头像发愁,像我这种万年不换头像的咸鱼从来没有这种烦恼.但是吧,有个个性化的头像还是非常有趣的,例如用自己的漫画来做头像,那么用Python如何实现呢?我打算把这个小技巧给大家分享. ...

  8. pytorch 图像分割的交并比_PyTorch专栏(二十二): 深度卷积对抗生成网络

    作者 | News 编辑 | 奇予纪 出品 | 磐创AI团队出品 简介 本教程通过一个例子来对 DCGANs 进行介绍.我们将会训练一个生成对抗网络(GAN)用于在展示了许多真正的名人的图片后产生新的 ...

  9. 理解与学习深度卷积生成对抗网络

    一.GAN 引言:生成对抗网络GAN,是当今的一大热门研究方向.在2014年,被Goodfellow大神提出来,当时的G神还是蒙特利尔大学的博士生.据有关媒体统计:CVPR2018的论文里,有三分之一 ...

最新文章

  1. 【c语言】符号常量的使用
  2. python 按照要求对字符串进行处理
  3. isp 图像算法(四)之white balance gain control 就是对 r,gr,gb,b 进行加权
  4. 如何在SAP云平台上使用MongoDB服务 1
  5. 二叉树的前序遍历Python解法
  6. BZOJ 2115 Wc2011 Xor DFS+高斯消元
  7. Ensure that config phoenix.schema.isNamespaceMappingEnabled is consistent on client and server
  8. Java LinkedList getFirst()方法与示例
  9. datax 导入数据中文乱码_DataX在有赞大数据平台的实践
  10. 售价150万的“AI老婆”,上市仅1小时就被抢光
  11. php sqlsrv 分页,sqlsrv php分页
  12. 小米扫地机器人原地不动_小米扫地机器人的4个不可思议
  13. HibernateProxy. Forgot to register a type adapter?
  14. Windows下安装神通数据库
  15. Java学习总结篇一初识jvav
  16. 互联网时代个人信息安全的重要性
  17. 微信小程序背景图片background无法在手机端显示问题解决方案
  18. 【kali-漏洞利用】(3.4)免杀Payload 生成工具(下):Veil后门使用、监听失败原因
  19. 正则表达式匹配字符串(scala)
  20. c语言中dna图案打印题,C语言打印DNA螺旋

热门文章

  1. Lambda表达式和Stream类的使用
  2. UBports安装Arduino记录
  3. filebeat重复采集数据问题排查
  4. python主机配置_python 之根据自己的需求配置hostname
  5. 星际无限高级合伙人培训大会在深圳南山举行
  6. C/C++语言100题练习计划 88——猜数游戏(二分查找实现)
  7. net logon服务无法启动
  8. 大型、超大型数据中心园区设计如何审查与优化
  9. 使用RNN模型构建人名分类器
  10. Android学习之路——转自stormzhang