文章目录

  • 1 测试鉴别器
  • 2 建立生成器
  • 3 测试生成器
  • 4 训练生成器
  • 5 使用生成器
  • 6 内存查看

上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类。
使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上
接下来,我们测试一下鉴别器是否可以正常工作,并建立生成器。

1 测试鉴别器

# 数据类建立
celeba_dataset = CelebADataset(r'F:\学习\AI\对抗网络\face-data\celeba_aligned_small.h5py')
celeba_dataset.plot_image(66)# 鉴别器类建立
D = Discriminator()
D.to(device)for image_data_tensor in celeba_dataset:# real dataD.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))# fake dataD.train(generate_random_image((218,178,3)), torch.cuda.FloatTensor([0.0]))

此处我们调用了两个类,一个是celeba_dataset(Dataset)类,一个是D(Discriminator)类。两个类在博文的上篇中完成了定义。此处分别使用real数据与fake数据对模型进行训练。fake数据使用的是随机生成的不规则像素点,real数据使用的是真是人脸数据。
在使用GPU的情况,此处预计会消耗5分钟左右。
训练完成后,可以绘制损失值的变化以查看训练效果。

D.plot_progress()
plt.show()

2 建立生成器

生成器与鉴别器高度类似,仅网络的结构和训练部分略有不同。
网格结构选取的是输入层为100个节点,中间层为单层结构,包含3*10*10个节点,输出层为3 * 218 * 178。输出层是完全根据照片的像素格式来确定的,输入层与中间层可以根据经验进行修改与优化。各层之间均采用全连接的连接方式。相关部分的代码如下:

class Generator(nn.Module):def __init__(self):# 父类继承super().__init__()# 定义神经网络self.model = nn.Sequential(nn.Linear(100, 3 * 10 * 10),nn.LeakyReLU(),nn.LayerNorm(3 * 10 * 10),nn.Linear(3 * 10 * 10, 3 * 218 * 178),nn.Sigmoid(),View((218, 178, 3)))

