批量归一化

在对神经网络的优化方法中,有一种使用十分广泛的方法——批量归一化,使得神经网络的识别准确度得到了极大的提升。

在网络的前向计算过程中,当输出的数据不再同一分布时,可能会使得loss的值非常大,使得网络无法进行计算。产生梯度爆炸的原因是因为网络的内部协变量转移,即正向传播的不同层参数会将反向训练计算时参照的数据样本分布改变。批量归一化的目的,就是要最大限度地保证每次的正向传播输出在同一分布上,这样反向计算时参照的数据样本分布就会与正向计算时的数据分布一样了,保证分布的统一。

了解了原理,批量正则化的做法就会变得简单,即将每一层运算出来的数据都归一化成均值为0方差为1的标准高斯分布。这样就会在保留样本分布特征的同时,又消除层与层间的分布差异。在实际的应用中,批量归一化的收敛非常快,并且有很强的泛化能力,在一些情况下,完全可以代替前面的正则化,dropout。

批量归一化的定义

在TensorFlow中有自带的BN函数定义:

tf.nn.batch_normalization(x,

maen,

variance,

offset,

scale,

variance_epsilon)

各个参数的含义如下:

x:代表输入

mean:代表样本的均值

variance:代表方差

offset:代表偏移量,即相加一个转化值,通常是用激活函数来做。

scale:代表缩放,即乘以一个转化值,同理,一般是1

variance_epsilon:为了避免分母是0的情况,给分母加一个极小值。

要使用这个函数,还需要另外的一个函数的配合:tf.nn.moments(),由此函数来计算均值和方差,然后就可以使用BN了,给函数的定义如下:

tf.nn.moments(x, axes, name, keep_dims=False),axes指定那个轴求均值和方差。

为了更好的效果,我们使用平滑指数衰减的方法来优化每次的均值和方差,这里可以使用

tf.train.ExponentialMovingAverage()函数,它的作用是让上一次的值对本次的值有一个衰减后的影响,从而使的每次的值连起来后会相对平滑一下。

批量归一化的简单用法

下面介绍具体的用法,在使用的时候需要引入头文件。

from tensorflow.contrib.layers.python.layers import batch_norm

函数的定义如下:

batch_norm(inputs,

decay,

center,

scale,

epsilon,

activation_fn,

param_initializers=None,

param_regularizers=None,

updates_collections=ops.GraphKeys.UPDATE_OPS,

is_training=True,

reuse=None,

variables_collections=None,

outputs_collections=None,

trainable=True,

batch_weights=None,

fused=False,

data_format=DATA_FORMAT_NHWC,

zero_debias_moving_mean=False,

scope=None,

renorm=False,

renorm_clipping=None,

renorm_decay=0.99)

各参数的具体含义如下:

inputs:输入

decay:移动平均值的衰减速度,使用的是平滑指数衰减的方法更新均值方差,一般会设置0.9,值太小会导致更新太快,值太大会导致几乎没有衰减,容易出现过拟合。

scale:是否进行变换,通过乘以一个gamma值进行缩放,我们常习惯在BN后面接一个线性变化,如relu。

epsilon:为了避免分母为0,给分母加上一个极小值,一般默认。

is_training:当为True时,代表训练过程,这时会不断更新样本集的均值和方差,当测试时,要设置为False,这样就会使用训练样本的均值和方差。

updates_collections:在训练时,提供一种内置的均值方差更新机制,即通过图中的tf.GraphKeys.UPDATE_OPS变量来更新。但它是在每次当前批次训练完成后才更新均值和方差,这样导致当前数据总是使用前一次的均值和方差,没有得到最新的值,所以一般设置为None,让均值和方差及时更新,但在性能上稍慢。

reuse:支持变量共享。

具体的代码如下:

x = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3])

y= tf.placeholder(dtype=tf.float32, shape=[None, 10])

train=tf.Variable(tf.constant(False))

x_images= tf.reshape(x, [-1, 32, 32, 3])def batch_norm_layer(value, train=False, name='batch_norm'):if train is notFalse:return batch_norm(value, decay=0.9, updates_collections=None, is_training=True)else:return batch_norm(value, decay=0.9, updates_collections=None, is_training=False)

w_conv1= init_cnn.weight_variable([3, 3, 3, 64]) #[-1, 32, 32, 3]

b_conv1 = init_cnn.bias_variable([64])

h_conv1= tf.nn.relu(batch_norm_layer((init_cnn.conv2d(x_images, w_conv1) +b_conv1), train))

h_pool1=init_cnn.max_pool_2x2(h_conv1)

w_conv2= init_cnn.weight_variable([3, 3, 64, 64]) #[-1, 16, 16, 64]

b_conv2 = init_cnn.bias_variable([64])

h_conv2= tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool1, w_conv2) +b_conv2), train))

h_pool2=init_cnn.max_pool_2x2(h_conv2)

w_conv3= init_cnn.weight_variable([3, 3, 64, 32]) #[-1, 18, 8, 32]

b_conv3 = init_cnn.bias_variable([32])

h_conv3= tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool2, w_conv3) +b_conv3), train))

h_pool3=init_cnn.max_pool_2x2(h_conv3)

w_conv4= init_cnn.weight_variable([3, 3, 32, 16]) #[-1, 18, 8, 32]

b_conv4 = init_cnn.bias_variable([16])

h_conv4= tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool3, w_conv4) +b_conv4), train))

h_pool4=init_cnn.max_pool_2x2(h_conv4)

w_conv5= init_cnn.weight_variable([3, 3, 16, 10]) #[-1, 4, 4, 16]

