批量标准化(batch normalization简称BN)主要是为了克服当神经网络层数加深而导致难以训练而诞生的。当深度神经网络随着网络深度加深,训练起来会越来越困难,收敛速度会很慢,还会产生梯度消失问题(vanishing gradient problem)。

在统计机器学习领域中有一个ICS(Internal Covariate Shift)理论:源域(source domain)和目标域(target domain)的数据分布是一致的。也就是训练数据和测试数据满足相同的分布,这是通过训练数据获得的模型在测试数据上有一个好的效果的保证。

Covariate Shift是指训练数据的样本和测试数据的样本分布不一致时,训练获取的模型无法很好的泛化。它是分布不一致假设之下的一个分支问题,也就是指源域和目标域的条件概率是一致的,但是其边缘概率不同。对于神经网络而言,神经网络的各层输出,在经过了层内操作后,各层输出分布会随着输入分布的变化而变化,而且差异会随着网络的深度增加而加大,但是每一层随指向的样本标记是不会改变的。

解决Covariate Shift问题可以通过对训练样本和测试样本的比例对训练样本做一个矫正,通过批量标准化来标准化某些层或所有层的输入,从而固定每层输入信号的均值与方差。

一、批量标准化的实现

批量标准化是在激活函数之前,对z=wx+b做标准化,使得输出结果满足标准的正态分布,即均值为0,方差为1。让每一层的输入有一个稳定的分布便于网络的训练。

二、批量标准化的优点

1、加大探索的步长,加快模型收敛的速度

2、更容易跳出局部最小值

3、破坏原来的数据分布,在一定程度上可以缓解过拟合。

当遇到神经网络收敛速度很慢或梯度爆炸等无法训练的情况时,可以尝试使用批量标准化来解决问题。

三、TensorFlow的批量标准化实例

1、tf.nn.moments(x,axes,shift=None,name=None,keep_dims=False)

函数介绍:计算x的均值和方差

参数介绍:

  • x:需要计算均值和方差的tensor
  • axes:指定求解x某个维度上的均值和方差,如果x是一维tensor,则axes=[0]
  • name:用于计算均值和方差操作的名称
  • keep_dims:是否产生与输入相同相同维度的结果
    z = tf.constant([[1,1,1,1,1],[2,2,2,2,2]],dtype=tf.float32)#计算z的均值和方差#计算列的均值和方差z_mean_col,z_var_col = tf.nn.moments(z,axes=[0])#[1.5 1.5 1.5 1.5 1.5] [0.25 0.25 0.25 0.25 0.25]#计算行的均值和方差z_mean_row,z_var_row = tf.nn.moments(z,axes=[1])#等价于axes=[-1],-1表示最后一维#[1. 2.] [0. 0.]#计算整个数组的均值和方差z_mean,z_var = tf.nn.moments(z,axes=[0,1])#1.5 0.25sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)print(sess.run(z))print(sess.run(z_mean_col),sess.run(z_var_col))print(sess.run(z_mean_row),sess.run(z_var_row))print(sess.run(z_mean),sess.run(z_var))

2、tf.nn.batch_normalization(x,mean,variance,offset,scale,variance_epsilon,name=None)

函数介绍:计算batch normalization

参数介绍:

  • x:输入的tensor,具有任意的维度
  • mean:输入tensor的均值
  • variance:输入tensor的方差
  • offset:偏置tensor,初始化为1
  • scale:比例tensor,初始化为0
  • variance_epsilon:一个接近于0的数,避免除以0
    z = tf.constant([[1,1,1,1,1],[2,2,2,2,2]],dtype=tf.float32)#计算z的均值和方差z_mean,z_var = tf.nn.moments(z,axes=[0,1])scale = tf.Variable(tf.ones([2,5]))shift = tf.Variable(tf.zeros([2,5]))#计算batch normalizationz_bath_norm = tf.nn.batch_normalization(z,z_mean,z_var,shift,scale,variance_epsilon=0.001)sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)print(sess.run(z))print(sess.run(z_bath_norm))

