文章目录

  • 1.实现效果
  • 2.环境配置
    • 2.1Python
    • 2.2Pytorch、CUDA
    • 2.3Python IDE
  • 3.具体实现
    • 3.1数据预处理(data.py)
      • (1)导入包
      • (2)定义数据类
    • 3.2模型Generator,Discriminator,权重初始化(model.py)
      • (1)导入包
      • (2)Generator
      • (3)Discriminator
      • (4)权重初始化
    • 3.3网络训练(net.py)
      • (1)导入包
      • (2)创建类
    • 3.4 主函数(main.py)
      • (1)导入文件
      • (2)定义超参数
      • (3)实例化
      • (4)进行训练
  • 4.训练过程
    • 4.1 Generator和Discriminator的Loss损失曲线图
    • 4.2 D(x)和D(G(z))曲线图
    • 4.3最终生成结果图
  • 5.完整代码
  • 6.引用参考
  • 7.问题反馈

1.实现效果

使用DCGAN训练faces数据集,最终实现生成二次元动漫头像。
最后虽然生成了动漫头像,但是一些细节还是和真实的图像差别较大,比如说眼睛大小,眼睛颜色等。
之后我会将MINIST数据集、Oxford17数据集、以及faces数据集在训练过程中不同轮次的输出结果做一个总结。
生成二次元动漫头像的程序依然是沿用data.py、model.py、net.py、main.py但具体的编程的细节呢有所改变。
之前MINIST以及Oxford17数据集的程序
这里:
【Pytorch】DCGAN实战(一):基于MINIST数据集的手写数字生成
【Pytorch】DCGAN实战(二):基于Oxord17的鲜花图像生成

2.环境配置

2.1Python

Python版本为3.7

2.2Pytorch、CUDA

在这里不详细介绍了,网上有很多的安装教程,小伙伴们自行查找吧!

2.3Python IDE

Pycharm

3.具体实现

整体分为4个文件:data.py、model.py、net.py、main.py

3.1数据预处理(data.py)

(1)导入包

from torch.utils.data import DataLoader
from torchvision import utils, datasets, transforms

(2)定义数据类

class ReadData():def __init__(self,data_path,image_size=64):self.root=data_pathself.image_size=image_sizeself.dataset=self.getdataset()def getdataset(self):#3.datasetdataset = datasets.ImageFolder(root=self.root,transform=transforms.Compose([transforms.Resize(self.image_size),transforms.CenterCrop(self.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))print(f'Total Size of Dataset: {len(dataset)}')return datasetdef getdataloader(self,batch_size=128):dataloader = DataLoader(self.dataset,batch_size=batch_size,shuffle=True,num_workers=0)return dataloader

3.2模型Generator,Discriminator,权重初始化(model.py)

(1)导入包

import torch.nn as nn

(2)Generator

