本文主要带领读者了解生成对抗神经网络(GAN),并使用提供的face数据集训练网络

GAN 入门

自 2014 年 Ian Goodfellow 的《生成对抗网络(Generative Adversarial Networks)》论文发表以来,GAN 的进展突飞猛进,生成结果也越来越具有照片真实感。
就在三年前,Ian Goodfellow 在 reddit 上回答 GAN 是否可以应用在文本领域的问题时,还认为 GAN 不能扩展到文本领域。

“由于 GAN 定义在实值数据上,因此 GAN 不能应用于 NLP。
GAN 的工作原理是训练一个生成网络,输出合成数据,然后利用判别网络判别合成数据。判别网络根据合成数据输出的梯度告诉你该如何对合成数据进行微调,使其更真实。
因此只有当合成数据是基于连续数字时,才能对其进行微调。如果是基于离散的数字,就没有办法做微小的改变。
例如,如果输出像素值为 1.0 的图像,则下一步可以将该像素值更改为 1.0001。
但如果输出单词‘penguin’,不能在下一步直接将其更改为‘penguin+.001’,因为没有‘penguin+.001’这样的单词。你必须从‘penguin’直接转变到‘ostrich’。
由于所有的 NLP 都是基于离散的值,如单词、字符或字节,所以目前还没有人知道该如何将 GAN 应用于 NLP。”

但是现在,GAN 已经可用于生成各种内容,包括图像、视频、音频和文本。这些输出的合成数据既可以用于训练其他的模型,也可以用于创建一些有趣的项目。

GAN 原理

GAN 由两个神经网络组成,一个是合成新样本的生成器,另一个是对比训练样本与生成样本的判别器。判别器的目标是区分“真实”和“虚假”的输入(对样本来自模型分布还是真实分布进行分类)。这些样本可以是图像、视频、音频片段和文本。

为了合成这些新的样本,生成器的输入为随机噪声,然后尝试从训练数据中学习到的分布中生成真实的图像。
判别器网络(卷积神经网络)输出相对于合成数据的梯度,其中包含着如何改变合成数据以使其更具真实感的信息。最终生成器收敛,它可以生成符合真实数据分布的样本,而判别器无法区分生成数据和真实数据。
ok,接下来我们就来实现一下

准备阶段

下载数据集
数据集,笔者这里已经为大家提供了,链接如下:
链接: https://pan.baidu.com/s/15wFZAANvr8gajiVY_1mI0A
提取码: c9vy
解压数据集
将下载好的数据集解压,放在工程目录下

加载数据集
加载数据集的代码,笔者这里直接提供给大家了,下面只是展示部分代码,文末会提供完整项目的代码链接

import multiprocessing
import tensorflow as tf
def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):@tf.functiondef _map_fn(img):img = tf.image.resize(img, [resize, resize])img = tf.clip_by_value(img, 0, 255)img = img / 127.5 - 1return imgdataset = disk_image_batch_dataset(img_paths,batch_size,drop_remainder=drop_remainder,map_fn=_map_fn,shuffle=shuffle,repeat=repeat)img_shape = (resize, resize, 3)len_dataset = len(img_paths) // batch_sizereturn dataset, img_shape, len_dataset
def batch_dataset(dataset,batch_size,drop_remainder=True,n_prefetch_batch=1,filter_fn=None,map_fn=None,n_map_threads=None,filter_after_map=False,shuffle=True,shuffle_buffer_size=None,repeat=None):

构建网络
搭建Generator,Generator包含两个部分,init部分和前向传播的call部分,代码如下

class Generator(keras.Model):def __init__(self):super(Generator, self).__init__()# z:[b,100]-->[b,3*3*512]-->[b,3,3,512]-->[b,64,64,3]self.fc=keras.layers.Dense(3*3*512)self.conv1=keras.layers.Conv2DTranspose(256,3,3,'valid')  # 反卷积self.bn1=keras.layers.BatchNormalization()self.conv2=keras.layers.Conv2DTranspose(128,5,2,'valid')self.bn2=keras.layers.BatchNormalization()self.conv3=keras.layers.Conv2DTranspose(3,4,3,'valid')def call(self, inputs, training=None, mask=None):# [z,100]-->[z,3*3*512]x=self.fc(inputs)x=tf.reshape(x,[-1,3,3,512])x=tf.nn.leaky_relu(x)x=tf.nn.leaky_relu(self.bn1(self.conv1(x),training=training))x=tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))x=self.conv3(x)x=tf.tanh(x)return x

