WGAN提出Wasserstein距离取代原始GAN的JS散度衡量两分布之间距离,使模型更加稳定并消除了mode collapse问题。关于WGAN的介绍,建议参考以下博客:

令人拍案叫绝的WassersteinGAN
GAN是怎么工作的
WGAN和GAN直观区别和优劣

这次依然是使用cifar数据集生成马的彩色图片,上期采用DCGAN实现,关于数据集的读取和生成模型的验证请参考DCGAN教程:

https://blog.csdn.net/Ephemeroptera/article/details/89019873)

这期我们使用更稳定的WGAN训练,下面给出WGAN框架:

"""
-------------------------------------------------------生死看淡,不服就GAN-------------------------------------------------------------------------
PROJECT: CIFAR10_WGAN
Author: Ephemeroptera
Date:2019-3-19
QQ:605686962"""
"""
WGAN说明:相比较原始GAN,WGAN提出以下改进:(1)判决器不再表示判决分数,而是表现最优Wasserstein距离,因此去掉sigmoid(2)损失函数去掉log(3)判别器采用权值区间截断,满足lipschitz连续(4)优化器建议使用基于动量的优化器,可以采用RMSPropOptimizer
"""
# 导入包
import numpy as np
import tensorflow as tf
import pickle
import TFRecordTools
import time############################################### 设置参数 ####################################################################################real_shape = [-1,32,32,3] # 真实样本尺寸
data_total = 5000 # 真实样本个数
batch_size = 64 # 批大小
noise_size = 128 # 噪声维度
max_iters = 50000 #的最大迭代次数
learning_rate = 5e-5 # 学习率
CRITIC_NUM = 5 # 每次迭代判别器训练次数
CLIP = [-0.1,0.1]# 判别器权值截断区间############################################# 定义生成器和判别器 ############################################################################## 定义生成器(32x32图片)
def Generator_DC_32x32(z, channel, is_train=True):""":param z: 噪声信号,tensor类型:param channnel: 生成图片的通道数:param is_train: 是否为训练状态,该参数主要用于作为batch_normalization方法中的参数使用(训练时候开启)"""# 训练时生成器不允许复用with tf.variable_scope("generator", reuse=(not is_train)):# layer1: noise_dim --> 4*4*512 --> 4x4x512 -->BN+relulayer1 = tf.layers.dense(z, 4 * 4 * 512)layer1 = tf.reshape(layer1, [-1, 4, 4, 512])layer1 = tf.layers.batch_normalization(layer1, training=is_train,)layer1 = tf.nn.relu(layer1)# layer1 = tf.nn.dropout(layer1, keep_prob=0.8)# dropout# layer2: deconv(ks=3x3,s=2,padding=same):4x4x512 --> 8x8x256 --> BN+relulayer2 = tf.layers.conv2d_transpose(layer1, 256, 3, strides=2, padding='same',kernel_initializer=tf.random_normal_initializer(0, 0.02),bias_initializer=tf.random_normal_initializer(0, 0.02))layer2 = tf.layers.batch_normalization(layer2, training=is_train)layer2 = tf.nn.relu(layer2)# layer2 = tf.nn.dropout(layer2, keep_prob=0.8)# dropout# layer3: deconv(ks=3x3,s=2,padding=same):8x8x256 --> 16x16x128 --> BN+relulayer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same',kernel_initializer=tf.random_normal_initializer(0, 0.02),bias_initializer=tf.random_normal_initializer(0, 0.02))layer3 = tf.layers.batch_normalization(layer3, training=is_train)layer3 = tf.nn.relu(layer3)# layer3 = tf.nn.dropout(layer3, keep_prob=0.8)# dropout# layer4: deconv(ks=3x3,s=2,padding=same):16x16x128 --> 32x32x64--> BN+relulayer4 = tf.layers.conv2d_transpose(layer3, 64, 3, strides=2, padding='same',kernel_initializer=tf.random_normal_initializer(0, 0.02),bias_initializer=tf.random_normal_initializer(0, 0.02))layer4 = tf.layers.batch_normalization(layer4, training=is_train)layer4 = tf.nn.relu(layer4)# layer4 = tf.nn.dropout(layer3, keep_prob=0.8)# dropout# logits: deconv(ks=3x3,s=2,padding=same):32x32x64 --> 32x32x3logits = tf.layers.conv2d_transpose(layer4, channel, 3, strides=1, padding='same',kernel_initializer=tf.random_normal_initializer(0, 0.02),bias_initializer=tf.random_normal_initializer(0, 0.02))# outputsoutputs = tf.tanh(logits)return logits,outputs# 定义判别器(32x32)
def Discriminator_DC_32x32(inputs_img, reuse=False, GAN = False,GP= False,alpha=0.2):"""@param inputs_img: 输入图片,tensor类型@param reuse:判别器复用@param GP: 使用WGAN-GP时关闭BN@param alpha: Leaky ReLU系数"""with tf.variable_scope("discriminator", reuse=reuse):# layer1: conv(ks=3x3,s=2,padding=same)+lrelu -->32x32x3 to 16x16x128layer1 = tf.layers.conv2d(inputs_img, 128, 3, strides=2, padding='same')if GP is False:layer1 = tf.layers.batch_normalization(layer1, training=True)layer1 = tf.nn.leaky_relu(layer1,alpha=alpha)# layer1 = tf.nn.dropout(layer1, keep_prob=0.8)# layer2: conv(ks=3x3,s=2,padding=same)+BN+lrelu -->16x16x128 to 8x8x256layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')if GP is False:layer2 = tf.layers.batch_normalization(layer2, training=True)layer2 = tf.nn.leaky_relu(layer2, alpha=alpha)# layer2 = tf.nn.dropout(layer2, keep_prob=0.8)# layer3: conv(ks=3x3,s=2,padding=same)+BN+lrelu -->8x8x256 to 4x4x512layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')if GP is False:layer3 = tf.layers.batch_normalization(layer3, training=True)layer3 = tf.nn.leaky_relu(layer3, alpha=alpha)layer3 = tf.reshape(layer3, [-1, 4*4* 512])# layer3 = tf.nn.dropout(layer2, keep_prob=0.8)# logits,output:logits = tf.layers.dense(layer3, 1)"WGAN:去除sigmoid"if GAN:outputs = Noneelse:outputs = tf.sigmoid(logits)return logits, outputs############################################## 定义计算图(网络) ########################################################----------------------输入----------------inputs_real = tf.placeholder(tf.float32, [None, real_shape[1], real_shape[2], real_shape[3]], name='inputs_real') # 真实样本输入
inputs_noise = tf.placeholder(tf.float32, [None, noise_size], name='inputs_noise') # 生成样本输入#-------------------生成和判别--------------
# 生成样本
_,g_outputs = Generator_DC_32x32(inputs_noise, real_shape[3], is_train=True) # 训练生成器
_,g_test = Generator_DC_32x32(inputs_noise, real_shape[3], is_train=False) # 测试生成器
# 判别样本
d_logits_real, _ = Discriminator_DC_32x32(inputs_real,GAN=True) #识别真样本
d_logits_fake, _ = Discriminator_DC_32x32(g_outputs,GAN=True,reuse=True) #识别假样本#------------定义原始GAN的损失函数--------------
"WGAN:损失函数去log,采用Wasserstein距离形式"
# 生成器
g_loss = tf.reduce_mean(-d_logits_fake)
# 判别器
d_loss = tf.reduce_mean(d_logits_fake - d_logits_real)#-------------------训练模型-----------------
# 分别获取生成器和判别器的变量空间
train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]# Optimizer
'WGAN:不建议使用基于动量的优化器'
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):# 保证分布白化先完成g_train_opt = tf.train.RMSPropOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)d_train_opt = tf.train.RMSPropOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
"WGAN:判别器权值区间截断,满足Lipschitz连续"
# clip
d_clip_opt = [tf.assign(var, tf.clip_by_value(var, CLIP[0], CLIP[1])) for var in d_vars]############################################# 调用TFRecord读取数据 ###################################################### 读取TFR,不打乱文件顺序,指定数据类型,开启多线程
[data,label] = TFRecordTools.ReadFromTFRecord(sameName= r'.\TFR\class7-*',isShuffle= False,datatype= tf.float64,labeltype= tf.int32,isMultithreading= True)
# 批量处理,送入队列数据,指定数据大小,打乱数据项,设置批次大小64
[data_batch,label_batch] = TFRecordTools.DataBatch(data,label,dataSize= 32*32*3,labelSize= 1,isShuffle= True,batchSize= 64)############################################### 迭代 #################################################################### 存储训练过程中生成日志
GenLog = []
# 存储loss
losses = []
# 保存生成器变量(仅保存生成器模型,保存最近5个)
saver = tf.train.Saver(var_list=[var for var in tf.trainable_variables()if var.name.startswith("generator")],max_to_keep=5)
# 定义批预处理
def batch_preprocess(data_batch):# 提取批数据batch = sess.run(data_batch)# 整理成RGB(Nx32x32x3)batch_images = np.reshape(batch, [-1, 3, 32, 32]).transpose((0, 2, 3, 1))  # (-1,32,32,3)# scale to -1, 1batch_images = batch_images * 2 - 1return  batch_images# 生成相关目录保存生成信息
def GEN_DIR():import osif not os.path.isdir('ckpt'):print('文件夹ckpt未创建,现在在当前目录下创建..')os.mkdir('ckpt')if not os.path.isdir('trainLog'):print('文件夹ckpt未创建,现在在当前目录下创建..')os.mkdir('trainLog')# 开启会话
with tf.Session() as sess:# 生成相关目录GEN_DIR()# 初始化变量init = (tf.global_variables_initializer(), tf.local_variables_initializer())sess.run(init)# 开启协调器coord = tf.train.Coordinator()# 启动线程threads = tf.train.start_queue_runners(sess=sess, coord=coord)time_start = time.time() # 开始计时for steps in range(max_iters):steps += 1# 判别器重复训练设置if steps < 25 or steps % 500 == 0:critic_num = 100else:critic_num = CRITIC_NUM# 重复训练判别器for i in range(CRITIC_NUM):batch_images = batch_preprocess(data_batch)  # imagesbatch_noise = np.random.normal(size=(batch_size, noise_size))  # noise(normal)_ = sess.run(d_train_opt, feed_dict={inputs_real: batch_images,inputs_noise: batch_noise})sess.run(d_clip_opt)# 训练生成器batch_images = batch_preprocess(data_batch)  # imagesbatch_noise = np.random.normal(size=(batch_size, noise_size))  # noise(normal)_ = sess.run(g_train_opt, feed_dict={inputs_real: batch_images,inputs_noise: batch_noise})#  记录训练信息if steps % 5 == 1:# (1)记录损失函数train_loss_d = d_loss.eval({inputs_real: batch_images,inputs_noise: batch_noise})train_loss_g = g_loss.eval({inputs_real: batch_images,inputs_noise: batch_noise})losses.append([train_loss_d, train_loss_g,steps])# (2)记录生成样本batch_noise = np.random.normal(size=(batch_size, noise_size))gen_samples = sess.run(g_test, feed_dict={inputs_noise: batch_noise})genLog = (gen_samples[0:11] + 1) / 2  # 恢复颜色空间(取10张)GenLog.append(genLog)# (3)打印信息print('step {}...'.format(steps),"Discriminator Loss: {:.4f}...".format(train_loss_d),"Generator Loss: {:.4f}...".format(train_loss_g))# (4)保存生成模型if steps % 300 ==0:saver.save(sess, './ckpt/generator.ckpt', global_step=steps)# 关闭线程coord.request_stop()coord.join(threads)#计时结束:
time_end = time.time()
print('迭代结束,耗时:%.2f秒'%(time_end-time_start))# 保存信息
#  保存loss记录
with open('./trainLog/loss_variation.loss', 'wb') as l:losses = np.array(losses)pickle.dump(losses,l)print('保存loss信息..')# 保存生成日志
with open('./trainLog/GenLog.log', 'wb') as g:pickle.dump(GenLog, g)print('保存GenLog信息..')

