阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读。
相关资料链接为:使用PyTorch构建GAN生成对抗
本次训练代码使用了本地GPU计算。

1 CelebADataset类的修改

原则上这一类不需要修改,但为了提升模型运行速度,可以对图片周边适当裁剪,保留五官等重要内容。

# 设置裁剪功能(辅助函数)
def crop_centre(img, new_width, new_height):height, width, _ = img.shapestartx = width//2 - new_width//2starty = height//2 - new_height//2return img[  starty:starty + new_height, startx:startx + new_width, :]

上面这个函数可以用来从图像的中心裁剪。该函数接收三个参数:

  • img:原始图像,需要是 numpy 数组形式
  • new_width:裁剪后图像的新宽度
  • new_height:裁剪后图像的新高度

该函数通过计算原始图像的中心位置,以及所需裁剪图像的起始位置,从而在 numpy 数组上实现裁剪。最后,函数返回裁剪后的图像。
有了这个函数后,可以在类中预置对图像的裁剪功能,需要对类的__getitem__方法和plot_image方法进行优化。

class CelebADataset(Dataset):def __getitem__(self, index):if index >= len(self.dataset):raise IndexError()img = numpy.array(self.dataset[str(index) + '.jpg'])img = crop_centre(img, 128, 128)return torch.cuda.FloatTensor(img).permute(2,0,1).view(1,3,128,128) / 255.0def plot_image(self, index):img = numpy.array(self.dataset[str(index)+'.jpg'])img = crop_centre(img, 128, 128)plt.imshow(img, interpolation='nearest')

2 鉴别器类的修改

鉴别器的网络结构是卷积GAN需要重点修改的地方。此次的卷积GAN设置了3个卷积层和1个全连接层。

class Discriminator(nn.Module):def __init__(self):   self.model = nn.Sequential(nn.Conv2d(3, 256, kernel_size=8, stride=2),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),nn.Conv2d(256, 256, kernel_size=8, stride=2),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),nn.Conv2d(256, 3, kernel_size=8, stride=2),nn.LeakyReLU(0.2),View(3*10*10),nn.Linear(3*10*10, 1),nn.Sigmoid())

经过裁剪的图片的小为128*128;
第一个卷积层使用了256个卷积核,每个卷积核大小为8,步长为2。这一卷积层将会输出256个特征图,特征图的大小为 128−82+1\frac{128-8}{2}+12128−8​+1 ,即61*61;
第二个卷积层使用了256个卷积核,每个卷积核大小为8,步长为2。这一卷积层将会输出256个特征图,特征图的大小为 61−82+1\frac{61-8}{2}+1261−8​+1 ,即27*27;
第二个卷积层使用了3个卷积核,每个卷积核大小为8,步长为2。这一卷积层将会输出3个特征图,特征图的大小为 27−82+1\frac{27-8}{2}+1227−8​+1 ,即10*10;
经过了3层的卷积后,图片的大小已经降到了(3*10*10)。

3 鉴别器测试

修改完鉴别器之后,可以使用真实图像和随即图像,初步判断鉴别器的能力与测试这部分修改后的代码是否存在BUG。

# 鉴别器类建立
D = Discriminator()
D.to(device)# 测试鉴别器
for image_data_tensor in celeba_dataset:# real dataD.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))# fake dataD.train(generate_random_image((1,3,128,128)), torch.cuda.FloatTensor([0.0]))pass

同样,可以查看损失函数的变化情况并使用测试集进行测试。

for image_data_tensor in celeba_dataset:# real dataD.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))# fake dataD.train(generate_random_image((1,3,128,128)), torch.cuda.FloatTensor([0.0]))passD.plot_progress()for i in range(4):image_data_tensor = celeba_dataset[random.randint(0,20000)]print( D.forward( image_data_tensor ).item() )passfor i in range(4):print( D.forward( generate_random_image((1,3,128,128))).item() )pass


可以看出,鉴别器对于数据的判断非常有信息。

之后还需对生成器进行同步修改,并使用代码生成图像,这部分内容将放在下篇。
使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成下

此部分的完整代码可在文末留言申请或在博主的资源区自行下载。

