WGAN-GP是针对WGAN的存在的问题提出来的,WGAN在真实的实验过程中依旧存在着训练困难、收敛速度慢的 问题,相比较传统GAN在实验上提升不是很明显。WGAN-GP在文章中指出了WGAN存在问题的原因,那就是WGAN在处理Lipschitz限制条件时直接采用了 weight clipping。相关讲解请参考

WGAN-GP的介绍

同往期一样,依然以生成cifar数据集中马的彩色图片为例,关于cifar数据集的读取和生成器模型的验证请参考第6期:

用DCGAN生成马的彩色图片

下面给出WGAN-GP框架

"""
-------------------------------------------------------生死看淡,不服就GAN-------------------------------------------------------------------------
PROJECT: CIFAR10_WGAN-GP
Author: Ephemeroptera
Date:2019-3-19
QQ:605686962"""
"""
WGAN说明:相比较WGAN,WGAN-GP提出以下改进:(1)用对判别器梯度惩罚取代WGAN的判决器权值区间截断(2)判别器取消BN操作(3)优化器使用ADAM
"""
# 导入包
import numpy as np
import tensorflow as tf
import pickle
import TFRecordTools
import time############################################### 设置参数 ####################################################################################real_shape = [-1,32,32,3] # 真实样本尺寸
data_total = 5000 # 真实样本个数
batch_size = 64 # 批大小
noise_size = 128 # 噪声维度
max_iters = 50000 #的最大迭代次数
learning_rate = 5e-5 # 学习率
beta1 = 0.5# ADAM参数1
beta2 = 0.9# ADAM参数2
CRITIC_NUM = 5 # 每次迭代判别器训练次数
lam = 10 #梯度惩罚权重############################################# 定义生成器和判别器 ############################################################################## 定义生成器(32x32图片)
def Generator_DC_32x32(z, channel, is_train=True):""":param z: 噪声信号,tensor类型:param channnel: 生成图片的通道数:param is_train: 是否为训练状态,该参数主要用于作为batch_normalization方法中的参数使用(训练时候开启)"""# 训练时生成器不允许复用with tf.variable_scope("generator", reuse=(not is_train)):# layer1: noise_dim --> 4*4*512 --> 4x4x512 -->BN+relulayer1 = tf.layers.dense(z, 4 * 4 * 512)layer1 = tf.reshape(layer1, [-1, 4, 4, 512])layer1 = tf.layers.batch_normalization(layer1, training=is_train,)layer1 = tf.nn.relu(layer1)# layer1 = tf.nn.dropout(layer1, keep_prob=0.8)# dropout# layer2: deconv(ks=3x3,s=2,padding=same):4x4x512 --> 8x8x256 --> BN+relulayer2 = tf.layers.conv2d_transpose(layer1, 256, 3, strides=2, padding='same',kernel_initializer=tf.random_normal_initializer(0, 0.02),bias_initializer=tf.random_normal_initializer(0, 0.02))layer2 = tf.layers.batch_normalization(layer2, training=is_train)layer2 = tf.nn.relu(layer2)# layer2 = tf.nn.dropout(layer2, keep_prob=0.8)# dropout# layer3: deconv(ks=3x3,s=2,padding=same):8x8x256 --> 16x16x128 --> BN+relulayer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same',kernel_initializer=tf.random_normal_initializer(0, 0.02),bias_initializer=tf.random_normal_initializer(0, 0.02))layer3 = tf.layers.batch_normalization(layer3, training=is_train)layer3 = tf.nn.relu(layer3)# layer3 = tf.nn.dropout(layer3, keep_prob=0.8)# dropout# layer4: deconv(ks=3x3,s=2,padding=same):16x16x128 --> 32x32x64--> BN+relulayer4 = tf.layers.conv2d_transpose(layer3, 64, 3, strides=2, padding='same',kernel_initializer=tf.random_normal_initializer(0, 0.02),bias_initializer=tf.random_normal_initializer(0, 0.02))layer4 = tf.layers.batch_normalization(layer4, training=is_train)layer4 = tf.nn.relu(layer4)# layer4 = tf.nn.dropout(layer3, keep_prob=0.8)# dropout# logits: deconv(ks=3x3,s=2,padding=same):32x32x64 --> 32x32x3logits = tf.layers.conv2d_transpose(layer4, channel, 3, strides=1, padding='same',kernel_initializer=tf.random_normal_initializer(0, 0.02),bias_initializer=tf.random_normal_initializer(0, 0.02))# outputsoutputs = tf.tanh(logits)return logits,outputs# 定义判别器(32x32)
def Discriminator_DC_32x32(inputs_img, reuse=False, GAN = False,GP= False,alpha=0.2):"""@param inputs_img: 输入图片,tensor类型@param reuse:判别器复用@param GP: 使用WGAN-GP时关闭BN@param alpha: Leaky ReLU系数"""with tf.variable_scope("discriminator", reuse=reuse):# layer1: conv(ks=3x3,s=2,padding=same)+lrelu -->32x32x3 to 16x16x128layer1 = tf.layers.conv2d(inputs_img, 128, 3, strides=2, padding='same')if GP is False:layer1 = tf.layers.batch_normalization(layer1, training=True)layer1 = tf.nn.leaky_relu(layer1,alpha=alpha)# layer1 = tf.nn.dropout(layer1, keep_prob=0.8)# layer2: conv(ks=3x3,s=2,padding=same)+BN+lrelu -->16x16x128 to 8x8x256layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')if GP is False:layer2 = tf.layers.batch_normalization(layer2, training=True)layer2 = tf.nn.leaky_relu(layer2, alpha=alpha)# layer2 = tf.nn.dropout(layer2, keep_prob=0.8)# layer3: conv(ks=3x3,s=2,padding=same)+BN+lrelu -->8x8x256 to 4x4x512layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')if GP is False:layer3 = tf.layers.batch_normalization(layer3, training=True)layer3 = tf.nn.leaky_relu(layer3, alpha=alpha)layer3 = tf.reshape(layer3, [-1, 4*4* 512])# layer3 = tf.nn.dropout(layer2, keep_prob=0.8)# logits,output:logits = tf.layers.dense(layer3, 1)"WGAN:去除sigmoid"if GAN:outputs = Noneelse:outputs = tf.sigmoid(logits)return logits, outputs############################################## 定义计算图(网络) ########################################################----------------------输入----------------inputs_real = tf.placeholder(tf.float32, [None, real_shape[1], real_shape[2], real_shape[3]], name='inputs_real') # 真实样本输入
inputs_noise = tf.placeholder(tf.float32, [None, noise_size], name='inputs_noise') # 生成样本输入#-------------------生成和判别--------------
# 生成样本
_,g_outputs = Generator_DC_32x32(inputs_noise, real_shape[3], is_train=True) # 训练生成器
_,g_test = Generator_DC_32x32(inputs_noise, real_shape[3], is_train=False) # 测试生成器
# 判别样本
'WGAN-GP:判别器废除批归一化'
d_logits_real, _ = Discriminator_DC_32x32(inputs_real,GAN=True,GP=True) #识别真样本
d_logits_fake, _ = Discriminator_DC_32x32(g_outputs,GAN=True,GP=True,reuse=True) #识别假样本#------------定义原始GAN的损失函数--------------
"WGAN:损失函数去log,采用Wasserstein距离形式"
# 生成器
g_loss = tf.reduce_mean(-d_logits_fake)
# 判别器
d_loss = tf.reduce_mean(d_logits_fake - d_logits_real)
'WGAN-GP:加入判别器梯度惩罚项'
# 判别器梯度惩罚项
alpha_dist = tf.contrib.distributions.Uniform(low=0., high=1.)#获取[0,1]之间正态分布
alpha = alpha_dist.sample((batch_size, 1, 1, 1))
interpolated = inputs_real + alpha*(g_outputs-inputs_real)# 对真实样本和生成样本之间插值
inte_logit,_ = Discriminator_DC_32x32(interpolated, GAN=True,GP=True,reuse=True)
gradients = tf.gradients(inte_logit, [interpolated,])[0]# 求得判别器梯度
grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1,2,3]))
gradient_penalty = tf.reduce_mean((grad_l2-1)**2) # 定义惩罚项
# 加入d_loss
d_loss+=gradient_penalty*lam
#-------------------训练模型-----------------
# 分别获取生成器和判别器的变量空间
train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]# Optimizer
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):# 保证分布白化先完成g_train_opt = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1, beta2=beta2).minimize(g_loss, var_list=g_vars)d_train_opt = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1, beta2=beta2).minimize(d_loss, var_list=d_vars)
############################################# 调用TFRecord读取数据 ###################################################### 读取TFR,不打乱文件顺序,指定数据类型,开启多线程
[data,label] = TFRecordTools.ReadFromTFRecord(sameName= r'.\TFR\class7-*',isShuffle= False,datatype= tf.float64,labeltype= tf.int32,isMultithreading= True)
# 批量处理,送入队列数据,指定数据大小,打乱数据项,设置批次大小64
[data_batch,label_batch] = TFRecordTools.DataBatch(data,label,dataSize= 32*32*3,labelSize= 1,isShuffle= True,batchSize= 64)############################################### 迭代 #################################################################### 存储训练过程中生成日志
GenLog = []
# 存储loss
losses = []
# 保存生成器变量(仅保存生成器模型,保存最近5个)
saver = tf.train.Saver(var_list=[var for var in tf.trainable_variables()if var.name.startswith("generator")],max_to_keep=5)
# 定义批预处理
def batch_preprocess(data_batch):# 提取批数据batch = sess.run(data_batch)# 整理成RGB(Nx32x32x3)batch_images = np.reshape(batch, [-1, 3, 32, 32]).transpose((0, 2, 3, 1))  # (-1,32,32,3)# scale to -1, 1batch_images = batch_images * 2 - 1return  batch_images# 生成相关目录保存生成信息
def GEN_DIR():import osif not os.path.isdir('ckpt'):print('文件夹ckpt未创建,现在在当前目录下创建..')os.mkdir('ckpt')if not os.path.isdir('trainLog'):print('文件夹ckpt未创建,现在在当前目录下创建..')os.mkdir('trainLog')# 开启会话
with tf.Session() as sess:# 生成相关目录GEN_DIR()# 初始化变量init = (tf.global_variables_initializer(), tf.local_variables_initializer())sess.run(init)# 开启协调器coord = tf.train.Coordinator()# 启动线程threads = tf.train.start_queue_runners(sess=sess, coord=coord)time_start = time.time() # 开始计时for steps in range(max_iters):steps += 1# 判别器重复训练设置if steps < 25 or steps % 500 == 0:critic_num = 100else:critic_num = CRITIC_NUM# 重复训练判别器for i in range(CRITIC_NUM):batch_images = batch_preprocess(data_batch)  # imagesbatch_noise = np.random.normal(size=(batch_size, noise_size))  # noise(normal)_ = sess.run(d_train_opt, feed_dict={inputs_real: batch_images,inputs_noise: batch_noise})# 训练生成器batch_images = batch_preprocess(data_batch)  # imagesbatch_noise = np.random.normal(size=(batch_size, noise_size))  # noise(normal)_ = sess.run(g_train_opt, feed_dict={inputs_real: batch_images,inputs_noise: batch_noise})#  记录训练信息if steps % 5 == 1:# (1)记录损失函数train_loss_d = d_loss.eval({inputs_real: batch_images,inputs_noise: batch_noise})train_loss_g = g_loss.eval({inputs_real: batch_images,inputs_noise: batch_noise})losses.append([train_loss_d, train_loss_g,steps])# (2)记录生成样本batch_noise = np.random.normal(size=(batch_size, noise_size))gen_samples = sess.run(g_test, feed_dict={inputs_noise: batch_noise})genLog = (gen_samples[0:11] + 1) / 2  # 恢复颜色空间(取10张)GenLog.append(genLog)# (3)打印信息print('step {}...'.format(steps),"Discriminator Loss: {:.4f}...".format(train_loss_d),"Generator Loss: {:.4f}...".format(train_loss_g))# (4)保存生成模型if steps % 300 ==0:saver.save(sess, './ckpt/generator.ckpt', global_step=steps)# 关闭线程coord.request_stop()coord.join(threads)#计时结束:
time_end = time.time()
print('迭代结束,耗时:%.2f秒'%(time_end-time_start))# 保存信息
#  保存loss记录
with open('./trainLog/loss_variation.loss', 'wb') as l:losses = np.array(losses)pickle.dump(losses,l)print('保存loss信息..')# 保存生成日志
with open('./trainLog/GenLog.log', 'wb') as g:pickle.dump(GenLog, g)print('保存GenLog信息..')

