我们来研究一下生成式对抗网络 GAN,并且用 TensorFlow 代码实现。

自从 Ian Goodfellow 在 14 年发表了 论文 Generative Adversarial Nets 以来,生成式对抗网络 GAN 广受关注,加上学界大牛 Yann Lecun 在 Quora 答题时曾说,他最激动的深度学习进展是生成式对抗网络,使得 GAN 成为近年来在机器学习领域的新宠,可以说,研究机器学习的人,不懂 GAN,简直都不好意思出门。

下面我们来简单介绍一下生成式对抗网络,主要介绍三篇论文:1)Generative Adversarial Networks;2)Conditional Generative Adversarial Nets;3)Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks。

首先来看下第一篇论文,了解一下 GAN 的过程和原理:

GAN 启发自博弈论中的二人零和博弈(two-player game),GAN 模型中的两位博弈方分别由生成式模型(generative model)和判别式模型(discriminative model)充当。生成模型 G 捕捉样本数据的分布,用服从某一分布(均匀分布,高斯分布等)的噪声 z 生成一个类似真实训练数据的样本,追求效果是越像真实样本越好;判别模型 D 是一个二分类器,估计一个样本来自于训练数据(而非生成数据)的概率,如果样本来自于真实的训练数据,D 输出大概率,否则,D 输出小概率。可以做如下类比:生成网络 G 好比假币制造团伙,专门制造假币,判别网络 D 好比警察,专门检测使用的货币是真币还是假币,G 的目标是想方设法生成和真币一样的货币,使得 D 判别不出来,D 的目标是想方设法检测出

来 G 生成的假币。如图所示:

在训练的过程中固定一方,更新另一方的网络权重,交替迭代,在这个过程中,双方都极力优化自己的网络,从而形成竞争对抗,直到双方达到一个动态的平衡(纳什均衡),此时生成模型 G 恢复了训练数据的分布(造出了和真实数据一模一样的样本),判别模型再也判别不出来结果,准确率为 50%,约等于乱猜。

上述过程可以表述为如下公式:

当固定生成网络 G 的时候,对于判别网络 D 的优化,可以这样理解:输入来自于真实数据,D 优化网络结构使自己输出 1,输入来自于生成数据,D 优化网络结构使自己输出 0;当固定判别网络 D 的时候,G 优化自己的网络使自己输出尽可能和真实数据一样的样本,并且使得生成的样本经过 D 的判别之后,D 输出高概率。

第一篇文章,在 MNIST 手写数据集上生成的结果如下图:

最右边的一列是真实样本的图像,前面五列是生成网络生成的样本图像,可以看到生成的样本还是很像真实样本的,只是和真实样本属于不同的类,类别是随机的。

第二篇文章想法很简单,就是给 GAN 加上条件,让生成的样本符合我们的预期,这个条件可以是类别标签(例如 MNIST 手写数据集的类别标签),也可以是其他的多模态信息(例如对图像的描述语言)等。用公式表示就是:

式子中的 y 是所加的条件,结构图如下:

生成结果如下图:

图中所加的条件 y 是类别标签。

第三篇文章,简称(DCGAN),在实际中是代码使用率最高的一篇文章,本系列文的代码也是这篇文章代码的初级版本,它优化了网络结构,加入了 conv,batch_norm 等层,使得网络更容易训练,网络结构如下:

可以有加条件和不加条件两种网络,论文还做了好多试验,展示了这个网络在各种数据集上的结果。有兴趣同学可以去看论文,此文我们只从代码的角度理解去理解它。

参考文献:

1. http://blog.csdn.net/solomon1558/article/details/52549409

前面我们了解了 GAN 的原理,下面我们就来用 TensorFlow 搭建 GAN(严格说来是 DCGAN,如无特别说明,本系列文章所说的 GAN 均指 DCGAN),如前面所说,GAN 分为有约束条件的 GAN,和不加约束条件的GAN,我们先来搭建一个简单的 MNIST 数据集上加约束条件的 GAN。

首先下载数据:在  /home/your_name/TensorFlow/DCGAN/ 下建立文件夹 data/mnist,从 http://yann.lecun.com/exdb/mnist/ 网站上下载 mnist 数据集 train-images-idx3-ubyte.gz,train-labels-idx1-ubyte.gz,t10k-images-idx3-ubyte.gz,t10k-labels-idx1-ubyte.gz 到 mnist 文件夹下得到四个 .gz 文件。

数据下载好之后,在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 read_data.py 读取数据,输入如下代码:

