生成对抗网络GANs(Generative Adversarial Nets

from datetime import datetime
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tensorflow.examples.tutorials.mnist import input_dataBATCH_SIZE = 128
LEARNING_RATE = 1e-4
Z_DIM = 100
IMAGE_W = 28
IMAGE_H = 28
model_dir = 'model_gan'x_in = tf.placeholder(tf.float32, shape=[None, 784])def load_mnist():return input_data.read_data_sets("./MNIST_data", one_hot=True)mnist = load_mnist()def get_W_b(input_dim, output_dim, name):W = tf.Variable(tf.random_normal([input_dim, output_dim], stddev=0.02), name=name.replace('_b', ''))b = tf.Variable(tf.zeros([output_dim], tf.float32), name=name.replace('_W', ''))return W, btmp = 256class GAN(object):def __init__(self, lr=LEARNING_RATE, batch_size=BATCH_SIZE, z_dim=Z_DIM):self.lr = lrself.batch_size = batch_sizeself.z_dim = z_dim# 生成器的权重self.gen_W1, self.gen_b1 = get_W_b(z_dim, tmp, 'gen_W_b_1')self.gen_W2, self.gen_b2 = get_W_b(tmp, IMAGE_H * IMAGE_W, 'gen_W_b_2')# 判别器的权重self.discrim_W1, self.discrim_b1 = get_W_b(IMAGE_H * IMAGE_W, tmp, 'discrim_W_b_1')self.discrim_W2, self.discrim_b2 = get_W_b(tmp, 1, 'discrim_W_b_2')# 判别器def discriminator(self, x):d_h1 = tf.nn.relu(tf.add(tf.matmul(x, self.discrim_W1), self.discrim_b1))d_h2 = tf.add(tf.matmul(d_h1, self.discrim_W2), self.discrim_b2)return tf.nn.sigmoid(d_h2)# 生成器def generator(self, z):g_h1 = tf.nn.relu(tf.add(tf.matmul(z, self.gen_W1), self.gen_b1))g_h2 = tf.add(tf.matmul(g_h1, self.gen_W2), self.gen_b2)return tf.nn.sigmoid(g_h2)# 建立模型def build_model(self):z_sample = np.random.uniform(-1., 1., size=[self.batch_size, self.z_dim]).astype('float32')g_image = self.generator(z_sample)d_real = self.discriminator(x_in)d_fake = self.discriminator(g_image)d_cost = -tf.reduce_mean(tf.log(d_real) + tf.log(1. - d_fake))g_cost = -tf.reduce_mean(tf.log(d_fake))return d_cost, g_cost, tf.reduce_mean(d_real), tf.reduce_mean(d_fake)# 画图
def plot_grid(samples):fig = plt.figure(figsize=(4, 4))gs = gridspec.GridSpec(4, 4)gs.update(wspace=0.05, hspace=0.05)for i, sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(IMAGE_H, IMAGE_W), cmap='Greys_r')return fig# 训练
def train():with tf.Session() as sess:gan = GAN()discrim_vars = list(filter(lambda x: x.name.startswith('discrim'), tf.trainable_variables()))gen_vars = list(filter(lambda x: x.name.startswith('gen'), tf.trainable_variables()))d_cost, g_cost, d_real, d_fake = gan.build_model()optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)d_opt = optimizer.minimize(d_cost, var_list=discrim_vars)g_opt = optimizer.minimize(g_cost, var_list=gen_vars)saver = tf.train.Saver()checkpoint = tf.train.latest_checkpoint(model_dir)if checkpoint:saver.restore(sess, checkpoint)  # 从模型中读取数据print('checkpoint: {}'.format(checkpoint))else:# 变量初始化sess.run(tf.global_variables_initializer())print("Started training {}".format(datetime.now().isoformat()[11:]))plot_index = 0for step in range(100000):batch_x, _ = mnist.train.next_batch(BATCH_SIZE)sess.run(d_opt, feed_dict={x_in: batch_x})sess.run(g_opt, feed_dict={x_in: batch_x})# 每1000个step保存一次图片if step % 1000 == 0:batch_x, _ = mnist.train.next_batch(BATCH_SIZE)d_cost_, d_real_, d_fake_ = sess.run([d_cost, d_real, d_fake], feed_dict={x_in: batch_x})g_cost_ = sess.run(g_cost, feed_dict={x_in: batch_x})print("step:{} Discriminator Loss {} Generator loss {}  d_real:{}  d_feak:{}".format(step, d_cost_,g_cost_, d_real_,d_fake_))z_sample = np.random.uniform(-1., 1., size=[16, Z_DIM]).astype('float32')g_image = sess.run(gan.generator(z_sample))fig = plot_grid(g_image)plt.savefig('D:\project\生成对抗网络\img\{}.png'.format(str(plot_index).zfill(4)), bbox_inches='tight')plot_index += 1plt.close(fig)# 保存模型saver.save(sess, "{}/model_gan.model".format(model_dir), global_step=step)print("Ended training {}".format(datetime.now().isoformat()[11:]))if __name__ == "__main__":train()

