在MNIST数据集上使用InfoGAN网络模型生活才能模拟数据,并且加入标签的loss函数同时实现AC-GAN网络。其中D和G都是卷积函数来生成,相当于在DCGAN基础上的InfoGAN例子。

实例描述

通过使用InfoGAN网络学习MNIST数据特征,生成以假乱真的MNIST的模拟样本,并发现内部潜在的特征信息。

1.引入头文件并加载MNIST数据

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow.contrib.slim as slimfrom tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/data/")#, one_hot=True)

2.网络结构介绍

建立两个噪声数据(一般噪声和隐含信息)与label结合放到生成器中,生成模拟样本,然后将模拟样本和真实样本分别输入到判别器中,生成判别结果、重构的隐含信息,以及样本标签。
在优化时,让判别器对真实的样本判别结果为1、对模拟数据的判别结果为0来做损失值计算(loss);对生成器让判别结果过为1来做损失值计算。

3.定义生成器和判别器

先从模拟噪声数据来恢复样本,生成器采用反卷积函数,这里通过两个全连接+两个反卷积模拟样本的生成,并且每一层都有BN处理。

def generator(x):reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0##定义生成器的变量域with tf.variable_scope('generator', reuse = reuse):#全连接层1024x = slim.fully_connected(x, 1024)x = slim.batch_norm(x, activation_fn=tf.nn.relu)#全连接成7*7*128x = slim.fully_connected(x, 7*7*128)x = slim.batch_norm(x, activation_fn=tf.nn.relu)x = tf.reshape(x, [-1, 7, 7, 128])#反卷积层,生成64个x = slim.conv2d_transpose(x, 64, kernel_size=[4,4], stride=2, activation_fn = None)x = slim.batch_norm(x, activation_fn = tf.nn.relu)#反卷积生成一个z = slim.conv2d_transpose(x, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid)return zdef leaky_relu(x):return tf.where(tf.greater(x, 0), x, 0.01 * x)def discriminator(x, num_classes=10, num_cont=2):reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0with tf.variable_scope('discriminator', reuse=reuse):x = tf.reshape(x, shape=[-1, 28, 28, 1])#两次卷积x = slim.conv2d(x, num_outputs = 64, kernel_size=[4,4], stride=2, activation_fn=leaky_relu)x = slim.conv2d(x, num_outputs=128, kernel_size=[4,4], stride=2, activation_fn=leaky_relu)x = slim.flatten(x)     #将x扁平化,指保留批次#全连接shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn = leaky_relu)#全连接recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn = leaky_relu)#1维输出层disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=None)disc = tf.squeeze(disc, -1)     #删除维度为1#10维分类recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None)#2维输出层recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid)return disc, recog_cat, recog_cont

如果判别器输入的是真正的样本,同样也要经过两次卷积,两次全连接层,生成的数据可以分别连接不同的输出层产生不同的结果,其中1维的输出层产生的判别结果为1或是0,10维的输出层产生分类结果,2维输出层产生隐含维度信息。

4.定义网络模型

一般噪声维度为38,应节点为z_rand;隐含信息维度为2,应节点为z_con,二者都是符合标准高斯分布的随机数。将他们与one_hot转换后的标签连接在一起放到生成器中。

batch_size = 10   # 获取样本的批次大小32
classes_dim = 10  # 10 classes
con_dim = 2
rand_dim = 38
n_input  = 784x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.int32, [None])z_con = tf.random_normal((batch_size, con_dim))#2列
z_rand = tf.random_normal((batch_size, rand_dim))#38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth = classes_dim), z_con, z_rand])#50列
gen = generator(z)
genout= tf.squeeze(gen, -1)y_real = tf.ones(batch_size) #真
y_fake = tf.zeros(batch_size)#假# 判别器
disc_real, class_real, _ = discriminator(x)
disc_fake, class_fake, con_fake = discriminator(gen)
pred_class = tf.argmax(class_fake, dimension=1)

5.定义损失函数与优化器

  判别器中,判别结果loss有两个:真实输入结果与模拟输入结果,将两者结合在一起生成loss_d。生成器的loss为自己输出的模拟数据,让其在判断器中为真,生成loss_g。
  定义网络中共有的loss:真实标签与输入真实样本判断的标签、真实的标签与输入模拟熟数据判别的标签、隐含信息的重构误差。
  这里也应用了一个技巧:将判别器的学习率设置的很小,将生成器的学习率设置的大一些。让生成器更快的进化速度来模拟真实数据。