import osimport numpy as npdef read_data():    # 数据目录data_dir = '/home/your_name/TensorFlow/DCGAN/data/mnist'      # 打开训练数据    fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))        # 转化成 numpy 数组loaded = np.fromfile(file=fd,dtype=np.uint8)        # 根据 mnist 官网描述的数据格式,图像像素从 16 字节开始trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float)   

    # 训练 labelfd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))loaded = np.fromfile(file=fd,dtype=np.uint8)trY = loaded[8:].reshape((60000)).astype(np.float)    

    # 测试数据fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))loaded = np.fromfile(file=fd,dtype=np.uint8)teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float)    

    # 测试 labelfd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))loaded = np.fromfile(file=fd,dtype=np.uint8)teY = loaded[8:].reshape((10000)).astype(np.float)trY = np.asarray(trY)teY = np.asarray(teY)    # 由于生成网络由服从某一分布的噪声生成图片,不需要测试集,# 所以把训练和测试两部分数据合并X = np.concatenate((trX, teX), axis=0)y = np.concatenate((trY, teY), axis=0)    # 打乱排序seed = 547np.random.seed(seed)np.random.shuffle(X)np.random.seed(seed)np.random.shuffle(y)    # 这里,y_vec 表示对网络所加的约束条件,这个条件是类别标签,# 可以看到,y_vec 实际就是对 y 的独热编码,关于什么是独热编码,# 请参考 http://www.cnblogs.com/Charles-Wan/p/6207039.htmly_vec = np.zeros((len(y), 10), dtype=np.float)        for i, label in enumerate(y):         y_vec[i,y[i]] = 1.0    return X/255., y_vec

这里顺便说明一下,由于 MNIST 数据总体占得内存不大(可以看下载的文件,最大的一个 45M 左右,)所以这样读取数据是允许的,一般情况下,数据特别庞大的时候,建议把数据转化成 tfrecords,用 TensorFlow 标准的数据读取格式,这样能带来比较高的效率。

然后,定义一些基本的操作层,例如卷积,池化,全连接等层,在 /home/your_name/TensorFlow/DCGAN/ 新建文件 ops.py,输入如下代码:

import tensorflow as tffrom tensorflow.contrib.layers.python.layers import batch_norm as batch_norm

# 常数偏置def bias(name, shape, bias_start = 0.0, trainable = True):dtype = tf.float32var = tf.get_variable(name, shape, tf.float32, trainable = trainable, initializer = tf.constant_initializer(bias_start, dtype = dtype))    return var

# 随机权重def weight(name, shape, stddev = 0.02, trainable = True):dtype = tf.float32var = tf.get_variable(name, shape, tf.float32, trainable = trainable, initializer = tf.random_normal_initializer(stddev = stddev, dtype = dtype))    

return var

# 全连接层def fully_connected(value, output_shape, name = 'fully_connected', with_w = False):shape = value.get_shape().as_list()with tf.variable_scope(name):weights = weight('weights', [shape[1], output_shape], 0.02)biases = bias('biases', [output_shape], 0.0)        if with_w:        return tf.matmul(value, weights) + biases, weights, biases     else:        return tf.matmul(value, weights) + biases

# Leaky-ReLu 层def lrelu(x, leak=0.2, name = 'lrelu'):with tf.variable_scope(name):        return tf.maximum(x, leak*x, name = name)        # ReLu 层def relu(value, name = 'relu'):with tf.variable_scope(name):        return tf.nn.relu(value)    # 解卷积层def deconv2d(value, output_shape, k_h = 5, k_w = 5, strides =[1, 2, 2, 1], name = 'deconv2d', with_w = False):with tf.variable_scope(name):weights = weight('weights', [k_h, k_w, output_shape[-1], value.get_shape()[-1]])deconv = tf.nn.conv2d_transpose(value, weights, output_shape, strides = strides)biases = bias('biases', [output_shape[-1]])deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())         if with_w:             return deconv, weights, biases         else:            return deconv            # 卷积层            def conv2d(value, output_dim, k_h = 5, k_w = 5, strides =[1, 2, 2, 1], name = 'conv2d'):with tf.variable_scope(name):weights = weight('weights', [k_h, k_w, value.get_shape()[-1], output_dim])conv = tf.nn.conv2d(value, weights, strides = strides, padding = 'SAME')biases = bias('biases', [output_dim])conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())        return conv

# 把约束条件串联到 feature mapdef conv_cond_concat(value, cond, name = 'concat'):    # 把张量的维度形状转化成 Python 的 listvalue_shapes = value.get_shape().as_list()cond_shapes = cond.get_shape().as_list()    # 在第三个维度上(feature map 维度上)把条件和输入串联起来,# 条件会被预先设为四维张量的形式,假设输入为 [64, 32, 32, 32] 维的张量,# 条件为 [64, 32, 32, 10] 维的张量,那么输出就是一个 [64, 32, 32, 42] 维张量    with tf.variable_scope(name):        return tf.concat(3, [value, cond * tf.ones(value_shapes[0:3] + cond_shapes[3:])])  # BN 层,这里我们直接用官方的 BN 层。        def batch_norm_layer(value, is_train = True, name = 'batch_norm'):with tf.variable_scope(name) as scope:        if is_train:        return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True, is_training = is_train, updates_collections = None, scope = scope)         else:            return batch_norm(value, decay = 0.9, epsilon = 1e-5, scale = True, is_training = is_train, reuse = True, updates_collections = None, scope = scope)

