Building a Fully-Connected GAN Network

#*************************************** 生死看淡,不服就GAN **************************************************************
"""
PROJECT:MNIST_GAN_MLP
Author:Ephemeroptera
Date:2018-4-24
QQ:605686962
Reference:' improved_wgan_training-master': <https://github.com/igul222/improved_wgan_training>'Zardinality/WGAN-tensorflow':<https://github.com/Zardinality/WGAN-tensorflow>'NELSONZHAO/zhihu':<https://github.com/NELSONZHAO/zhihu>
"""# import dependency
import tensorflow as tf
import numpy as np
import pickle
import visualization
import os
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from threading import Thread
from time import sleep
import time
import cv2

# import MNIST dataset
mnist_dir = r'../mnist_dataset'
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(mnist_dir)

#------------------------------------------------ define modules -------------------------------------------------

# define generator
def Generator_MLP(latents, out_dim, reuse=False):
    units = 128
    with tf.variable_scope("generator", reuse=reuse):
        # dense0
        dense0 = tf.layers.dense(latents, units, activation=tf.nn.leaky_relu, name='dense0')
        # dropout
        dropout = tf.layers.dropout(dense0, rate=0.2, name='dropout')
        # dense1
        logits = tf.layers.dense(dropout, out_dim, name='dense1')
        # output
        outputs = tf.tanh(logits, name='outputs')
        return logits, outputs

# define discriminator
def Discriminator_MLP(input, out_dim, reuse=False):
    units = 128
    with tf.variable_scope("discriminator", reuse=reuse):
        # dense0
        dense0 = tf.layers.dense(input, units, activation=tf.nn.leaky_relu, name='dense0',
                                 kernel_initializer=tf.random_normal_initializer(0, 0.1))
        # dense1
        logits = tf.layers.dense(dense0, out_dim, name='dense1',
                                 kernel_initializer=tf.random_normal_initializer(0, 0.1))
        # output
        outputs = tf.sigmoid(logits, name='outputs')
        return logits, outputs

# count the total number of trainable parameters
def COUNT_VARS(vars):
    total_para = 0
    for variable in vars:
        # get the shape of each variable
        shape = variable.get_shape()
        variable_para = 1
        for dim in shape:
            variable_para *= dim.value
        total_para += variable_para
    return total_para

# display parameter information
def ShowParasList(paras):
    p = open('./trainLog/Paras.txt', 'w')
    p.writelines(['vars_total: %d' % COUNT_VARS(paras), '\n'])
    for variable in paras:
        p.writelines([variable.name, str(variable.get_shape()), '\n'])
        print(variable.name, variable.get_shape())
    p.close()

# build related dirs
def GEN_DIR():
    if not os.path.isdir('ckpt'):
        print('DIR: ckpt NOT FOUND, BUILDING ON CURRENT PATH..')
        os.mkdir('ckpt')
    if not os.path.isdir('trainLog'):
        print('DIR: trainLog NOT FOUND, BUILDING ON CURRENT PATH..')
        os.mkdir('trainLog')

#---------------------------------------------- build graph -------------------------------------------------------------
# hyper-parameters
latents_dim = 128
img_dim = 28*28
smooth = 0.1
learn_rate = 0.001

# define inputs
latents = tf.placeholder(shape=[None,latents_dim],dtype=tf.float32,name='latents')
input_real = tf.placeholder(shape=[None,img_dim],dtype=tf.float32,name='input_real')

# get outputs of G and D
_, g_outputs = Generator_MLP(latents,img_dim,reuse=False)
d_logits_real, d_outputs_real = Discriminator_MLP(input_real,1,reuse=False)
d_logits_fake, d_outputs_fake = Discriminator_MLP(g_outputs,1,reuse=True)

# define losses
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,labels=tf.ones_like(d_logits_real)) * (1 - smooth))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,labels=tf.zeros_like(d_logits_fake)))
d_loss = tf.add(d_loss_real, d_loss_fake)
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,labels=tf.ones_like(d_logits_fake)) * (1 - smooth))
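# Note: with smooth = 0.1, the objective actually optimized is
#   d_loss = 0.9 * CE(D(x), 1) + CE(D(G(z)), 0)
#   g_loss = 0.9 * CE(D(G(z)), 1)
# i.e. the real-sample term of the discriminator loss (and the generator
# loss) is scaled down by 10%, a simple smoothing trick in the spirit of
# one-sided label smoothing, meant to keep the discriminator from
# becoming over-confident on real samples.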
# gradient descent
train_vars = tf.trainable_variables()
ShowParasList(train_vars) # display
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
# each optimizer updates only its own sub-network via var_list
d_train_opt = tf.train.AdamOptimizer(learn_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learn_rate).minimize(g_loss, var_list=g_vars)

#------------------------------------------------ iterations ------------------------------------------------------------

GEN_DIR()
max_iters = 20000
batch_size = 64
critic_n = 5
GenLog = []
Losses = []
saver = tf.train.Saver(var_list=g_vars)

# record training info
def SavingRecords():
    global Losses
    global GenLog
    # saving Losses
    with open('./trainLog/loss_variation.loss', 'wb') as l:
        losses = np.array(Losses)
        pickle.dump(losses, l)
        print('saving Losses successfully!')
    # saving generated samples
    with open('./trainLog/GenLog.log', 'wb') as g:
        GenLog = np.array(GenLog)
        pickle.dump(GenLog, g)
        print('saving GenLog successfully!')

# define training
def training():
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        time_start = time.time()  # go
        for steps in range(max_iters + 1):
            # fetch a batch of real images
            data_batch = mnist.train.next_batch(batch_size)[0]
            # ops.SHOW('real', data_batch[0].reshape([28, 28, 1]))
            data_batch = data_batch * 2 - 1  # rescale [0, 1] -> [-1, 1] to match the tanh output
            data_batch = data_batch.astype(np.float32)
            z = np.random.normal(0, 1, size=[batch_size, latents_dim]).astype(np.float32)
            # train the discriminator
            for n in range(critic_n):
                sess.run(d_train_opt, feed_dict={input_real: data_batch, latents: z})
            # train the generator
            sess.run(g_train_opt, feed_dict={latents: z})
            # record training losses
            train_loss_d = sess.run(d_loss, feed_dict={input_real: data_batch, latents: z})
            train_loss_g = sess.run(g_loss, feed_dict={latents: z})
            info = [steps, train_loss_d, train_loss_g]
            # record generated samples
            gen_samples = sess.run(g_outputs, feed_dict={latents: z})
            visualization.CV2_BATCH_SHOW((np.reshape(gen_samples[0:9], [-1, 28, 28, 1]) + 1) / 2, 1, 3, 3, delay=1)
            print('iters::%d/%d..Discriminator_loss:%.3f..Generator_loss:%.3f..'
                  % (steps, max_iters, train_loss_d, train_loss_g))
            if steps % 5 == 0:
                Losses.append(info)
                GenLog.append(gen_samples)
            if steps % 1000 == 0 and steps > 0:
                saver.save(sess, './ckpt/generator.ckpt', global_step=steps)
            if steps == max_iters:
                # cv2.destroyAllWindows()
                # set up a thread to save the training records
                sleep(3)
                thread1 = Thread(target=SavingRecords, args=())
                thread1.start()
            yield info

#------------------------------------------------- ANIMATION ----------------------------------------------------------
# ANIMATION
"""
note: in this code , we will see the runtime-variation of G,D losses
"""
iters = []
dloss = []
gloss = []
fig = plt.figure()
ax1 = fig.add_subplot(2,1,1,xlim=(0, max_iters), ylim=(-1, 1))
ax2 = fig.add_subplot(2,1,2,xlim=(0, max_iters), ylim=(-20, 20))
ax1.set_title('discriminator_loss')
ax2.set_title('generator_loss')
line1, = ax1.plot([], [], color='red',lw=1,label='discriminator')
line2, = ax2.plot([], [],color='blue', lw=1,label='generator')
fig.tight_layout()

def init():
    line1.set_data([], [])
    line2.set_data([], [])
    return line1, line2

def update(info):
    iters.append(info[0])
    dloss.append(info[1])
    gloss.append(info[2])
    line1.set_data(iters, dloss)
    line2.set_data(iters, gloss)
    return line1, line2

ani = FuncAnimation(fig, update, frames=training, init_func=init, blit=True, interval=1, repeat=False)
plt.show()

Experimental results

1. Loss curves
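The training loop pickles the loss history to ./trainLog/loss_variation.loss. A minimal sketch for re-plotting it after training, assuming the [step, d_loss, g_loss] row layout written by SavingRecords:

import pickle
import matplotlib.pyplot as plt

with open('./trainLog/loss_variation.loss', 'rb') as f:
    losses = pickle.load(f)  # array of [step, d_loss, g_loss] rows
plt.plot(losses[:, 0], losses[:, 1], color='red', lw=1, label='discriminator')
plt.plot(losses[:, 0], losses[:, 2], color='blue', lw=1, label='generator')
plt.xlabel('iterations')
plt.legend()
plt.show()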

2. Generation log
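GenLog.log stores a batch of generator outputs every 5 iterations. A minimal sketch for inspecting the last recorded batch, assuming the (num_records, batch_size, 784) array written by SavingRecords, with tanh outputs in [-1, 1]:

import pickle
import numpy as np
import matplotlib.pyplot as plt

with open('./trainLog/GenLog.log', 'rb') as f:
    gen_log = pickle.load(f)  # (num_records, batch_size, 784)

# show the first 9 samples of the last recorded batch
imgs = (np.reshape(gen_log[-1][:9], [-1, 28, 28]) + 1) / 2
for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(imgs[i], cmap='gray')
    plt.axis('off')
plt.show()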

3. Verifying the generator
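To verify the generator independently of the training script, one can rebuild just the generator graph and restore the latest checkpoint (the Saver above stored only g_vars, so the generator variables alone are enough). A minimal sketch, assuming a fresh process with Generator_MLP, latents_dim, and img_dim defined as in the script above:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

latents = tf.placeholder(shape=[None, latents_dim], dtype=tf.float32, name='latents')
_, g_outputs = Generator_MLP(latents, img_dim, reuse=False)  # same definition as above
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))
    z = np.random.normal(0, 1, size=[9, latents_dim]).astype(np.float32)
    samples = sess.run(g_outputs, feed_dict={latents: z})
    imgs = (np.reshape(samples, [-1, 28, 28]) + 1) / 2
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(imgs[i], cmap='gray')
        plt.axis('off')
    plt.show()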
