利用DCGAN网络生成动漫人物头像(pytorch实现)

最近在学习生成式对抗网络(GAN),非常喜欢知乎上看到的一个生成动漫人物头像的例子。但可惜的是,他是利用Tensorflow中已经有人造好的轮子:carpedm20/DCGAN-tensorflow,直接使用这个代码实现的。最近正好学完cs231N的课程,就用它来练练手吧。

一、准备数据

网上有很多GAN原理的介绍,此处不再多说,直接上代码!
首先是导入文件所需要的模块,因为本人萌新一个,所以会有些多于的包,测试使用的。

#加载所需要的模块和库,设定展示图片函数以及其它对图像预处理函数
import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import d2lzh_pytorch as d2l
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torchvision.datasets import ImageFolder
import os
from torchvision import transforms
import torchvision

确定使用GPU还是CPU

dtype = torch.cuda.FloatTensor
#cpu用这个
#dtype = torch.FloatTensor

导入图像文件
loader_train.iter().next()中有两个tensor,第一个为图像矩阵,第二个为标签
[0]表读取tensor中的第一个tensor(即图像矩阵),[1]为标签
numpy()将tensor转numpy
squeeze()从数组的形状中删除单维度条目,即把shape中为1的维度去掉,此处加不加均可

plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
#一些参数的设置
NOISE_DIM = 96
batch_size = 128
#加载图像
data_dir = 'D:/data/faces/'
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_augs = transforms.Compose([transforms.Resize(size=28),transforms.ToTensor(),])
loader_train = DataLoader(ImageFolder(os.path.join(data_dir,), transform=train_augs),batch_size=batch_size)
der_train.__iter__().next()[0].numpy().squeeze()

图像显示函数

def show_images(images):images1 = np.reshape(images, [images.shape[0], -1])images = np.reshape(images, [images.shape[0], 3,-1])  # images reshape to (batch_size, D)sqrtn = int(np.ceil(np.sqrt(images.shape[0])))#batch_sizedim = int(images.shape[1])#3sqrtimg = int(np.ceil(np.sqrt(images.shape[2])))#96fig = plt.figure(figsize=(sqrtn, sqrtn))#Figure(864x864)#figsize=(4,3)为图像英寸宽4高3英寸gs = gridspec.GridSpec(sqrtn, sqrtn)#gridspec.GridSpec()创建区域,参数5,5的意思就是每行五个,每列五个gs.update(wspace=0.05, hspace=0.05)for i, img in enumerate(images1):#enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels( [])ax.set_aspect('equal')plt.imshow(img.reshape([dim,sqrtimg,sqrtimg]).transpose(1,2,0))#换通道,对于torch中则是permutereturn
show_images(imgs)
plt.show()

一次取128个图像,源图像大小是9696,为了减轻运行负担,我们输入2828的图像。
现在我们来看看这些图片都长啥样

二、定义随机噪声

我们生成一些随机噪声,把它扔给生成器

#Random Noise
def sample_noise(batch_size, dim):temp = torch.rand(batch_size, dim) + torch.rand(batch_size, dim) * (-1)return temp

定义一些展平、初始化函数,之后会用到


class Flatten(nn.Module):def forward(self, x):N, C, H, W = x.size()  # read in N, C, H, Wreturn x.view(N, -1)  # "flatten" the C * H * W values into a single vector per imageclass Unflatten(nn.Module):"""An Unflatten module receives an input of shape (N, C*H*W) and reshapes itto produce an output of shape (N, C, H, W)."""def __init__(self, N=-1, C=128, H=7, W=7):super(Unflatten, self).__init__()self.N = Nself.C = Cself.H = Hself.W = Wdef forward(self, x):return x.view(self.N, self.C, self.H, self.W)def initialize_weights(m):if isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d):init.xavier_uniform_(m.weight.data)

好,接下来开始进行DCGAN的核心部分啦

三、定义生成器、辨识器的损失函数

生成器和辨识器的loss值是训练反映效果的重要指标

Bce_loss = nn.BCEWithLogitsLoss()
def discriminator_loss(logits_real,logits_fake):#batch sizeN = logits_real.size()#目标label,全部设为1表示判别器需要做到的是将正确的全识别正确,错误全识别为错误true_labels = Variable(torch.ones(N)).type(dtype)real_image_loss = Bce_loss(logits_real,true_labels)#识别正确的为正确fake_image_loss = Bce_loss(logits_fake,1-true_labels)#识别错误的为错误loss = real_image_loss + fake_image_lossreturn lossdef generator_loss(logits_fake):#batch sizeN = logits_fake.size()#生成器的作用是将所有假向真靠拢true_labels = Variable(torch.ones(N)).type(dtype)#计算生成器的损失loss=Bce_loss(logits_fake,true_labels)return lossdef get_optimizer(model):#定义使用的优化算法optimizer = Noneoptimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.5, 0.999))return optimizer

