上海站 | 高性能计算之GPU CUDA培训

4月13-15日

三天密集式学习 快速带你晋级
阅读全文
>

正文共4488个字,4张图,预计阅读时间12分钟。

tensorflow中关于BN(Batch Normalization)的函数主要有两个,分别是:

  • tf.nn.moments

  • tf.nn.batch_normalization

关于这两个函数,官方API中有详细的说明,具体的细节可以点链接查看,关于BN的介绍可以参考这篇论文(https://arxiv.org/abs/1502.03167),我来说说自己的理解。

不得不吐槽一下,tensorflow的官方API很少给例子,太不人性化了,人家numpy做的就比tensorflow强。

对了,moments函数的计算结果一般作为batch_normalization的部分输入!这就是两个函数的关系,下面展开介绍!

tf.nn.moments函数

官方的输入定义如下:

def moments(x, axes, name=None, keep_dims=False)

解释如下:

x 可以理解为我们输出的数据,形如 [batchsize, height, width, kernels]

axes 表示在哪个维度上求解,是个list,例如 [0, 1, 2]

name 就是个名字,不多解释

keep_dims 是否保持维度,不多解释

这个函数的输出有两个,用官方的话说就是:

Two Tensor objects: mean and variance.

解释如下:

  • mean 就是均值啦

  • variance 就是方差啦

关于这个函数的最基本的知识就介绍完了,但依然没明白这函数到底是干啥的,下面通过几个例子来说明:

1、计算2×3维向量的mean和variance,程序节选如下:

img = tf.Variable(tf.random_normal([2, 3])) axis = list(range(len(img.get_shape()) - 1)) mean, variance = tf.nn.moments(img, axis)

输出的结果如下:

img = [[ 0.69495416  2.08983064 -1.08764684]         [ 0.31431156 -0.98923939 -0.34656194]] mean =  [ 0.50463283  0.55029559 -0.71710438] variance =  [ 0.0362222   2.37016821  0.13730171]

有了例子和结果,就很好理解了,moments函数就是在 [0] 维度上求了个均值和方差,对于axis这个参数的理解,可以参考这里。

另外,针对2×3大小的矩阵,axis还可以这么理解,若axis = [0],那么我们2×3的小矩阵可以理解成是一个包含了2个长度为3的一维向量,然后就是求这两个向量的均值和方差啦!多个向量的均值、方差计算请自行脑补。

当然了,这个例子只是一个最简单的例子,如果换做求形如“[batchsize, height, width, kernels]”数据的mean和variance呢?接下来来简单分析一下。

2、计算卷积神经网络某层的的mean和variance

假定我们需要计算数据的形状是 [batchsize, height, width, kernels],熟悉CNN的都知道,这个在tensorflow中太常见了,例程序如下:

img = tf.Variable(tf.random_normal([128, 32, 32, 64])) axis = list(range(len(img.get_shape()) - 1)) mean, variance = tf.nn.moments(img, axis)

形如[128, 32, 32, 64]的数据在CNN的中间层非常常见,那么,为了给出一个直观的认识,这个函数的输出结果如下,可能输出的数字比较多。。。

mean =  [ -1.58071518e-03   9.46253538e-04   9.92774963e-04  -2.57909298e-04             4.31227684e-03   2.85443664e-03  -3.51431966e-03  -2.95847654e-04            -1.57856941e-03  -7.36653805e-04  -3.81006300e-03   1.95848942e-03            -2.19231844e-03   1.88898295e-04   3.09050083e-03   1.28045678e-04            -5.45501709e-04  -7.49588013e-04   3.41436267e-03   4.55856323e-04             1.21808052e-03   1.71916187e-03   2.33578682e-03  -9.98377800e-04             1.01172924e-03  -3.25803459e-03   1.98090076e-03  -9.53197479e-04             3.37207317e-03   6.27857447e-03  -2.22939253e-03  -1.75476074e-04             1.82938576e-03   2.28643417e-03  -2.59208679e-03  -1.05714798e-03            -1.82652473e-03   4.51803207e-05  -1.38700008e-03   1.88308954e-03            -3.67999077e-03  -4.22883034e-03   8.54551792e-04  -1.30176544e-04            -1.02388859e-03   3.15248966e-03  -1.00244582e-03  -3.58343124e-04             9.68813896e-04  -3.17507982e-03  -2.61783600e-03  -5.57708740e-03            -3.49491835e-04   7.54106045e-03  -9.98616219e-04   5.13806939e-04             1.08468533e-03   1.58560276e-03  -2.76589394e-03  -1.18827820e-03            -4.92024422e-03   3.14301252e-03   9.12249088e-04  -1.98567938e-03] variance =  [ 1.00330877  1.00071466  1.00299144  1.00269675  0.99600208  0.99615276                0.9968518   1.00154674  0.99785519  0.99120021  1.00565553  0.99633628                0.99637395  0.99959981  0.99702841  0.99686354  1.00210547  1.00151515                1.00124979  1.00289011  1.0019592   0.99810153  1.00296855  1.0040164                1.00397885  0.99348587  0.99743217  0.99921477  1.00718474  1.00182319                1.00461221  1.00222814  1.00570309  0.99897575  1.00203466  1.0002507                1.00139284  1.0015136   1.00439298  0.99371535  1.00209546  1.00239146                0.99446201  1.00200033  1.00330424  0.99965429  0.99676734  0.99974728                0.99562836  1.00447667  0.9969337   1.0026046   0.99110448  1.00229466                1.00264072  0.99483615  1.00260413  1.0050714   1.00082493  1.00062656                1.0020628   1.00507069  1.00343442  0.99490905]

然后我解释一下这些数字到底是怎么来的,可能对于2×3这么大的矩阵,理解起来比较容易,但是对于 [128, 32, 32, 64] 这样的4维矩阵,理解就有点困难了。

其实很简单,可以这么理解,一个batch里的128个图,经过一个64 kernels卷积层处理,得到了128×64个图,再针对每一个kernel所对应的128个图,求它们所有像素的mean和variance,因为总共有64个kernels,输出的结果就是一个一维长度64的数组啦!

手画示意图太丑了,我重新画了一个!

计算mean和variance

tf.nn.batch_normalization函数

官方对函数输入的定义是:

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None):