结果展示

最后一次生成样本

训练期间生成日志

损失函数

验证生成器

生死看淡,不服就GAN(八)----WGAN的改进版本WGAN-GP相关推荐

  1. 雷军推红米Redmi独立品牌喊话友商:生死看淡 不服就干

    雷帝网 雷建平 1月10日报道 小米今日在北京召开独立品牌红米Redmi发布会,并发布该品牌首款产品Redmi Note 7. 作为首款产品,Redmi Note 7坚持"死磕性价比&quo ...

  2. 友商逼急 雷急跳墙:生死看淡 不服就干

    友商逼急    雷急跳墙:生死看淡 不服就干 短短一个小时的红米Note7手机产品发布会,雷军怼了友商8次:甚至在媒体群访环节,雷军也抑制不住愤怒之情,提到友商面色铁青,以至于有人说,这次发布会的雷军 ...

  3. 生死看淡 不服就干!雷军这次真的被逼急了

    来源 | 网易科技 作者 | 崔玉贤 短短一个小时的红米Note7手机产品发布会,雷军怼了友商8次:甚至在媒体群访环节,雷军也抑制不住愤怒之情,提到友商面色铁青,以至于有人说,这次发布会的雷军不像&q ...

  4. Redmi K40系列要做旗舰“焊门员”:生死看淡 不服就焊

    经过了一段时间的密集预热,根据此前官宣的消息,全新的Redmi K40系列旗舰将于2月25日也就是明天正式发布.而随着发布会进入最后的倒计时,Redmi官方的预热行动也进入了最后的冲刺阶段.近日Red ...

  5. 生死看淡,不服就GAN(七)----用更稳定的生成模型WGAN生成cifar

    WGAN提出Wasserstein距离取代原始GAN的JS散度衡量两分布之间距离,使模型更加稳定并消除了mode collapse问题.关于WGAN的介绍,建议参考以下博客: 令人拍案叫绝的Wasse ...

  6. 生死看淡,不服就GAN(六)----用DCGAN生成马的彩色图片

    1. 首先我们需要的一组真实样本集来自cifar10,因此先制作一个读取cifar10的脚本. """ --------------------------------- ...

  7. 生死看淡,不服就GAN(五)----用DCGAN生成MNIST手写体

    搭建DCGAN网络 #*************************************** 生死看淡,不服就GAN ************************************* ...

  8. 生死看淡,不服就GAN(四)---- 用全连层GAN生成MNIST手写体

    搭建全连接GAN网络 #*************************************** 生死看淡,不服就GAN ************************************ ...

  9. 雷军的100亿计划:不服就干,生死看淡

    图片来自小米官网 整理 | 琥珀 出品 | AI 科技大本营 1 月 10 日,红米品牌正式独立. 11 日,雷军在小米年会上宣布,2019 年,小米将正式启动"手机+AIoT"双 ...

最新文章

  1. Scala’s parallel collections
  2. 第十六届全国大学生智能车提问与回复 |7月10日
  3. 苹果签名分发系统需要什么配置的服务器呢,苹果/IOS超级签名分发系统
  4. LeetCode_111.二叉树的最小深度
  5. LINK : fatal error LNK1123: 转换到 COFF 期间失败: 文件无效或损坏
  6. 深度学习(三)——Autoencoder, 词向量
  7. Java笔记 —— 继承
  8. android京东流式布局,京东移动端首页流式布局
  9. 【银联支付】php接入银联支付
  10. centos 切换终端_centos进入不同终端的几种方法
  11. LCD显示屏和OLED显示屏的区别
  12. R语言ggplot2 | 如何自定义facet分面的坐标轴范围
  13. 使用certbot自动续签ssl证书
  14. 基于SpringBoot的健身房管理系统
  15. 利用逻辑分析仪测定单片机延时函数时间
  16. Linux UDP下C语言实现TFTP协议客户端
  17. ipa文件的安装方法
  18. GD32F303调试小记(二)之SPI(软件SPI、硬件SPI、硬件SPI+DMA)
  19. 建筑学计算机出图报告,建筑系举办计算机生成图像技术工作坊
  20. 银行信用卡评分模型(一)| python

热门文章

  1. python程序设计实用教程清华大学出版社_清华大学出版社-图书详情-《Python程序设计简明教程》...
  2. opencv小游戏(05):小车的运动
  3. ZOJ 3755 - Mines (状压DP)
  4. 《SolidWorks 2014中文版完全自学手册》——导读
  5. 【MQTT】SpringBoot整合MQTT(EMQX)
  6. 做直播能有多赚钱,Python告诉你
  7. Shell一句话根据进程名杀死进程
  8. Android:InflateException: Binary XML file line #12: Error inflating class null
  9. 玩个游戏好难 Win10我的世界(Minecraft)下载
  10. 基于虹软人证核验 2.0 Android SDK开发集成入门