# 判别器 loss
#判断真实输入为真
loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real))
#判断模拟输入为假
loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake))
loss_d = (loss_d_r + loss_d_f) / 2# 生成器loss
#将生成器的结果不断趋近真
loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_real))#  共有的loss
#模拟数据标签识别
loss_cf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))
#真实数据标签识别
loss_cr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))
loss_c =(loss_cf + loss_cr) / 2# 隐含信息loss
loss_con =tf.reduce_mean(tf.square(con_fake-z_con))# 获得各个网络中各自的训练参数
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]disc_global_step = tf.Variable(0, trainable=False)
gen_global_step = tf.Variable(0, trainable=False)train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d + loss_c + loss_con, var_list = d_vars, global_step = disc_global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g + loss_c + loss_con, var_list = g_vars, global_step = gen_global_step)

  AC-GAN就是将loss_cr加入到loss_c中。没有loss_cr,会损失真实分类与模拟数据之间对应关系。

6.开始训练与测试

with tf.Session() as sess:sess.run(tf.global_variables_initializer())for epoch in range(training_epochs):avg_cost = 0.total_batch = int(mnist.train.num_examples/batch_size)# 遍历全部数据集for i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)#取数据feeds = {x: batch_xs, y: batch_ys}# Fit training using batch datal_disc, _, l_d_step = sess.run([loss_d, train_disc, disc_global_step],feeds)l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step],feeds)# 显示训练中的详细信息if epoch % display_step == 0:print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc),l_gen)print("完成!")# 测试print ("Result:", loss_d.eval({x: mnist.test.images[:batch_size],y:mnist.test.labels[:batch_size]}), loss_g.eval({x: mnist.test.images[:batch_size],y:mnist.test.labels[:batch_size]}))

7.可视化

  可视化部分会生活舱呢个两个图片:原样本对应的模拟数据图片、利用隐含信息生成的模拟样本图片。

  • 原样本与对应的模拟数据图片会将对应的分类、预测分类、隐含信息一起打印出来。
  • 利用隐含信息生成的模拟样本图片会在整个[0,1]空间里均匀抽样,与样本的标签混合在一起,生成模拟数据。
 # 根据图片模拟生成图片show_num = 10gensimple,d_class,inputx,inputy,con_out = sess.run([genout,pred_class,x,y,con_fake], feed_dict={x: mnist.test.images[:batch_size],y: mnist.test.labels[:batch_size]})f, a = plt.subplots(2, 10, figsize=(10, 2))for i in range(show_num):a[0][i].imshow(np.reshape(inputx[i], (28, 28)))a[1][i].imshow(np.reshape(gensimple[i], (28, 28)))print("d_class",d_class[i],"inputy",inputy[i],"con_out",con_out[i])plt.draw()plt.show()  my_con=tf.placeholder(tf.float32, [batch_size,2])myz = tf.concat(axis=1, values=[tf.one_hot(y, depth = classes_dim), my_con, z_rand])mygen = generator(myz)mygenout= tf.squeeze(mygen, -1) my_con1 = np.ones([10,2])a = np.linspace(0.0001, 0.99999, 10)y_input= np.ones([10])figure = np.zeros((28 * 10, 28 * 10))my_rand = tf.random_normal((10, rand_dim))for i in range(10):for j in range(10):my_con1[j][0]=a[i]my_con1[j][1]=a[j]y_input[j] = jmygenoutv =  sess.run(mygenout,feed_dict={y:y_input,my_con:my_con1})for jj in range(10):digit = mygenoutv[jj].reshape(28, 28)figure[i * 28: (i + 1) * 28,jj * 28: (jj + 1) * 28] = digitplt.figure(figsize=(10, 10))plt.imshow(figure, cmap='Greys_r')plt.show()

  从上面的结果中,可以很容易观察到,除了可控的类别信息外,隐含信息找某些维度具有非常显著的语义信息。两个维度可能与倾斜和粗细程度。

