GAN的原理

GAN是一种典型的生成网络模型,它类似于编解码结构,通过训练,他能够生成不同于训练集的各种图片。

首先先训练判别器,把真图通过判别器的输出和真标签作损失,把假图通过判别器的输出和假标签作损失,让它具备判别真图和假图的能力。然后再训练生成器,把生成器生成的假图通过判别器的输出和真标签作损失。经过反复的训练,让判别器难以分辨生成图的真假,也就是让它判别为真或为假的概率各为0.5

数据集下载

网上下载的动漫头像数据集有很多不清晰的奇异样本,对此我做了清洗,剩下的都是符合标准的,可直接下载
百度网盘:https://pan.baidu.com/s/1–zFrJdg1gtW2wJ6wtWQsQ
密码:bu55

网络结构

生成网络

相当于一个编码器

class NetD(nn.Module):# 构建一个判别器,相当与一个二分类问题, 生成一个值def __init__(self):super(NetD, self).__init__()ndf = opt.ndfself.main = nn.Sequential(# 输入96*96*3nn.Conv2d(3, ndf, 5, 3, 1, bias=False),nn.LeakyReLU(negative_slope=0.2, inplace=True),# 输入32*32*ndfnn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, True),# 输入16*16*ndf*2nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, True),# 输入为8*8*ndf*4nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, True),# 输入为4*4*ndf*8nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=True),nn.Sigmoid()  # 分类问题)def forward(self, x):return self.main(x).view(-1)

生成器

相当于一个解码器

class NetG(nn.Module):# 定义一个生成模型,通过输入噪声来产生一张图片def __init__(self):super(NetG, self).__init__()ngf = opt.ngfself.main = nn.Sequential(# 假定输入为一张1*1*opt.nz维的数据(opt.nz维的向量)nn.ConvTranspose2d(opt.nz , ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(inplace=True),# 输入一个4*4*ngf*8nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 输入一个8*8*ngf*4nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=True),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 输入一个16*16*ngf*2nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(inplace=True),# 输入一个32*32*ngfnn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),nn.Tanh()# 输出一张96*96*3)def forward(self, x):return self.main(x)

GAN网络结构设计要点

1、在D网络中用stride卷积(stride>1)代替pooling层,在G网络中用conv2d_transpose代替上采样层
2、在G和D网络中直接将BN应用到所有层会导致样本震荡和模型不稳定,通过在G网络输出层和D网络输入层不采用BN层可以有效防止这种现象
3、不使用全连接层作为输出
4、G网络中除了输出层用tanh激活,其他层都是用ReLu激活
5、D网络中都使用LeakyReLu激活

网络模型训练

训练细节

1、预处理环节,将图像scale到tanh的[-1,1]
2、所有的参数初始化由(0,0.02)的正态分布中随机得到
3、LeakyReLu的斜率是0.2(默认)
4、优化器Adam的learning rate=0.0002,momentum参数betas的beta1从0.9降为0.5,beta2默认,防止震荡和不稳定
5、可以G网络训练1次,然后D网络训练1次,如此反复;也可以G网络先训练几次后,D网络再训练1次,如此反复。前者效果出得较快,后者较慢。
训练代码