四、定义DC生成器和DC辨识器

生成器生成假图片,辨识器分辨出哪些是假图片、哪些是训练图片。当辨识器无法辨别出生成器生成的图片和训练图片时,则达到了我们预期的效果。

def build_dc_classifier():"""Build and return a PyTorch model for the DCGAN discriminator implementingthe architecture above."""return nn.Sequential(Unflatten(batch_size, 3, 28, 28),nn.Conv2d(3, 32,kernel_size=5, stride=1),nn.LeakyReLU(negative_slope=0.01),nn.MaxPool2d(2, stride=2),nn.Conv2d(32, 64,kernel_size=5, stride=1),nn.LeakyReLU(negative_slope=0.01),nn.MaxPool2d(kernel_size=2, stride=2),Flatten(),nn.Linear(4*4*64, 4*4*64),nn.LeakyReLU(negative_slope=0.01),nn.Linear(4*4*64,1))
def build_dc_generator(noise_dim=NOISE_DIM):"""Build and return a PyTorch model implementing the DCGAN generator usingthe architecture described above."""return nn.Sequential(nn.Linear(noise_dim, 1024),nn.ReLU(),nn.BatchNorm1d(1024),nn.Linear(1024, 7*7*128),nn.BatchNorm1d(7*7*128),Unflatten(batch_size, 128, 7, 7),nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),nn.ReLU(inplace=True),nn.BatchNorm2d(num_features=128),nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.BatchNorm2d(num_features=64),nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1),nn.Tanh(),Flatten())

定义好生成器和辨识器后,我们开始为其准备训练函数

def run_a_gan(D, G, D_solver, G_solver, discriminator_loss, generator_loss, show_every=250,batch_size=128, noise_size=96, num_epochs=10):iter_count = 0for epoch in range(num_epochs):for x, _ in loader_train:if len(x) != batch_size:continueD_solver.zero_grad()real_data = Variable(x).type(dtype)logits_real = D(2 * (real_data - 0.5)).type(dtype)g_fake_seed = Variable(sample_noise(batch_size, noise_size)).type(dtype)fake_images = G(g_fake_seed).detach()logits_fake = D(fake_images.view(batch_size, 3, 28, 28))d_total_error = discriminator_loss(logits_real, logits_fake)d_total_error.backward()D_solver.step()G_solver.zero_grad()g_fake_seed = Variable(sample_noise(batch_size, noise_size)).type(dtype)fake_images = G(g_fake_seed)gen_logits_fake = D(fake_images.view(batch_size, 3, 28, 28))g_error = generator_loss(gen_logits_fake)g_error.backward()G_solver.step()if (iter_count % show_every == 0):#print(iter_count, d_total_error.data, g_error.data)#imgs_numpy = fake_images.data.cpu().numpy()#show_images(imgs_numpy[0:16])print("iter_count",iter_count,"g_loss", g_error.data, "d_loss", d_total_error.data)#plt.show()#print()iter_count += 1imgs_numpy = fake_images.data.cpu().numpy()show_images(imgs_numpy[0:16])plt.show()print("iter_count", iter_count, "g_loss", g_error.data, "d_loss", d_total_error.data)print()

当以上准备完毕后,开始训练吧

D_DC = build_dc_classifier().type(dtype)
D_DC.apply(initialize_weights)
G_DC = build_dc_generator().type(dtype)
G_DC.apply(initialize_weights)D_DC_solver = get_optimizer(D_DC)
G_DC_solver = get_optimizer(G_DC)run_a_gan(D_DC, G_DC, D_DC_solver, G_DC_solver, discriminator_loss, generator_loss, num_epochs=10)

五、生成结果

迭代2000左右生成的图像如下

咳咳,虽然有些图片比较崩坏,但是能看到一些动漫头像的轮廓了,部分图片已经有着比较清晰的五官了。

由于这是笔者初学GAN的第一站,无论是判别器和生成器的架构设置,还是训练程度的把握都没有明确的认知,训练过程也没有一个确切的参数衡量图像生成的好坏,导致最后生成的图像仅仅只有一个模糊的轮廓。

想要生成清晰的人物头像可以参考https://blog.csdn.net/york1996/article/details/82776704一文。

