包含从头开始构建Autoencoders模型的完整代码。

(关注“我爱计算机视觉”公众号,一个有价值有深度的公众号~)
在本教程中,我们一起来探索一个非监督学习神经网络——Autoencoders(自动编码器)。
自动编码器是用于在输出层再现输入数据的深度神经网络,所以输出层中的神经元的数量与输入层中的神经元的数量完全相同。
如下图所示:

该图像展示了典型的深度自动编码器的结构。自动编码器网络结构的目标是在输出层创建输入的表示,使得两者尽可能接近(相似)。 但是,自动编码器的实际使用是用来得到具有最低数据丢失量的输入数据的压缩版本。 在机器学习项目中的作用类似于主成分分析( Principle Component Analysis,PCA),PCA的作用是在有大量属性的数据集上训练模型时找到最佳和最相关属性。

自动编码器以类似的方式工作。 其编码器部分将输入数据压缩,确保重要数据不会丢失,但数据的整体大小会显著减小。 这个概念称为降维( Dimensionality Reduction)。
降维的缺点是,压缩数据是一个黑盒子,即我们无法确定其压缩后数据中的数据结构的具体含义。 比如,假设我们有一个包含5个参数的数据集,我们在这些数据上训练一个自动编码器。 编码器不会为了获得更好的表示而省略某些参数,它会将参数融合在一起(压缩后的变量时综合变量)以创建压缩版本,使得参数更少(比如从5个压缩到3个)。
自动编码器有两个部分,即编码器和解码器。

编码器压缩输入数据,而解码器则基于压缩表示的数据反过来恢复数据的未压缩版本,以尽可能准确地创建输入的重建。

我们将使用Tensorflow的layers API创建自动编码器神经网络,并在mnist数据集上对其进行测试。

首先,我们导入相关的Python库,并读入mnist数据集。 如果数据集存在于本地计算机上,那么它将自动读取,否则将通过运行以下命令自动下载。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.layers import fully_connectedmnist=input_data.read_data_sets("/MNIST_data/",one_hot=True)

接下来,我们为方便起见创建一些常量,并事先声明我们的激活函数。 mnist数据集中的图像大小为28×28像素,即784像素,我们将其压缩为196像素。 当然你也可以更进一步缩小像素大小。 但是,压缩太多可能会导致自动编码器丢失信息。

num_inputs=784    #28x28 pixels
num_hid1=392
num_hid2=196
num_hid3=num_hid1
num_output=num_inputs
lr=0.01
actf=tf.nn.relu

现在,我们为每一层的weights 和 biases创建变量。 然后,我们使用先前声明的激活函数创建layer。

X=tf.placeholder(tf.float32,shape=[None,num_inputs])
initializer=tf.variance_scaling_initializer()w1=tf.Variable(initializer([num_inputs,num_hid1]),dtype=tf.float32)
w2=tf.Variable(initializer([num_hid1,num_hid2]),dtype=tf.float32)
w3=tf.Variable(initializer([num_hid2,num_hid3]),dtype=tf.float32)
w4=tf.Variable(initializer([num_hid3,num_output]),dtype=tf.float32)b1=tf.Variable(tf.zeros(num_hid1))
b2=tf.Variable(tf.zeros(num_hid2))
b3=tf.Variable(tf.zeros(num_hid3))
b4=tf.Variable(tf.zeros(num_output))hid_layer1=actf(tf.matmul(X,w1)+b1)
hid_layer2=actf(tf.matmul(hid_layer1,w2)+b2)
hid_layer3=actf(tf.matmul(hid_layer2,w3)+b3)
output_layer=actf(tf.matmul(hid_layer3,w4)+b4)

在一般情况下,TensorFlow的工程通常不使用tf.variance_scaling_initializer()。 但是,我们在这里使用它是因为我们正在处理不断变化大小的输入。 因此,placeholder张量形状(placeholder用于输入批处理)根据输入大小的形状调整自身,从而防止我们遇到任何维度错误。 通过简单地将先前带有相关weights 和 biases 的隐藏层作为输入输送到激活函数(ReLu)中来创建后面的隐藏层。

我们将使用RMSE损失函数用于此神经网络并将其传递给Adam优化器。你也可以替换这些来获得更多结果。

loss=tf.reduce_mean(tf.square(output_layer-X))
optimizer=tf.train.AdamOptimizer(lr)
train=optimizer.minimize(loss)
init=tf.global_variables_initializer()

现在,我们定义epochs和batch size并运行session。 我们使用mnist类的mnist.train.next_batch()来获取每个新的batch。 此外,我们将在每个epoch之后输出训练loss以监控其训练。

num_epoch=5
batch_size=150
num_test_images=10
with tf.Session() as sess:sess.run(init)    for epoch in range(num_epoch):num_batches=mnist.train.num_examples//batch_size        for iteration in range(num_batches):X_batch,y_batch=mnist.train.next_batch(batch_size)sess.run(train,feed_dict={X:X_batch})train_loss=loss.eval(feed_dict={X:X_batch})print("epoch {} loss {}".format(epoch,train_loss))

最后,我们将编写一个小的绘图函数来绘制原始图像和重建图,以查看我们训练得到的模型的工作情况。

results=output_layer.eval(feed_dict={X:mnist.test.images[:num_test_images]})    #Comparing original images with reconstructionsf,a=plt.subplots(2,10,figsize=(20,4))    for i in range(num_test_images):a[0][i].imshow(np.reshape(mnist.test.images[i],(28,28)))a[1][i].imshow(np.reshape(results[i],(28,28)))