# opt参数
ngf=96
ndf=96
nz=256
img_size=96
batch_size=100
num_workers=4
netg_path=r"网络参数/netg_5.pt"
netd_path=r"网络参数/netd_5.pt"
lr1=0.0002
lr2=0.0002
beta1=0.5
epochs=200
d_every=1
g_every=5
save_every=20
from torchvision.utils import save_image
import Nets
import torch
from torch.utils.data import DataLoader
import opt
import torch.nn as nn
import datasetif __name__=="__main__":# 1. 加载数据dataset = dataset.Dataset()dataloader = DataLoader(dataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.num_workers,drop_last=True)# 2.初始化网络netg, netd = Nets.NetG(), Nets.NetD()# 3. 设定优化器参数optimize_g = torch.optim.Adam(netg.parameters(), lr=opt.lr1, betas=(opt.beta1,0.999))optimize_d = torch.optim.Adam(netd.parameters(), lr=opt.lr2, betas=(opt.beta1,0.999))loss_func = nn.BCELoss()# 4. 定义标签, 并且开始注入生成器的输入noisetrue_labels = torch.ones(opt.batch_size)fake_labels = torch.zeros(opt.batch_size)noises = torch.randn(opt.batch_size, opt.nz, 1, 1)#  6.训练网络netg.train()netd.train()for epoch in range(opt.epochs):for i, img in enumerate(dataloader):real_img = img# 训练判别器if i % opt.d_every == 0:optimize_d.zero_grad()# 真图real_out = netd(real_img)error_d_real = loss_func(real_out, true_labels)error_d_real.backward()# 随机生成的假图noises = noises.detach()fake_image = netg(noises).detach()fake_out = netd(fake_image)error_d_fake = loss_func(fake_out, fake_labels)error_d_fake.backward()optimize_d.step()# 计算losserror_d = error_d_fake + error_d_realprint("第{0}轮: 判别网络   损失:{1}  对真图评分:{2}  对生成图评分:{3}".format(epoch+1,error_d.item(),real_out.data.mean(),fake_out.data.mean()))# 训练生成器if i % opt.g_every == 0 and i>0:optimize_g.zero_grad()noises.data.copy_(torch.randn(opt.batch_size, opt.nz, 1, 1))fake_img = netg(noises)output = netd(fake_img)error_g = loss_func(output, true_labels)error_g.backward()optimize_g.step()print("       生成网络   损失:{0}".format(error_g.item()))#  7.保存模型和图片if i % opt.save_every == 0 and i>0:fix_noises = torch.randn(opt.batch_size, opt.nz, 1, 1)fix_fake_image = netg(fix_noises)# save_image(real_img.data*0.5+0.5, "./img/{0}-{1}-real_img.jpg".format(epoch, i), nrow=10)save_image(fix_fake_image.data*0.5+0.5, "./image/{0}-{1}-fake_img.jpg".format(epoch, i), nrow=10)torch.save(netd.state_dict(), opt.netd_path)torch.save(netg.state_dict(), opt.netg_path)

效果展示

生成网络随机生成的头像