实例88:构建InfoGAN生成MNIST模拟数据相关推荐

  1. 【Pytorch神经网络实战案例】17 带W散度的WGAN-div模型生成Fashon-MNST模拟数据

    1 WGAN-div 简介 W散度的损失函数GAN-dv模型使用了W散度来替换W距离的计算方式,将原有的真假样本采样操作换为基于分布层面的计算. 2 代码实现 在WGAN-gp的基础上稍加改动来实现, ...

  2. 【Pytorch神经网络实战案例】15 WGAN-gp模型生成Fashon-MNST模拟数据

    1 WGAN-gp模型生成模拟数据案例说明 使用WGAN-gp模型模拟Fashion-MNIST数据的生成,会使用到WGAN-gp模型.深度卷积GAN(DeepConvolutional GAN,DC ...

  3. 如何快速在oracle内生成数据,[Oracle]快速生成大量模拟数据的方法

    快速生成大量模拟数据的方法: create table TEST(id integer, TEST_NUMBER NUMBER(18,6)); insert into TEST select i+j, ...

  4. 【Pytorch神经网络实战案例】13 构建变分自编码神经网络模型生成Fashon-MNST模拟数据

    1 变分自编码神经网络生成模拟数据案例说明 变分自编码里面真正的公式只有一个KL散度. 1.1 变分自编码神经网络模型介绍 主要由以下三个部分构成: 1.1.1 编码器 由两层全连接神经网络组成,第一 ...

  5. lazy-mock ,一个生成后端模拟数据的懒人工具

    lazy-mock   lazy-mock 是基于koa2构建的,使用lowdb持久化数据到JSON文件.只需要简单的配置就可以实现和json-server差不多的功能,但是比json-server更 ...

  6. 推荐一个生成后端模拟数据的懒人工具:lazy-mock

    点击上方蓝色"程序猿DD",选择"设为星标" 回复"资源"获取独家整理的学习资料! 作者 | 若邪 来源 | https://juejin. ...

  7. 【Pytorch神经网络实战案例】14 构建条件变分自编码神经网络模型生成可控Fashon-MNST模拟数据

    1 条件变分自编码神经网络生成模拟数据案例说明 在实际应用中,条件变分自编码神经网络的应用会更为广泛一些,因为它使得模型输出的模拟数据可控,即可以指定模型输出鞋子或者上衣. 1.1 案例描述 在变分自 ...

  8. node.js 生成文件_如何使用Node.js在几秒钟内生成模拟数据

    node.js 生成文件 介绍 (Introduction) In most of the applications, you need to have some static JSON data w ...

  9. 【Pytorch神经网络实战案例】16 条件WGAN模型生成可控Fashon-MNST模拟数据

    1 条件GAN前置知识 条件GAN也可以使GAN所生成的数据可控,使模型变得实用, 1.1 实验描述 搭建条件GAN模型,实现向模型中输入标签,并使其生成与标签类别对应的模拟数据的功能,基于WGAN- ...

  10. doceker模拟数据的生成

    需求的计算,用sparkcore+sparkSQL离线数据源:离线数据源解析动作表中的一条行为对应着用户的一次行为(点击,搜索,下单或者购买) 先用mocker模拟数据//一共有100个用户,有重复v ...

最新文章

  1. C++ 类的内存分布
  2. Jquery mobile技术咖们走进来瞧瞧吧
  3. 聊天机器人中的深度学习技术(引言)
  4. 作为程序员,你吃过哪些数学的亏?
  5. 虚拟软驱影像文件制作程序下载路径:http://download.csdn.net/source/738137
  6. mysql 存储引擎版本_mysql不同版本和存储引擎选型的验证
  7. JS 中通过对象关联实现『继承』
  8. Debug在中Eclipse的应用
  9. qt中dll缺失以及无法启动程序的正确解决方法
  10. poj Gone Fishing 枚举加贪心 当初做的很纠结啊!!终于A了,与大家分享一下经验
  11. RAW-socket
  12. Android studio3.5读取项目资源文件的图片
  13. kali工具 -- setoolkit(克隆网站及利用)
  14. 亚马逊获20亿美元信用额度:有助新业务投资
  15. SQL 一条SQL语句 统计 各班总人数、男女各总人数 、该班级男女 比例
  16. VMware安装Centos7_64位系统安装步骤
  17. 读写锁,为什么要用读写锁;
  18. [Swift]LeetCode120. 三角形最小路径和 | Triangle
  19. 自己对win10虚拟内存的理解,不一定对
  20. JPEG图像压缩算法流程详解

热门文章

  1. 初识STM32F407芯片
  2. apache camel_使用Apache Camel开始使用REST服务
  3. java偏向锁_Java锁事之偏向锁
  4. Navicat Premiumx64 使用注册机激活
  5. 什么是ipo表,ipo图,hipo图
  6. 计算机考研复试【英语面试题汇总】
  7. PS打开PSD文档服务器未响应,ps打不开psd文件的解决方法
  8. CxImage 使用报错解决办法
  9. 架构师之路:星环大数据架构师的培训心得
  10. jflash合并bin文件