搭建Discriminator,同上

class Discriminator(keras.Model):def __init__(self):super(Discriminator, self).__init__()# [b,64,64,3]-->[b,1]self.conv1=keras.layers.Conv2D(64,5,3,'valid')self.conv2=keras.layers.Conv2D(128,5,3,'valid')self.bn2=keras.layers.BatchNormalization()self.conv3=keras.layers.Conv2D(256,5,3,'valid')self.bn3=keras.layers.BatchNormalization()# [b,h,w,c]-->[b,-1]self.flatten=keras.layers.Flatten()# [b,-1]-->[b,1]self.fc=keras.layers.Dense(1)def call(self, inputs, training=None, mask=None):x=tf.nn.leaky_relu(self.conv1(inputs))x=tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))x=tf.nn.leaky_relu(self.bn3(self.conv3(x),training=training))x=self.flatten(x)logits=self.fc(x)return logits

训练GAN
定义相关数据,包括epoch,lr等等
这些数据可以自定义,笔者这里就不改动了

  z_dim = 100epochs = 50000batch_size = 512learning_rate = 0.0002is_training = True

加载数据

 img_path=glob.glob(r'E:\python_pro\TF2.0\GAN\faces\*.jpg')dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)

可以打印查看数据集信息:

(512, 64, 64, 3), (64, 64, 3)
(512, 64, 64, 3) ,1.0, -1.0

定义优化器,注意我们在开始训练时,需要新建训练GAN图片的文件,为查看数据提供持久化依据

    for epoch in range(epochs):batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)batch_x = next(db_iter)# train Dwith tf.GradientTape() as tape:d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)grads = tape.gradient(d_loss, discriminator.trainable_variables)d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))with tf.GradientTape() as tape:g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)grads = tape.gradient(g_loss, generator.trainable_variables)g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))if epoch % 100 == 0:print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss))z = tf.random.uniform([100, z_dim])fake_image = generator(z, training=False)img_path = os.path.join('GAN_IMAGE', 'gan%d.png'%epoch)save_result(fake_image.numpy(), 10, img_path, color_mode='P')

训练结果

接下来我们来看看,训练的效果图,注意,GAN的训练过程是非常非常非常慢的,大概训练十几个小时,才能有个比较好的效果,有的数据集甚至会训练几天之久,这个随数据集的大小和对最终效果的要求来定的。笔者这个数据集比较的简单,只是给大家做演示,好了,废话就不过多的说了,上图




上述分别是训练了100epoch、500、1500、4000的效果图,可以看到随着训练的次数增加,效果因为越来越好了

总结

大家在训练GAN时,还是需要一个好一些的GPU显卡才行,这样可以体验GPU给我们带来的加速效果。这样会使得训练的速度大大加快。
笔者水平有限,如有表述不准确的地方还请谅解,有错误的地方欢迎大家批评指正。
最后还是希望大家动手实践实践,共同进步。
最终的代码链接:https://github.com/huzixuan1/TF_2.0/tree/master/GAN

