主要内容
1.自编码器的TensorFlow实现代码(详细代码注释)
2.该实现中的函数总结

平台:
1.windows 10 64位
2.Anaconda3-4.2.0-Windows-x86_64.exe (当时TF还不支持python3.6,又懒得在高版本的anaconda下配置多个Python环境,于是装了一个3-4.2.0(默认装python3.5),建议装anaconda3的最新版本,TF1.2.0版本已经支持python3.6!)
3.TensorFlow1.1.0

老样子,先贴代码:

# -*- coding: utf-8 -*-
"""
Created on Tue Jun 20 12:59:16 2017@author: ASUS
"""
import numpy as np
import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data# 定义Xavier初始化函数  tf.random_uniform产生均匀分布
def xavier_init(fan_in, fan_out, constant = 1):low = -constant * np.sqrt(6.0 / (fan_in + fan_out))high = constant * np.sqrt(6.0 / (fan_in + fan_out))return tf.random_uniform((fan_in, fan_out), minval = low, maxval = high, dtype = tf.float32)
# 定义自编码器类
class AdditiveGaussianNoiseAutoEncoder(object):def __init__(self, n_input, n_hidden, transfer_function = tf.nn.softplus,optimizer = tf.train.AdamOptimizer(), scale = 0.1):self.n_input = n_inputself.n_hidden = n_hiddenself.transfer = transfer_functionself.scale = tf.placeholder(tf.float32)self.training_scale = scale network_weights = self._initialize_weights()self.weights = network_weights# 定义网络结构 # x 为输入,因此要用placeholder进行“占位符”操作self.x = tf.placeholder(tf.float32, [None, self.n_input])# hidden 是隐含层 ,此自编码器只含一个隐层# self.x + scale * tf.random_normal((n_input,)) 是加噪声,scale是噪声系数#  hidden = f(w*x1+b), f 是激活函数 #  weights w1和b1 分别表示第一层的权值、偏置self.hidden = self.transfer(tf.add(tf.matmul(self.x + scale * tf.random_normal((n_input,)),self.weights['w1']), self.weights['b1']))# 定义重构self.reconstruction = tf.add(tf.matmul(self.hidden,self.weights['w2']),self.weights['b2'])# 定义 平方误差为cost # tf.pow()是计算幂 2.0则表示计算平方, tf.subtract是对应元素相减self.cost = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.reconstruction, self.x), 2.0))# 定义优化器 对损失 self.cost进行优化self.optimizer = optimizer.minimize(self.cost)# 全局参数初始化init = tf.global_variables_initializer()# 创建会话 sessself.sess = tf.Session()self.sess.run(init)# 定义权值初始化函数,AE的权值存放在一个字典里# w1采用xavier初始化,其余设置为全0 def _initialize_weights(self):all_weights = dict()all_weights['w1'] = tf.Variable(xavier_init(self.n_input, self.n_hidden))all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden],dtype = tf.float32))all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden,self.n_input], dtype = tf.float32))all_weights['b2'] = tf.Variable(tf.zeros([self.n_input],dtype = tf.float32))return all_weights # 定义 执行一步训练的函数def partial_fit(self, X):cost, opt = self.sess.run((self.cost, self.optimizer),feed_dict = {self.x: X, self.scale: self.training_scale})return cost# 定义计算总的cost,在评测AE时用到# 只让Session执行一个计算图节点 self.costdef calc_total_cost(self, X):return self.sess.run(self.cost, feed_dict = {self.x : X, self.scale: self.training_scale})# 定义transform函数,# 作用是 返回AE隐含层的输出结果# 目的是 提供一个接口来获取抽象后的特征def transform(self, X):return self.sess.run(self.hidden, feed_dic = {self.scale: self.training_scale})# 定义generate函数,将隐含层的输出作为输入,# 通过重建层(reconstruction)来复原原始数据def generate(self, hidden = None):if hidden is None:hidden = np.random.normal(size = self.weights['b1'])return self.sess.run(self.reconstruction, feed_dict = {self.hidden: hidden})# 定义重构函数# 包括抽象特征的提取和 通过抽象特征来复原原始数据def reconstruct(self, X):return self.sess.run(self.reconstruction, fedd_dict = {self.x: X, self.scale: self.traning_scale})# getWeights获取 隐含层权重def getWeights(self):return self.sess.run(self.weights['w1'])def getBiases(self):return self.sess.run(self.weights['b1'])mnist = input_data.read_data_sets('MNIST_data', one_hot = True)# 定义函数 对 mnist数据进行标准化 (减均值,除以标准差)
# 利用skleran里的 StandardScaler类
def standard_scale(X_train, X_test):preprocessor = prep.StandardScaler().fit(X_train)X_train = preprocessor.transform(X_train)X_test = preprocessor.transform(X_test)return X_train, X_test# 定义函数 获取随机block数据
def get_random_block_from_data(data, batch_size):start_index = np.random.randint(0, len(data) - batch_size )return data[start_index:(start_index + batch_size)]# 数据标准化
X_train, X_test = standard_scale(mnist.train.images, mnist.test.images)# 设置基本参数
n_samples = int(mnist.train.num_examples)
training_epochs = 20
batch_size = 128
display_step = 1# 创建AGN(Additive Gaussian Noise,加性高斯噪声)自编码器实例
autoencoder = AdditiveGaussianNoiseAutoEncoder(n_input = 784,n_hidden = 200,transfer_function = tf.nn.softplus,optimizer = tf.train.AdamOptimizer(learning_rate = 0.001),scale = 0.01)# 迭代训练
for epoch in range(training_epochs):avg_cost = 0.total_batch = int(n_samples / batch_size)for i in range(total_batch):batch_xs = get_random_block_from_data(X_train, batch_size)cost = autoencoder.partial_fit(batch_xs)avg_cost += cost / n_samples * batch_sizeif epoch % display_step ==0 :print('Epoch: ', '%04d' % (epoch+1), 'Cost = ','{:.9f}'.format(avg_cost))print('Total cost: ' + str(autoencoder.calc_total_cost(X_test)))
此代码主要实现了一个去噪自编码器,噪声采用的AGN(Additive Gaussian Noise,加性高斯噪声)。