在进行损失计算时,我们将鉴别器的返回值作为实际输出,将torch.cuda.FloatTensor([1.0]作为目标输出,来计算损失。相关比分的代码如下:

class Generator(nn.Module):def train(self, D, inputs, targets):# 计算输出g_output = self.forward(inputs)# 将输出传至鉴别器d_output = D.forward(g_output)# 计算损失loss = D.loss_function(d_output, targets)

对于生成器的完整代码,也将在文末进行提供。

3 测试生成器

未经训练的生成器,应该具备生成类似雪花马赛克的随机图像能力。下面建立了一个生成器类,并用未经训练的生成器直接输出图像。

G = Generator()
G.to(device)output = G.forward(generate_random_seed(100))
img = output.detach().cpu().numpy()
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()

如果代码运行正常,应得到类似下面的图象。

4 训练生成器

训练时,对数据集进行遍历,并且依次执行下面三步:

  1. 使用真实照片数据,对鉴别器进行训练,期望的鉴别器输出值为1;
  2. 使用生成器输出的fake数据,对鉴别器进行训练,期望的鉴别器输出值为0;
  3. 使用鉴别器的返回值,训练生成器,生成器所希望的鉴别器输出为1
    具体代码如下:
for image_data_tensor in celeba_dataset:# train discriminator on trueD.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))# train discriminator on false# use detach() so gradients in G are not calculatedD.train(G.forward(generate_random_seed(100)).detach(), torch.cuda.FloatTensor([0.0]))# train generatorG.train(D, generate_random_seed(100), torch.cuda.FloatTensor([1.0]))

在训练后,可以分别查看鉴别器与生成器的损失变化曲线。

D.plot_progress()
G.plot_progress()

下图为鉴别器损失值变化曲线

下图为生成器损失值变化曲线

5 使用生成器

6 内存查看

最后可以查看一下本次训练的内存使用情况
(1)分配给张量的当前内存(输出单位是GB)

torch.cuda.memory_allocated(device) / (1024*1024*1024)

我的输出结果为:0.6999950408935547
(2)分配给张量的总内存(输出单位是GB)

torch.cuda.max_memory_allocated(device) / (1024*1024*1024)

我的输出结果为:0.962151050567627
(3)内存消耗汇总

print(torch.cuda.memory_summary(device, abbreviated=True))

输出结果如下:

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  733998 KB |     985 MB |   14018 GB |   14017 GB |
|---------------------------------------------------------------------------|
| Active memory         |  733998 KB |     985 MB |   14018 GB |   14017 GB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    1086 MB |    1086 MB |    1086 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |    9426 KB |   12685 KB |  353393 MB |  353383 MB |
|---------------------------------------------------------------------------|
| Allocations           |      68    |      87    |    2580 K  |    2580 K  |
|---------------------------------------------------------------------------|
| Active allocs         |      68    |      87    |    2580 K  |    2580 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |      15    |      15    |      15    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      11    |      14    |    1410 K  |    1410 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

代码文件:博客附件代码

使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下相关推荐

  1. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

    文章目录 1 数据集描述 2 GPU设置 3 设置Dataset类 4 设置辨别器类 5 辅助函数与辅助类 1 数据集描述 此项目使用的是著名的celebA(CelebFaces Attribute) ...

  2. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别

    文章目录 1 生成对抗网络基本概念 2 生成对抗网络建模 2.1 建立MnistDataset类 2.2 建立鉴别器 2.3 测试鉴别器 2.4 Mnist生成器制作 3 模型的训练 4 模型表现的判 ...

  3. 使用PyTorch构建卷积GAN源码(详细步骤讲解+注释版) 02人脸图片生成 上

    阅读提示:本篇文章的代码为在普通GAN代码上实现人脸图片生成的修改,文章内容仅包含修改内容,全部代码讲解需结合下面的文章阅读. 相关资料链接为:使用PyTorch构建GAN生成对抗 本次训练代码使用了 ...

  4. 使用PyTorch构建神经网络(详细步骤讲解+注释版) 01-建立分类器类

    文章目录 1 数据准备 2 数据预览 3 简单神经网络创建 3.1 设计网络结构 3.2 损失函数相关设置 3.3 向网络传递信息 3.4 定义训练函数train 4 函数汇总 1 数据准备 神经网络 ...

  5. pytorch基于GAN生成对抗网络的数据集扩充

    文章目录 前言 一.GAN基本原理 1.结构图 2.目标函数 二.实现 1.实现流程图 2.实例 2.1采集少量原始数据 2.2GAN模型训练(注意修改图片路径) 2.3用训练好的模型扩充数据集(生成 ...

  6. 深度学习(九) GAN 生成对抗网络 理论部分

    GAN 生成对抗网络 理论部分 前言 一.Pixel RNN 1.图片的生成模型 2.Pixel RNN 3.Pixel CNN 二.VAE(Variational Autoencoder) 1.VA ...

  7. 深度学习 GAN生成对抗网络-1010格式数据生成简单案例

    一.前言 本文不花费大量的篇幅来推导数学公式,而是使用一个非常简单的案例来帮助我们了解GAN生成对抗网络. 二.GAN概念 生成对抗网络(Generative Adversarial Networks ...

  8. PyTorch 实现 GAN 生成式对抗网络 含代码

    GAN 网络结构 GAN 公式的理解 简单线性 GAN 代码如下 卷积 GAN 代码如下 Ref 网络结构 GAN 公式的理解 minGmaxDV(D,G)=Ex∼Pdata(x)[logD(x)]+ ...

  9. 54_pytorch GAN(生成对抗网络)、Gan代码示例、WGAN代码示例

    1.54.GAN(生成对抗网络) 1.54.1.什么是GAN 2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文.没错,我说的就是<Generative ...

最新文章

  1. 启动任务管理器命令符,doc命令
  2. 计算机科学与技术研究目的,计算机科学与技术专业培养目标分析
  3. 双十一变身大型奥数竞赛现场?数学不好的你请转场
  4. ajax 输入不为空,ajax POST响应为空
  5. 微服务架构会和分布式单体架构高度重合吗
  6. unity创建项目报错:解决sentinel key not found (h0007) Unity
  7. Android RecyclerView网格布局动画
  8. 动态规划-hdoj-4832-百度之星2014初赛第二场
  9. 556. 下一个更大元素 III
  10. 如何理解图像的概率分布?为什么N(0,1)的向量可以Gen图片?
  11. CSS + JavaScript 实现八卦太极图
  12. 别让生活 耗尽了你的耐心和向往 你还有诗和远方...
  13. 【工具使用】GPU的各项参数说明
  14. vue动态调节背景图片
  15. 东华理工大学计算机网络期末考试试卷,东华理工大学计算机网络计算题
  16. Java反射之Filed(类中的属性对象)
  17. matlab2018a课后答案,[2018年最新整理]matlab习题及答案.doc
  18. 安卓:三分钟实现物流配送页面(时间轴效果)
  19. mbedtls 连接 阿里云物联网
  20. 阴阳师人数最多的服务器,说说阴阳师中玩家最少的几个鬼区服务器

热门文章

  1. Python第三方库导出、导入、离线安装
  2. QGIS中安装Python第三方库
  3. 华为HCIE证大数据证书社会认可度怎么样?
  4. linux 查看网卡 万兆,Linux查看网卡是千兆还是万兆网卡
  5. 区块链+社交=颠覆性的革新?
  6. 【chatGPT】01 数组、二维数组在不同语言中的存储方式
  7. PHP等于==和恒等于===的区别
  8. matlab三角函数降次,Matlab实验-传递函数表示方法.ppt
  9. php中使用视频流的方式播放视频。
  10. JAVA java学习(2)——————java下载安装与环境配置