基于GAN的动漫头像生成相关推荐

  1. 基于GAN的动漫头像生成系统(源码&教程)

    1.研究背景 我们都喜欢动漫角色,并试图创造我们的定制角色.然而,要掌握绘画技巧需要巨大的努力,之后我们首先有能力设计自己的角色.为了弥补这一差距,动画角色的自动生成提供了一个机会,在没有专业技能的情 ...

  2. 基于DCGAN的动漫头像生成

    基于DCGAN的动漫头像生成 数据 数据集:动漫图库爬虫获得,经过数据清洗,裁剪得到动漫头像.分辨率为3 * 96 * 96,共5万多张动漫头像的图片,从知乎用户何之源处下载. 生成器:输入为随机噪声 ...

  3. 基于DCGAN的动漫头像生成神经网络实现

    一.前言 1.什么是DCGAN? 2.DCGAN的TensorFlow实现 3.什么是转置卷积? 4.转置卷积的Tensorflow实现 5.Batch Normalization解读 本文假设读者已 ...

  4. 基于DCGAN动漫头像生成的课题意义200字

    DCGAN(深度卷积生成对抗网络)使用生成式对抗网络(GAN)技术来生成动漫头像,这是一个有趣而有趣的研究课题.DCGAN可以有效地利用深度学习技术来自动生成许多有趣的动漫头像.它可以使用较少的数据集 ...

  5. 基于DCGAN动漫头像生成的意义用论文方式表达

    针对DCGAN技术用于动漫头像生成的意义,本文试图从技术层面探索和讨论该技术的可用性.本文将探索DCGAN技术如何利用卷积神经网络来模拟动漫头像,以及如何使用自动编码器和生成器来创建动漫头像.本文将比 ...

  6. Pytorch 使用GAN实现二次元人物头像生成 保姆级教程(数据集+实现代码+数学原理)

    Pytorch 使用DCGAN实现二次元人物头像生成(实现代码+公式推导) GAN介绍   算法主体   推导证明(之后将补全完整过程)   随机梯度下降训练D,G   DCGAN介绍及相关原理 Py ...

  7. 【生成对抗网络】基于DCGAN的二次元人物头像生成(TensorFlow2)

    文章目录 1 导包 2 数据准备 3 定义生成器 4 定义判别器 5 定义损失函数和优化器 6 定义训练批次函数 7 定义可视化训练结果函数 8 定义训练主函数 9 训练 10 结果 11 使用生成器 ...

  8. 动漫头像生成的课题背景200字不要和上面的重复

    近年来,使用动漫头像的应用越来越多,从社交媒体.单机游戏.网络游戏等等,动漫头像的应用涉及到许多方面.随着人工智能技术的发展,让计算机自动生成动漫头像的研究越来越受到重视.该课题的研究目的是使用深度学 ...

  9. 深度学习之基于GAN实现手写数字生成

    在弄毕设的时候,室友的毕设是基于DCGAN实现音乐的自动生成.那是第一次接触对抗神经网络,当时听室友的描述就是两个CNN,一个生成一个监测,在互相博弈. 最近我关注的一个大神在弄有关于GAN的东西,所 ...

  10. Python基于Tensorflow实现DCGAN-动漫头像生成

    目录 前言 DCGAN简介 python代码 1. 导入python包.定义全局变量 2. 读取数据 3. 搭建生成器generator 4. 搭建判别器discriminator 5.  搭建GAN ...

最新文章

  1. 【Kaggle-MNIST之路】自定义程序结构(七)
  2. 文巾解题35. 搜索插入位置
  3. c语言复杂的程序代码,C语言中复杂结构的序列化
  4. 分享在MVC3.0中使用jQuery DataTable 插件
  5. 基于VS Code创建Java command-line app
  6. ssh 看apache_使用Apache KeyedObjectPool的ssh连接池
  7. 【017】◀▶ C#学习(九) - ADO.NET
  8. 算法笔记_100:蓝桥杯练习 算法提高 三个整数的排序(Java)
  9. php 正则匹配所有路径,与文件路径匹配的PHP正则表达式
  10. ofstream清空文件内容_回收站被删除的文件怎么恢复 回收站清空了怎么恢复
  11. libiec61850探究【1】-第一个MMS通讯实例
  12. 面向深度学习的多模态融合技术研究综述
  13. 高等数学张宇18讲 第二讲 极限与连续
  14. UVA 10098 Generating Fast
  15. defined 用法
  16. 电磁原理---电磁炉
  17. Taro(React)实现具有滚动效果的倒数计时器
  18. MultiSim电路仿真之受控源的使用
  19. Flink DataStream中join
  20. shell脚本——批量创建用户

热门文章

  1. 开学作业——如何做好课堂笔记
  2. 汽车零部件行业追溯系统的应用
  3. 梦三国测试服显示连接服务器失败,我的登陆进去以后说与服务器失去连接怎么回事...
  4. 经典上海弄堂线路攻略
  5. Android蓝牙通讯(服务端、客户端)
  6. java实现token 过期,java – SQS ExpiredToken:请求中包含的安全令牌是过期状态码:403...
  7. PVE系统更换大硬盘的扩容方法
  8. fedora linux搜狗输入法,GitHub - Hello-Linux/fedora-Sougou-Pinyin: fedora 搜狗拼音,安装超简单,各种精美皮肤!...
  9. MTK6577 Android源代码目录
  10. 自动化学科前沿讲座分享,作业,自动化与人工智能