自编码器的介绍可以查看如下链接:
http://ufldl.stanford.edu/wiki/index.php/Autoencoders_and_Sparsity
简单讲,自编码器就是对原始数据利用神经网络进行编码,这里的码其实就是隐含层的输出,通过BP算法对神经网络的权值进行修改,最终通过这些权值对原始数据做运算(运算则是编码过程),得到输出(输出则是编好的码) ,而有编码,就有解码。解码呢,就是将编好的码还原成原始数据,这里的还原方法,同样是采用神经网络的某一层。

其中用到Xavier初始化方法是2010年 Xavier提出的,有兴趣可拜读:
《Understanding the Difficult of Training Deep Feedforward Neural Networks》

其中用到的函数总结(续上篇):
1. sess = tf.InteractiveSession() 将sess注册为默认的session
2. tf.placeholder() , Placeholder是输入数据的地方,也称为占位符,通俗的理解就是给输入数据(此例中的图片x)和真实标签(y_)提供一个入口,或者是存放地。(个人理解,可能不太正确,后期对TF有深入认识的话再回来改~~)
3. tf.Variable() Variable是用来存储模型参数,与存储数据的tensor不同,tensor一旦使用掉就消失
4. tf.matmul() 矩阵相乘函数
5. tf.reduce_mean 和tf.reduce_sum 是缩减维度的计算均值,以及缩减维度的求和
6. tf.argmax() 是寻找tensor中值最大的元素的序号 ,此例中用来判断类别
7. tf.cast() 用于数据类型转换
————————————–我是分割线(一)———————————–

  1. tf.random_uniform 生成均匀分布的随机数
  2. tf.train.AdamOptimizer() 创建优化器,优化方法为Adam(adaptive moment estimation,Adam优化方法根据损失函数对每个参数的梯度的一阶矩估计和二阶矩估计动态调整针对于每个参数的学习速率)
  3. tf.placeholder “占位符”,只要是对网络的输入,都需要用这个函数这个进行“初始化”
  4. tf.random_normal 生成正态分布
  5. tf.add 和 tf.matmul 数据的相加 、相乘
  6. tf.reduce_sum 缩减维度的求和
  7. tf.pow 求幂函数
  8. tf.subtract 数据的相减
  9. tf.global_variables_initializer 定义全局参数初始化
  10. tf.Session 创建会话.
  11. tf.Variable 创建变量,是用来存储模型参数的变量。是有别于模型的输入数据的
  12. tf.train.AdamOptimizer (learning_rate = 0.001) 采用Adam进行优化,学习率为 0.001