使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成 上相关推荐

  1. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

    文章目录 1 测试鉴别器 2 建立生成器 3 测试生成器 4 训练生成器 5 使用生成器 6 内存查看 上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类. 使用PyTorch构建GAN生 ...

  2. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

    文章目录 1 数据集描述 2 GPU设置 3 设置Dataset类 4 设置辨别器类 5 辅助函数与辅助类 1 数据集描述 此项目使用的是著名的celebA(CelebFaces Attribute) ...

  3. 使用PyTorch构建神经网络(详细步骤讲解+注释版) 01-建立分类器类

    文章目录 1 数据准备 2 数据预览 3 简单神经网络创建 3.1 设计网络结构 3.2 损失函数相关设置 3.3 向网络传递信息 3.4 定义训练函数train 4 函数汇总 1 数据准备 神经网络 ...

  4. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别

    文章目录 1 生成对抗网络基本概念 2 生成对抗网络建模 2.1 建立MnistDataset类 2.2 建立鉴别器 2.3 测试鉴别器 2.4 Mnist生成器制作 3 模型的训练 4 模型表现的判 ...

  5. 编译和安装gdb源码详细步骤介绍

    1.gdb源码下载 (1)源码下载网址:https://ftp.gnu.org/gnu/gdb/: (2)本文下面的编译是按照8.2版本的源码进行的,其余版本的源码可能会报错,需要自行解决: 2.编译 ...

  6. MTCNN源码详细解读(1)- PNet/RNet/ONet的网络结构和损失函数

    代码地址 https://github.com/AITTSMD/MTCNN-Tensorflow 这里我就不在进行MTCNN的介绍了.分析的再清楚都不如从源码的实现去分析. Talk is cheap ...

  7. SpringMVC+Maven开发项目源码详细介绍

    代码地址如下: http://www.demodashi.com/demo/11638.html Spring MVC概述 Spring MVC框架是一个开源的Java平台,为开发强大的基于Java的 ...

  8. spark word2vec 源码详细解析

    spark word2vec 源码详细解析 简单介绍spark word2vec skip-gram 层次softmax版本的源码解析 word2vec 的原理 只需要看层次哈弗曼树skip-gram ...

  9. 低CUDA算力显卡用上高版本pytorch(ubuntu18.04源码编译特定版本pytorch(v1.10.0))

    低CUDA算力显卡用上高版本pytorch(ubuntu18.04源码编译特定版本pytorch(v1.10.0)) 一 电脑配置 二 正常情况下源码编译步骤 三 我的编译过程及出现的问题 首先 安装 ...

最新文章

  1. JPA保存数据自动加入创建人,修改人
  2. 我之我见:samba共享
  3. 关于 Java Collections API 您不知道的 5 件事--转
  4. dell服务器从硬盘引导,就是折腾 篇三:戴尔H710 mini(D1版本)阵列卡刷直通模式 附硬盘引导和还原IR模式办法...
  5. 读债务危机0806:2007到2011年泡沫蔓延
  6. [HDOJ5289]Assignment(RMQ,二分)
  7. 深度学习的实用层面 —— 1.9 正则化输入
  8. 常用脚本--在线重建或重整实例下所有索引
  9. 嵌入式操作系统内核原理和开发(开篇)
  10. 【luogu1709】小B的询问 - 莫队
  11. busybox(二)编译
  12. web测试软件act,使用ACT对Web程序进行性能容量测试.doc
  13. XML入门的常见问题
  14. C3P0数据库连接池
  15. 最全java面试题及答案(208道)
  16. 计算机械效率的简便公式,机械效率公式
  17. Excel如何将一列数据转为一行
  18. MT4外汇结余净值可用预付款
  19. 没有什么秘密的学习方法
  20. 机器学习之为什么要数据预处理?如何预处理数据?

热门文章

  1. Could not find artifact org.aopalliance:com.springsource.org.aopalliance:pom:1.0.0 in central (https
  2. Python基于周立功盒子的二次开发的准备工作
  3. C#让textbox不能写入
  4. 基于51单片机的指纹密码锁设计
  5. Wget下载整个网站(包含图片/JS/CSS)
  6. (五)SQLite数据存储与读取
  7. 快速入门Maxwell基本操作流程(3D部分)
  8. 使用python的scapy库,提供一个发送nbns询问包的一个示例代码
  9. 用Python+Mysql+MDUI实现的数据库增删改查列表操作及单,多文件上传实例
  10. 大IPD之——学习华为让业务主管成为人力资源管理的第一责任人(十六)