在我们之前的文章中,我们学习了如何构造一个简单的 GAN 来生成 MNIST 手写图片。对于图像问题,卷积神经网络相比于简单地全连接的神经网络更具优势,因此,我们这一节我们将继续深入 GAN,通过融合卷积神经网络来对我们的 GAN 进行改进,实现一个深度卷积 GAN。如果还没有亲手实践过 GAN 的小伙伴可以先去学习一下上一篇专栏:生成对抗网络(GAN)之 MNIST 数据生成。

专栏中的所有代码都在我的 GitHub中,欢迎 star 与 fork。

本次代码在 NELSONZHAO/zhihu/dcgan,里面包含了两个文件:

  • dcgan_mnist:基于 MNIST 手写数据集构造深度卷积 GAN 模型

  • dcgan_cifar:基于 CIFAR 数据集构造深度卷积 GAN 模型

本文主要以 MNIST 为例进行介绍,两者在本质上没有差别,只在细微的参数上有所调整。由于穷学生资源有限,没有对模型增加迭代次数,也没有构造更深的模型。并且也没有选取像素很高的图像,高像素非常消耗计算量。本节只是一个抛砖引玉的作用,让大家了解 DCGAN 的结构,如果有资源的小伙伴可以自己去尝试其他更清晰的图片以及更深的结构,相信会取得很不错的结果。

工具

  • Python3

  • TensorFlow 1.0

  • Jupyter notebook

正文

整个正文部分将包括以下部分:

- 数据加载

- 模型输入

- Generator

- Discriminator

- Loss

- Optimizer

- 训练模型

- 可视化

数据加载

数据加载部分采用 TensorFlow 中的 input_data 接口来进行加载。关于加载细节在前面的文章中已经写了很多次啦,相信看过我文章的小伙伴对 MNIST 加载也非常熟悉,这里不再赘述。

模型输入

在 GAN 中,我们的输入包括两部分,一个是真实图片,它将直接输入给 discriminator 来获得一个判别结果;另一个是随机噪声,随机噪声将作为 generator 来生成图片的材料,generator 再将生成图片传递给 discriminator 获得一个判别结果。

上面的函数定义了输入图片与噪声图片两个 tensor。

Generator

生成器接收一个噪声信号,基于该信号生成一个图片输入给判别器。在上一篇专栏文章生成对抗网络(GAN)之 MNIST 数据生成中,我们的生成器是一个全连接层的神经网络,而本节我们将生成器改造为包含卷积结构的网络,使其更加适合处理图片输入。整个生成器结构如下:

我们采用了 transposed convolution 将我们的噪声图片转换为了一个与输入图片具有相同 shape 的生成图像。我们来看一下具体的实现代码:

上面的代码是整个生成器的实现细节,里面包含了一些 trick,我们来一步步地看一下。

首先我们通过一个全连接层将输入的噪声图像转换成了一个 1 x 4*4*512 的结构,再将其 reshape 成一个 [batch_size, 4, 4, 512] 的形状,至此我们其实完成了第一步的转换。接下来我们使用了一个对加速收敛及提高卷积神经网络性能中非常有效的方法——加入 BN(batch normalization),它的思想是归一化当前层输入,使它们的均值为 0 和方差为 1,类似于我们归一化网络输入的方法。它的好处在于可以加速收敛,并且加入 BN 的卷积神经网络受权重初始化影响非常小,具有非常好的稳定性,对于提升卷积性能有很好的效果。关于 batch normalization,我会在后面专栏中进行一个详细的介绍。

完成 BN 后,我们使用 Leaky ReLU 作为激活函数,在上一篇专栏中我们已经提过这个函数,这里不再赘述。最后加入 dropout 正则化。剩下的 transposed convolution 结构层与之类似,只不过在最后一层中,我们不采用 BN,直接采用 tanh 激活函数输出生成的图片。

