使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下
文章目录
- 1 测试鉴别器
- 2 建立生成器
- 3 测试生成器
- 4 训练生成器
- 5 使用生成器
- 6 内存查看
上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类。
使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上
接下来,我们测试一下鉴别器是否可以正常工作,并建立生成器。
1 测试鉴别器
# 数据类建立
celeba_dataset = CelebADataset(r'F:\学习\AI\对抗网络\face-data\celeba_aligned_small.h5py')
celeba_dataset.plot_image(66)# 鉴别器类建立
D = Discriminator()
D.to(device)for image_data_tensor in celeba_dataset:# real dataD.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))# fake dataD.train(generate_random_image((218,178,3)), torch.cuda.FloatTensor([0.0]))
此处我们调用了两个类,一个是celeba_dataset(Dataset)类,一个是D(Discriminator)类。两个类在博文的上篇中完成了定义。此处分别使用real数据与fake数据对模型进行训练。fake数据使用的是随机生成的不规则像素点,real数据使用的是真是人脸数据。
在使用GPU的情况,此处预计会消耗5分钟左右。
训练完成后,可以绘制损失值的变化以查看训练效果。
D.plot_progress()
plt.show()
2 建立生成器
生成器与鉴别器高度类似,仅网络的结构和训练部分略有不同。
网格结构选取的是输入层为100个节点,中间层为单层结构,包含3*10*10
个节点,输出层为3 * 218 * 178
。输出层是完全根据照片的像素格式来确定的,输入层与中间层可以根据经验进行修改与优化。各层之间均采用全连接的连接方式。相关部分的代码如下:
class Generator(nn.Module):def __init__(self):# 父类继承super().__init__()# 定义神经网络self.model = nn.Sequential(nn.Linear(100, 3 * 10 * 10),nn.LeakyReLU(),nn.LayerNorm(3 * 10 * 10),nn.Linear(3 * 10 * 10, 3 * 218 * 178),nn.Sigmoid(),View((218, 178, 3)))
在进行损失计算时,我们将鉴别器的返回值作为实际输出,将torch.cuda.FloatTensor([1.0]
作为目标输出,来计算损失。相关比分的代码如下:
class Generator(nn.Module):def train(self, D, inputs, targets):# 计算输出g_output = self.forward(inputs)# 将输出传至鉴别器d_output = D.forward(g_output)# 计算损失loss = D.loss_function(d_output, targets)
对于生成器的完整代码,也将在文末进行提供。
3 测试生成器
未经训练的生成器,应该具备生成类似雪花马赛克的随机图像能力。下面建立了一个生成器类,并用未经训练的生成器直接输出图像。
G = Generator()
G.to(device)output = G.forward(generate_random_seed(100))
img = output.detach().cpu().numpy()
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()
如果代码运行正常,应得到类似下面的图象。
4 训练生成器
训练时,对数据集进行遍历,并且依次执行下面三步:
- 使用真实照片数据,对鉴别器进行训练,期望的鉴别器输出值为1;
- 使用生成器输出的fake数据,对鉴别器进行训练,期望的鉴别器输出值为0;
- 使用鉴别器的返回值,训练生成器,生成器所希望的鉴别器输出为1
具体代码如下:
for image_data_tensor in celeba_dataset:# train discriminator on trueD.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))# train discriminator on false# use detach() so gradients in G are not calculatedD.train(G.forward(generate_random_seed(100)).detach(), torch.cuda.FloatTensor([0.0]))# train generatorG.train(D, generate_random_seed(100), torch.cuda.FloatTensor([1.0]))
在训练后,可以分别查看鉴别器与生成器的损失变化曲线。
D.plot_progress()
G.plot_progress()
下图为鉴别器损失值变化曲线
下图为生成器损失值变化曲线
5 使用生成器
6 内存查看
最后可以查看一下本次训练的内存使用情况
(1)分配给张量的当前内存(输出单位是GB)
torch.cuda.memory_allocated(device) / (1024*1024*1024)
我的输出结果为:0.6999950408935547
(2)分配给张量的总内存(输出单位是GB)
torch.cuda.max_memory_allocated(device) / (1024*1024*1024)
我的输出结果为:0.962151050567627
(3)内存消耗汇总
print(torch.cuda.memory_summary(device, abbreviated=True))
输出结果如下:
|===========================================================================|
| PyTorch CUDA memory summary, device ID 0 |
|---------------------------------------------------------------------------|
| CUDA OOMs: 0 | cudaMalloc retries: 0 |
|===========================================================================|
| Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---------------------------------------------------------------------------|
| Allocated memory | 733998 KB | 985 MB | 14018 GB | 14017 GB |
|---------------------------------------------------------------------------|
| Active memory | 733998 KB | 985 MB | 14018 GB | 14017 GB |
|---------------------------------------------------------------------------|
| GPU reserved memory | 1086 MB | 1086 MB | 1086 MB | 0 B |
|---------------------------------------------------------------------------|
| Non-releasable memory | 9426 KB | 12685 KB | 353393 MB | 353383 MB |
|---------------------------------------------------------------------------|
| Allocations | 68 | 87 | 2580 K | 2580 K |
|---------------------------------------------------------------------------|
| Active allocs | 68 | 87 | 2580 K | 2580 K |
|---------------------------------------------------------------------------|
| GPU reserved segments | 15 | 15 | 15 | 0 |
|---------------------------------------------------------------------------|
| Non-releasable allocs | 11 | 14 | 1410 K | 1410 K |
|---------------------------------------------------------------------------|
| Oversize allocations | 0 | 0 | 0 | 0 |
|---------------------------------------------------------------------------|
| Oversize GPU segments | 0 | 0 | 0 | 0 |
|===========================================================================|
代码文件:博客附件代码
使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下相关推荐
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上
文章目录 1 数据集描述 2 GPU设置 3 设置Dataset类 4 设置辨别器类 5 辅助函数与辅助类 1 数据集描述 此项目使用的是著名的celebA(CelebFaces Attribute) ...
- 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别
文章目录 1 生成对抗网络基本概念 2 生成对抗网络建模 2.1 建立MnistDataset类 2.2 建立鉴别器 2.3 测试鉴别器 2.4 Mnist生成器制作 3 模型的训练 4 模型表现的判 ...
- 使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成 上
阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读. 相关资料链接为:使用PyTorch构建GAN生成对抗 本次训练代码使用了 ...
- 使用PyTorch构建神经网络(详细步骤讲解+注释版) 01-建立分类器类
文章目录 1 数据准备 2 数据预览 3 简单神经网络创建 3.1 设计网络结构 3.2 损失函数相关设置 3.3 向网络传递信息 3.4 定义训练函数train 4 函数汇总 1 数据准备 神经网络 ...
- pytorch基于GAN生成对抗网络的数据集扩充
文章目录 前言 一.GAN基本原理 1.结构图 2.目标函数 二.实现 1.实现流程图 2.实例 2.1采集少量原始数据 2.2GAN模型训练(注意修改图片路径) 2.3用训练好的模型扩充数据集(生成 ...
- 深度学习(九) GAN 生成对抗网络 理论部分
GAN 生成对抗网络 理论部分 前言 一.Pixel RNN 1.图片的生成模型 2.Pixel RNN 3.Pixel CNN 二.VAE(Variational Autoencoder) 1.VA ...
- 深度学习 GAN生成对抗网络-1010格式数据生成简单案例
一.前言 本文不花费大量的篇幅来推导数学公式,而是使用一个非常简单的案例来帮助我们了解GAN生成对抗网络. 二.GAN概念 生成对抗网络(Generative Adversarial Networks ...
- PyTorch 实现 GAN 生成式对抗网络 含代码
GAN 网络结构 GAN 公式的理解 简单线性 GAN 代码如下 卷积 GAN 代码如下 Ref 网络结构 GAN 公式的理解 minGmaxDV(D,G)=Ex∼Pdata(x)[logD(x)]+ ...
- 54_pytorch GAN(生成对抗网络)、Gan代码示例、WGAN代码示例
1.54.GAN(生成对抗网络) 1.54.1.什么是GAN 2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文.没错,我说的就是<Generative ...
最新文章
- 启动任务管理器命令符,doc命令
- 计算机科学与技术研究目的,计算机科学与技术专业培养目标分析
- 双十一变身大型奥数竞赛现场?数学不好的你请转场
- ajax 输入不为空,ajax POST响应为空
- 微服务架构会和分布式单体架构高度重合吗
- unity创建项目报错:解决sentinel key not found (h0007) Unity
- Android RecyclerView网格布局动画
- 动态规划-hdoj-4832-百度之星2014初赛第二场
- 556. 下一个更大元素 III
- 如何理解图像的概率分布?为什么N(0,1)的向量可以Gen图片?
- CSS + JavaScript 实现八卦太极图
- 别让生活 耗尽了你的耐心和向往 你还有诗和远方...
- 【工具使用】GPU的各项参数说明
- vue动态调节背景图片
- 东华理工大学计算机网络期末考试试卷,东华理工大学计算机网络计算题
- Java反射之Filed(类中的属性对象)
- matlab2018a课后答案,[2018年最新整理]matlab习题及答案.doc
- 安卓:三分钟实现物流配送页面(时间轴效果)
- mbedtls 连接 阿里云物联网
- 阴阳师人数最多的服务器,说说阴阳师中玩家最少的几个鬼区服务器
热门文章
- Python第三方库导出、导入、离线安装
- QGIS中安装Python第三方库
- 华为HCIE证大数据证书社会认可度怎么样?
- linux 查看网卡 万兆,Linux查看网卡是千兆还是万兆网卡
- 区块链+社交=颠覆性的革新?
- 【chatGPT】01 数组、二维数组在不同语言中的存储方式
- PHP等于==和恒等于===的区别
- matlab三角函数降次,Matlab实验-传递函数表示方法.ppt
- php中使用视频流的方式播放视频。
- JAVA java学习(2)——————java下载安装与环境配置