tensorflow中batch normalization的用法

网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下:

1.原理

公式如下:

y=γ(x-μ)/σ+β

其中x是输入,y是输出,μ是均值,σ是方差,γ和β是缩放(scale)、偏移(offset)系数。

一般来讲,这些参数都是基于channel来做的,比如输入x是一个16*32*32*128(NWHC格式)的feature map,那么上述参数都是128维的向量。其中γ和β是可有可无的,有的话,就是一个可以学习的参数(参与前向后向),没有的话,就简化成y=(x-μ)/σ。而μ和σ,在训练的时候,使用的是batch内的统计值,测试/预测的时候,采用的是训练时计算出的滑动平均值。

2.tensorflow中使用

tensorflow中batch normalization的实现主要有下面三个:

tf.nn.batch_normalization

tf.layers.batch_normalization

tf.contrib.layers.batch_norm

封装程度逐个递进,建议使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,因为在tensorflow官网的解释比较详细。我平时多使用tf.layers.batch_normalization,因此下面的步骤都是基于这个。

3.训练

训练的时候需要注意两点,(1)输入参数training=True,(2)计算loss时,要添加以下代码(即添加update_ops到最后的train_op中)。这样才能计算μ和σ的滑动平均(测试时会用到)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)with tf.control_dependencies(update_ops):train_op = optimizer.minimize(loss)

4.测试

测试时需要注意一点,输入参数training=False,其他就没了

5.预测

预测时比较特别,因为这一步一般都是从checkpoint文件中读取模型参数,然后做预测。一般来说,保存checkpoint的时候,不会把所有模型参数都保存下来,因为一些无关数据会增大模型的尺寸,常见的方法是只保存那些训练时更新的参数(可训练参数),如下:

var_list = tf.trainable_variables()
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

但使用了batch_normalization,γ和β是可训练参数没错,μ和σ不是,它们仅仅是通过滑动平均计算出的,如果按照上面的方法保存模型,在读取模型预测时,会报错找不到μ和σ。更诡异的是,利用tf.moving_average_variables()也没法获取bn层中的μ和σ(也可能是我用法不对),不过好在所有的参数都在tf.global_variables()中,因此可以这么写:

var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

按照上述写法,即可把μ和σ保存下来,读取模型预测时也不会报错,当然输入参数training=False还是要的。

注意上面有个不严谨的地方,因为我的网络结构中只有bn层包含moving_mean和moving_variance,因此只根据这两个字符串做了过滤,如果你的网络结构中其他层也有这两个参数,但你不需要保存,建议使用诸如bn/moving_mean的字符串进行过滤。

 

2018.4.22更新

提供一个基于mnist的示例,供大家参考。包含两个文件,分别用于train/test。注意bn_train.py文件的51-61行,仅保存了网络中的可训练变量和bn层利用统计得到的mean和var。注意示例中需要下载mnist数据集,要保持电脑可以联网。

tensorflow学习笔记1:batch normalization 用法相关推荐

  1. 【深度学习笔记】Batch Normalization 以及其如何解决梯度消失问题

    前言 Batch Normalization作为最近一年来DL的重要成果,已经广泛被证明其有效性和重要性.目前几乎已经成为DL的标配了,任何有志于学习DL的同学们朋友们雷迪斯俺的詹特曼们都应该好好学一 ...

  2. tensorflow学习笔记(七):CNN手写体(MNIST)识别

    文章目录 一.CNN简介 二.主要函数 三.CNN的手写体识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.CNN简介 一般的卷积神经网络由以下几个层组成:卷积层,池化层,非线性激活函数 ...

  3. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  4. tensorflow学习笔记1

    tensorflow学习笔记1 本文主要记录我在慕课上观看北大曹建老师的<人工智能实践:Tensorflow笔记>,链接:https://www.icourse163.org/course ...

  5. tensorflow学习笔记(八):LSTM手写体(MNIST)识别

    文章目录 一.LSTM简介 二.主要函数 三.LSTM手写体(MNIST)识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.LSTM简介 LSTM是一种特殊的RNN,很好的解决了RNN中 ...

  6. 【TensorFlow学习笔记】完美解决 pip3 install tensorflow 没有models库,读取PTB数据

    安装tensorflow 我使用的是最最最简单的容易的 pip3 install <TensorFlow学习笔记> 一. 安装win10下python3.6的tensorflow的CPU版 ...

  7. TensorFlow学习笔记02:使用tf.data读取和保存数据文件

    TensorFlow学习笔记02:使用tf.data读取和保存数据文件 使用`tf.data`读取和写入数据文件 读取和写入csv文件 写入csv文件 读取csv文件 读取和保存TFRecord文件 ...

  8. [TensorFlow 学习笔记-04]卷积函数之tf.nn.conv2d

    [版权说明] TensorFlow 学习笔记参考: 李嘉璇 著 TensorFlow技术解析与实战 黄文坚 唐源 著 TensorFlow实战郑泽宇  顾思宇 著 TensorFlow实战Google ...

  9. TensorFlow学习笔记:Retrain Inception_v3(一)

    转:http://www.jianshu.com/p/613c3b08faea 0. 概要 最新的物体识别模型可能含有数百万个参数,将耗费几周的时间去完全训练.因此我们采用迁移学习的方法,在已经训练好 ...

  10. Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题

    Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题 参考文章: (1)Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题 (2)http ...

最新文章

  1. 我的Python分析成长之路8
  2. 我就想要个两年1024徽章~!
  3. SAP云平台和SAP传统Netweaver系统互联的技术方式
  4. DPM 2012 SP1---安装并部署DPM 2012 SP1服务器
  5. Vivado MMCM IP核接口信号介绍
  6. 12306 被质疑过度获取用户隐私;直播答题外挂横行;阿里云辟谣称绝不做虚拟货币 | 一周业界事
  7. 数据库设计系列[04]组织结构加入权限系统
  8. 800万像素3倍光变 奥林巴斯FE280降价
  9. Ps 初学者教程,如何在图片中创建双重曝光效果?
  10. oracle 取时间的日期函数,Oracle日期函数简介
  11. 前端实现H5制作海报
  12. 计算机管理员绩效指标,网络管理员绩效kpi考核标准..doc
  13. 第8章 资源管理调度框架YARN
  14. ISP模块之色彩增强算法--HSV空间Saturation通道调整
  15. word自动编号与文字间距太大怎么办
  16. 大写金额换算器iOS版源代码
  17. GUI界的大战: QT VS GTK
  18. 产生按指数分布的随机数----摘自csdn
  19. 【C库函数】strlen函数详解
  20. Python-(生成由0到9组成的n位数字)

热门文章

  1. android message to iphone,这款应用可以将苹果的iMessage带到安卓系统
  2. access ea 可以联网吗_EA自家Origin平台高级会员Origin Access Premier现已上线
  3. sap 双计量单位_SAP双计量单位配置指南CUNI.doc
  4. php 开启 pathinfo,Nginx + php-fpm 开启 PATH_INFO 模式
  5. magento mysql_解决Magento环境Mysql经常挂掉的问题
  6. 传入oracle中的日期类型,Oracle中的日期类型及相关函数
  7. 如何根据iframe内嵌页面调整iframe高宽续篇
  8. 四 Lync Server 2013 部署指南-前端部署(2)
  9. 英国大概率退出欧盟!
  10. 将Nginx添加到系统服务(使其可使用service命令控制)