在上面的 transposed convolution 中,很多小伙伴肯定会对每一层 size 的变化疑惑,在这里来讲一下在 TensorFlow 中如何来计算每一层 feature map 的 size。首先,在卷积神经网络中,假如我们使用一个 k x k 的 filter 对 m x m x d 的图片进行卷积操作,strides 为 s,在 TensorFlow 中,当我们设置 padding='same'时,卷积以后的每一个 feature map 的 height 和 width 为;当设置 padding='valid'时,每一个 feature map 的 height 和 width 为。那么反过来,如果我们想要进行 transposed convolution 操作,比如将 7 x 7 的形状变为 14 x 14,那么此时,我们可以设置 padding='same',strides=2 即可,与 filter 的 size 没有关系;而如果将 4 x 4 变为 7 x 7 的话,当设置 padding='valid'时,即,此时 s=1,k=4 即可实现我们的目标。

上面的代码中我也标注了每一步 shape 的变化。

Discriminator

Discriminator 接收一个图片,输出一个判别结果(概率)。其实 Discriminator 完全可以看做一个包含卷积神经网络的图片二分类器。结构如下:

实现代码如下:

上面代码其实就是一个简单的卷积神经网络图像识别问题,最终返回 logits(用来计算 loss)与 outputs。这里没有加入池化层的原因在于图片本身经过多层卷积以后已经非常小了,并且我们加入了 batch normalization 加速了训练,并不需要通过 max pooling 来进行特征提取加速训练。

Loss Function

Loss 部分分别计算 Generator 的 loss 与 Discriminator 的 loss,和之前一样,我们加入 label smoothing 防止过拟合,增强泛化能力。

Optimizer

GAN 中实际包含了两个神经网络,因此对于这两个神经网络要分开进行优化。代码如下:

这里的 Optimizer 和我们之前不同,由于我们使用了 TensorFlow 中的 batch normalization 函数,这个函数中有很多 trick 要注意。首先我们要知道,batch normalization 在训练阶段与非训练阶段的计算方式是有差别的,这也是为什么我们在使用 batch normalization 过程中需要指定 training 这个参数。上面使用 tf.control_dependencies 是为了保证在训练阶段能够一直更新 moving averages。具体参考 A Gentle Guide to Using Batch Normalization in Tensorflow - Rui Shu。

训练

到此为止,我们就完成了深度卷积 GAN 的构造,接着我们可以对我们的 GAN 来进行训练,并且定义一些辅助函数来可视化迭代的结果。代码太长就不放上来了,可以直接去我的 GitHub 下载。

我这里只设置了 5 轮 epochs,每隔 100 个 batch 打印一次结果,每一行代表同一个 epoch 下的 25 张图:

我们可以看出仅仅经过了少部分的迭代就已经生成非常清晰的手写数字,并且训练速度是非常快的。

上面的图是最后几次迭代的结果。我们可以回顾一下上一篇的一个简单的全连接层的 GAN,收敛速度明显不如深度卷积 GAN。

总结

到此为止,我们学习了一个深度卷积 GAN,并且看到相比于之前简单的 GAN 来说,深度卷积 GAN 的性能更加优秀。当然除了 MNST 数据集以外,小伙伴儿们还可以尝试很多其他图片,比如我们之前用到过的 CIFAR 数据集,我在这里也实现了一个 CIFAR 数据集的图片生成,我只选取了马的图片进行训练:

刚开始训练时:

训练 50 个 epochs:

这里我只设置了 50 次迭代,可以看到最后已经生成了非常明显的马的图像,可见深度卷积 GAN 的优势。

我的 GitHub:NELSONZHAO (Nelson Zhao)

上面包含了我的专栏中所有的代码实现,欢迎 star,欢迎 fork。

