使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成 上
阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读。
相关资料链接为:使用PyTorch构建GAN生成对抗
本次训练代码使用了本地GPU计算。
1 CelebADataset类的修改
原则上这一类不需要修改,但为了提升模型运行速度,可以对图片周边适当裁剪,保留五官等重要内容。
# 设置裁剪功能(辅助函数)
def crop_centre(img, new_width, new_height):height, width, _ = img.shapestartx = width//2 - new_width//2starty = height//2 - new_height//2return img[ starty:starty + new_height, startx:startx + new_width, :]
上面这个函数可以用来从图像的中心裁剪。该函数接收三个参数:
- img:原始图像,需要是 numpy 数组形式
- new_width:裁剪后图像的新宽度
- new_height:裁剪后图像的新高度
该函数通过计算原始图像的中心位置,以及所需裁剪图像的起始位置,从而在 numpy 数组上实现裁剪。最后,函数返回裁剪后的图像。
有了这个函数后,可以在类中预置对图像的裁剪功能,需要对类的__getitem__
方法和plot_image
方法进行优化。
class CelebADataset(Dataset):def __getitem__(self, index):if index >= len(self.dataset):raise IndexError()img = numpy.array(self.dataset[str(index) + '.jpg'])img = crop_centre(img, 128, 128)return torch.cuda.FloatTensor(img).permute(2,0,1).view(1,3,128,128) / 255.0def plot_image(self, index):img = numpy.array(self.dataset[str(index)+'.jpg'])img = crop_centre(img, 128, 128)plt.imshow(img, interpolation='nearest')
2 鉴别器类的修改
鉴别器的网络结构是卷积GAN需要重点修改的地方。此次的卷积GAN设置了3个卷积层和1个全连接层。
class Discriminator(nn.Module):def __init__(self): self.model = nn.Sequential(nn.Conv2d(3, 256, kernel_size=8, stride=2),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),nn.Conv2d(256, 256, kernel_size=8, stride=2),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),nn.Conv2d(256, 3, kernel_size=8, stride=2),nn.LeakyReLU(0.2),View(3*10*10),nn.Linear(3*10*10, 1),nn.Sigmoid())
经过裁剪的图片的小为128*128;
第一个卷积层使用了256个卷积核,每个卷积核大小为8,步长为2。这一卷积层将会输出256个特征图,特征图的大小为 128−82+1\frac{128-8}{2}+12128−8+1 ,即61*61;
第二个卷积层使用了256个卷积核,每个卷积核大小为8,步长为2。这一卷积层将会输出256个特征图,特征图的大小为 61−82+1\frac{61-8}{2}+1261−8+1 ,即27*27;
第二个卷积层使用了3个卷积核,每个卷积核大小为8,步长为2。这一卷积层将会输出3个特征图,特征图的大小为 27−82+1\frac{27-8}{2}+1227−8+1 ,即10*10;
经过了3层的卷积后,图片的大小已经降到了(3*10*10)。
3 鉴别器测试
修改完鉴别器之后,可以使用真实图像和随即图像,初步判断鉴别器的能力与测试这部分修改后的代码是否存在BUG。
# 鉴别器类建立
D = Discriminator()
D.to(device)# 测试鉴别器
for image_data_tensor in celeba_dataset:# real dataD.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))# fake dataD.train(generate_random_image((1,3,128,128)), torch.cuda.FloatTensor([0.0]))pass
同样,可以查看损失函数的变化情况并使用测试集进行测试。
for image_data_tensor in celeba_dataset:# real dataD.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))# fake dataD.train(generate_random_image((1,3,128,128)), torch.cuda.FloatTensor([0.0]))passD.plot_progress()for i in range(4):image_data_tensor = celeba_dataset[random.randint(0,20000)]print( D.forward( image_data_tensor ).item() )passfor i in range(4):print( D.forward( generate_random_image((1,3,128,128))).item() )pass
可以看出,鉴别器对于数据的判断非常有信息。
之后还需对生成器进行同步修改,并使用代码生成图像,这部分内容将放在下篇。
使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成下
此部分的完整代码可在文末留言申请或在博主的资源区自行下载。
使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成 上相关推荐
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下
文章目录 1 测试鉴别器 2 建立生成器 3 测试生成器 4 训练生成器 5 使用生成器 6 内存查看 上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类. 使用PyTorch构建GAN生 ...
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上
文章目录 1 数据集描述 2 GPU设置 3 设置Dataset类 4 设置辨别器类 5 辅助函数与辅助类 1 数据集描述 此项目使用的是著名的celebA(CelebFaces Attribute) ...
- 使用PyTorch构建神经网络(详细步骤讲解+注释版) 01-建立分类器类
文章目录 1 数据准备 2 数据预览 3 简单神经网络创建 3.1 设计网络结构 3.2 损失函数相关设置 3.3 向网络传递信息 3.4 定义训练函数train 4 函数汇总 1 数据准备 神经网络 ...
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别
文章目录 1 生成对抗网络基本概念 2 生成对抗网络建模 2.1 建立MnistDataset类 2.2 建立鉴别器 2.3 测试鉴别器 2.4 Mnist生成器制作 3 模型的训练 4 模型表现的判 ...
- 编译和安装gdb源码详细步骤介绍
1.gdb源码下载 (1)源码下载网址:https://ftp.gnu.org/gnu/gdb/: (2)本文下面的编译是按照8.2版本的源码进行的,其余版本的源码可能会报错,需要自行解决: 2.编译 ...
- MTCNN源码详细解读(1)- PNet/RNet/ONet的网络结构和损失函数
代码地址 https://github.com/AITTSMD/MTCNN-Tensorflow 这里我就不在进行MTCNN的介绍了.分析的再清楚都不如从源码的实现去分析. Talk is cheap ...
- SpringMVC+Maven开发项目源码详细介绍
代码地址如下: http://www.demodashi.com/demo/11638.html Spring MVC概述 Spring MVC框架是一个开源的Java平台,为开发强大的基于Java的 ...
- spark word2vec 源码详细解析
spark word2vec 源码详细解析 简单介绍spark word2vec skip-gram 层次softmax版本的源码解析 word2vec 的原理 只需要看层次哈弗曼树skip-gram ...
- 低CUDA算力显卡用上高版本pytorch(ubuntu18.04源码编译特定版本pytorch(v1.10.0))
低CUDA算力显卡用上高版本pytorch(ubuntu18.04源码编译特定版本pytorch(v1.10.0)) 一 电脑配置 二 正常情况下源码编译步骤 三 我的编译过程及出现的问题 首先 安装 ...
最新文章
- JPA保存数据自动加入创建人,修改人
- 我之我见:samba共享
- 关于 Java Collections API 您不知道的 5 件事--转
- dell服务器从硬盘引导,就是折腾 篇三:戴尔H710 mini(D1版本)阵列卡刷直通模式 附硬盘引导和还原IR模式办法...
- 读债务危机0806:2007到2011年泡沫蔓延
- [HDOJ5289]Assignment(RMQ,二分)
- 深度学习的实用层面 —— 1.9 正则化输入
- 常用脚本--在线重建或重整实例下所有索引
- 嵌入式操作系统内核原理和开发(开篇)
- 【luogu1709】小B的询问 - 莫队
- busybox(二)编译
- web测试软件act,使用ACT对Web程序进行性能容量测试.doc
- XML入门的常见问题
- C3P0数据库连接池
- 最全java面试题及答案(208道)
- 计算机械效率的简便公式,机械效率公式
- Excel如何将一列数据转为一行
- MT4外汇结余净值可用预付款
- 没有什么秘密的学习方法
- 机器学习之为什么要数据预处理?如何预处理数据?
热门文章
- Could not find artifact org.aopalliance:com.springsource.org.aopalliance:pom:1.0.0 in central (https
- Python基于周立功盒子的二次开发的准备工作
- C#让textbox不能写入
- 基于51单片机的指纹密码锁设计
- Wget下载整个网站(包含图片/JS/CSS)
- (五)SQLite数据存储与读取
- 快速入门Maxwell基本操作流程(3D部分)
- 使用python的scapy库,提供一个发送nbns询问包的一个示例代码
- 用Python+Mysql+MDUI实现的数据库增删改查列表操作及单,多文件上传实例
- 大IPD之——学习华为让业务主管成为人力资源管理的第一责任人(十六)