本实例在MNIST数据集上使用AEGAN模型进行特征压缩及重建,并且加入标签信息loss实现AC-GAN网络。其中D和G都是通过卷积网络实现。

实例描述

  使用InfoGAN网络,在其基础上添加自编码网络,将InfoGAN的参数固定,训练反向生成器(自编码网络中的编码器),并将生成的模型用于MNIST数据集样本重建,得到相似的样本。

1.添加反向生成器

  添加反向生成器inversegenerator函数。该函数的功能是将图片生成特征吗,其结构与判别器相似,均为生成器的反向操作,即两个卷积层加上两个全连接层。

#反向生成器定义,结构与判别器类似
def inversegenerator(x):reuse = len([t for t in tf.global_variables() if t.name.startswith('inversegenerator')]) > 0with tf.variable_scope('inversegenerator', reuse=reuse):#两个卷积x = tf.reshape(x, shape=[-1, 28, 28, 1])x = slim.conv2d(x, num_outputs = 64, kernel_size=[4,4], stride=2, activation_fn=leaky_relu)x = slim.conv2d(x, num_outputs=128, kernel_size=[4,4], stride=2, activation_fn=leaky_relu)#两个全连接x = slim.flatten(x)        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn = leaky_relu)z = slim.fully_connected(shared_tensor, num_outputs=50, activation_fn = leaky_relu)return z

2.添加自编码网络代码

  自编码网络输入不是真实图片,而是生成器生成的图片generator(z),通过inversegenerator来压缩特征,生成与生成器输入噪声一样维度,然后将生成器当做自编码中的解码器重建出原始生成的图片。
  将自编码还原的图片与GAN生成生成的输入图片进行平方差计算,得到自编码的损失值loss_ae。

z_con = tf.random_normal((batch_size, con_dim))      #2列
z_rand = tf.random_normal((batch_size, rand_dim))    #38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth = classes_dim), z_con, z_rand])#50列
#生成器生成的图片gen模拟数据
gen = generator(z)
genout= tf.squeeze(gen, -1)#自编码网络
aelearning_rate =0.01
igen = generator(inversegenerator(generator(z)))    #生成器生成的模拟数据-反生成原始数据-再生成图片的数据
loss_ae = tf.reduce_mean(tf.pow(gen - igen, 2))#输出
igenout = generator(inversegenerator(x))

3.添加自编码网络的训练参数列表,定义优化器

  自编码网络的训练参数与前面的GAN几乎一样,使用MonitoredTrainingSession来管理检查点文件,定义global_step。定义train_ae优化器,并将global_step放入优化器中。

# 获得各个网络中各自的训练参数
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]
ae_vars =  [var for var in t_vars if 'inversegenerator' in var.name]gen_global_step = tf.Variable(0, trainable=False)
global_step = tf.train.get_or_create_global_step()#使用MonitoredTrainingSession,必须有train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d + loss_c + loss_con, var_list = d_vars, global_step = global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g + loss_c + loss_con, var_list = g_vars, global_step = gen_global_step)
train_ae = tf.train.AdamOptimizer(aelearning_rate).minimize(loss_ae, var_list = ae_vars, global_step = global_step)training_GANepochs = 3   #训练GAN迭代3次数据集
training_aeepochs = 6    #训练AE迭代3次数据集(从3开始到6)
display_step = 1

  本例需要训练GAN和AE两个网络,使用MonitoredTrainingSession管理后只能由有一个global_step,于是将global_step分段来管理两个网络的训练。每次迭代训练都会遍历整个数据集,先GAN迭代3次,在让AE迭代3测。

4.起动session依次训练GAN和AE网络

  使用MonitoredTrainingSession创建session。令程序每2分钟保存一次检查点文件。

with tf.train.MonitoredTrainingSession(checkpoint_dir='log/aecheckpoints',save_checkpoint_secs  =120) as sess:total_batch = int(mnist.train.num_examples/batch_size)print("ae_global_step.eval(session=sess)",global_step.eval(session=sess),int(global_step.eval(session=sess)/total_batch))for epoch in range( int(global_step.eval(session=sess)/total_batch),training_GANepochs):avg_cost = 0.# 遍历全部数据集for i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)#取数据feeds = {x: batch_xs, y: batch_ys}# Fit training using batch datal_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step],feeds)l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step],feeds)# 显示训练中的详细信息if epoch % display_step == 0:print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc),l_gen)print("GAN完成!")# 测试print ("Result:", loss_d.eval({x: mnist.test.images[:batch_size],y:mnist.test.labels[:batch_size]},session = sess), loss_g.eval({x: mnist.test.images[:batch_size],y:mnist.test.labels[:batch_size]},session = sess))



  从图中可以看出,InfoGAN只会生成属于原始数据分布的图片,而AEGAN会生成与原始图片更相近的图片。
  这种网络有压缩特征与重建两部分用途,重建样本常常用于处理图像的恢复与重建,还可以将重建的模拟数据保存起来空充数据集,也可以应用在超分辨率重建部分;