用GAN来做图像生成,这是最好的方法相关推荐

  1. 【每周CV论文推荐】初学基于GAN的三维图像生成有哪些经典论文需要阅读

    欢迎来到<每周CV论文推荐>.在这个专栏里,还是本着有三AI一贯的原则,专注于让大家能够系统性完成学习,所以我们推荐的文章也必定是同一主题的. 当前二维图像生成领域的发展已经非常成熟,但是 ...

  2. GAN|在图像生成领域里,GAN这一大家族是如何生根发芽的

    作者:思源 本文经机器之心(微信公众号:almosthuman2014)授权转载,禁止二次转载 图像生成领域的 SOTA 排名涉及非常多的数据集与度量方法,我们并不能直观展示不同 GAN 的发展路线. ...

  3. 今晚直播 | 旷视研究院王毅:用于条件图像生成的注意力归一化

    「PW Live」是 PaperWeekly 的学术直播间,旨在帮助更多的青年学者宣传其最新科研成果.我们一直认为,单向地输出知识并不是一个最好的方式,而有效地反馈和交流可能会让知识的传播更加有意义, ...

  4. 直播预告 | 旷视研究院王毅:用于条件图像生成的注意力归一化

    「PW Live」是 PaperWeekly 的学术直播间,旨在帮助更多的青年学者宣传其最新科研成果.我们一直认为,单向地输出知识并不是一个最好的方式,而有效地反馈和交流可能会让知识的传播更加有意义, ...

  5. CVPR 2020 Oral | 妙笔生花新境界,语义级别多模态图像生成

    GAN已经成为图像生成的有力工具,现今GAN已经不再局限于生成以假乱真的图,而是向着更加灵活可操控的方向发展. 今天向大家介绍的CVPR 2020 的文章出自华中科技大学白翔老师组,特别要提醒,文中有 ...

  6. 图像生成 - 使用BigGAN在Imagenet数据集上生成高质量图像。

    图像生成是计算机视觉领域中的重要问题,其目的是生成具有高质量和真实感的图像.最近,Google提出的BigGAN方法在图像生成任务上取得了巨大的成功,可以生成高分辨率和高质量的图像.在本文中,我们将介 ...

  7. CV之IG:图像生成(Image Generation)的简介、使用方法、案例应用之详细攻略

    CV之IG:图像生成(Image Generation)的简介.使用方法.案例应用之详细攻略 目录 图像生成(Image Generation)的简介 图像生成(Image Generation)的使 ...

  8. 在图像生成领域里,GAN这一大家族是如何生根发芽的

    作者:思源 生成对抗网络这一 ML 新成员目前已经枝繁叶茂了,截止今年 5 月份,目前 GAN 至少有 300+的论文与变体.而本文尝试借助机器之心 SOTA 项目梳理生成对抗网络的架构与损失函数发展 ...

  9. 图像生成王者不是GAN?扩散模型最近有点火:靠加入类别条件,效果直达SOTA

    博雯 发自 凹非寺 量子位 报道 | 公众号 QbitAI OpenAI刚刚推出的年末新作GLIDE,又让扩散模型小火了一把. 这个基于扩散模型的文本图像生成大模型参数规模更小,但生成的图像质量却更高 ...

最新文章

  1. Session 过期问题处理
  2. Python + Selenium 练习篇 - 获取页面所有邮箱
  3. MySQL bin-log 日志清理方式
  4. matplotlib 标签_Python可视化matplotlibamp;seborn14热图heatmap
  5. Go支持自定义数据类型:使用type来定义,类似于数据类型的一个别名
  6. 杭电1869六度分离
  7. Java导出Excel数据错乱
  8. 基于单片机的水温控制系统设计
  9. matlab sil,丰田使用高精度发动机模型和SIL+M前置开发发动机控制系统
  10. c语言int转为dint,【转】IQMATH使用
  11. In Search of the Holy Grail 寻找圣杯 中文翻译
  12. jquery ZeroClipboard实现黏贴板功能,兼容所有浏览器
  13. C语言 crc32校验算法原理,CRC循环冗余校验的实现原理
  14. Ghidra Java API报NoClassDefFoundError的解决方法
  15. 微信小程序页面跳转,url传参参数丢失问题
  16. 打开Word文档的时候提示 “安全警告 宏已被禁用”
  17. 连接共享打印机提示【操作失败,错误为0x0000011b】
  18. php5.3不能连接mssql数据库的解决方法
  19. 【rtthread番外】第三篇:套接字抽象层SAL
  20. python 神经网络可以输出连续值_dqn 神经网络输出

热门文章

  1. 是否能被3,5,7同时整除(3.4)(Java)
  2. 【c语言】hello
  3. 【c语言】蓝桥杯基础练习 01字串
  4. 启程 - 《每日五分钟搞定大数据》
  5. CentOS 7.4 安装 MySQL 5.6.40 完美教程
  6. Top100论文导读:深入理解卷积神经网络CNN(Part Ⅱ)
  7. 移动4G打造排污视频监控系统助力咸宁环保建设
  8. Java8 - 接口默认方法
  9. linux应用程序安装与管理
  10. iPhone开发四剑客之《iPhone开发秘籍》