DCGAN生成动漫人物头像---pytroch版相关推荐

  1. pytorch实现DCGAN生成动漫人物头像

    pytorch实现DCGAN生成动漫人物头像 DCGAN原理 参考这一系列文章 数据集 21551张64*64动漫人物头像 生成效果 训练1个epoch(emm-) 训练10个epoch(起码有颜色了 ...

  2. 使用DCGAN生成动漫人物头像

    使 用 D C G A N 生 成 动 漫 人 物 头 像 使用DCGAN生成动漫人物头像 使用DCGAN生成动漫人物头像 import tensorflow as tf from tensorflo ...

  3. Pytorch 使用DCGAN生成动漫人物头像 入门级实战教程

    有关DCGAN实战的小例子之前已经更新过一篇,感兴趣的朋友可以点击查看 Pytorch 使用DCGAN生成MNIST手写数字 入门级教程 有关DCGAN的相关原理:DCGAN论文解读-----DCGA ...

  4. pytorch学习——DCGAN——生成动漫人物头像

    本文参考官方博客以及李宏毅老师讲解. 另参考https://blog.csdn.net/sunqiande88/article/details/80219842 关于其中转置卷积和卷积的问题,会另外开 ...

  5. python 动漫卡通人物图片大全,『TensorFlow』DCGAN生成动漫人物头像_下

    一.计算图效果以及实际代码实现 计算图效果 实际模型实现 相关介绍移步我的github项目. 二.生成器与判别器设计 生成器 相关参量, 噪声向量z维度:100 标签向量y维度:10(如果有的话) 生 ...

  6. 『TensorFlow』DCGAN生成动漫人物头像_下

    『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...

  7. 【GAN实战项目:DCGAN in Tensorflow生成动漫人物头像】代码学习

    目录 一.爬虫代码 二.头像截取 三.训练 四.遇到的问题及处理方法 五.生成效果 DCGAN的原理 和GAN是一样的,只是把G和D换成了两个卷积神经网络(CNN).但不是直接换就可以了,DCGAN对 ...

  8. 基础 | 如何通过DCGAN实现动漫人物图像的自动生成?

    点击上方"机器学习与生成对抗网络",关注"星标" 获取有趣.好玩的前沿干货! 文章来源:淘系技术 背景 基于生成对抗网络(GAN)的动漫人物生成近年来兴起的动漫 ...

  9. 深度学习之基于DCGAN实现动漫人物的生成

    注:因为硬件原因,这次的实验并没有生成图片,但是代码应该是没有问题的,可以参考学习一下. 本次基于DCGAN实现动漫人物的生成.最终的效果可以参考大神**K同学啊**的博客.与上篇文章基于DCGAN生 ...

  10. 如何通过DCGAN实现动漫人物图像的自动生成?

    深度学习领域在近几年取得了重大突破,其中大部分研究成果都基于感知技术,计算机通过模仿人类的思维方式,感知物体.识别内容.生成对抗网络的理念由Goodfellow于2014年提出的,它的发展历程只有六年 ...

最新文章

  1. 【Python】实现将testlink上的用例指定格式保存至Excel,用于修改上传
  2. 被迫重构代码,这次我干掉了 if-else
  3. Python装饰器-装饰流程,执行顺序
  4. Windows Phone 7 开发资源汇总
  5. [MySQL]MySQL分区与传统的分库分表(精华)
  6. 经典SQL语句大全(技巧篇)
  7. SSH登录太慢(等很久才提示输入密码)的问题
  8. day5 Java中的方法与重载
  9. ELK 源码详细安装步骤
  10. 摩拜回应裁员传闻:属正常业务调整 部分岗位仍在招聘
  11. logback读取src/test/resource下的配置文件
  12. 帆软邮件STMP配置、发送测试及邮件发送失败日志排查
  13. MS08_067漏洞复现
  14. 小学计算机座位安排表,戳痛父母们的班级座位表,安排孩子怎么坐也是一门学问...
  15. Let's Encrypt申请证书-保姆教程
  16. uni-app在安卓手机监听物理返回键
  17. 给自家人做个招聘广告,前后端和移动工程师看过来
  18. 厦大2021级期末上机考试
  19. 未来链上跨境支付、融资领域龙头 Tribal 的发展与机遇
  20. MySQL面试:为什么用自增列作为主键

热门文章

  1. .NET发送邮箱(验证码)
  2. 码农的自我修养之从需求分析到软件设计
  3. 算法分析与设计实验报告——图的m着色问题
  4. python中断输入_在 Python 中接管键盘中断信号的实现方法
  5. 使用MindSpore进行一阶导数计算
  6. #Geek Founders# 蒋涛的 CES 2016 感受 - Day 4 (总结版)
  7. 【数据处理】 python 极速极简画图——频数(率)分布直方图
  8. Vmware14安装ubuntu18
  9. 支付宝退款流程 php,支付宝退款接口对接流程PHP语言
  10. 例题4-6 师兄帮帮忙(A Typical Homework (a.k.a Shi Xiong Bang Bang Mang),Rujia Liu's Present 5, UVa 12412)