文章目录

  • 什么是自编码器
  • 损失函数的设计

什么是自编码器

自编码器就是将原始数据进行编码,进行降低维度,发现数据之间的规律的过程。
数据降维比如mnist的图片为28*28像素,将图片向量化之后的得到一个长度为784的向量。在网络训练的过程中,网络只用到该向量之中少量元素,其中的大部分元素对于网络来说是没有用的,自编码器通过无监督学习来提取有用的信息,对于手写数字图片来说可能是颜色为黑色的像素点,将图片中的很大一部分白色像素舍弃,只提取对网络有用的信息,到达降低数据维度的目的。

作用

  1. 先对无标注的数据进行自编码器的训练,然后将有标注的数据第一步输入自编码器,自编码器的输出输入到神经网络之中,达到提升网络输出精度的目的。

  2. 用于神经网络权重的初始化。

基本形式如下,其中f(x)为编码函数,g(x)为解码函数
x → f ( x ) h → g ( x ) x ′ x\overset{f(x)}{\rightarrow}h\overset{g(x)}{\rightarrow}x' x→f(x)h→g(x)x′
训练的过程,我们采用下面形式的约束
x ≈ x ′ x\approx x' x≈x′
即设计一个损失函数,让编码器的输入和输出尽可能相似
其中
x ′ = g ( f ( x ) ) x'=g(f(x)) x′=g(f(x))
但自编码器学习的目的不是学习上面这个恒等函数,这可能导致过拟合,我们要自编码器学习的是将稀疏数据到稠密的一个映射,而不是输入和输出完全相等,这样提取的特征才能为后面的网络所用。

损失函数的设计

以mnist数据集为例,设输入x的数据是长度为784的一维向量,编码后h是长度为196的一维向量,解码后x’是长度为784的一维向量。