batch_norm 里的 decay 指的是滑动平均的 decay,epsilon 作用是加到分母 variance 上避免分母为零,scale 是个布尔变量,如果为真值 True, 结果要乘以 gamma,否则 gamma 不使用,is_train 也是布尔变量,为真值代表训练过程,否则代表测试过程(在 BN 层中,训练过程和测试过程是不同的,具体请参考论文:https://arxiv.org/abs/1502.03167)。关于 batch_norm 的其他的参数,请看参考文献2。

参考文献:

1. https://github.com/carpedm20/DCGAN-tensorflow

2. https://github.com/tensorflow/tensorflow/blob/b826b79718e3e93148c3545e7aa3f90891744cc0/tensorflow/contrib/layers/python/layers/layers.py#L100

如何用 TensorFlow 实现生成式对抗网络(GAN)相关推荐

  1. 简述生成式对抗网络 GAN

    本文主要阐述了对生成式对抗网络的理解,首先谈到了什么是对抗样本,以及它与对抗网络的关系,然后解释了对抗网络的每个组成部分,再结合算法流程和代码实现来解释具体是如何实现并执行这个算法的,最后通过给出一个 ...

  2. 深度学习之生成式对抗网络 GAN(Generative Adversarial Networks)

    一.GAN介绍 生成式对抗网络GAN(Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.它源于2014年发表的论文:& ...

  3. 王飞跃教授:生成式对抗网络GAN的研究进展与展望

    本次汇报的主要内容包括GAN的提出背景.GAN的理论与实现模型.发展以及我们所做的工作,即GAN与平行智能.  生成式对抗网络GAN GAN是Goodfellow在2014年提出来的一种思想,是一种比 ...

  4. 生成式对抗网络GAN(一)—基于python实现

    基于python实现生成式对抗网络GAN 构建和训练一个生成对抗网络(GAN) ,使其可以生成数字(0-9)的手写图像. 学习目标 从零开始构建GAN的生成器和判别器. 创建GAN的生成器和判别器的损 ...

  5. 《生成式对抗网络GAN的研究进展与展望》论文笔记

    本文主要是对论文:王坤峰, 苟超, 段艳杰, 林懿伦, 郑心湖, 王飞跃. 生成式对抗网络GAN的研究进展与展望. 自动化学报, 2017, 43(3): 321-332. 进行总结. 相关博客地址: ...

  6. 深度学习之生成式对抗网络GAN

    一.GAN介绍 生成式对抗网络GAN(Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块 ...

  7. 生成式对抗网络(GAN, Generaitive Adversarial Networks)总结

    最近要做有关图像生成的工作-也是小白,今天简单学习一些有关GAN的基础知识,很浅,入个门,大神勿喷. GAN目前确实是在深度学习领域最热门,最有前景的方向之一.近几年有关于GAN的论文非常非常之多,从 ...

  8. 生成式对抗网络GAN模型搭建

    生成式对抗网络GAN模型搭建 目录 一.理论部分 1.GAN基本原理介绍 2.对KL散度的理解 3.模块导入命令 二.编程实现 1.加载所需要的模块和库,设定展示图片函数以及其他对图像预处理函数 1) ...

  9. 利用Tensorflow构建生成对抗网络GAN以生成数据

    使用生成对抗网络(GAN)生成数据 本文主要内容 介绍了自动编码器的基本原理 比较了生成模型与自动编码器的区别 描述了GAN模型的网络结构 分析了GAN模型的目标核函数以及训练过程 介绍了利用Goog ...

最新文章

  1. Kali Linux搜索软件包
  2. linux 编译zbar
  3. 利用HTML5 canvas合并图片并解决Filaed to execute 'toDataURL' on 'HTMLCanvasElement'异常
  4. centos 6.5 安装 lamp 后mysql不能启动_Lamp的搭建--centos6.5下安装mysql
  5. Leetcode怎么调试java代码,在Clion上调试LeetCode代码
  6. void init(void) 分析 ! \linux-1.0\init\main.c
  7. 阿里云服务器配置开发环境第五章:Centos7.3切换为iptables防火墙
  8. 高清壁纸:60款可爱的圣诞节电脑桌面壁纸《下篇》
  9. SpringBoot实战教程(5)| 整合Freemaker
  10. 学习一个Vue模板项目
  11. c语言中一些公用的方法
  12. 第三届阿里云磐久智维算法大赛——GRU BaseLine
  13. 计算机打印机出现副本1,打印机提示Administrator的1个文档被挂起
  14. Debain查看ip地址
  15. Bat shell 脚本相关查询记录
  16. JAVA后端如何保证业务操作的幂等性
  17. QQ.阿里旺旺.淘宝.在线网页链接代码及详解
  18. 【转载】MLC(Multi-Label Classification) 多标签分类
  19. 被“投机之王”奉为交易核心的时间要素到底是什么?
  20. 数据库——ODBC连接

热门文章

  1. 编写一个函数进行左移或右移的位运算
  2. java字符串复制空值_Java脚本:去除字符串中空值
  3. vue data数据修改_史上最强vue总结,万字长文
  4. Shell test 命令
  5. mongoose 实用 API 总结
  6. 生产环境 Apache 和 php 配置优化(一)
  7. 转:JAVA常见错误处理方法 和 JVM内存结构
  8. 高性能集群软件Keepalived之基础知识篇
  9. Win8/Win8.1值得做的十多项优化方法
  10. CentOS 5.6 修改国内网易163高速源