实例99:使用AEGAN对MNIST数据集压缩特征及重建
本实例在MNIST数据集上使用AEGAN模型进行特征压缩及重建,并且加入标签信息loss实现AC-GAN网络。其中D和G都是通过卷积网络实现。
实例描述
使用InfoGAN网络,在其基础上添加自编码网络,将InfoGAN的参数固定,训练反向生成器(自编码网络中的编码器),并将生成的模型用于MNIST数据集样本重建,得到相似的样本。
1.添加反向生成器
添加反向生成器inversegenerator函数。该函数的功能是将图片生成特征吗,其结构与判别器相似,均为生成器的反向操作,即两个卷积层加上两个全连接层。
#反向生成器定义,结构与判别器类似
def inversegenerator(x):reuse = len([t for t in tf.global_variables() if t.name.startswith('inversegenerator')]) > 0with tf.variable_scope('inversegenerator', reuse=reuse):#两个卷积x = tf.reshape(x, shape=[-1, 28, 28, 1])x = slim.conv2d(x, num_outputs = 64, kernel_size=[4,4], stride=2, activation_fn=leaky_relu)x = slim.conv2d(x, num_outputs=128, kernel_size=[4,4], stride=2, activation_fn=leaky_relu)#两个全连接x = slim.flatten(x) shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn = leaky_relu)z = slim.fully_connected(shared_tensor, num_outputs=50, activation_fn = leaky_relu)return z
2.添加自编码网络代码
自编码网络输入不是真实图片,而是生成器生成的图片generator(z),通过inversegenerator来压缩特征,生成与生成器输入噪声一样维度,然后将生成器当做自编码中的解码器重建出原始生成的图片。
将自编码还原的图片与GAN生成生成的输入图片进行平方差计算,得到自编码的损失值loss_ae。
z_con = tf.random_normal((batch_size, con_dim)) #2列
z_rand = tf.random_normal((batch_size, rand_dim)) #38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth = classes_dim), z_con, z_rand])#50列
#生成器生成的图片gen模拟数据
gen = generator(z)
genout= tf.squeeze(gen, -1)#自编码网络
aelearning_rate =0.01
igen = generator(inversegenerator(generator(z))) #生成器生成的模拟数据-反生成原始数据-再生成图片的数据
loss_ae = tf.reduce_mean(tf.pow(gen - igen, 2))#输出
igenout = generator(inversegenerator(x))
3.添加自编码网络的训练参数列表,定义优化器
自编码网络的训练参数与前面的GAN几乎一样,使用MonitoredTrainingSession来管理检查点文件,定义global_step。定义train_ae优化器,并将global_step放入优化器中。
# 获得各个网络中各自的训练参数
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]
ae_vars = [var for var in t_vars if 'inversegenerator' in var.name]gen_global_step = tf.Variable(0, trainable=False)
global_step = tf.train.get_or_create_global_step()#使用MonitoredTrainingSession,必须有train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d + loss_c + loss_con, var_list = d_vars, global_step = global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g + loss_c + loss_con, var_list = g_vars, global_step = gen_global_step)
train_ae = tf.train.AdamOptimizer(aelearning_rate).minimize(loss_ae, var_list = ae_vars, global_step = global_step)training_GANepochs = 3 #训练GAN迭代3次数据集
training_aeepochs = 6 #训练AE迭代3次数据集(从3开始到6)
display_step = 1
本例需要训练GAN和AE两个网络,使用MonitoredTrainingSession管理后只能由有一个global_step,于是将global_step分段来管理两个网络的训练。每次迭代训练都会遍历整个数据集,先GAN迭代3次,在让AE迭代3测。
4.起动session依次训练GAN和AE网络
使用MonitoredTrainingSession创建session。令程序每2分钟保存一次检查点文件。
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/aecheckpoints',save_checkpoint_secs =120) as sess:total_batch = int(mnist.train.num_examples/batch_size)print("ae_global_step.eval(session=sess)",global_step.eval(session=sess),int(global_step.eval(session=sess)/total_batch))for epoch in range( int(global_step.eval(session=sess)/total_batch),training_GANepochs):avg_cost = 0.# 遍历全部数据集for i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)#取数据feeds = {x: batch_xs, y: batch_ys}# Fit training using batch datal_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step],feeds)l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step],feeds)# 显示训练中的详细信息if epoch % display_step == 0:print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc),l_gen)print("GAN完成!")# 测试print ("Result:", loss_d.eval({x: mnist.test.images[:batch_size],y:mnist.test.labels[:batch_size]},session = sess), loss_g.eval({x: mnist.test.images[:batch_size],y:mnist.test.labels[:batch_size]},session = sess))
从图中可以看出,InfoGAN只会生成属于原始数据分布的图片,而AEGAN会生成与原始图片更相近的图片。
这种网络有压缩特征与重建两部分用途,重建样本常常用于处理图像的恢复与重建,还可以将重建的模拟数据保存起来空充数据集,也可以应用在超分辨率重建部分;
实例99:使用AEGAN对MNIST数据集压缩特征及重建相关推荐
- DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)
DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...
- DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别并预测(超过99%)
DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别并预测(超过99%) 目录 输出结果 设计思路 核心代码 输出结果 准确度都在99%以上 1.出错记录 ...
- MNIST数据集的gist特征提取(含全部实例代码下载地址)
这些天处理图像检索的一些benchmark数据集,今天处理了MNIST数据集,并对其进行了特征的提取.我的方法可能不一定是最优,但是按照这样的步骤来做,得到了我最后想要的特征数据结果.需要的朋友可以参 ...
- 基于pytorch的MNIST数据集的四层CNN,测试准确率99.77%
基于pytorch的MNIST数据集的四层CNN,测试准确率99.77% MNIST数据集 环境配置 文件存储结构 代码 引入库 调用GPU 初始化变量 导入数据集并进行数据增强 导入测试集 加载测试 ...
- pytorch dropout_PyTorch初探MNIST数据集
前言: 本文主要描述了如何使用现在热度和关注度比较高的Pytorch(深度学习框架)构建一个简单的卷积神经网络,并对MNIST数据集进行了训练和测试.MNIST数据集是一个28*28的手写数字图片集合 ...
- pytorch保存准确率_初学Pytorch:MNIST数据集训练详解
前言 本文讲述了如何使用Pytorch(一种深度学习框架)构建一个简单的卷积神经网络,并使用MNIST数据集(28*28手写数字图片集)进行训练和测试.针对过程中的每个步骤都尽可能的给出了详尽的解释. ...
- 使用mnist数据集_使用MNIST数据集上的t分布随机邻居嵌入(t-SNE)进行降维
使用mnist数据集 It is easy for us to visualize two or three dimensional data, but once it goes beyond thr ...
- python数据集的预处理_关于Pytorch的MNIST数据集的预处理详解
关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu ...
- 基于tensorflow+RNN的MNIST数据集手写数字分类
2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...
- [转载] 卷积神经网络做mnist数据集识别
参考链接: 卷积神经网络在mnist数据集上的应用 Python TensorFlow是一个非常强大的用来做大规模数值计算的库.其所擅长的任务之一就是实现以及训练深度神经网络. 在本教程中,我们将学到 ...
最新文章
- c语言的图像拼接,OpenCV实现多图像拼接成一张大图分享!
- 在C#中读取枚举值的描述属性
- 工程师进阶之路(二)
- [转载]常用正则表达式
- .net读写配置文件
- intellij关联本地的maven的repository
- 如何提升github的clone速度(简单粗暴,亲测有效)
- 0018计算机基础知识,0018 0019计算机应用基础上机试题
- 最小二乘支持向量机的分析与改进及Python实现
- linux下C语言简单实现线程池
- 项目管理系统Redmine安装
- java php cms_内容管理系统的开发策略研究——以PHP CMS、Node.js CMS、Java CMS为例
- 关于在Webservice里使用LinqToSQL遇到一对多关系的父子表中子表需要ToList输出泛型而产生循环引用错误的解决办法!...
- 《编写有效用例》阅读笔记04
- HDU 2089:不要62(数位DP)
- 你是否还在写try-catch-finally?来使用try-with-resources优雅地关闭流吧
- python开发酷q插件gui_Python酷Q应用开发
- Bjui前端框架文档链接
- 杜邦线改成焊线_做杜邦线(假)教程
- 2021-2022-2 ACM集训队每周程序设计竞赛(1) - 问题 B: 蹩脚两轮车 - 题解