关于这几个参数,可以参考这篇论文和这个博客,我这里就直接给出一个公式的截图了,如下:

晦涩难懂的公式

官方对参数的解释如下

官方的解释

这一堆参数里面,我们已经知道x、mean、variance这三个,那offset和scale呢??答案是:这两个参数貌似是需要训练的,其中offset一般初始化为0,scale初始化为1,另外offset、scale的shape与mean相同。

variance_epsilon这个参数设为一个很小的数就行,比如0.001。

但是,我这里要但是一下!BN在神经网络进行training和testing的时候,所用的mean、variance是不一样的!这个博客里已经说明了,但具体怎么操作的呢?我们看下面的代码:

update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, BN_DECAY) update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, BN_DECAY) tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean) tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance) mean, variance = control_flow_ops.cond(['is_training'], lambda: (mean, variance), lambda: (moving_mean, moving_variance))

看不懂没关系,这段代码的意思就是计算moving mean(滑动平均)、moving variance(滑动方差),然后利用 (moving_mean, moving_variance) 进行网络测试。

关于BN的完整实现,在Ryan Dahl的repository里有,名字叫做tensorflow-resnet(https://github.com/ry/tensorflow-resnet),可以自行查看。

原文链接:https://www.jianshu.com/p/0312e04e4e83

查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:

www.leadai.org

请关注人工智能LeadAI公众号,查看更多专业文章

大家都在看


LSTM模型在问答系统中的应用

基于TensorFlow的神经网络解决用户流失概览问题

最全常见算法工程师面试题目整理(一)

最全常见算法工程师面试题目整理(二)

TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络

装饰器 | Python高级编程

今天不如来复习下Python基础

谈谈Tensorflow的Batch Normalization相关推荐

  1. tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)

    tensorflow 在实现 Batch Normalization(各个网络层输出的归一化)时,主要用到以下两个 api: tf.nn.moments(x, axes, name=None, kee ...

  2. tensorflow中batch normalization的用法

    转载网址:如果侵权,联系我删除 https://www.cnblogs.com/hrlnw/p/7227447.html https://www.cnblogs.com/eilearn/p/97806 ...

  3. 谈Tensorflow的Batch Normalization

    tensorflow中关于BN(Batch Normalization)的函数主要有两个,分别是: tf.nn.moments tf.nn.batch_normalization 关于这两个函数,官方 ...

  4. 黑猿大叔-译文 | TensorFlow实现Batch Normalization

    正文共10537个字,8张图,预计阅读时间27分钟. 原文:Implementing Batch Normalization in Tensorflow(https://r2rt.com/implem ...

  5. tensorflow没有这个参数_解决TensorFlow中Batch Normalization参数没有保存的问题

    batch normalization的坑我真的是踩到要吐了,几个月前就踩了一次,看了网上好多资料,虽然跑通了但是当时没记录下来,结果这次又遇到了.时隔几个月,已经忘得差不多了,结果又花了半天重新踩了 ...

  6. batch normalization详解

    1.引入BN的原因 1.加快模型的收敛速度 2.在一定程度上缓解了深度网络中的"梯度弥散"问题,从而使得训练深层网络模型更加容易和稳定. 3.对每一批数据进行归一化.这个数据是可以 ...

  7. Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作...

    使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...

  8. 3.1 Tensorflow: 批标准化(Batch Normalization)

    ##BN 简介 背景 批标准化(Batch Normalization )简称BN算法,是为了克服神经网络层数加深导致难以训练而诞生的一个算法.根据ICS理论,当训练集的样本数据和目标样本集分布不一致 ...

  9. 深度学习总结:用pytorch做dropout和Batch Normalization时需要注意的地方,用tensorflow做dropout和BN时需要注意的地方,

    用pytorch做dropout和BN时需要注意的地方 pytorch做dropout: 就是train的时候使用dropout,训练的时候不使用dropout, pytorch里面是通过net.ev ...

最新文章

  1. linux jna调用so动态库
  2. .Net Discovery系列之四 深入理解.Net垃圾收集机制(下)
  3. 51 NOD 1363 最小公倍数之和 (欧拉函数思维应用)
  4. P3527 [POI2011]MET-Meteors 整体二分 + 树状数组
  5. tar打包时排除一些文件或者目录
  6. ArrayList 类方法toArray的一点疑惑
  7. 英文原始文本的读取与处理
  8. 基于C++和QT实现的简单数独游戏软件
  9. R软件和RStudio的入门介绍
  10. 微信小程序emoji表情输入框制作
  11. 可以删除电脑文件的c语言程序,win7电脑c盘都有哪些文件可以删除
  12. 【RAM IP】RAM IP核简介及实验
  13. 运鸿蒙之息 行祈者之意,祈禳之禳关度煞科
  14. Halting Problem图灵机问题
  15. 使用MCU SPI访问具有非标准SPI接口ADC的方法
  16. idea关于找不到包的问题,比如:Java:程序包org.springframework.beans.factory.annotation不存在
  17. Word2Vec算法和源码分析完整版
  18. 【genius_platform软件平台开发】第八十二讲:ARM Neon指令集一(ARM NEON Intrinsics, SIMD运算, 优化心得)
  19. 使用SOLIDWORKS验证光线模拟
  20. 如何取消PDF武侠小说中的密码

热门文章

  1. jpa 托管_java – jpa非托管实体
  2. 西门子plm_西门子PLM副总裁:NX,智能的CAD平台
  3. disable path length limit_通过Antsword看绕过disable_functions
  4. 南师大附中2021高考成绩查询,2021高考倒计时,你有一份师大附中专属回忆录待查收~...
  5. math库是python语言的数学模块_Python 数学模块(Math)
  6. 升级无法登录_JeeSite v4.2.2 发布,代码生成增强、Boot 2.3、短信登录、性能提升...
  7. php 读取excel转数组中,thinkphp5使用PHPExcel读取excel csv到数组
  8. [翻译]基于ASP.NET的NumericTextBox控件[Carol]
  9. Linux性能监测(系统监测统计命令详解)
  10. 4-Ubuntu—终端下重启与关机