github:https://github.com/SPECTRELWF/pytorch-GAN-study

网络结构

最近在疯狂补深度学习一些基本架构的基础,看了一下大佬的GAN的原始论文,说实话一头雾水,不是能看的很懂。推荐B站李宏毅老师的机器学习2021的课程,听完以后明白多了。原始论文中就说了一个generator和一个discriminator的结构,并没有细节的说具体是怎么去定义的,对新手不太友好,参考了Github的Pytorch-Gan-master仓库的代码,做了一下照搬吧,照着敲一边代码就明白了GAN的思想了。网上找了一张稍微好点的网络结构图:

因为生成对抗网络需要去度量两个分布之间的距离,原始的GAN并没有一个很好的度量,具体细节可以看李宏毅老师的课。导致GAN的训练会比较困难,而且整个LOSS是基本无变化的,但从肉眼上还是能清楚的看到生成的结果在变好。

数据集介绍

使用的是网络上的二次元人脸数据集,数据集链接:
网盘链接:https://pan.baidu.com/s/1MFulwMQJ78U2MCqRUYjkMg
提取码:58v6
其中包含16412张二次元人脸图像,每张图片的分辨率为96*96,


只需要在我上一篇文章MNIST手写数字生成的基础上修改一下dataload就行,完整代码可以去github下载

generator

# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(* block(opt.latent_dim,128,normalize=False),* block(128,256),* block(256,512),* block(512,1024),nn.Linear(1024,int(np.prod(image_shape))),nn.Tanh())def forward(self,z):img = self.model(z)img = img.view(img.size(0),*image_shape)return img

discriminator

class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(image_shape)),512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256,1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0),-1)validity = self.model(img_flat)return validity

完整代码:

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/14 下午3:05import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from dataloader.face_loader import face_loader
import torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs('new_images', exist_ok=True)parser = argparse.ArgumentParser()  # 添加参数parser.add_argument("--n_epochs", type=int, default=1000, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=256, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=96, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval betwen image samples")opt = parser.parse_args()
print(opt)image_shape = (opt.channels, opt.img_size, opt.img_size)cuda = True if torch.cuda.is_available() else False# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(* block(opt.latent_dim,128,normalize=False),* block(128,256),* block(256,512),* block(512,1024),nn.Linear(1024,int(np.prod(image_shape))),nn.Tanh())def forward(self,z):img = self.model(z)img = img.view(img.size(0),*image_shape)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(image_shape)),512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256,1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0),-1)validity = self.model(img_flat)return validity# lossadversarial_loss = torch.nn.BCELoss()#初始化G和D
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()transforms = transforms.Compose([transforms.Resize(opt.img_size),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]),
])
# loaddata
train_data = face_loader(transforms)
dataloader = torch.utils.data.DataLoader(train_data,batch_size = opt.batch_size,shuffle=True,
)optimizer_G = torch.optim.Adam(generator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor#train
for epoch in range(opt.n_epochs):for i ,imgs in enumerate(dataloader):valid = Variable(Tensor(imgs.size(0),1).fill_(1.0),requires_grad = False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)real_imgs = Variable(imgs.type(Tensor))optimizer_G.zero_grad()z = Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],opt.latent_dim))))gen_imgs = generator(z)g_loss = adversarial_loss(discriminator(gen_imgs),valid)g_loss.backward()optimizer_G.step()#train Discriminatoroptimizer_D.zero_grad()real_loss = adversarial_loss(discriminator(real_imgs),valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),fake)d_loss = (real_loss+fake_loss)/2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:256], "new_images/%d.png" % batches_done, nrow=16, normalize=True)torch.save(generator.state_dict(),"G.pth")
torch.save(discriminator.state_dict(),"D.pth")

结果

GAN网络的训练是比较困难的,我设置批大小为1024,训练了两百个epoch,给出一些结果。
第0次迭代:

第1000次迭代:

这个时候其实就有了人脸的一些基本轮廓了,但细节不好,细节处理是GAN的缺点之一。

第10000次迭代:

第20000次迭代:

第30000次迭代:

第40000次迭代:

从第20000次迭代之后从肉眼上看上去就没什么进步了。