生成对抗网络GANs相关推荐

  1. 生成对抗网络GANs理解(附代码)

    生成对抗网络GANs理解(附代码) 原文地址: http://blog.csdn.net/sxf1061926959/article/details/54630462 生成模型和判别模型 理解对抗网络 ...

  2. 判别两棵树是否相等 设计算法_一文看懂生成对抗网络 - GANs?(附:10种典型算法+13种应用)...

    生成对抗网络 – GANs 是最近2年很热门的一种无监督算法,他能生成出非常逼真的照片,图像甚至视频.我们手机里的照片处理软件中就会使用到它. 本文将详细介绍生成对抗网络 – GANs 的设计初衷.基 ...

  3. 掌握生成对抗网络(GANs),召唤专属二次元老婆(老公)不是梦

    全文共6706字,预计学习时长12分钟或更长 近日,<狮子王>热映,其逼真的外形,几乎可以以假乱真,让观众不禁大呼:awsl,这也太真实了吧! 实体模型.CGI动画.实景拍摄.VR等技术娴 ...

  4. 一文读懂生成对抗网络GANs(附学习资源)

    原文标题:AnIntuitive Introduction to Generative Adversarial Networks 作者:KeshavDhandhania.ArashDelijani 翻 ...

  5. 【干货】生成对抗网络GANs算法在医学图像领域应用总结

    Goodfellow等人,介绍了生成对抗网络(GAN)以模拟数据分布.由于与两个基本属性相关的原因,GAN可以合成真实图像. GAN是一种无监督的训练方法,可以通过类似于人类学习图像特征的方式获取信息 ...

  6. 小说生成对抗网络GANs

    1.前言 机器学习可以分为两类模型,一种是判别模型(discrimination model),给定一个输入,通过模型判别输入的类别.另一种是生成模型(generative model),给定输入,通 ...

  7. Nikolai Yakovenko大佬:深度学习的下一个热点:生成对抗网络(GANs)将改变世界

    生成式对抗网络-简称GANs-将成为深度学习的下一个热点,它将改变我们认知世界的方式. 准确来讲,对抗式训练为指导人工智能完成复杂任务提供了一个全新的思路,某种意义上他们(人工智能)将学习如何成为一个 ...

  8. 深度学习的下一个热点:生成对抗网络(GANs)将改变世界

    本文作者 Nikolai Yakovenko 毕业于哥伦比亚大学,目前是 Google 的工程师,致力于构建人工智能系统,专注于语言处理.文本分类.解析与生成. 生成式对抗网络-简称GANs-将成为深 ...

  9. 训练生成对抗网络的一些需要关注的问题

    Utkarsh Desai 2020-09-22 Tuesday ➤ 01导读 生成对抗网络是个好东西,不过训练比较麻烦,这里有一些技巧和陷阱,分享给大家. 生成对抗网络(GANs)是当前深度学习研究 ...

最新文章

  1. 美多商城之支付(支付宝介绍)
  2. 那个一年发了4篇Cell的研究生,后来怎么样了?
  3. 软考自查:多媒体基础知识
  4. curl抓取页面时遇到重定向的解决方法(转)
  5. 十个简单好用的设计技巧[SM]
  6. js计算工时,去周末,设置上下班时间
  7. Java IO: 其他字节流(上)
  8. ROS笔记(6) ROS通讯机制
  9. 如何为项目和产品提供资源——优化工作时间、激励团队和预算
  10. 【python】yaml文件操作
  11. 心电图系统服务器与存储系统,心电图网络信息化管理系统
  12. 201771010112罗松《面向对象程序设计(java)》第二周学习总结
  13. Gradle配置及同一应用不同版本配置不同资源文件,不同签名,包名进行打包
  14. cobbler简单入门
  15. 学习日志-勉励自己-自律
  16. Python学习a1——背景及基础
  17. 细谈围城---我的启示录
  18. 沙漠求生十五选五实验
  19. 如何将excel.xls文件批量转换成.xlsx格式
  20. 微信朋友圈 html 字体颜色,微信朋友圈怎么发文字,朋友圈字体颜色可以改吗?...

热门文章

  1. 软件许可协议怎么写?
  2. 深入解析AAVE智能合约:计算和利率
  3. 反向交易: 圣杯还是危险的假象
  4. Python暴力破解wifi密码,你看了你也行
  5. 01组团队项目-Beta冲刺-3/5
  6. 关于hadoop访问8088端口显示只有一个节点
  7. stata 入门(双重差分模型)
  8. 【计算机存储单位】字,字节,字符
  9. 如何利用无线路由接收无线信号
  10. msec 腾讯CICD程序框架发布