自编码网络能够自学习样本特征的网络,属于无监督学习模型的网络,可以从无标注的数据中学习特征,它可以给出比原始数据更好的特征描述,具有较强的特征学习能力。

主要的网络结构就是高维特征样本---》编码成---》低维特征---》解码回---》高维特征,下面以MNIST数据集为示例进行演示:

  1. import tensorflow as tf
  2. #导入数据集合
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. mnist = input_data.read_data_sets('/data/',one_hot=True)
  5. #整体流程,原始图片像素28*28-784
  6. #784-》256-》128-》128-》256-》784
  7. learning_rate = 0.01
  8. n_hidden_1 = 256     #第一层256个结点
  9. n_hidden_2 = 128     #第二层128个结点
  10. n_input = 784
  11. x = tf.placeholder('float',[None,n_input])
  12. y = x
  13. weights = {
  14. 'encoder_h1':tf.Variable(tf.random_normal([n_input,n_hidden_1])),
  15. 'encoder_h2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2])),
  16. 'decoder_h1':tf.Variable(tf.random_normal([n_hidden_2,n_hidden_1])),
  17. 'decoder_h2':tf.Variable(tf.random_normal([n_hidden_1,n_input])),
  18. }
  19. biases = {
  20. 'encoder_b1':tf.Variable(tf.zeros([n_hidden_1])),
  21. 'encoder_b2':tf.Variable(tf.zeros([n_hidden_2])),
  22. 'decoder_b1':tf.Variable(tf.zeros([n_hidden_1])),
  23. 'decoder_b2':tf.Variable(tf.zeros([n_input])),
  24. }
  25. def encoder(x):
  26. layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x,weights['encoder_h1']),biases['encoder_b1']))
  27. layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['encoder_h2']),biases['encoder_b2']))
  28. return layer_2
  29. def decoder(x):
  30. layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x,weights['decoder_h1']),biases['decoder_b1']))
  31. layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1,weights['decoder_h2']),biases['decoder_b2']))
  32. return layer_2
  33. pred = decoder(encoder(x))
  34. cost = tf.reduce_mean(tf.pow(y-pred,2))
  35. optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
  36. training_epochs = 20  #共迭代20次
  37. batch_size = 256      #每次取256个样本
  38. display_step = 5      #迭代5次输出一次信息
  39. #启动会话
  40. with tf.Session() as sess:
  41. sess.run(tf.global_variables_initializer())
  42. total_batch = int(mnist.train.num_examples/batch_size)
  43. #开始训练
  44. for epoch in range(training_epochs):
  45. for i in range(total_batch):
  46. batch_xs,batch_ys = mnist.train.next_batch(batch_size)#取数据
  47. _,c = sess.run([optimizer,cost],feed_dict={x:batch_xs})#训练模型
  48. if epoch % display_step == 0:#输出日志信息
  49. print("Epoch:",'%4d' % (epoch+1),'cost=',"{:.9f}".format(c))
  50. print('Training Finished!')
  51. correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
  52. accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float'))
  53. print('Accuracy:',1-accuracy.eval({x:mnist.test.images,y:mnist.test.images}))

Tensorflow实现MNIST数据自编码(1)相关推荐

  1. Tensorflow实现MNIST数据自编码(3)

    前面自编码(1)和自编码(2)是针对高维数据维数进行降低维数角度改进模型,但是还需要让这些特征具有抗干扰能力,输入的特征数据受到干扰时,生成特征依然不会怎么变化,使自动编码器具有更好的泛化能力 #en ...

  2. Tensorflow实现MNIST数据自编码(2)

    对自编码(1)进行改进,(1)中实现的网络是用2个编码层,2个解码层,现在对它进行添加编码层和解码层分别为4层 原始数据784-->256-->64-->16-->2 #enc ...

  3. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  4. 使用Tensorflow操作MNIST数据

    MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例.而TensorFlow的封装让使用MNIST数据集变得更加方便.MNIST数据集是NIST数据集的 ...

  5. TensorFlow学习笔记(十四)TensorFLow 用mnist数据做classification

    之前的例子,给的都是tf来做regression,也就是回归问题,现在用tf来做一个classification的处理,也就是分类问题. 这里用的数据集是mnist数据. 代码: "&quo ...

  6. Tensorflow实战之下载MNIST数据,自动分成train, validation和test三个数据集

    TensorFlow 实战Google深度学习框架 第2版 ,郑泽宇之P96.下载MNIST数据,自动分成train, validation和test三个数据集,源码如下: #!/usr/bin/en ...

  7. 【TensorFlow】MNIST手写数字识别

    MNIST   MNIST是一个非常简单的机器视觉数据集.如图,它由几万字28像素×28像素的手写数字组成,这些图片只包含灰度值信息.我们的任务是对这些手写数字的图片进行分类,转成0~9一共10类. ...

  8. 机器学习入门案例简单理解——Tensorflow之MNIST解析

    深度学习简单介绍 首先要简单区别几个概念:人工智能,机器学习,深度学习,神经网络.这几个词应该是出现的最为频繁的,但是他们有什么区别呢? 人工智能:人类通过直觉可以解决的问题,如:自然语言理解,图像识 ...

  9. TF之DCGAN:基于TF利用DCGAN测试MNIST数据集并进行生成过程全记录

    TF之DCGAN:基于TF利用DCGAN测试MNIST数据集并进行生成 目录 测试结果 测试过程全记录 测试结果 train_00_0099 train_00_0799 train_00_0899 t ...

最新文章

  1. 游戏基础体验研究:玩家想要什么样的美术品质?
  2. C#远程调用技术WebService葵花宝典
  3. linux 线程管理、同步机制等
  4. SAS在金融中的应用七
  5. 俄罗斯电力公司T Plus完成25MW光伏电站
  6. python遇到天猫反爬虫_用Python爬取天猫评价-我的新游戏
  7. 高等数学第七版下册 同济大学数学系 编 课后答案 习题解析
  8. x2分布临界值表(卡方分布)
  9. 2016年APP推广应该怎么做?
  10. WPS VBA遇到的问题小记
  11. eoj 3279 爱狗狗的两个dalao(dfs)
  12. js 数组按奇偶拆分_js数组拆分问题
  13. 笔记本不显示计算机储存盘,电脑开机黑屏只有鼠标箭头的解决办法 电脑硬盘无法分区怎么办...
  14. Ubuntu 14.04.2 系统无线网络不稳定问题
  15. MySQL Deamon少量解读
  16. Linux-vim设置
  17. 深圳市数字经济指数发布:数字经济蓬勃发展,数字用户深度渗透
  18. SAP AS ABAP 7.52 SP04, Developer Edition 免费下载
  19. 【零基础 快速学Java】韩顺平 p104-147 流程控制:顺序、分支、循环、跳转 控制语句 (if、for、while、dowhile、break、continue、return)
  20. 用R语言画一朵玫瑰花

热门文章

  1. rabbitmq 集群数据存储与单点故障
  2. scrapy proxy and user_agent
  3. mysql proxy 读写分离 1
  4. 10个不太为人所知的,但实用的PHP函数
  5. 使用 IntraWeb (31) - IntraWeb 的 Xml 操作使用的是 NativeXml
  6. 如何提高代码质量:代码复查
  7. Ubuntu用命令行发邮件mutt,报警发短信通知
  8. 使用jquery+json实现ajax的方法
  9. GdiPlus[28]: IGPPen: 建立复合画笔
  10. live messenger与稀疏文件—Sparse File Bit