本次迭代实验结果如下:

最后一次生成样本

训练过程生成日志

损失函数

生成器验证

生死看淡,不服就GAN(七)----用更稳定的生成模型WGAN生成cifar相关推荐

  1. 友商逼急 雷急跳墙:生死看淡 不服就干

    友商逼急    雷急跳墙:生死看淡 不服就干 短短一个小时的红米Note7手机产品发布会,雷军怼了友商8次:甚至在媒体群访环节,雷军也抑制不住愤怒之情,提到友商面色铁青,以至于有人说,这次发布会的雷军 ...

  2. 雷军推红米Redmi独立品牌喊话友商:生死看淡 不服就干

    雷帝网 雷建平 1月10日报道 小米今日在北京召开独立品牌红米Redmi发布会,并发布该品牌首款产品Redmi Note 7. 作为首款产品,Redmi Note 7坚持"死磕性价比&quo ...

  3. 生死看淡 不服就干!雷军这次真的被逼急了

    来源 | 网易科技 作者 | 崔玉贤 短短一个小时的红米Note7手机产品发布会,雷军怼了友商8次:甚至在媒体群访环节,雷军也抑制不住愤怒之情,提到友商面色铁青,以至于有人说,这次发布会的雷军不像&q ...

  4. Redmi K40系列要做旗舰“焊门员”:生死看淡 不服就焊

    经过了一段时间的密集预热,根据此前官宣的消息,全新的Redmi K40系列旗舰将于2月25日也就是明天正式发布.而随着发布会进入最后的倒计时,Redmi官方的预热行动也进入了最后的冲刺阶段.近日Red ...

  5. 生死看淡,不服就GAN(六)----用DCGAN生成马的彩色图片

    1. 首先我们需要的一组真实样本集来自cifar10,因此先制作一个读取cifar10的脚本. """ --------------------------------- ...

  6. 生死看淡,不服就GAN(五)----用DCGAN生成MNIST手写体

    搭建DCGAN网络 #*************************************** 生死看淡,不服就GAN ************************************* ...

  7. 生死看淡,不服就GAN(八)----WGAN的改进版本WGAN-GP

    WGAN-GP是针对WGAN的存在的问题提出来的,WGAN在真实的实验过程中依旧存在着训练困难.收敛速度慢的 问题,相比较传统GAN在实验上提升不是很明显.WGAN-GP在文章中指出了WGAN存在问题 ...

  8. 生死看淡,不服就GAN(四)---- 用全连层GAN生成MNIST手写体

    搭建全连接GAN网络 #*************************************** 生死看淡,不服就GAN ************************************ ...

  9. 生死看淡,不服就干,小米终于迎来了久违的大幅反弹

    市调机构Canalys公布的2019年四季度的数据显示,小米以23%的增速高居全球前五大手机企业第一名,.这也是它在2019年四个季度当中唯一的一个季度取得正增长. 市调机构IDC给出的2019年一季 ...

最新文章

  1. Redis主从复制下的工作原理
  2. ubuntu12.4上安装minigui3.0.12
  3. 自己写的一个报表,研究SAP CRM ibase保存问题
  4. 明明可以靠技术吃饭,现在却非要出来当编剧!
  5. query的list()和iterate()区别 面试题
  6. Knative Serving 之路由管理和 Ingress
  7. regsvr32.exe
  8. java 蓝桥杯算法训练 寂寞的数(题解)
  9. Android App界面和流畅度优化
  10. vmware硬件兼容官方查询地址
  11. WPF 第三方控件学习使用——可停靠布局控件(AvalonDock)
  12. fun的用法c语言,fun的用法_fun的用法
  13. 周立功DTU+温度传感器,ZWS物联网平台尝试
  14. html简单登录页面代码
  15. 【python】只保留字符串中的英文字母
  16. 用react-custom-scrollbars插件美化 滚动条
  17. Pr零基础入门指南笔记四
  18. Android hook微信 apk 实时获取微信聊天消息记录
  19. pta 7-1 走楼梯升级版(递归)
  20. 数据仓库十大主题;TeraData金融数据模型

热门文章

  1. 背包型动态规划——零钱兑换
  2. 基于OpenGL的Android系统视频转换功能实现
  3. 获取星期--蔡勒公式
  4. lambda-view: JS源码阅读工具
  5. 2005-11-11
  6. 利用Photoshop对证件照换底且抠出头发丝
  7. fanyibishe
  8. poj1061青蛙的约会
  9. [爱情智慧]爱作的女人,最后都不怎么好!学会述情才能婚姻幸福!
  10. 道家修真分哪几个境界?