实例99:使用AEGAN对MNIST数据集压缩特征及重建相关推荐

  1. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...

  2. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别并预测(超过99%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别并预测(超过99%) 目录 输出结果 设计思路 核心代码 输出结果 准确度都在99%以上 1.出错记录 ...

  3. MNIST数据集的gist特征提取(含全部实例代码下载地址)

    这些天处理图像检索的一些benchmark数据集,今天处理了MNIST数据集,并对其进行了特征的提取.我的方法可能不一定是最优,但是按照这样的步骤来做,得到了我最后想要的特征数据结果.需要的朋友可以参 ...

  4. 基于pytorch的MNIST数据集的四层CNN,测试准确率99.77%

    基于pytorch的MNIST数据集的四层CNN,测试准确率99.77% MNIST数据集 环境配置 文件存储结构 代码 引入库 调用GPU 初始化变量 导入数据集并进行数据增强 导入测试集 加载测试 ...

  5. pytorch dropout_PyTorch初探MNIST数据集

    前言: 本文主要描述了如何使用现在热度和关注度比较高的Pytorch(深度学习框架)构建一个简单的卷积神经网络,并对MNIST数据集进行了训练和测试.MNIST数据集是一个28*28的手写数字图片集合 ...

  6. pytorch保存准确率_初学Pytorch:MNIST数据集训练详解

    前言 本文讲述了如何使用Pytorch(一种深度学习框架)构建一个简单的卷积神经网络,并使用MNIST数据集(28*28手写数字图片集)进行训练和测试.针对过程中的每个步骤都尽可能的给出了详尽的解释. ...

  7. 使用mnist数据集_使用MNIST数据集上的t分布随机邻居嵌入(t-SNE)进行降维

    使用mnist数据集 It is easy for us to visualize two or three dimensional data, but once it goes beyond thr ...

  8. python数据集的预处理_关于Pytorch的MNIST数据集的预处理详解

    关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu ...

  9. 基于tensorflow+RNN的MNIST数据集手写数字分类

    2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...

  10. [转载] 卷积神经网络做mnist数据集识别

    参考链接: 卷积神经网络在mnist数据集上的应用 Python TensorFlow是一个非常强大的用来做大规模数值计算的库.其所擅长的任务之一就是实现以及训练深度神经网络. 在本教程中,我们将学到 ...

最新文章

  1. c语言的图像拼接,OpenCV实现多图像拼接成一张大图分享!
  2. 在C#中读取枚举值的描述属性
  3. 工程师进阶之路(二)
  4. [转载]常用正则表达式
  5. .net读写配置文件
  6. intellij关联本地的maven的repository
  7. 如何提升github的clone速度(简单粗暴,亲测有效)
  8. 0018计算机基础知识,0018 0019计算机应用基础上机试题
  9. 最小二乘支持向量机的分析与改进及Python实现
  10. linux下C语言简单实现线程池
  11. 项目管理系统Redmine安装
  12. java php cms_内容管理系统的开发策略研究——以PHP CMS、Node.js CMS、Java CMS为例
  13. 关于在Webservice里使用LinqToSQL遇到一对多关系的父子表中子表需要ToList输出泛型而产生循环引用错误的解决办法!...
  14. 《编写有效用例》阅读笔记04
  15. HDU 2089:不要62(数位DP)
  16. 你是否还在写try-catch-finally?来使用try-with-resources优雅地关闭流吧
  17. python开发酷q插件gui_Python酷Q应用开发
  18. Bjui前端框架文档链接
  19. 杜邦线改成焊线_做杜邦线(假)教程
  20. 2021-2022-2 ACM集训队每周程序设计竞赛(1) - 问题 B: 蹩脚两轮车 - 题解

热门文章

  1. keras中的K.gradients()函数
  2. bzoj 2827 千山鸟飞绝
  3. FileZilla查看站点隐藏文件
  4. excel高级筛选怎么用_Excel表格自动筛选的9个高级用法
  5. 双机热备系统的方案与软件浅析
  6. 战争英雄、同性恋和计算机科学的奠基人
  7. Unity3D中玩家的移动方式,三大类型,八种方式
  8. 以CRM系统为案例讲解数据分析(重要性介绍及分析方法)
  9. 团队作业8----第二次项目冲刺(Beta阶段) 第四天
  10. 知乎spark与hadoop讨论