Pytorch:GAN生成对抗网络实现二次元人脸的生成
github:https://github.com/SPECTRELWF/pytorch-GAN-study
网络结构
最近在疯狂补深度学习一些基本架构的基础,看了一下大佬的GAN的原始论文,说实话一头雾水,不是能看的很懂。推荐B站李宏毅老师的机器学习2021的课程,听完以后明白多了。原始论文中就说了一个generator和一个discriminator的结构,并没有细节的说具体是怎么去定义的,对新手不太友好,参考了Github的Pytorch-Gan-master仓库的代码,做了一下照搬吧,照着敲一边代码就明白了GAN的思想了。网上找了一张稍微好点的网络结构图:
因为生成对抗网络需要去度量两个分布之间的距离,原始的GAN并没有一个很好的度量,具体细节可以看李宏毅老师的课。导致GAN的训练会比较困难,而且整个LOSS是基本无变化的,但从肉眼上还是能清楚的看到生成的结果在变好。
数据集介绍
使用的是网络上的二次元人脸数据集,数据集链接:
网盘链接:https://pan.baidu.com/s/1MFulwMQJ78U2MCqRUYjkMg
提取码:58v6
其中包含16412张二次元人脸图像,每张图片的分辨率为96*96,
只需要在我上一篇文章MNIST手写数字生成的基础上修改一下dataload就行,完整代码可以去github下载
generator
# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(* block(opt.latent_dim,128,normalize=False),* block(128,256),* block(256,512),* block(512,1024),nn.Linear(1024,int(np.prod(image_shape))),nn.Tanh())def forward(self,z):img = self.model(z)img = img.view(img.size(0),*image_shape)return img
discriminator
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(image_shape)),512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256,1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0),-1)validity = self.model(img_flat)return validity
完整代码:
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/11/14 下午3:05import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from dataloader.face_loader import face_loader
import torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs('new_images', exist_ok=True)parser = argparse.ArgumentParser() # 添加参数parser.add_argument("--n_epochs", type=int, default=1000, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=256, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=96, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval betwen image samples")opt = parser.parse_args()
print(opt)image_shape = (opt.channels, opt.img_size, opt.img_size)cuda = True if torch.cuda.is_available() else False# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(* block(opt.latent_dim,128,normalize=False),* block(128,256),* block(256,512),* block(512,1024),nn.Linear(1024,int(np.prod(image_shape))),nn.Tanh())def forward(self,z):img = self.model(z)img = img.view(img.size(0),*image_shape)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(image_shape)),512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256,1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0),-1)validity = self.model(img_flat)return validity# lossadversarial_loss = torch.nn.BCELoss()#初始化G和D
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()transforms = transforms.Compose([transforms.Resize(opt.img_size),transforms.ToTensor(),transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]),
])
# loaddata
train_data = face_loader(transforms)
dataloader = torch.utils.data.DataLoader(train_data,batch_size = opt.batch_size,shuffle=True,
)optimizer_G = torch.optim.Adam(generator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor#train
for epoch in range(opt.n_epochs):for i ,imgs in enumerate(dataloader):valid = Variable(Tensor(imgs.size(0),1).fill_(1.0),requires_grad = False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)real_imgs = Variable(imgs.type(Tensor))optimizer_G.zero_grad()z = Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],opt.latent_dim))))gen_imgs = generator(z)g_loss = adversarial_loss(discriminator(gen_imgs),valid)g_loss.backward()optimizer_G.step()#train Discriminatoroptimizer_D.zero_grad()real_loss = adversarial_loss(discriminator(real_imgs),valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),fake)d_loss = (real_loss+fake_loss)/2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:256], "new_images/%d.png" % batches_done, nrow=16, normalize=True)torch.save(generator.state_dict(),"G.pth")
torch.save(discriminator.state_dict(),"D.pth")
结果
GAN网络的训练是比较困难的,我设置批大小为1024,训练了两百个epoch,给出一些结果。
第0次迭代:
第1000次迭代:
这个时候其实就有了人脸的一些基本轮廓了,但细节不好,细节处理是GAN的缺点之一。
第10000次迭代:
第20000次迭代:
第30000次迭代:
第40000次迭代:
从第20000次迭代之后从肉眼上看上去就没什么进步了。
Pytorch:GAN生成对抗网络实现二次元人脸的生成相关推荐
- GAN (生成对抗网络) 手写数字图片生成
GAN (生成对抗网络) 手写数字图片生成 文章目录 GAN (生成对抗网络) 手写数字图片生成 Discriminator Network Generator Network 简单版本的生成对抗网络 ...
- 生成对抗网络简介,深度卷积生成对抗网络(DCGAN)简介
本博客是个人学习的笔记,讲述的是生成对抗网络(generate adversarial network ) 的一种架构:深度生成对抗网络 的简单介绍,下一节将使用 tensorflow 搭建 DCGA ...
- 基于PyTorch的生成对抗网络入门(3)——利用PyTorch搭建生成对抗网络(GAN)生成彩色图像超详解
目录 一.案例描述 二.代码详解 2.1 获取数据 2.2 数据集类 2.3 构建判别器 2.3.1 构造函数 2.3.2 测试判别器 2.4 构建生成器 2.4.1 构造函数 2.4.2 测试生成器 ...
- 学习笔记|生成对抗网络(Generative Adversarial Networks,GAN)——让机器学习具有创造力
文章目录 1.生成对抗网络概述 1.1 对"生成"的理解 1.2 对"对抗"的理解 2. 生成对抗网络的理论基础 1.生成对抗网络概述 有时候我们希望网络具有一 ...
- 生成对抗网络(GAN)相比传统训练方法有什么优势?(一)
作者:元峰 链接:https://www.zhihu.com/question/56171002/answer/148593584 来源:知乎 著作权归作者所有.商业转载请联系作者获得授权,非商业转载 ...
- 54_pytorch GAN(生成对抗网络)、Gan代码示例、WGAN代码示例
1.54.GAN(生成对抗网络) 1.54.1.什么是GAN 2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文.没错,我说的就是<Generative ...
- 改善图像处理效果的五大生成对抗网络
作者 | Martin Isaksson 译者 | Sambodhi 策划 | 刘燕 在图像处理方面,机器学习实践者们正在逐渐转向借助生成对抗网络的力量,本文带你了解其中五种生成对抗网络,可根据自己的 ...
- 【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题
参考资料: <PyTorch深度学习>(人民邮电出版社)第7章 生成网络 PyTorch官方文档 廖星宇著<深度学习入门之Pytorch>第6章 生成对抗网络 其他参考的网络资 ...
- 生成对抗网络入门指南(内含资源和代码)
python进阶教程 机器学习 深度学习 长按二维码关注 进入正文 前言:春节假期结束了,在这个假期中,原本好好的计划终究没能跟上变化,发生了很多意想不到的事情,导致公众号有近十天没能更新,首先给长期 ...
最新文章
- 逐飞 RT1064 库 GCC (VSCode) 移植踩坑
- Github在windows7环境下使用入门
- Python笔记总结(1)
- BFE Ingress Controller正式发布!
- 斐波那契数列的3种求法及几种素数筛法
- 【ArcGIS微课1000例】0003:按属性选择(Select by Attributes)
- Oracle_spatial的空间索引
- Win10 PC 能打电话了?腾讯追讨前员工 1940 万;淘宝进军 MR 购物 | 极客头条
- 学习python:异常处理
- ip pv uv及相应统计shell
- Vista 自动激活工具(最新 最权威 所有版本 可升级)
- Texture ASTC转换ETC
- 基于opencv的图片文字识别实战
- 安徽师范大学计算机学院在哪个校区,2021年安徽师范大学有几个校区,大一新生在哪个校区...
- 游戏运动模糊技术讲解
- SQLyog 新建mysql链接时 错误号码 2058
- java暗黑破坏神,《暗黑破坏神2》1.10 雇用兵详细介绍
- 无法为数据库中的对象分配空间,因为'PRIMARY'文件组已满问题处理方式
- wps在线预览接口_WPS文档在线预览接入的一点心得
- Python编程:从入门到实践关于pi,百万位圆周率,pi_million_digits.txt,分享给大家
热门文章
- 技术部门Leader是不是一定要技术大牛担任?
- 脑电情绪识别:脑功能连接网络与局部激活信息结合
- 免费网络学术资源获取
- 在你的计算机上使用qr码登录,如何在Android 10上使用QR码共享您的Wi-fi凭据 | MOS86...
- winform listview 设置选中项 图片_实战PyQt5: 069-MV框架中的项视图拖放功能
- bilibili怎么设置弹幕数量_python爬虫:bilibili弹幕爬取+词云生成
- 子类重写父类虚函数_C/C++编程笔记:关于C++的虚函数和多态,你真的了解吗?...
- linux中atoi函数的实现 值得借鉴,【转】atoi()函数的实现
- Shiro+springboot+mybatis(md5+salt+散列)认证与授权-02
- vue实现时间选择器,精确到秒