TensorFlow的batch_normalization相关推荐

  1. Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作...

    使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...

  2. tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)

    tensorflow 在实现 Batch Normalization(各个网络层输出的归一化)时,主要用到以下两个 api: tf.nn.moments(x, axes, name=None, kee ...

  3. Tensorflow笔记(基础):批处理(batch_normalization)

    CODE # - * - coding: utf - 8 -*- # # 作者:田丰(FontTian) # 创建时间:'2017/8/2' # 邮箱:fonttian@Gmaill.com # CS ...

  4. Tensorflow中的tf.layers.batch_normalization()用法

    使用tf.layers.batch_normalization()需要三步: 在卷积层将激活函数设置为None. 使用batch_normalization. 使用激活函数激活. 需要特别注意的是:在 ...

  5. 【目标检测】(8) ASPP改进加强特征提取模块,附Tensorflow完整代码

    各位同学好,最近想改进一下YOLOV4的SPP加强特征提取模块,看到很多论文中都使用语义分割中的ASPP模块来改进,今天用Tensorflow复现一下代码. YOLOV4的主干网络代码可见我上一篇文章 ...

  6. 【神经网络】(17) EfficientNet 代码复现,网络解析,附Tensorflow完整代码

    各位同学好,今天和大家分享一下如何使用 Tensorflow 复现 EfficientNet 卷积神经网络模型. EfficientNet 的网络结构和 MobileNetV3 比较相似,建议大家在学 ...

  7. 【神经网络】(16) MobileNetV3 代码复现,网络解析,附Tensorflow完整代码

    各位同学好,今天和大家分享一下如何使用 Tensorflow 构建 MobileNetV3 轻量化网络模型. MobileNetV3 做了如下改动(1)更新了V2中的逆转残差结构:(2)使用NAS搜索 ...

  8. 【神经网络】(15) Xception 代码复现,网络解析,附Tensorflow完整代码

    各位同学好,今天和大家分享一下如何使用 Tensorflow 构建 Xception 神经网络模型. 在前面章节中,我已经介绍了很多种轻量化卷积神经网络模型,感兴趣的可以看一下:https://blo ...

  9. 【神经网络】(14) MnasNet 代码复现,网络解析,附Tensorflow完整代码

    各位同学好,今天和大家分享一下如何使用 Tensorflow 复现谷歌轻量化神经网络 MnasNet  通常而言,移动端(手机)和终端(安防监控.无人驾驶)上的设备计算能力有限,无法搭载庞大的神经网络 ...

最新文章

  1. JavaScript基本知识
  2. unity 灯笼_如何创建将自己拼成文字的漂亮灯笼
  3. IEEE 发布年终总结,AI 奇迹不再是故事
  4. 极光推送配置(Android Studio),亲测有效
  5. 海量数据处理——位图法bitmap
  6. 从零开始学python网络爬虫-从零开始学Python 三(网络爬虫)
  7. Shape Context
  8. [Bzoj2243][SDOI2011]染色(线段树树剖)
  9. 一个Linux下C线程池的实现(转)
  10. 北京师范大学新生入学计算机考试内容,北京师范大学
  11. linux内存操作--ioremap和mmap
  12. IDEA下Springcloud框架搭建(一)之服务注册与发现
  13. hadoop的学习之一
  14. ‘catkin_make‘ is currently not installed问题修复
  15. 使用Python3将BT种子转磁力链接
  16. 高通平台驱动常见问题
  17. 零信任嵌入式安全沙箱技术,企业应用软件的技术底座
  18. 尺缩钟慢之动钟变慢——思想实验推导狭义相对论(七)
  19. oa系统打不开只能重启服务器,oa系统打不开怎么办-oa系统打不开的解决方法 - 河东软件园...
  20. PowerPC VxWorks BSP分析(4.3)——BSP定制

热门文章

  1. ps界面为啥突然变大了_5个一劳永逸的Ps设置,让Ps用起来更轻松
  2. 2021云南曲靖富源区高考成绩查询,云南富源第一中学2021年录取分数线
  3. C语言strtok函数使用实例以及注意事项
  4. 十四条令PHP初学者头疼问题大总结
  5. 十四条令PHP初学者头疼问题大总结(1)
  6. C学习笔记-字符串处理函数
  7. 国内服务器的提供商有哪些
  8. 十一个顶级的Git 客户端,绝对很实用!
  9. 软件构造:防御式拷贝(Defensive Copying)
  10. 小飞鱼通达二开 企业微信与通达OA的另一种集成方式(图文)