Pytorch:GAN生成对抗网络实现二次元人脸的生成相关推荐

  1. GAN (生成对抗网络) 手写数字图片生成

    GAN (生成对抗网络) 手写数字图片生成 文章目录 GAN (生成对抗网络) 手写数字图片生成 Discriminator Network Generator Network 简单版本的生成对抗网络 ...

  2. 生成对抗网络简介,深度卷积生成对抗网络(DCGAN)简介

    本博客是个人学习的笔记,讲述的是生成对抗网络(generate adversarial network ) 的一种架构:深度生成对抗网络 的简单介绍,下一节将使用 tensorflow 搭建 DCGA ...

  3. 基于PyTorch的生成对抗网络入门(3)——利用PyTorch搭建生成对抗网络(GAN)生成彩色图像超详解

    目录 一.案例描述 二.代码详解 2.1 获取数据 2.2 数据集类 2.3 构建判别器 2.3.1 构造函数 2.3.2 测试判别器 2.4 构建生成器 2.4.1 构造函数 2.4.2 测试生成器 ...

  4. 学习笔记|生成对抗网络(Generative Adversarial Networks,GAN)——让机器学习具有创造力

    文章目录 1.生成对抗网络概述 1.1 对"生成"的理解 1.2 对"对抗"的理解 2. 生成对抗网络的理论基础 1.生成对抗网络概述 有时候我们希望网络具有一 ...

  5. 生成对抗网络(GAN)相比传统训练方法有什么优势?(一)

    作者:元峰 链接:https://www.zhihu.com/question/56171002/answer/148593584 来源:知乎 著作权归作者所有.商业转载请联系作者获得授权,非商业转载 ...

  6. 54_pytorch GAN(生成对抗网络)、Gan代码示例、WGAN代码示例

    1.54.GAN(生成对抗网络) 1.54.1.什么是GAN 2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文.没错,我说的就是<Generative ...

  7. 改善图像处理效果的五大生成对抗网络

    作者 | Martin Isaksson 译者 | Sambodhi 策划 | 刘燕 在图像处理方面,机器学习实践者们正在逐渐转向借助生成对抗网络的力量,本文带你了解其中五种生成对抗网络,可根据自己的 ...

  8. 【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题

    参考资料: <PyTorch深度学习>(人民邮电出版社)第7章 生成网络 PyTorch官方文档 廖星宇著<深度学习入门之Pytorch>第6章 生成对抗网络 其他参考的网络资 ...

  9. 生成对抗网络入门指南(内含资源和代码)

    python进阶教程 机器学习 深度学习 长按二维码关注 进入正文 前言:春节假期结束了,在这个假期中,原本好好的计划终究没能跟上变化,发生了很多意想不到的事情,导致公众号有近十天没能更新,首先给长期 ...

最新文章

  1. 逐飞 RT1064 库 GCC (VSCode) 移植踩坑
  2. Github在windows7环境下使用入门
  3. Python笔记总结(1)
  4. BFE Ingress Controller正式发布!
  5. 斐波那契数列的3种求法及几种素数筛法
  6. 【ArcGIS微课1000例】0003:按属性选择(Select by Attributes)
  7. Oracle_spatial的空间索引
  8. Win10 PC 能打电话了?腾讯追讨前员工 1940 万;淘宝进军 MR 购物 | 极客头条
  9. 学习python:异常处理
  10. ip pv uv及相应统计shell
  11. Vista 自动激活工具(最新 最权威 所有版本 可升级)
  12. Texture ASTC转换ETC
  13. 基于opencv的图片文字识别实战
  14. 安徽师范大学计算机学院在哪个校区,2021年安徽师范大学有几个校区,大一新生在哪个校区...
  15. 游戏运动模糊技术讲解
  16. SQLyog 新建mysql链接时 错误号码 2058
  17. java暗黑破坏神,《暗黑破坏神2》1.10 雇用兵详细介绍
  18. 无法为数据库中的对象分配空间,因为'PRIMARY'文件组已满问题处理方式
  19. wps在线预览接口_WPS文档在线预览接入的一点心得
  20. Python编程:从入门到实践关于pi,百万位圆周率,pi_million_digits.txt,分享给大家

热门文章

  1. 技术部门Leader是不是一定要技术大牛担任?
  2. 脑电情绪识别:脑功能连接网络与局部激活信息结合
  3. 免费网络学术资源获取
  4. 在你的计算机上使用qr码登录,如何在Android 10上使用QR码共享您的Wi-fi凭据 | MOS86...
  5. winform listview 设置选中项 图片_实战PyQt5: 069-MV框架中的项视图拖放功能
  6. bilibili怎么设置弹幕数量_python爬虫:bilibili弹幕爬取+词云生成
  7. 子类重写父类虚函数_C/C++编程笔记:关于C++的虚函数和多态,你真的了解吗?...
  8. linux中atoi函数的实现 值得借鉴,【转】atoi()函数的实现
  9. Shiro+springboot+mybatis(md5+salt+散列)认证与授权-02
  10. vue实现时间选择器,精确到秒