b_conv5 = init_cnn.bias_variable([10])

h_conv5= tf.nn.relu(batch_norm_layer((init_cnn.conv2d(h_pool4, w_conv5) +b_conv5), train))

h_pool5= init_cnn.avg_pool_4x4(h_conv5) #[-1, 4, 4, 10]

y_pool= tf.reshape(h_pool5, shape=[-1, 10])

cross_entropy= tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pool))

optimizer= tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)

加上了BN层之后,识别的准确率显著的得到了提升,并且计算速度也是飞起。

tensorflow 数据归一化_TensorFlow——批量归一化操作相关推荐

  1. 梯度消失、梯度爆炸、过拟合问题之神经网络应对方案:数据预处理、批量归一化、非饱和激活函数、梯度缩放和梯度裁剪、权重初始化、提前终止、集成学习、l1l2、Dropout

    数据预处理.批量归一化Batch Normalization.非饱和激活函数.梯度缩放(Gradient Scaling)和梯度裁剪(Gradient Clipping).权重初始化(Xavier+H ...

  2. Lecture6:激活函数、权值初始化、数据预处理、批量归一化、超参数选择

    目录 1.最小梯度下降(Mini-batch SGD) 2.激活函数 2.1 sigmoid 2.2 tanh 2.3 ReLU 2.4 Leaky ReLU 2.5 ELU 2.6 最大输出神经元 ...

  3. 深度学习 --- 优化入门四(Batch Normalization(批量归一化)一)

    前几节我们详细的探讨了,梯度下降存在的问题和优化方法,本节将介绍在数据处理方面很重要的优化手段即批量归一化(批量归一化). 批量归一化(Batch Normalization)并不能算作是一种最优化算 ...

  4. 【Pytorch神经网络理论篇】 16 过拟合问题的优化技巧(三):批量归一化

    1 批量归一化理论 1.1 批量归一化原理 1.2 批量归一化定义 将每一层运算出来的数据归一化成均值为0.方差为1的标准高斯分布.这样就会在保留样本的分布特征,又消除了层与层间的分布差异. 在实际应 ...

  5. (pytorch-深度学习)批量归一化

    批量归一化 批量归一化(batch normalization)层能让较深的神经网络的训练变得更加容易 通常来说,数据标准化预处理对于浅层模型就足够有效了.随着模型训练的进行,当每层中参数更新时,靠近 ...

  6. 深度学习入门(三十二)卷积神经网络——BN批量归一化

    深度学习入门(三十二)卷积神经网络--BN批量归一化 前言 批量归一化batch normalization 课件 批量归一化 批量归一化层 批量归一化在做什么? 总结 教材 1 训练深层网络 2 批 ...

  7. layui框架数据表格的批量删除

    layui框架数据表格的批量删除操作 此文献为layui框架的数据表格的批量删除,批量删除顾名思义就是把大量的数据进行删除操作 由于有点项目数据繁多,如果在要删除的时候一个一个的删除的话,就会很麻烦. ...

  8. TensorFlow数据归一化

    TensorFlow数据归一化 1. tf.nn.l2_normalize     - l2_normalize(x, dim, epsilon=1e-12,name=None)     - outp ...

  9. Batch Normalization批量归一化

    深度学习捷报连连.声名鹊起,随机梯度下降成了训练深度网络的主流方法.尽管随机梯度下降法对于训练深度网络简单高效,但是它有个毛病,就是需要我们人为的去选择参数,比如学习率.参数初始化.权重衰减系数.Dr ...

最新文章

  1. acctmod-ftp.sh
  2. 初学python还是swift-请问零基础学习python 和swift哪个更好入门呢?
  3. iOS开发-Xcode入门ObjC程序
  4. 你在看Netflix,Netflix也在看你
  5. 一份完整的问卷模板_一份完整市场推广策划方案模板
  6. c# 配置文件App.config操作类库
  7. c语言斐波那契数列递归数组,C语言数据结构学习:递归之斐波那契数列
  8. redis value多大会影响性能_Redis 最常见面试问题
  9. PC批量转换网易ncm音乐
  10. win10 pip install talib一直安装失败
  11. python存储16bit和32bit图像
  12. centOS6添加开机启动
  13. hortonworks_具有在IBM POWER8上运行的Hortonworks Data Platform(HDP)的SAS软件
  14. 《算法图解》系列笔记(七)—— 狄克斯特拉算法
  15. cpu,寄存器,控制器,运算器
  16. 未来属于智能,智能存在未在每个角落-称重
  17. 初中计算机考试wps文字,初中信息技术WPS表格测试题.docx
  18. 3.计算机的应用领域及其发展趋势是什么,计算机应用的现状及其发展趋势
  19. python爬虫-获取腾讯视频的弹幕
  20. 微信小程序开发解决按钮大小问题

热门文章

  1. 米奇emoji_一些常用的 Emoji 符号(可直接复制)
  2. php自动提交百度收录,wordpress站点如何自动提交百度收录
  3. 安卓9开机 bootanimation.zip_小米手机如何从安卓10退回安卓9系统,不丢失传感器,不变砖...
  4. ZooKeeper学习第七期--ZooKeeper一致性原理(转)
  5. 理解 Generator 的执行
  6. Java8 HashMap之tableSizeFor
  7. 三表联查,这是我目前写过的最长的sql语句,嗯嗯,果然遇到问题才能让我更快成长,更复杂的语句也有了一些心得了...
  8. oracle修改用户密码
  9. list numpy array tensor转换
  10. 【数据结构】BFS 代码模板