介绍

本文是李宏毅GAN课程课后作业HW3_1(二次元头像生成,Keras实现)的Pytorch版本。写这篇的原因是一方面刚开始接触GAN,二是个人比较习惯用Pytorch,所以将keras改成Pytorch版本。

实现所需要的资源:

链接:https://pan.baidu.com/s/1cLmFNQpJe1DOI96IVuvVyQ
提取码:nha2

本文一个改动就是将kernel=4变成了3,因为kernel一般都是奇数。其他和原网络基本相同。

下面是主要部分的代码,包括网络模块和训练/验证/测试两个模块。
完整的代码见 https://github.com/AsajuHuishi/Generate_a_quadratic_image_with_GAN

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn as nn
from torch.autograd import Variableimport matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import time
import visdom

1.网络模块

生成器
##定义卷积核
def default_conv(in_channels,out_channels,kernel_size,bias=True):return nn.Conv2d(in_channels,out_channels,kernel_size,padding=kernel_size//2,   #保持尺寸bias=bias)
##定义ReLU
def default_relu():return nn.ReLU(inplace=True)
## reshape
def get_feature(x):return x.reshape(x.size()[0],128,16,16)class Generator(nn.Module):def __init__(self,input_dim=100,conv=default_conv,relu=default_relu,reshape=get_feature):super(Generator,self).__init__()head = [nn.Linear(input_dim,128*16*16),relu()]self.reshape = reshape                               #16x16body = [nn.Upsample(scale_factor=2,mode='nearest'),  #32x32conv(128,128,3),relu(),nn.Upsample(scale_factor=2,mode='nearest'),  #64x64conv(128,64,3),relu(),conv(64,3,3),nn.Tanh()]self.head = nn.Sequential(*head)self.body = nn.Sequential(*body)def forward(self,x):#x:(batchsize,input_dim)x = self.head(x)x = self.reshape(x)x = self.body(x)return x        #(batchsize,3,64,64)def name(self):return 'Generator'
判别器
class Discriminator(nn.Module):def __init__(self,conv=default_conv,relu=default_relu):super(Discriminator,self).__init__()main = [conv(3,32,3),relu(),conv(32,64,3),relu(),conv(64,128,3),relu(),conv(128,256,3),relu()]self.main = nn.Sequential(*main)self.fc = nn.Linear(256*64*64,1)self.sigmoid = nn.Sigmoid()def forward(self,x):#x:(batchsize,3,64,64)x = self.main(x)#(b,256,64,64)x = x.view(x.size()[0],-1)#(b,256*64*64)x = self.fc(x) #(b,1)x = self.sigmoid(x)return xdef name(self):return 'Discriminator'

2.训练/验证/测试模块

相关参数、模型初始化
class GAN(nn.Module):def __init__(self,args):super(GAN,self).__init__()self.img_size = 64self.channels = 3    self.latent_dim = args.latent_dimself.num_epoch = args.num_epochself.batch_size = args.batch_sizeself.cuda = args.cudaself.interval = 20 #每相邻20个epoch验证一次self.continue_training = args.continue_training #是否是继续训练        ## 生成器初始化self.generator = Generator(self.latent_dim)## 判别器初始化self.discriminator = Discriminator()self.testmodelpath = args.testmodelpathself.datapath = args.datapathif self.cuda:self.generator.cuda()self.discriminator.cuda()self.continue_training_isrequired()
训练+dataloader数据集
    def trainer(self):## 读入图片数据,分batchprint('===> Data preparing...')import torchvision.transforms as transformsfrom torch.utils.data import DataLoaderfrom torchvision.datasets import ImageFoldertransform = transforms.ToTensor()  ##dataloader输出是tensor,不加这个会报错dataset = ImageFolder(self.datapath,transform=transform)dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, drop_last=True)       ##drop_last: dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃num_batch = len(dataloader) #batch的数量为len(dataloader)=总图片数/batchsizeprint('num_batch:',num_batch)#dataloader: (batchsize,3,64,64) 分布0-1## 判别值target_real = Variable(torch.ones(self.batch_size,1))target_false = Variable(torch.zeros(self.batch_size,1))one_const = Variable(torch.ones(self.batch_size,1))if self.cuda:target_real = target_real.cuda()target_false = target_false.cuda()one_const = one_const.cuda()## 优化器optim_generator = optim.Adam(self.generator.parameters(),lr=0.0002,betas=(0.5,0.999))optim_discriminator = optim.Adam(self.discriminator.parameters(),lr=0.0002,betas=(0.5,0.999))## 误差函数
#        content_criterion = nn.MSELoss()adversarial_criterion = nn.BCELoss()## 训 练 开 始for epoch in range(self.start_epoch,self.num_epoch): ##epoch##用于观察一个epoch不同batch的平均lossmean_dis_loss = 0.0mean_gen_con_loss = 0.0mean_gen_adv_loss = 0.0mean_gen_total_loss = 0.0for i,data in enumerate(dataloader):  ##循环次数:batch的数量为len(dataloader)=总图片数//batchsizeif epoch<3 and i%10==0:print('epoch%d: %d/%d'%(epoch,i,len(dataloader)))##1.1生成noisegen_input = np.random.normal(0,1,(self.batch_size,self.latent_dim)).astype(np.float32)gen_input = torch.from_numpy(gen_input)gen_input = torch.autograd.Variable(gen_input,requires_grad=True)if self.cuda:gen_input = gen_input.cuda()                    fake = self.generator(gen_input) ##生成器生成的(batchsize,3,64,64)real, _ = data  #data:list[tensor,tensor]取第零个 real:(batchsize,3,64,64)if self.cuda:real = real.cuda()fake = fake.cuda()## 1.固定G,训练判别器D                self.discriminator.zero_grad()dis_loss1 = adversarial_criterion(self.discriminator(real),target_real)dis_loss2 = adversarial_criterion(self.discriminator(fake.detach()),target_false)##注意经过G的网络再进入D网络之前要detach()之后再进入dis_loss = 0.5*(dis_loss1+dis_loss2)
#                print('epoch:%d--%d,判别器loss:%.6f'%(epoch,i,dis_loss))dis_loss.backward()optim_discriminator.step()mean_dis_loss+=dis_loss## 2.固定D,训练生成器Gself.generator.zero_grad()##生成noisegen_input = np.random.normal(0,1,(self.batch_size,self.latent_dim)).astype(np.float32)gen_input = torch.from_numpy(gen_input)gen_input = torch.autograd.Variable(gen_input,requires_grad=True)if self.cuda:gen_input = gen_input.cuda()    fake = self.generator(gen_input) ##生成器生成的(batchsize,3,64,64)  gen_con_loss = 0gen_adv_loss = adversarial_criterion(self.discriminator(fake),one_const)##固定D更新Ggen_total_loss = gen_con_loss + gen_adv_loss
#                print('epoch:%d--%d,生成器loss:%.6f'%(epoch,i,gen_total_loss))gen_total_loss.backward()optim_generator.step()mean_gen_con_loss+=gen_con_lossmean_gen_adv_loss+=gen_adv_lossmean_gen_total_loss+=gen_total_loss## 一个epoch输出一次print('epoch:%d/%d'%(epoch, self.num_epoch))print('Discriminator_Loss: %.4f'%(mean_dis_loss/num_batch))print('Generator_total_Loss:%.4f'%(mean_gen_total_loss/num_batch))## 保存模型state_dis = {'dis_model': self.discriminator.state_dict(), 'epoch': epoch}state_gen = {'gen_model': self.generator.state_dict(), 'epoch': epoch}if not os.path.isdir('checkpoint'):os.mkdir('checkpoint') torch.save(state_dis, 'checkpoint/'+self.discriminator.name()+'__'+str(epoch+1)) #each epochtorch.save(state_gen, 'checkpoint/'+self.generator.name()+'__'+str(epoch+1))     #each epochtorch.save(state_dis, 'checkpoint/'+self.discriminator.name())    #final  torch.save(state_gen, 'checkpoint/'+self.generator.name())        #final  ## 验证模型if epoch<45 or epoch%self.interval==0:self.validater(epoch)print('--'.center(12,'-'))
验证
    def validater(self,epoch):vis = visdom.Visdom(env='generate_girl_epoch%d'%(epoch))r,c = 3,3gen_input_val = np.random.normal(0,1,(r*c,self.latent_dim)).astype(np.float32)gen_input_val = torch.from_numpy(gen_input_val)gen_input_val = torch.autograd.Variable(gen_input_val)if self.cuda:gen_input_val = gen_input_val.cuda()   output_val = self.generator(gen_input_val)     #(r*c,3,64,64)output_val = output_val.cpu()output_val = output_val.data.numpy()      #(r*c,3,64,64)        img = np.transpose(output_val,(0,2,3,1))  #(r*c,64,64,3) fig, axs = plt.subplots(r,c)cnt = 0for i in range(r):for j in range(c):vis.image(output_val[cnt],opts={'title':'epoch%d_cnt%d'%(epoch,cnt)}) axs[i, j].imshow(img[cnt, :, :, :])axs[i, j].axis('off')cnt += 1   if not os.path.isdir('images'):os.mkdir('images') fig.savefig('images/val_%d.png'%(epoch+1)) ##保存验证结果plt.close()
测试
    def tester(self,gen_input_test): #输入(N,latent_dim)assert gen_input_test.shape[1]==self.latent_dim, \'dimension 1''s size expect %d,but input %d'%(self.latent_dim,gen_input_test.shape[1])gen_input_test = gen_input_test.astype(np.float32)gen_input_test = torch.from_numpy(gen_input_test)gen_input_test = torch.autograd.Variable(gen_input_test)if self.cuda:gen_input_test = gen_input_test.cuda()   ## 下载验证结果if os.path.isdir('checkpoint'):try:checkpoint_gen = torch.load(self.testmodelpath)self.generator.load_state_dict(checkpoint_gen['gen_model'])except FileNotFoundError:print('Can\'t found dict')output_test = self.generator(gen_input_test)          output_test = output_test.cpu()output_test = output_test.data.numpy()      #(N,3,64,64)img = np.transpose(output_test,(0,2,3,1))  #(N,64,64,3) if not os.path.isdir('images'):os.mkdir('images')         N = img.shape[0] #图像个数for i in range(N):plt.imshow(img[i, :, :, :])plt.axis('off')plt.savefig('images/test_%d.png'%(i+1)) ##保存结果plt.close()

结果和原keras相比没什么区别,毕竟网络都差不多,也不需要过高期望,而且网络本身比较小,生成一个好看的人脸,对是五官是否协调有很大的要求,是很有挑战的事情。

输入:

np.random.normal(0,1,(1,self.latent_dim)).astype(np.float32)

部分结果:

GAN二次元头像生成Pytorch实现(附完整代码)相关推荐

  1. 基于pytorch搭建多特征CNN-LSTM时间序列预测代码详细解读(附完整代码)

    系列文章目录 lstm系列文章目录 1.基于pytorch搭建多特征LSTM时间序列预测代码详细解读(附完整代码) 2.基于pytorch搭建多特征CNN-LSTM时间序列预测代码详细解读(附完整代码 ...

  2. python代码手机壁纸_Python制作微信好友背景墙教程(附完整代码)

    引言 前段时间,微信朋友圈开始出现了一种晒照片新形式,微信好友墙,即在一张大图片中展示出自己的所有微信好友的头像. 效果如下图,出于隐私考虑,这里作了模糊处理. 是不是很炫,而且这还是独一无二的,毕竟 ...

  3. SpringSecurity的安全认证的详解说明(附完整代码)

    SpringSecurity登录认证和请求过滤器以及安全配置详解说明 环境 系统环境:win10 Maven环境:apache-maven-3.8.6 JDK版本:1.8 SpringBoot版本:2 ...

  4. 单选按钮_PerlTk教程之按钮Button、复选按钮Checkbutton、单选按钮Radiobutton(附完整代码)...

    <Perl-Tk教程之按钮Button.复选按钮Checkbutton.单选按钮Radiobutton>Perl-Tk中有三种不同形式的按钮组件可供选择,它们分别是按钮(Button), ...

  5. c++ 三次多项式拟合_线性回归进阶版,多项式线性回归讲解与实现(附完整代码)...

    每天给小编五分钟,小编用自己的代码,带你轻松学习深度学习!本文将会带你做完一个深度学习进阶版的线性回归---多项式线性回归,带你进一步掌握线性回归这一深度学习经典模型,然后在此基础上,小编将在下篇文章 ...

  6. java抽奖_JAVA实现用户抽奖功能(附完整代码)

    需求分析 1)实现三个基本功能:登录.注册.抽奖. 2)登录:用户输入账号密码进行登录,输入账号后会匹配已注册的用户,若输入用户不存在则退出,密码有三次输入机会,登录成功后主界面会显示已登录用户的账号 ...

  7. 想要快速爬取整站图片?速进(附完整代码)

      大家好,我是不温卜火,是一名计算机学院大数据专业大三的学生,昵称来源于成语-不温不火,本意是希望自己性情温和.作为一名互联网行业的小白,博主写博客一方面是为了记录自己的学习过程,另一方面是总结自己 ...

  8. java登录注册抽奖完整代码_JAVA实现用户抽奖功能(附完整代码)

    需求分析 1)实现三个基本功能:登录.注册.抽奖. 2)登录:用户输入账号密码进行登录,输入账号后会匹配已注册的用户,若输入用户不存在则退出,密码有三次输入机会,登录成功后主界面会显示已登录用户的账号 ...

  9. 基于MATLAB的三维数据插值拟合与三次样条拟合算法(附完整代码)

    目录 一. 三维插值 例题1 二. 高维度插值拟合 格式一 格式二 格式三 格式四 格式五 例题2 三. 单变量三次样条插值 例题3 例题4 四. 多变量三次样条插值 例题6 一. 三维插值 首先三维 ...

