tensorflow学习笔记1:batch normalization 用法
tensorflow中batch normalization的用法
网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下:
1.原理
公式如下:
y=γ(x-μ)/σ+β
其中x是输入,y是输出,μ是均值,σ是方差,γ和β是缩放(scale)、偏移(offset)系数。
一般来讲,这些参数都是基于channel来做的,比如输入x是一个16*32*32*128(NWHC格式)的feature map,那么上述参数都是128维的向量。其中γ和β是可有可无的,有的话,就是一个可以学习的参数(参与前向后向),没有的话,就简化成y=(x-μ)/σ。而μ和σ,在训练的时候,使用的是batch内的统计值,测试/预测的时候,采用的是训练时计算出的滑动平均值。
2.tensorflow中使用
tensorflow中batch normalization的实现主要有下面三个:
tf.nn.batch_normalization
tf.layers.batch_normalization
tf.contrib.layers.batch_norm
封装程度逐个递进,建议使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,因为在tensorflow官网的解释比较详细。我平时多使用tf.layers.batch_normalization,因此下面的步骤都是基于这个。
3.训练
训练的时候需要注意两点,(1)输入参数training=True,(2)计算loss时,要添加以下代码(即添加
update_ops到最后的train_op中)。这样才能计算μ和σ的滑动平均(测试时会用到)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)with tf.control_dependencies(update_ops):train_op = optimizer.minimize(loss)
4.测试
测试时需要注意一点,输入参数training=False,其他就没了
5.预测
预测时比较特别,因为这一步一般都是从checkpoint文件中读取模型参数,然后做预测。一般来说,保存checkpoint的时候,不会把所有模型参数都保存下来,因为一些无关数据会增大模型的尺寸,常见的方法是只保存那些训练时更新的参数(可训练参数),如下:
var_list = tf.trainable_variables() saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
但使用了batch_normalization,γ和β是可训练参数没错,μ和σ不是,它们仅仅是通过滑动平均计算出的,如果按照上面的方法保存模型,在读取模型预测时,会报错找不到μ和σ。更诡异的是,利用tf.moving_average_variables()也没法获取bn层中的μ和σ(也可能是我用法不对),不过好在所有的参数都在
tf.global_variables()中,因此可以这么写:
var_list = tf.trainable_variables() g_list = tf.global_variables() bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name] bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name] var_list += bn_moving_vars saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
按照上述写法,即可把μ和σ保存下来,读取模型预测时也不会报错,当然输入参数training=False还是要的。
注意上面有个不严谨的地方,因为我的网络结构中只有bn层包含moving_mean和moving_variance,因此只根据这两个字符串做了过滤,如果你的网络结构中其他层也有这两个参数,但你不需要保存,建议使用诸如bn/moving_mean的字符串进行过滤。
2018.4.22更新
提供一个基于mnist的示例,供大家参考。包含两个文件,分别用于train/test。注意bn_train.py文件的51-61行,仅保存了网络中的可训练变量和bn层利用统计得到的mean和var。注意示例中需要下载mnist数据集,要保持电脑可以联网。
tensorflow学习笔记1:batch normalization 用法相关推荐
- 【深度学习笔记】Batch Normalization 以及其如何解决梯度消失问题
前言 Batch Normalization作为最近一年来DL的重要成果,已经广泛被证明其有效性和重要性.目前几乎已经成为DL的标配了,任何有志于学习DL的同学们朋友们雷迪斯俺的詹特曼们都应该好好学一 ...
- tensorflow学习笔记(七):CNN手写体(MNIST)识别
文章目录 一.CNN简介 二.主要函数 三.CNN的手写体识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.CNN简介 一般的卷积神经网络由以下几个层组成:卷积层,池化层,非线性激活函数 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- tensorflow学习笔记1
tensorflow学习笔记1 本文主要记录我在慕课上观看北大曹建老师的<人工智能实践:Tensorflow笔记>,链接:https://www.icourse163.org/course ...
- tensorflow学习笔记(八):LSTM手写体(MNIST)识别
文章目录 一.LSTM简介 二.主要函数 三.LSTM手写体(MNIST)识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.LSTM简介 LSTM是一种特殊的RNN,很好的解决了RNN中 ...
- 【TensorFlow学习笔记】完美解决 pip3 install tensorflow 没有models库,读取PTB数据
安装tensorflow 我使用的是最最最简单的容易的 pip3 install <TensorFlow学习笔记> 一. 安装win10下python3.6的tensorflow的CPU版 ...
- TensorFlow学习笔记02:使用tf.data读取和保存数据文件
TensorFlow学习笔记02:使用tf.data读取和保存数据文件 使用`tf.data`读取和写入数据文件 读取和写入csv文件 写入csv文件 读取csv文件 读取和保存TFRecord文件 ...
- [TensorFlow 学习笔记-04]卷积函数之tf.nn.conv2d
[版权说明] TensorFlow 学习笔记参考: 李嘉璇 著 TensorFlow技术解析与实战 黄文坚 唐源 著 TensorFlow实战郑泽宇 顾思宇 著 TensorFlow实战Google ...
- TensorFlow学习笔记:Retrain Inception_v3(一)
转:http://www.jianshu.com/p/613c3b08faea 0. 概要 最新的物体识别模型可能含有数百万个参数,将耗费几周的时间去完全训练.因此我们采用迁移学习的方法,在已经训练好 ...
- Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题
Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题 参考文章: (1)Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题 (2)http ...
最新文章
- 我的Python分析成长之路8
- 我就想要个两年1024徽章~!
- SAP云平台和SAP传统Netweaver系统互联的技术方式
- DPM 2012 SP1---安装并部署DPM 2012 SP1服务器
- Vivado MMCM IP核接口信号介绍
- 12306 被质疑过度获取用户隐私;直播答题外挂横行;阿里云辟谣称绝不做虚拟货币 | 一周业界事
- 数据库设计系列[04]组织结构加入权限系统
- 800万像素3倍光变 奥林巴斯FE280降价
- Ps 初学者教程,如何在图片中创建双重曝光效果?
- oracle 取时间的日期函数,Oracle日期函数简介
- 前端实现H5制作海报
- 计算机管理员绩效指标,网络管理员绩效kpi考核标准..doc
- 第8章 资源管理调度框架YARN
- ISP模块之色彩增强算法--HSV空间Saturation通道调整
- word自动编号与文字间距太大怎么办
- 大写金额换算器iOS版源代码
- GUI界的大战: QT VS GTK
- 产生按指数分布的随机数----摘自csdn
- 【C库函数】strlen函数详解
- Python-(生成由0到9组成的n位数字)
热门文章
- android message to iphone,这款应用可以将苹果的iMessage带到安卓系统
- access ea 可以联网吗_EA自家Origin平台高级会员Origin Access Premier现已上线
- sap 双计量单位_SAP双计量单位配置指南CUNI.doc
- php 开启 pathinfo,Nginx + php-fpm 开启 PATH_INFO 模式
- magento mysql_解决Magento环境Mysql经常挂掉的问题
- 传入oracle中的日期类型,Oracle中的日期类型及相关函数
- 如何根据iframe内嵌页面调整iframe高宽续篇
- 四 Lync Server 2013 部署指南-前端部署(2)
- 英国大概率退出欧盟!
- 将Nginx添加到系统服务(使其可使用service命令控制)