均方误差(MSE) 用来衡量输入数据和输出数据的相似度。表示如下
L ( x , g ( f ( x ) ) = E x ∼ p d a t a ∥ x − g ( f ( x ) ) ∥ 2 L(x,g(f(x))=E_{x\sim~p_{data}}\left\|x-g(f(x))\right\|^2 L(x,g(f(x))=Ex∼ pdata​​∥x−g(f(x))∥2
在学习的过程中,均方误差可能变得很小,这样会导过拟合,而我们期望的是一个泛化能力很强的编码器,所以我们加如L1正则化相对熵(KLD)来抑制过拟合。
L1正则化仅仅作用于编码,因为我们关心的是编码的过程,解码器只是方便显示编码之后的结果。
Ω ( h ) = λ ∑ i ∣ h i ∣ \Omega(h)=\lambda\sum\nolimits_{i}\left|h_i\right| Ω(h)=λ∑i​∣hi​∣
h i h_i hi​是第 i i i个神经元的激励值。

相对熵KL Divergence,KLD)。首先定义隐层神经元j的平均活跃度$ {\hat{\rho}}_j$
ρ ^ j = 1 N ∑ i = 1 N h j ( i ) \hat{\rho}_j=\frac{1}{N}\sum\nolimits_{i=1}^{N}{h}^{(i)}_j ρ^​j​=N1​∑i=1N​hj(i)​
对 i i i求和表示对训练集(N个样本)的所有输入取均值,j代表第几个隐层神经元,激励值通常在 0 ∼ 1 0\sim1 0∼1之间。对稀疏性的约束就是令神经元的平均活跃度接近稀疏性系数 ρ : ρ ^ j ≈ ρ \rho:\hat{\rho}_j\approx\rho ρ:ρ^​j​≈ρ这个系数通常取接近0的值,从而约束隐层神经元的活跃程度,可以把 ρ \rho ρ理解成某个神经元被激活的概率。

为了实现这个约束,需要在损失函数中添加一个损失项
∑ j = 1 M K L ( ρ ∥ ρ ^ j ) = ∑ j = 1 M [ ρ ⋅ log ⁡ p ρ ^ j + ( 1 − ρ ) ⋅ log ⁡ 1 − p 1 − ρ ^ j ] \sum\nolimits_{j=1}^{M}KL(\rho\parallel\hat{\rho}_j)=\sum\nolimits_{j=1}^{M}[\rho\cdot\log\frac{p}{\hat{\rho}_j}+(1-\rho)\cdot\log\frac{1-p}{1-\hat{\rho}_j}] ∑j=1M​KL(ρ∥ρ^​j​)=∑j=1M​[ρ⋅logρ^​j​p​+(1−ρ)⋅log1−ρ^​j​1−p​]
KLD是一种衡量两个分布之间差异的方法,式中的 ρ \rho ρ和 ρ ^ j \hat{\rho}_j ρ^​j​分别表示期望和实际的隐层神经元的输出两点分布(两点分别代表饱和和睡眠)的均值和期望。

两种损失函数可定义为下面版本:

L1:
L ( x , g ( f ( x ) ) = E x ∼ p d a t a ∥ x − g ( f ( x ) ) ∥ 2 + λ ∑ i ∣ h i ∣ L(x,g(f(x))=E_{x\sim~p_{data}}\left\|x-g(f(x))\right\|^2+\lambda\sum\nolimits_{i}\left|h_i\right| L(x,g(f(x))=Ex∼ pdata​​∥x−g(f(x))∥2+λ∑i​∣hi​∣
KLD:
L ( x , g ( f ( x ) ) = E x ∼ p d a t a ∥ x − g ( f ( x ) ) ∥ 2 + β ∑ j = 1 M [ ρ ⋅ log ⁡ p ρ ^ j + ( 1 − ρ ) ⋅ log ⁡ 1 − p 1 − ρ ^ j ] L(x,g(f(x))=E_{x\sim~p_{data}}\left\|x-g(f(x))\right\|^2+\beta\sum_{j=1}^{M}[\rho\cdot\log\frac{p}{\hat{\rho}_j}+(1-\rho)\cdot\log\frac{1-p}{1-\hat{\rho}_j}] L(x,g(f(x))=Ex∼ pdata​​∥x−g(f(x))∥2+βj=1∑M​[ρ⋅logρ^​j​p​+(1−ρ)⋅log1−ρ^​j​1−p​]

两种损失函数该如何选择

隐层激活函数类型 重构层激活函数类型 MSE L1 KLD
Sigmoid Sigmoid True False True
Relu Softplus True True False

隐层使用Sigmoid时,隐层输出值在(0,1)之间,可用来计算KLD。隐层使用Relu时,隐层的输出值在 [ 0 , + ∞ ) [0,+\infty) [0,+∞),不能使用KLD。

实现代码如下:

# coding:utf-8import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
import numpy as np
import matplotlib.pylab as pltlearning_rate = 0.0001
lambda_l2_w = 0.01
n_epochs = 100
batch_size = 128
print_interval = 200hidden_size = 196
input_size = 784
image_width = 28
model = 'sigmoid'X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(path='./data/')x = tf.placeholder(tf.float32,shape=[None,784],name='x')print('Build Network')
if model=='relu':net = InputLayer(x, name='input')net = DenseLayer(net,n_units=hidden_size,act=tf.nn.relu, name='relu1')encode_img = net.outputsrecon_layer1 = DenseLayer(net,n_units=input_size,act=tf.nn.softplus,name='recon_layer1')if model=='sigmoid':net = InputLayer(x, name='input')net = DenseLayer(net,n_units=hidden_size,act=tf.nn.sigmoid, name='sigmoid1')encode_img = net.outputsrecon_layer1 = DenseLayer(net,n_units=input_size,act=tf.nn.sigmoid,name='recon_layer1')y = recon_layer1.outputs
train_params = recon_layer1.all_params[-4:]mse = tf.reduce_sum(tf.squared_difference(y,x),1)
mse = tf.reduce_mean(mse)# w1 和 w2 采用L2正则化
L2_w = tf.contrib.layers.l2_regularizer(lambda_l2_w)(train_params[0])+\tf.contrib.layers.l2_regularizer(lambda_l2_w)(train_params[2])# 稀疏性约束
activation_out = recon_layer1.all_layers[-2]
L1_a = 0.001 * tf.reduce_mean(activation_out)# 相对熵(KLD)
beta = 0.5
rho = 0.15
p_hat = tf.reduce_mean(activation_out,0)
KLD = beta * tf.reduce_sum(rho * tf.log(tf.divide(rho,p_hat)) +(1-rho) * tf.log((1-rho)/(tf.subtract(float(1),p_hat))))# 联合损失函数
if model=='sigmoid':cost = mse + L2_w + KLD
if model=='relu':cost = mse + L2_w + L1_a# 定义优化器
train_op = tf.train.AdamOptimizer(learning_rate).minimize(cost)
saver = tf.train.Saver()# 模型训练
total_batch = X_train.shape[0] // batch_sizewith tf.Session() as sess:sess.run(tf.global_variables_initializer())for epoch in range(n_epochs):avg_cost = 0for i in range(total_batch):batch_x,batch_y =X_train[i*batch_size:(i+1)*batch_size],y_train[i*batch_size:(i+1)*batch_size]batch_x = np.array(batch_x).astype(np.float32)batch_cost, _ = sess.run([cost, train_op], feed_dict={x: batch_x})if not i % print_interval:print('Minibatch: %03d | Cost:  %.3f' % (i + 1, batch_cost))print('Epoch:   %03d | AvgCost:  %.3f' % (epoch + 1, avg_cost / i + 1))saver.save(sess,save_path='./model/3-101.ckpt')# 恢复参数
n_images=15
fig,axes=plt.subplots(nrows=2,ncols=n_images,sharex=True,sharey=True,figsize=(20,2.5))test_images = X_test[:n_images]with tf.Session() as sess:# 加载训练好的模型saver.restore(sess,save_path='./model/3-101.ckpt')# 获取重构参数decoded = sess.run(recon_layer1.outputs,feed_dict={x:test_images})# 恢复编码器的权重参数if model=='relu':weights = sess.run(tl.layers.get_variables_with_name('relu1/W:0',False,True))if model=='sigmoid':weights = sess.run(tl.layers.get_variables_with_name('sigmoid1/W:0',False,True))# 获取解码器的权重参数recon_weights = sess.run(tl.layers.get_variables_with_name('recon_layer1/W:0',False,True))recon_bias = sess.run(tl.layers.get_variables_with_name('recon_layer1/b:0',False,True))for i in range(n_images):for ax,img in zip(axes,[test_images,decoded]):ax[i].imshow(img[i].reshape(image_width,image_width),cmap='binary')plt.show()

实验结果:

MSE

MSE+L2

MSE+L2+KLD

总结:

自编码器通过监督学习来发现数据集内部特征,提取有用信息,达到降维的目的。在现实中存在大量无标注的数据,先用这部分数据训练一个自编码器,在神经网络训练过程中,先将数据喂入自编码器,自编码器输出的结果再喂入神经网络进行训练。通过这种操作达到提升训练效果的目的。

这个自编码器不是说输出数据和输入数据的相似度越高越好,相似度太高可能出现过拟合的情况,我们所希望的是自编码器对同一类型的数据都具有编码能力,即要求自编码器有很强的泛化能力。为了达到这个目的,在实验中对编码器增加正则化和KLD惩罚项,使得自编码器学习到稀疏性特征。需要注意的是两个版本的损失函数的使用场景各不相同。

参考《一起玩转TensorLayer》

自编码器的原理及实现相关推荐

  1. 11旋转编码器原理_旋转编码器的原理是什么?增量式编码器和绝对式编码器有什么区别?...

    先给出结论,最重要的区别在于:增量式编码器没有记忆,断电重启必须回到参考零位,才能找到需要的位置,而绝对式编码器,有记忆,断电重启不用回到零位,即可知道目标所在的位置. 接下来细说一下,主要包含如下的 ...

  2. 光电编码器的原理及应用场合_旋转式光电编码器工作原理及在视觉检测中的使用...

    一.光电编码器工作原理 光电编码器,是一种通过光电转换将输出轴上的机械几何位移量转换成脉冲或数字量的传感器.这是目前应用最多的传感器,光电编码器是由光栅盘和光电检测装置组成.光栅盘是在一定直径的圆板上 ...

  3. 编码器类型原理知识汇总(增量式/绝对式/绝对值)

    编码器以信号原理来分,有增量式编码器(SPC)和绝对式编码器(APC). 绝对式编码器可以记录编码器在一个绝对坐标系上的位置,而增量式编码器可以输出编码器从预定义的起始位置发生的增量变化. 增量式编码 ...

  4. 增量式编码器工作原理超详细图解

    旋转编码器是由光栅盘(又叫分度码盘)和光电检测装置(又叫接收器)组成.光栅盘是在一定直径的圆板上等分地开通若干个长方形孔.由于光栅盘与电机同轴,电机旋转时,光栅盘与电机同速旋转,发光二极管垂直照射光栅 ...

  5. 增量式(相对式)编码器与绝对式编码器工作原理

    增量式(相对式)编码器与绝对式编码器工作原理 增量式编码器工作原理 绝对式编码器工作原理 根据检测原理,编码器可分为光学式.磁式.感应式和电容式.根据其刻度方法及信号输出形式,可分为增量式.绝对式以及 ...

  6. 增量式旋转编码器工作原理

    增量式旋转编码器工作原理 增量式旋转编码器通过内部两个光敏接受管转化其角度码盘的时序和相位关系,得到其角度码盘角度位移量增加(正方向)或减少(负方向).在接合数字电路特别是单片机后,增量式旋转编码器在 ...

  7. 绝对值编码器工作原理是什么?单圈/多圈绝对值编码器有何区别?

    在前两篇文章中,小编对增量式编码器以及绝对式编码器有所阐述.为增进大家对编码器的认识,本文将对绝对值编码器予以介绍.通过本文,你将了解到什么是绝对值编码器.绝对值编码器的工作原理以及单圈/多圈绝对值编 ...

  8. 光电编码器的原理及应用场合_【技术浅析】编码器原理在数控系统维修中的应用...

    摘要:本文分析了编码器工作原理及其在数控系统中的应用,  结合维修工作中常见的机床零点丢失故障案例,找出有效的解决方法.      关键词:编码器 FANUC 数控系统 参考点 目前数控机床采用日本 ...

  9. 简述旋转编码器的工作原理_什么是编码器,编码器工作原理介绍

    点击上方蓝色字体 机械菌 关注我们,涨知识涨见识就在这里. 正文开始 编码器(encoder)是将信号(如比特流)或数据进行编制.转换为可用以通讯.传输和存储的信号形式的设备.编码器把角位移或直线位移 ...

  10. VAE(变分自编码器)原理简介

    一.技术背景 变分自编码器(VAE)是一种深度生成模型,可以用于从高维数据中提取潜在的低维表示,并用于生成新的样本数据.自编码器(Autoencoder)是深度学习领域中常用的一种无监督学习方法,其基 ...

最新文章

  1. getline及读文件总结
  2. 【学习笔记】SAP Fiori相关概念介绍
  3. 保存的图数据丢失_自从用了这2个功能,再也没有担心过文档丢失
  4. Java线程池ThreadPoolExecutor
  5. 深入Redis客户端(redis客户端属性、redis缓冲区、关闭redis客户端)
  6. 【华为云技术分享】看得见的安心,一手掌握华为云DRS迁移进度
  7. d3d 渲染遇到的几个问题
  8. 曹 雷 : 证券基金经营机构如何理解科技是投资而非投入
  9. Unity3D 渲染管线全流程解析
  10. 亚马逊Alexa Connect Kit(ACK)
  11. Spring 事务扩展机制 TransactionSynchronization
  12. 3D 文件格式 - 对应厂商
  13. http://www.jdon.com/
  14. 2012,三星势必问鼎中原
  15. 客户端在线更新-QT
  16. 成都敏之澳电商:拼多多商家怎么看店铺是否降权导?
  17. Android拦截电话与短信(电话拒接/短信拒收)
  18. IT技术外包公司值得去吗? | 关于 ICC Contractor 你应该知道的!
  19. 网络和网路互联的设计
  20. uvc摄像头代码解析6

热门文章

  1. CSDN联合BSV发布首个区块链开发工程师能力认证
  2. C语言:二维数组:求平均数
  3. JDBC连接数据库模板
  4. PCF8951(AD-DA)
  5. 随身车联网——车联网生态新物种(附发布会视频)
  6. 【flask高级】从源码深入理解flask路由之endpoint
  7. soot基础 -- 相关数据结构SootClass,SootMethod,SootBody,Unit的进一步说明
  8. 多层循环给数组添加元素重复添加问题
  9. db2 replace函数的用法_48R软件数据的基本处理之删除重复数据(duplicated()、unique()、distinct()函数)...
  10. java程序步骤_java编写程序的步骤是什么?java编写程序步骤实例讲解