最新文章

  1. Python:XPath与lxml类库
  2. css布局中的居中问题
  3. 北京智源人工智能研究院2020年博士后招收简章
  4. 《LeetCode力扣练习》剑指 Offer 24. 反转链表 Java
  5. 在线IDE之关键字另色显示
  6. AliOS Things 硬件抽象层(HAL)对接系列2 — SPI driver porting
  7. Leetcode-121. 买卖股票的最佳时机
  8. mysql killed进程不结束_优秀的数据库产品——MySQL 云数据库服务
  9. sql中datetime日期类型字段比较(mysqloracle)
  10. Buck-Boost变换
  11. 什么是 Caché?
  12. LeetCode 338. 比特位计数(动态规划)
  13. 大数据分析平台搭建方式有哪些
  14. 希望是一个全新的开始
  15. 《人生若只如初见——古典诗词的美丽与哀愁》--安意如
  16. Mac Android Studio连接MuMu模拟器
  17. Pytorch框架学习记录10——线性层
  18. 云班课计算机基础测试题,云班课在高职计算机基础微课教学中应用探究.doc
  19. html中多个空格怎么打?
  20. 【2309. 兼具大小写的最好英文字母】

热门文章

  1. 结构化查询语句简称mysql_整理MySql常用查询语句
  2. go第三方库文档 日志构建zap
  3. 解剖SQLSERVER 第十七篇 使用 OrcaMDF Corruptor 故意损坏数据库(译)
  4. 基于JAVA的游戏补丁共享网站实现
  5. java企查查爬_爬取企查查热搜
  6. 面向民航的航空数据链协议解析应用研究
  7. python中倒背如流_倒背如流中倒背是什么意思古代有种背书方法是倒背,倒
  8. 【读书笔记】《JS函数式编程指南》(一)
  9. JAVA RESTful WebService实战笔记(二)
  10. 自己做量化交易软件(9通通量化框架的雏形建立