class Generator(nn.Module):def __init__(self, nz,ngf,nc):super(Generator, self).__init__()self.nz = nzself.ngf = ngfself.nc=ncself.main = nn.Sequential(# input is Z, going into a convolutionnn.ConvTranspose2d(self.nz, self.ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(self.ngf * 8),nn.ReLU(True),# state size. (ngf*8) x 4 x 4nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ngf * 4),nn.ReLU(True),# state size. (ngf*4) x 8 x 8nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ngf * 2),nn.ReLU(True),# state size. (ngf*2) x 16 x 16nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ngf),nn.ReLU(True),# state size. (ngf) x 32 x 32nn.ConvTranspose2d(self.ngf, self.nc, 4, 2, 1, bias=False),nn.Tanh()# state size. (nc) x 64 x 64)def forward(self, input):return self.main(input)

(3)Discriminator

class Discriminator(nn.Module):def __init__(self, ndf,nc):super(Discriminator, self).__init__()self.ndf=ndfself.nc=ncself.main = nn.Sequential(# input is (nc) x 64 x 64nn.Conv2d(self.nc, self.ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf) x 32 x 32nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ndf * 2),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*2) x 16 x 16nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ndf * 4),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*4) x 8 x 8nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ndf * 8),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*8) x 4 x 4nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),# state size. (1) x 1 x 1nn.Sigmoid())def forward(self, input):return self.main(input)

(4)权重初始化

def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)

3.3网络训练(net.py)

(1)导入包

import torch
import torch.nn as nn
from torchvision import utils, datasets, transforms
import time
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import os

(2)创建类

class DCGAN():def __init__(self,lr,beta1,nz, batch_size,num_showimage,device, model_save_path,figure_save_path,generator, discriminator, data_loader,):self.real_label=1self.fake_label=0self.nz=nzself.batch_size=batch_sizeself.num_showimage=num_showimageself.device = deviceself.model_save_path=model_save_pathself.figure_save_path=figure_save_pathself.G = generator.to(device)self.D = discriminator.to(device)self.opt_G=torch.optim.Adam(self.G.parameters(), lr=lr, betas=(beta1, 0.999))self.opt_D = torch.optim.Adam(self.D.parameters(), lr=lr, betas=(beta1, 0.999))self.criterion = nn.BCELoss().to(device)self.dataloader=data_loaderself.fixed_noise = torch.randn(self.num_showimage, nz, 1, 1, device=device)self.img_list = []self.G_loss_list = []self.D_loss_list = []self.D_x_list = []self.D_z_list = []def train(self,num_epochs):loss_tep = 10G_loss=0D_loss=0print("Starting Training Loop...")# For each epochfor epoch in range(num_epochs):#**********计时*********************beg_time = time.time()# For each batch in the dataloaderfor i, data in enumerate(self.dataloader):############################# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))###########################x = data[0].to(self.device)b_size = x.size(0)lbx = torch.full((b_size,), self.real_label, dtype=torch.float, device=self.device)D_x = self.D(x).view(-1)LossD_x = self.criterion(D_x, lbx)D_x_item = D_x.mean().item()# print("log(D(x))")z = torch.randn(b_size, self.nz, 1, 1, device=self.device)gz = self.G(z)lbz1 = torch.full((b_size,), self.fake_label, dtype=torch.float, device=self.device)D_gz1 = self.D(gz.detach()).view(-1)LossD_gz1 = self.criterion(D_gz1, lbz1)D_gz1_item = D_gz1.mean().item()# print("log(1 - D(G(z)))")LossD = LossD_x + LossD_gz1# print("log(D(x)) + log(1 - D(G(z)))")self.opt_D.zero_grad()LossD.backward()self.opt_D.step()# print("update LossD")D_loss+=LossD############################# (2) Update G network: maximize log(D(G(z)))###########################lbz2 = torch.full((b_size,), self.real_label, dtype=torch.float, device=self.device) # fake labels are real for generator costD_gz2 = self.D(gz).view(-1)D_gz2_item = D_gz2.mean().item()LossG = self.criterion(D_gz2, lbz2)# print("log(D(G(z)))")self.opt_G.zero_grad()LossG.backward()self.opt_G.step()# print("update LossG")G_loss+=LossGend_time = time.time()# **********计时*********************run_time = round(end_time - beg_time)# print('lalala')print(f'Epoch: [{epoch + 1:0>{len(str(num_epochs))}}/{num_epochs}]',f'Step: [{i + 1:0>{len(str(len(self.dataloader)))}}/{len(self.dataloader)}]',f'Loss-D: {LossD.item():.4f}',f'Loss-G: {LossG.item():.4f}',f'D(x): {D_x_item:.4f}',f'D(G(z)): [{D_gz1_item:.4f}/{D_gz2_item:.4f}]',f'Time: {run_time}s',end='\r\n')# print("lalalal2")# Save Losses for plotting laterself.G_loss_list.append(LossG.item())self.D_loss_list.append(LossD.item())# Save D(X) and D(G(z)) for plotting laterself.D_x_list.append(D_x_item)self.D_z_list.append(D_gz2_item)# # Save the Best Model# if LossG < loss_tep:#     torch.save(self.G.state_dict(), 'model.pt')#     loss_tep = LossGif not os.path.exists(self.model_save_path):os.makedirs(self.model_save_path)torch.save(self.D.state_dict(), self.model_save_path + 'disc_{}.pth'.format(epoch))torch.save(self.G.state_dict(), self.model_save_path + 'gen_{}.pth'.format(epoch))# Check how the generator is doing by saving G's output on fixed_noisewith torch.no_grad():fake = self.G(self.fixed_noise).detach().cpu()self.img_list.append(utils.make_grid(fake * 0.5 + 0.5, nrow=10))print()if not os.path.exists(self.figure_save_path):os.makedirs(self.figure_save_path)plt.figure(1,figsize=(8, 4))plt.title("Generator and Discriminator Loss During Training")plt.plot(self.G_loss_list[::10], label="G")plt.plot(self.D_loss_list[::10], label="D")plt.xlabel("iterations")plt.ylabel("Loss")plt.axhline(y=0, label="0", c="g")  # asymptoteplt.legend()plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'loss.jpg', bbox_inches='tight')plt.figure(2,figsize=(8, 4))plt.title("D(x) and D(G(z)) During Training")plt.plot(self.D_x_list[::10], label="D(x)")plt.plot(self.D_z_list[::10], label="D(G(z))")plt.xlabel("iterations")plt.ylabel("Probability")plt.axhline(y=0.5, label="0.5", c="g")  # asymptoteplt.legend()plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'D(x)D(G(z)).jpg', bbox_inches='tight')fig = plt.figure(3,figsize=(5, 5))plt.axis("off")ims = [[plt.imshow(item.permute(1, 2, 0), animated=True)] for item in self.img_list]ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)HTML(ani.to_jshtml())# ani.to_html5_video()ani.save(self.figure_save_path + str(num_epochs) + 'epochs_' + 'generation.gif')plt.figure(4,figsize=(8, 4))# Plot the real imagesplt.subplot(1, 2, 1)plt.axis("off")plt.title("Real Images")real = next(iter(self.dataloader))  # real[0]image,real[1]labelplt.imshow(utils.make_grid(real[0][:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))# Load the Best Generative Model# self.G.load_state_dict(#     torch.load(self.model_save_path + 'disc_{}.pth'.format(epoch), map_location=torch.device(self.device)))self.G.eval()# Generate the Fake Imageswith torch.no_grad():fake = self.G(self.fixed_noise).cpu()# Plot the fake imagesplt.subplot(1, 2, 2)plt.axis("off")plt.title("Fake Images")fake = utils.make_grid(fake[:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0)plt.imshow(fake)# Save the comparation resultplt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'result.jpg', bbox_inches='tight')plt.show()def test(self,epoch):# Size of the Figureplt.figure(figsize=(8, 4))# Plot the real imagesplt.subplot(1, 2, 1)plt.axis("off")plt.title("Real Images")real = next(iter(self.dataloader))#real[0]image,real[1]labelplt.imshow(utils.make_grid(real[0][:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))# Load the Best Generative Modelself.G.load_state_dict(torch.load(self.model_save_path + 'disc_{}.pth'.format(epoch), map_location=torch.device(self.device)))self.G.eval()# Generate the Fake Imageswith torch.no_grad():fake = self.G(self.fixed_noise.to(self.device))# Plot the fake imagesplt.subplot(1, 2, 2)plt.axis("off")plt.title("Fake Images")fake = utils.make_grid(fake * 0.5 + 0.5, nrow=10)plt.imshow(fake.permute(1, 2, 0))# Save the comparation resultplt.savefig(self.figure_save_path+'result.jpg', bbox_inches='tight')plt.show()

3.4 主函数(main.py)

(1)导入文件

from data import ReadData
from model import Discriminator, Generator, weights_init
from net import DCGAN
import torch

(2)定义超参数

ngpu=1
ngf=64
ndf=64
nc=3
nz=100
lr=0.003
beta1=0.5
batch_size=100
num_showimage=100data_path="./oxford17_class"
model_save_path="./models/"
figure_save_path="./figures/"device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

(3)实例化

dataset=ReadData(data_path)
dataloader=dataset.getdataloader(batch_size=batch_size)G = Generator(nz,ngf,nc).apply(weights_init)
print(G)
D = Discriminator(ndf,nc).apply(weights_init)
print(D)dcgan=DCGAN( lr,beta1,nz,batch_size,num_showimage,device, model_save_path,figure_save_path,G, D, dataloader)

(4)进行训练

dcgan.train(num_epochs=20)

4.训练过程

4.1 Generator和Discriminator的Loss损失曲线图

训练过程中Generator和Discriminator的Loss曲线图(以200个epoch为例):

4.2 D(x)和D(G(z))曲线图

训练过程中Discriminator输出(以200个epoch为例):

4.3最终生成结果图

训练结束后生成图片(以5个epoch为例):

5.完整代码

链接:https://pan.baidu.com/s/15J6sZL3rCPLm2jZFEuyzNw
提取码:DGAN

6.引用参考

https://blog.csdn.net/qq_42951560/article/details/112199229
https://blog.csdn.net/qq_42951560/article/details/110308336

7.问题反馈

如果运行有问题,欢迎给我私信留言!

【Pytorch】DCGAN实战(三):二次元动漫头像生成相关推荐

  1. 基于DCGAN的动漫头像生成

    基于DCGAN的动漫头像生成 数据 数据集:动漫图库爬虫获得,经过数据清洗,裁剪得到动漫头像.分辨率为3 * 96 * 96,共5万多张动漫头像的图片,从知乎用户何之源处下载. 生成器:输入为随机噪声 ...

  2. 基于DCGAN动漫头像生成的意义用论文方式表达

    针对DCGAN技术用于动漫头像生成的意义,本文试图从技术层面探索和讨论该技术的可用性.本文将探索DCGAN技术如何利用卷积神经网络来模拟动漫头像,以及如何使用自动编码器和生成器来创建动漫头像.本文将比 ...

  3. 基于DCGAN动漫头像生成的课题意义200字

    DCGAN(深度卷积生成对抗网络)使用生成式对抗网络(GAN)技术来生成动漫头像,这是一个有趣而有趣的研究课题.DCGAN可以有效地利用深度学习技术来自动生成许多有趣的动漫头像.它可以使用较少的数据集 ...

  4. 基于GAN的动漫头像生成系统(源码&教程)

    1.研究背景 我们都喜欢动漫角色,并试图创造我们的定制角色.然而,要掌握绘画技巧需要巨大的努力,之后我们首先有能力设计自己的角色.为了弥补这一差距,动画角色的自动生成提供了一个机会,在没有专业技能的情 ...

  5. 动漫头像生成的课题背景200字不要和上面的重复

    近年来,使用动漫头像的应用越来越多,从社交媒体.单机游戏.网络游戏等等,动漫头像的应用涉及到许多方面.随着人工智能技术的发展,让计算机自动生成动漫头像的研究越来越受到重视.该课题的研究目的是使用深度学 ...

  6. (tensorflow学习) DCGAN 动漫头像生成

    训练过程 数据收集 训练DCGAN生成动漫头像,首先需要大量的训练数据 可以用python写个爬虫去https://konachan.net/爬动漫图片, 然后用https://github.com/ ...

  7. 使用 gan-qp 做动漫头像生成实验

    代码托管在github https://github.com/One-sixth/gan-qp-mod-pytorch 我使用 gan-qp 在我的动漫头像数据集(3.7w张图像)上训练的,不过效果不 ...

  8. 生成二次元动漫头像DCGAN代码实现

    写在前面 此份代码可以在pycharm上运行,前提是已经安装tensorflow2.0gpu版本 import tensorflow as tf import matplotlib.pyplot as ...

  9. 基于DCGAN的动漫头像生成神经网络实现

    一.前言 1.什么是DCGAN? 2.DCGAN的TensorFlow实现 3.什么是转置卷积? 4.转置卷积的Tensorflow实现 5.Batch Normalization解读 本文假设读者已 ...

最新文章

  1. Neutron — Hierarchical Port Binding(层次化端口绑定)
  2. 如何自动播放光盘、解决win7电脑不能播放光盘
  3. SAP Cloud for Customer My settings按钮被disable的原因分析
  4. ajax 局部页面替换innerhtml,ajax jquery 页面局部刷新的不同实现代码
  5. yaf(5) smarty
  6. 后端基础概念:各种OCV一网打尽(上篇)
  7. MyBatis之使用XML配置SQL映射(二)CRUD映射配置
  8. 小米线刷包需要解压么_小米10刷机教程,线刷升级更新官方系统包
  9. 爱心打印函数(基于EasyX图形库)
  10. 怎么批量删除 Word、Excel 以及文本文档中的空白行?
  11. RNN denoise
  12. vh和vw是什么单位?
  13. 跨时钟域处理所用到的同步器
  14. php面试题和答案整理
  15. 硬核,这年头机器人都开始自学“倒车入库”了
  16. python如何画3个相切的圆_使用python绘制4个相切的圆形
  17. 神经网络的基本结构介绍
  18. 虚拟内存,物理内存,页面文件,还有任务管理器
  19. 艾司博讯:拼多多访客怎么进来在哪里看到
  20. 每天学点英语语法-重头开始8

热门文章

  1. vscode使用lldb调试C++程序
  2. mtk6582平台GT9157触摸屏驱动移植
  3. Q2表现喜忧参半,“在线音频第一股”荔枝还甜吗?
  4. 《互联网理财一册通》一一12.1 网银银行移动客户端
  5. App出海:如何制定社媒营销策略?
  6. 敏捷项目管理流程-敏捷Scrum项目管理汇总(思维导图)
  7. 二级C语言操作例题(十)
  8. 自考会计原理和实务能带什么计算机,高等教育自学考试“会计原理与实务”命题说明...
  9. python安装dill_Python dill包_程序模块 - PyPI - Python中文网
  10. UNO 游戏实现心得 (version 1)