使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上
文章目录
- 1 数据集描述
- 2 GPU设置
- 3 设置Dataset类
- 4 设置辨别器类
- 5 辅助函数与辅助类
1 数据集描述
此项目使用的是著名的celebA(CelebFaces Attribute)数据集。其包含10,177个名人身份的202,599张人脸图片,每张图片都做好了特征标记,包含人脸bbox标注框、5个人脸特征点坐标以及40个属性标记,数据由香港中文大学开放提供(不包含商业用途的使用)。
在实际训练前,已经将数据处理成了HDF5的数据集格式。使用h5py处理HDF5数据集可以提供很多方便,使得数据处理更加高效、灵活、可扩展,显著提升训练过程的文件读取速度。可以使用h5py包自行对数据进行处理,也可直接下载我已经处理好的HDF5数据格式。
如需了解更多h5py相关知识,可以查看HDF5补充知识。
2 GPU设置
前面几篇博客的内容,都是对手写数字这个数据集的处理,CPU还能吃得消。这次数据输入明显增加,需要使用GPU处理数据。如电脑无NAVIDIA独显,建议使用Google Colab执行代码,Colab提供了免费的GPU算力。
if torch.cuda.is_available():torch.set_default_tensor_type(torch.cuda.FloatTensor)print("using cuda:", torch.cuda.get_device_name(0))passdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
这段代码的作用是,如果当前设备有可用的CUDA,则将默认的张量类型设置为CUDA浮点张量并输出使用的CUDA设备的名称。然后,它将设备设置为CUDA设备(如果有)或CPU。
3 设置Dataset类
基于面向对象编程的基本原则,我们建立一个Dataset类,使类具有数据读取、获取指定索引的数据与绘制指定索引的图像,具体代码如下:
class CelebADataset(Dataset):def __init__(self, file):self.file_object = h5py.File(file, 'r')self.dataset = self.file_object['img_align_celeba']passdef __len__(self):return len(self.dataset)def __getitem__(self, index):if index >= len(self.dataset):raise IndexError()img = numpy.array(self.dataset[str(index) + '.jpg'])return torch.cuda.FloatTensor(img) / 255.0def plot_image(self, index):plt.imshow(numpy.array(self.dataset[str(index) + '.jpg']), interpolation='nearest')plt.show()
在获取指定索引对应的数据时,如果指定数大于索引的最大值,我们命令程序返回一个IndexError()错误,以便于快速查找问题所在。
为了理解这一数据类,我们对类进行使用:
celeba_dataset = CelebADataset('文件地址.h5py')
这里创建了一个名为celeba_dataset
的CelebADataset
类,并传入了文件的所在路径file
。在__init__中,使用h5py.File
方法读取路经所在的文件。
celeba_dataset.plot_image(66)
绘制数据集中66.jpg图形。如果前面代码正确,此处将绘制出数据集中的人脸头像。如果为绘制出图形并产生报错,考虑路径是否有误以及数据格式是否正确。
4 设置辨别器类
本项目的核心类为鉴别器类与生成器类,下面开始编写鉴别器类。首先建立神经网络框架:
class Discriminator(nn.Module):def __init__(self):# 父类继承super().__init__()# 神经网络定义self.model = nn.Sequential(View(218 * 178 * 3),nn.Linear(3 * 218 * 178, 100),nn.LeakyReLU(),nn.LayerNorm(100),nn.Linear(100, 1),nn.Sigmoid())# 创建损失函数self.loss_function = nn.BCELoss()# 创建优化器self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)# 初始化计数器self.counter = 0self.progress = []
这段代码定义了一个名为Discriminator的类,继承了PyTorch中nn.Module类。在__init__函数中,通过nn.Sequential定义了一个神经网络模型,包括三个线性层,两个激活函数,一个归一化层。一开始的View(218*178*3)
是新代码。它的作用是将大小为(218, 178, 3) 的三维图像张量重塑成一个长度为218×178×3的一维张量。基于自上而下的编程习惯,我们会在后面对View进行定义。
在此基础上,定义了损失函数nn.BCELoss()和优化器Adam,并定义了一个计数器和一个存储进度的列表。
class Discriminator(nn.Module):def forward(self, inputs):# simply run modelreturn self.model(inputs)def train(self, inputs, targets):# calculate the output of the networkoutputs = self.forward(inputs)# calculate lossloss = self.loss_function(outputs, targets)# increase counter and accumulate error every 10self.counter += 1if (self.counter % 10 == 0):self.progress.append(loss.item())if (self.counter % 1000 == 0):print("counter = ", self.counter)# 梯度归零,向后传递,优化执行self.optimiser.zero_grad()loss.backward()self.optimiser.step()def plot_progress(self):df = pandas.DataFrame(self.progress, columns=['loss'])df.plot(ylim=(0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
接下来定义forward功能,train功能,plot_progress功能。在forward()函数中,它只是让模型对输入数据进行前向传播并返回网络的输出。在train()函数中,它使用输入数据和目标数据来计算网络的损失,并使用优化器来更新网络的参数。最后,plot_progress()函数可以用来绘制训练进度。以上类方法与手写字体识别博文中的定义完全相同,如有需要可找到对应博文查看。
5 辅助函数与辅助类
class View(nn.Module):def __init__(self, shape):super().__init__()self.shape = shape, # 逗号不是多打的,代表这是元组def forward(self, x):return x.view(*self.shape)
在前面定义鉴别器类时,我们已经使用了View,此处对View进行补充定义。在 forward 方法中,它对输入的 x 应用了 view 方法,并将 shape 属性作为参数传入。这个模型的作用是将输入的张量的形状调整为 shape 属性所指定的形状。
def generate_random_image(size):random_data = torch.rand(size)return random_datadef generate_random_seed(size):random_data = torch.randn(size)return random_data
以上两个随机张量生成器,其作用与手写数字识别中的作用完全相同,在此不再赘述。后续在使用时也会再进行介绍。
截至目前,我们已经建立好了模型所必需的鉴别器类与Dataset类。下一篇会讲解最重要的鉴别器类以及对模型的训练与使用。
使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下
使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上相关推荐
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下
文章目录 1 测试鉴别器 2 建立生成器 3 测试生成器 4 训练生成器 5 使用生成器 6 内存查看 上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类. 使用PyTorch构建GAN生 ...
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别
文章目录 1 生成对抗网络基本概念 2 生成对抗网络建模 2.1 建立MnistDataset类 2.2 建立鉴别器 2.3 测试鉴别器 2.4 Mnist生成器制作 3 模型的训练 4 模型表现的判 ...
- 使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成 上
阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读. 相关资料链接为:使用PyTorch构建GAN生成对抗 本次训练代码使用了 ...
- 使用PyTorch构建神经网络(详细步骤讲解+注释版) 01-建立分类器类
文章目录 1 数据准备 2 数据预览 3 简单神经网络创建 3.1 设计网络结构 3.2 损失函数相关设置 3.3 向网络传递信息 3.4 定义训练函数train 4 函数汇总 1 数据准备 神经网络 ...
- pytorch基于GAN生成对抗网络的数据集扩充
文章目录 前言 一.GAN基本原理 1.结构图 2.目标函数 二.实现 1.实现流程图 2.实例 2.1采集少量原始数据 2.2GAN模型训练(注意修改图片路径) 2.3用训练好的模型扩充数据集(生成 ...
- 深度学习(九) GAN 生成对抗网络 理论部分
GAN 生成对抗网络 理论部分 前言 一.Pixel RNN 1.图片的生成模型 2.Pixel RNN 3.Pixel CNN 二.VAE(Variational Autoencoder) 1.VA ...
- 深度学习 GAN生成对抗网络-1010格式数据生成简单案例
一.前言 本文不花费大量的篇幅来推导数学公式,而是使用一个非常简单的案例来帮助我们了解GAN生成对抗网络. 二.GAN概念 生成对抗网络(Generative Adversarial Networks ...
- PyTorch 实现 GAN 生成式对抗网络 含代码
GAN 网络结构 GAN 公式的理解 简单线性 GAN 代码如下 卷积 GAN 代码如下 Ref 网络结构 GAN 公式的理解 minGmaxDV(D,G)=Ex∼Pdata(x)[logD(x)]+ ...
- 54_pytorch GAN(生成对抗网络)、Gan代码示例、WGAN代码示例
1.54.GAN(生成对抗网络) 1.54.1.什么是GAN 2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文.没错,我说的就是<Generative ...
最新文章
- 力扣(LeetCode)刷题,简单题(第18期)
- pytorch模型转onnx-量化rknn(bisenet)
- 低头是家和月光,抬头是车和远方
- MSChart使用导航之开发
- 【Q】之防火墙的SNAT DANT原理应用
- python 创建子类_python创建子类的方法分析
- [Windows Phone] 如何在 Windows Phone 应用程式制作市集搜寻
- 18、Windows API 图形用户界面(2)
- Django SCRF跨站点请求伪造
- Photoshop快捷键和技巧大全
- 苹果往事:乔布斯和 iPod 的诞生
- 使用idea中JD-Intellij插件反编译
- 电脑消除,彻底清除电脑垃圾,让电脑运行更流畅!
- markDown简单使用说明
- R包中文文本挖掘chinese.misc的中文说明
- 四选一多路开关电路描述
- 计算机上如何保存ico格式,PS怎么保存ico格式
- SQL 中判断条件的先后顺序,会引起索引失效么?
- Gdal关于CAD转SHP格式
- Java编程随机发红包,红包随机算法Java实现
热门文章
- php 汉字到html乱码怎么办,html网页乱码问题以及解决办法总结
- Linux网络管理—brctl命令
- 用c++模拟直线插补和圆弧插补
- 四分之一车辆垂向动力学模型
- VirtualLab基础实验教程-5.泊松亮斑
- 题解 P4466 [国家集训队]和与积
- pxe装机dhcp获取不到_Windows – PXE启动无法通过DHCP获取IP地址,但DHCP在操作系统启动时有效...
- Ubuntu安装dos2unix工具
- Android 反编译初探-基础篇
- css设置按钮样式_使用CSS设置按钮样式的快速指南