在这里,我们可以看到重建并不完美,但非常接近原始图像。 注意上图中,2的重建看起来像是3,这是由于压缩时信息丢失造成的。

我们可以通过超参数调整来改进自动编码器模型,并且还可以通过在GPU上运行训练来提高速度。

获取完整代码,请访问:
https://github.com/Tathagatd96/Deep-Autoencoder-using-Tensorflow

原文链接:
https://towardsdatascience.com/deep-autoencoders-using-tensorflow-c68f075fd1a3

点击阅读原文可以在www.52cv.net查看本文。

欢迎转发,让更多人看到。

更多精彩推荐:

Kaggle新上比赛:空客公司卫星图像船体分割

重磅推荐!日立开源语义分割数据集标注工具Semantic Segmentation Editor

计算机视觉研究入门全指南

开源OCR文字识别软件Calamari

TensorFlow深度自动编码器入门实践相关推荐

  1. Tensorflow深度学习入门(1)——Tensorflow环境搭建

    Tensorflow深度学习入门--环境搭建 自测以下的环境搭建方式是行得通的,目前我用的就是这些 1.        下载安装Ubuntu 14.04 虚拟机 https://github.com/ ...

  2. [TensorFlow深度学习入门]实战九·用CNN做科赛网TibetanMNIST藏文手写数字数据集准确率98%+

    [TensorFlow深度学习入门]实战九·用CNN做科赛网TibetanMNIST藏文手写数字数据集准确率98.8%+ 我们在博文,使用CNN做Kaggle比赛手写数字识别准确率99%+,在此基础之 ...

  3. [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

    [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...

  4. TensorFlow深度学习应用实践

    http://product.dangdang.com/25207334.html 内 容 简 介 本书总的指导思想是在掌握深度学习的基本知识和特性的基础上,培养使用TensorFlow进行实际编程以 ...

  5. 面向隐私AI的TensorFlow深度定制化实践

    作者 | Rosetta团队 出品 | AI科技大本营(ID:rgznai100) 之前我们整体上介绍了基于深度学习框架开发隐私 AI 框架中的工程挑战和可行解决方案.在这一篇文章中,我们进一步结合 ...

  6. 人工神经网络理论、设计及应用_TensorFlow深度学习应用实践:教你如何掌握深度学习模型及应用...

    前言 通过TensorFlow图像处理,全面掌握深度学习模型及应用. 全面深入讲解反馈神经网络和卷积神经网络理论体系. 结合深度学习实际案例的实现,掌握TensorFlow程序设计方法和技巧. 着重深 ...

  7. tensorflow 语义slam_研究《视觉SLAM十四讲从理论到实践第2版》PDF代码+《OpenCV+TensorFlow深度学习与计算机视觉实战》PDF代码笔记...

    我们知道随着人工神经网络和深度学习的发展,通过模拟视觉所构建的卷积神经网络模型在图像识别和分类上取得了非常好的效果,借助于深度学习技术的发展,使用人工智能去处理常规劳动,理解语音语义,帮助医学诊断和支 ...

  8. 深度学习算法实践(基于Theano和TensorFlow)

    深度学习算法实践(基于Theano和TensorFlow) 闫涛 周琦 著 ISBN:9787121337932 包装:平装 开本:16开 用纸:胶版纸 正文语种:中文 出版社:电子工业出版社 出版时 ...

  9. envi 文件 生成mat_JVM 内存分析工具 MAT 的深度讲解与实践——入门篇

    1. MAT 工具简介 MAT(全名:Memory Analyzer Tool),是一款快速便捷且功能强大丰富的 JVM 堆内存离线分析工具.其通过展现 JVM 异常时所记录的运行时堆转储快照(Hea ...

最新文章

  1. Android 之小技巧
  2. python 文件不存在时才能写入,读写模式xt
  3. dw指向html的根路径,dreamweaver中绝对、文档相对和站点根目录相对路径区分
  4. 查看表字段信息 sql,mysql,oracle
  5. Linux之atime,ctime,mtime的区别
  6. docker-machine 下载iso慢的问题
  7. mac git 自动补全
  8. c语言字符雨动画代码,c语言实现数字雨
  9. 《念奴娇·赤壁怀古》古词鉴赏
  10. 华为鸿蒙爆出惊天骗局,华为鸿蒙系统爆出惊天骗局!
  11. 悟空互动:如何让百度更快的收录网站,试试快速收录提交入口!
  12. Inheritance with Jackson
  13. JS面试题汇总(六)
  14. AI医学影像千亿长坡,“医疗AI第一股”鹰瞳科技为何能滚起雪球?
  15. 了解CSS属性font-kerning,font-smoothing,font-variant
  16. 图片验证码不显示解决方案
  17. N71005-第五周
  18. (一)、音视频相关名词
  19. 渡一教育公开课web前端开发JavaScript精英课学习笔记(七)对象和包装类
  20. 'NULL' undeclared错误

热门文章

  1. css 下边框 90%,css怎么设置下边框
  2. ios xcode文件前缀_IOS Xcode开发中 文件名的后缀名m,mm,cpp,h区别
  3. c语言练习书,谁有C语言入门的练习题?
  4. encoding python3_关于 Python3 的编码
  5. html隐藏并失效,如果元素开始隐藏,css过渡将不起作用
  6. python画苹果标志图片_Mac生成APP图标和启动图的脚本
  7. java页面间面向对象的方法面试题_JAVA面向对象面试题带答案(墙裂推荐)
  8. linux配置rsync服务器
  9. python两个一维数组合并_python:16.合并两个排序的链表
  10. php网页审批权限设置,Linux下ThinkPHP网站目录权限设置