【TensorFlow-windows】(二) 实现一个去噪自编码器相关推荐

  1. 【theano-windows】学习笔记十三——去噪自编码器

    前言 上一章节学习了卷积的写法,主要注意的是其实现在theano.tensor.nnet和theano.sandbox.cuda.dnn中都有对应函数实现, 这一节就进入到无监督或者称为半监督的网络构 ...

  2. ICLR 2020:从去噪自编码器到生成模型

    作者丨苏剑林 单位丨追一科技 研究方向丨NLP,神经网络 个人主页丨kexue.fm 在我看来,几大顶会之中,ICLR 的论文通常是最有意思的,因为它们的选题和风格基本上都比较轻松活泼.天马行空,让人 ...

  3. tensorflow 卷积、反卷积形式的去噪自编码器

    tensorflow 卷积.反卷积形式的去噪自编码器 对于去噪自编码器,网上好多都是利用全连接神经网络进行构建,我自己写了一个卷积.反卷积形式的去噪自编码器,其中的参数调优如果有兴趣的话,可以自行修改 ...

  4. TensorFlow实现去噪自编码器(Denoising Autoencoder)

    TensorFlow实现去噪自编码器(Denoising Autoencoder) 去噪自编码器(Denoising Autoencoder, DAE) DAE模型架构 DAE实现 数据预处理 模型构 ...

  5. tensorflow学习笔记二——建立一个简单的神经网络拟合二次函数

    tensorflow学习笔记二--建立一个简单的神经网络 2016-09-23 16:04 2973人阅读 评论(2) 收藏 举报  分类: tensorflow(4)  目录(?)[+] 本笔记目的 ...

  6. Scikit-Learn TensorFlow机器学习实用指南(二):一个完整的机器学习项目【上】

    机器学习实用指南(二):一个完整的机器学习项目[上] 作者:LeonG 本文参考自:<Hands-On Machine Learning with Scikit-Learn & Tens ...

  7. [自编码器:理论+代码]:自编码器、栈式自编码器、欠完备自编码器、稀疏自编码器、去噪自编码器、卷积自编码器

    写在前面 因为时间原因本文有些图片自己没有画,来自网络的图片我尽量注出原链接,但是有的链接已经记不得了,如果有使用到您的图片,请联系我,必注释. 自编码器及其变形很多,本篇博客目前主要基于普通自编码器 ...

  8. 【AI实战】快速掌握TensorFlow(二):计算图、会话

    2019独角兽企业重金招聘Python工程师标准>>> 在前面的文章中,我们已经完成了AI基础环境的搭建(见文章:Ubuntu + Anaconda + TensorFlow + G ...

  9. 【theano-windows】学习笔记十四——堆叠去噪自编码器

    前言 前面已经学习了softmax,多层感知器,CNN,AE,dAE,接下来可以仿照多层感知器的方法去堆叠自编码器 国际惯例,参考文献: Stacked Denoising Autoencoders ...

最新文章

  1. php跨域访问java,案例:PHP Ajax 跨域最佳解决方案
  2. ubuntu 16.04 更新后搜狗输入法无法输入中文的问题
  3. Winpcap 中sockaddr_storage问题收藏
  4. 【NLP】巧借“他山之石”,生成信息量大、可读性强且稳定的摘要
  5. 【LeetCode笔记】543. 二叉树的直径(Java、dfs、二叉树)
  6. Python 根据文件绝对路径删除文件
  7. IPv6網絡開發范例
  8. VMware 中的操作系统切换模式后总是连接不上互联网可能的问题之一
  9. paip.提升用户体验---c++ qt自定义窗体(1)---标题栏的绘制
  10. Kali Linux全网最细安装教程
  11. 自动语音呼叫中心系统
  12. pioneer软件VoLTE测试步骤,世纪鼎利Pioneer连接移动平台进行VoLTE测试操作说明综述...
  13. Power BI中字体使用微软雅黑
  14. linux无字幕打开文件,解决SMPLAYER无画面/无字幕
  15. ANSNP中线安防 安科瑞 时丽花
  16. 怎么把所有图片变成一样的大小
  17. cdn cfdn是什么_P2P+CDN=PCDN
  18. 人工智能学习培训哪家好
  19. 真香!用Python检测和识别车牌(附代码)
  20. 【Bio】基础生物学 - 基因 gene

热门文章

  1. 对软件工程的问题及个别软件的分析
  2. 用Barcode生成条形码图片
  3. CJ20N被删除物料的历史记录
  4. java swing container_Java Swing GUI学习(一)
  5. 操作数据库pymysql
  6. ISE创建Microblaze软核(三)
  7. Hibernate读书笔记---继承映射
  8. 是不是Cookie让禁用了,Session就一定不能用了呢
  9. WMPLib.WindowsMediaPlayer 的用法
  10. 没有调用save或update方法,却有sql语句执行