Tensorflow2.0实战之GAN相关推荐

  1. 神经网络与深度学习——TensorFlow2.0实战(笔记)(二)(安装TensorFlow2.0)

    创建环境并激活 conda create --name tensorflow2.0 python==3.7 activate tensorflow2.0 安装相关软件包(conda命令或pip命令2选 ...

  2. 神经网络与深度学习——TensorFlow2.0实战(笔记)(二)(开发环境介绍)

    开发环境介绍 Python3 1.结构清晰,简单易学 2.丰富的标准库 3.强大的的第三方生态系统 4.开源.开放体系 5.高可扩展性:胶水语言 6.高可扩展性:胶水语言 7.解释型语言,实现复杂算法 ...

  3. TensorFlow2.0实战: 入门到进阶深度学习

    TensorFlow2.0 入门到进阶 课程以Tensorflow2.0框架为主体,以图像分类.房价预测.文本分类等项目为依托,讲解Tensorflow框架的使用方法,同时学习到相关的深度学习/机器学 ...

  4. 笔记3:Tensorflow2.0实战之MNSIT数据集

    最近Tensorflow相继推出了alpha和beta两个版本,这两个都属于tensorflow2.0版本:早听说新版做了很大的革新,今天就来用一下看看 这里还是使用MNSIT数据集进行测试 导入必要 ...

  5. Tensorflow2.0实战练习之猫狗数据集(包含自定义训练和迁移学习)

    最近在学习使用Tenforflow2.0,写下这篇文章,用来帮助和我一样的初学者,文章中如果存在某些问题,还希望各位指出. 目录 数据集介绍 数据处理及增强 VGG模型介绍 模型搭建 训练及结果展示 ...

  6. 【深度学习与tensorflow2.0实战】(网易云课堂)13-GAN

    本文目录 GAN原理 纳什均衡-D.G EM距离 GAN实战 **gan.py** dataset.py GAN原理 Having Fun ▪ https://reiinakano.github.io ...

  7. 神经网络与深度学习——TensorFlow2.0实战(笔记)(六)(Matplotlib绘图基础<折线图和柱状图>python)

    折线图(Line Chart): 散点图的基础上,将相邻的点用线段相连接 plot()函数 #折线图:在散点图的基础上将相邻两个点链接 #描述变量变化的趋势 #plot(x,y,color,marke ...

  8. 神经网络与深度学习——TensorFlow2.0实战(笔记)(五)(Matplotlib绘图基础<散点图>python)

    散点图(Scatter): 是数据点在直角坐标系中的分布图 scatter() 函数 marker参数--数据点样式 添加文字--text() 函数 坐标轴设置 增加图例 绘制标准正态分布的散点图步骤 ...

  9. 神经网络与深度学习——TensorFlow2.0实战(笔记)(五)(Matplotlib绘图基础<1>python)

    数据可视化 数据分析阶段:理解和洞察数据之间的关系 算法调试阶段:发现问题,优化算法 项目总结阶段:展示项目成果 Matplotlib: 第三方库,可以快速方便地生成高质量的图表 安装Matplotl ...

最新文章

  1. memcache和memcached安装
  2. 如果提高声音测距的分辨率?
  3. Linux内核 sysctl.conf 优化设置
  4. 一个小厂前端 Leader 如何筛选候选人?
  5. apache camel_发掘Apache Camel的力量
  6. flash动画设计期末作业_「2019年下学期」第二十五二十六节:期末作品三-吉祥物设计...
  7. mysql 101_MySQL 调优/优化的 101 个建议!
  8. Oracle数据库ORA-00942: 表或视图不存在的问题
  9. 用JavaScript简单编程——基础篇
  10. 蚂蚁金服OceanBase“击败”甲骨文?呵呵!
  11. 项目成本管理---控制成本
  12. Ubuntu Emacs Fcitx 中文输入法设置
  13. cad转pdf怎么变成黑白?
  14. 尺缩钟慢之动钟变慢——思想实验推导狭义相对论(七)
  15. hdu4747-线段树
  16. sql server 2008 千万条数据分页查询
  17. RRT、RRTConnect、RRT*——Matlab算法
  18. 使用Java和FFempeg批量转码B站缓存下来的列表视频,成MP4格式
  19. 电脑设备中PCI简易通讯控制器驱动显示黄色感叹号图标怎么办【申明:来源于网络】
  20. QTimer定时器问题分析

热门文章

  1. Java基础-控制流程-5. 中断控制流程语句
  2. 软考(6)--数据库
  3. VI3之vCenterServer配置的备份与还原
  4. ViT (Vision Transformer) ----LSTM网络
  5. ios 自定义拍照页面_30分钟搞定iOS自定义相机
  6. python 数组比较大小_python – 比较两个不同长度的numpy数组
  7. kubernetes安装_在 Kubernetes 上安装 Gitlab CI Runner
  8. gms签名不一致_电子签名拍照-多媒体互动装置介绍「振邦视界」
  9. 大部分Java程序员都会忽略的几个问题,你中招没?
  10. 必做作业3:原型化系统