转载网址:如果侵权,联系我删除

https://www.cnblogs.com/hrlnw/p/7227447.html

https://www.cnblogs.com/eilearn/p/9780696.html

https://www.cnblogs.com/stingsl/p/6428694.html

神经网络学习过程本质就是为了学习数据分布,一旦训练数据与测试数据的分布不同,那么网络的泛化能力也大大降低;另外一方面,一旦每批训练数据的分布各不相同(batch 梯度下降),那么网络就要在每次迭代都去学习适应不同的分布,这样将会大大降低网络的训练速度,这也正是为什么我们需要对数据都要做一个归一化预处理的原因。

对于深度网络的训练是一个复杂的过程,只要网络的前面几层发生微小的改变,那么后面几层就会被累积放大下去。一旦网络某一层的输入数据的分布发生改变,那么这一层网络就需要去适应学习这个新的数据分布,所以如果训练过程中,训练数据的分布一直在发生变化,那么将会影响网络的训练速度。

我们知道网络一旦train起来,那么参数就要发生更新,除了输入层的数据外(因为输入层数据,我们已经人为的为每个样本归一化),后面网络每一层的输入数据分布是一直在发生变化的,因为在训练的时候,前面层训练参数的更新将导致后面层输入数据分布的变化。以网络第二层为例:网络的第二层输入,是由第一层的参数和input计算得到的,而第一层的参数在整个训练过程中一直在变化,因此必然会引起后面每一层输入数据分布的改变。我们把网络中间层在训练过程中,数据分布的改变称之为:“Internal  Covariate Shift”。Paper所提出的算法,就是要解决在训练过程中,中间层数据分布发生改变的情况,于是就有了Batch  Normalization,这个牛逼算法的诞生。

1.原理

公式如下:

y=γ(x-μ)/σ+β

其中x是输入,y是输出,μ是均值,σ是方差,γ和β是缩放(scale)、偏移(offset)系数。

一般来讲,这些参数都是基于channel来做的,比如输入x是一个16*32*32*128(NWHC格式)的feature map,那么上述参数都是128维的向量。其中γ和β是可有可无的,有的话,就是一个可以学习的参数(参与前向后向),没有的话,就简化成y=(x-μ)/σ。而μ和σ,在训练的时候,使用的是batch内的统计值,测试/预测的时候,采用的是训练时计算出的滑动平均值。

2.tensorflow中使用

tensorflow中batch normalization的实现主要有下面三个:

tf.nn.batch_normalization

tf.layers.batch_normalization

tf.contrib.layers.batch_norm

封装程度逐个递进,建议使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,因为在tensorflow官网的解释比较详细。我平时多使用tf.layers.batch_normalization,因此下面的步骤都是基于这个。

3.训练

训练的时候需要注意两点,(1)输入参数training=True,(2)计算loss时,要添加以下代码(即添加update_ops到最后的train_op中)。这样才能计算μ和σ的滑动平均(测试时会用到)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)with tf.control_dependencies(update_ops):train_op = optimizer.minimize(loss)

4.测试

测试时需要注意一点,输入参数training=False,其他就没了

5.预测

预测时比较特别,因为这一步一般都是从checkpoint文件中读取模型参数,然后做预测。一般来说,保存checkpoint的时候,不会把所有模型参数都保存下来,因为一些无关数据会增大模型的尺寸,常见的方法是只保存那些训练时更新的参数(可训练参数),如下:

var_list = tf.trainable_variables()
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

但使用了batch_normalization,γ和β是可训练参数没错,μ和σ不是,它们仅仅是通过滑动平均计算出的,如果按照上面的方法保存模型,在读取模型预测时,会报错找不到μ和σ。更诡异的是,利用tf.moving_average_variables()也没法获取bn层中的μ和σ(也可能是我用法不对),不过好在所有的参数都在tf.global_variables()中,因此可以这么写:

var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

按照上述写法,即可把μ和σ保存下来,读取模型预测时也不会报错,当然输入参数training=False还是要的。

注意上面有个不严谨的地方,因为我的网络结构中只有bn层包含moving_mean和moving_variance,因此只根据这两个字符串做了过滤,如果你的网络结构中其他层也有这两个参数,但你不需要保存,建议使用诸如bn/moving_mean的字符串进行过滤。

 

2018.4.22更新

提供一个基于mnist的示例,供大家参考。包含两个文件,分别用于train/test。注意bn_train.py文件的51-61行,仅保存了网络中的可训练变量和bn层利用统计得到的mean和var。注意示例中需要下载mnist数据集,要保持电脑可以联网。

import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_datatf.logging.set_verbosity(tf.logging.INFO)if __name__ == '__main__':mnist = input_data.read_data_sets('mnist', one_hot=True)x = tf.placeholder(tf.float32, [None, 784])y_ = tf.placeholder(tf.float32, [None, 10])image = tf.reshape(x, [-1, 28, 28, 1])conv1 = tf.layers.conv2d(image, filters=32, kernel_size=[3, 3], strides=[1, 1], padding='same',activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),name='conv1')bn1 = tf.layers.batch_normalization(conv1, training=True, name='bn1')pool1 = tf.layers.max_pooling2d(bn1, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool1')conv2 = tf.layers.conv2d(pool1, filters=64, kernel_size=[3, 3], strides=[1, 1], padding='same',activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),name='conv2')bn2 = tf.layers.batch_normalization(conv2, training=True, name='bn2')pool2 = tf.layers.max_pooling2d(bn2, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool2')flatten_layer = tf.contrib.layers.flatten(pool2, 'flatten_layer')weights = tf.get_variable(shape=[flatten_layer.shape[-1], 10], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1), name='fc_weights')biases = tf.get_variable(shape=[10], dtype=tf.float32,initializer=tf.constant_initializer(0.0), name='fc_biases')logit_output = tf.nn.bias_add(tf.matmul(flatten_layer, weights), biases, name='logit_output')cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logit_output))pred_label = tf.argmax(logit_output, 1)label = tf.argmax(y_, 1)accuracy = tf.reduce_mean(tf.cast(tf.equal(pred_label, label), tf.float32))update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)global_step = tf.get_variable('global_step', [], dtype=tf.int32,initializer=tf.constant_initializer(0), trainable=False)learning_rate = tf.train.exponential_decay(learning_rate=0.1, global_step=global_step, decay_steps=5000,decay_rate=0.1, staircase=True)opt = tf.train.AdadeltaOptimizer(learning_rate=learning_rate, name='optimizer')with tf.control_dependencies(update_ops):grads = opt.compute_gradients(cross_entropy)train_op = opt.apply_gradients(grads, global_step=global_step)tf_config = tf.ConfigProto()tf_config.gpu_options.allow_growth = Truetf_config.allow_soft_placement = Truesess = tf.InteractiveSession(config=tf_config)sess.run(tf.global_variables_initializer())# only save trainable and bn variablesvar_list = tf.trainable_variables()if global_step is not None:var_list.append(global_step)g_list = tf.global_variables()bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]var_list += bn_moving_varssaver = tf.train.Saver(var_list=var_list,max_to_keep=5)# save all variables# saver = tf.train.Saver(max_to_keep=5)if tf.train.latest_checkpoint('ckpts') is not None:saver.restore(sess, tf.train.latest_checkpoint('ckpts'))train_loops = 10000for i in range(train_loops):batch_xs, batch_ys = mnist.train.next_batch(32)_, step, loss, acc = sess.run([train_op, global_step, cross_entropy, accuracy],feed_dict={x: batch_xs, y_: batch_ys})if step % 100 == 0:  # print training infolog_str = 'step:%d \t loss:%.6f \t acc:%.6f' % (step, loss, acc)tf.logging.info(log_str)if step % 1000 == 0:  # save current modelsave_path = os.path.join('ckpts', 'mnist-model.ckpt')saver.save(sess, save_path, global_step=step)sess.close()
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_datatf.logging.set_verbosity(tf.logging.INFO)if __name__ == '__main__':mnist = input_data.read_data_sets('mnist', one_hot=True)x = tf.placeholder(tf.float32, [None, 784])y_ = tf.placeholder(tf.float32, [None, 10])image = tf.reshape(x, [-1, 28, 28, 1])conv1 = tf.layers.conv2d(image, filters=32, kernel_size=[3, 3], strides=[1, 1], padding='same',activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),name='conv1')bn1 = tf.layers.batch_normalization(conv1, training=False, name='bn1')pool1 = tf.layers.max_pooling2d(bn1, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool1')conv2 = tf.layers.conv2d(pool1, filters=64, kernel_size=[3, 3], strides=[1, 1], padding='same',activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),name='conv2')bn2 = tf.layers.batch_normalization(conv2, training=False, name='bn2')pool2 = tf.layers.max_pooling2d(bn2, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool2')flatten_layer = tf.contrib.layers.flatten(pool2, 'flatten_layer')weights = tf.get_variable(shape=[flatten_layer.shape[-1], 10], dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.1), name='fc_weights')biases = tf.get_variable(shape=[10], dtype=tf.float32,initializer=tf.constant_initializer(0.0), name='fc_biases')logit_output = tf.nn.bias_add(tf.matmul(flatten_layer, weights), biases, name='logit_output')cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logit_output))pred_label = tf.argmax(logit_output, 1)label = tf.argmax(y_, 1)accuracy = tf.reduce_mean(tf.cast(tf.equal(pred_label, label), tf.float32))tf_config = tf.ConfigProto()tf_config.gpu_options.allow_growth = Truetf_config.allow_soft_placement = Truesess = tf.InteractiveSession(config=tf_config)saver = tf.train.Saver()if tf.train.latest_checkpoint('ckpts') is not None:saver.restore(sess, tf.train.latest_checkpoint('ckpts'))else:assert 'can not find checkpoint folder path!'loss, acc = sess.run([cross_entropy,accuracy],feed_dict={x: mnist.test.images,y_: mnist.test.labels})log_str = 'loss:%.6f \t acc:%.6f' % (loss, acc)tf.logging.info(log_str)sess.close()

tensorflow中batch normalization的用法相关推荐

  1. tensorflow没有这个参数_解决TensorFlow中Batch Normalization参数没有保存的问题

    batch normalization的坑我真的是踩到要吐了,几个月前就踩了一次,看了网上好多资料,虽然跑通了但是当时没记录下来,结果这次又遇到了.时隔几个月,已经忘得差不多了,结果又花了半天重新踩了 ...

  2. tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)

    tensorflow 在实现 Batch Normalization(各个网络层输出的归一化)时,主要用到以下两个 api: tf.nn.moments(x, axes, name=None, kee ...

  3. 谈Tensorflow的Batch Normalization

    tensorflow中关于BN(Batch Normalization)的函数主要有两个,分别是: tf.nn.moments tf.nn.batch_normalization 关于这两个函数,官方 ...

  4. 谈谈Tensorflow的Batch Normalization

    上海站 | 高性能计算之GPU CUDA培训 4月13-15日 三天密集式学习 快速带你晋级 阅读全文 > 正文共4488个字,4张图,预计阅读时间12分钟. tensorflow中关于BN(B ...

  5. 黑猿大叔-译文 | TensorFlow实现Batch Normalization

    正文共10537个字,8张图,预计阅读时间27分钟. 原文:Implementing Batch Normalization in Tensorflow(https://r2rt.com/implem ...

  6. 五个角度解释深度学习中 Batch Normalization为什么效果好?

    https://www.toutiao.com/a6699953853724361220/ 深度学习模型中使用Batch Normalization通常会让模型得到更好表现,其中原因到底有哪些呢?本篇 ...

  7. tensorflow中的eval的用法

    tensorflow中eval()的用法 做人工智能实验的过程中遇到这样一段代码不是很明白是什么意思: 查阅资料后明白了tensorflow中eval的用法: with tf.Session() as ...

  8. Tensorflow中placeholder函数的用法

    文章目录 简介 实现 简介 在代码层面上,每一个tensor值在graph上都是一个op,当我们将train数据分成一个个minibatch然后传入网络上进行训练时,每一个minibatch都将是一个 ...

  9. batch normalization详解

    1.引入BN的原因 1.加快模型的收敛速度 2.在一定程度上缓解了深度网络中的"梯度弥散"问题,从而使得训练深层网络模型更加容易和稳定. 3.对每一批数据进行归一化.这个数据是可以 ...

最新文章

  1. 吃了这些数据集和模型,跟 AI 学跳舞,做 TensorFlowBoys
  2. 正则表达式练习 Regex Golf
  3. Angular cli 发布自定义组件
  4. Flutter 以Dialog Activity形式展现
  5. mybatis学习(45):开启二级缓存
  6. python判断字典,列表,元组为空的方法。
  7. 【UI设计师必备】完美色彩搭配超级实用GUI的素材
  8. 图像算法三:【图像增强--空间域】图像平滑、中值滤波、图像锐化
  9. c语言解三元一次方程组_一次二次反比例,一山更比一山高?二次函数三大解析式详解...
  10. 远程执行python脚本_python 远程执行服务器上的脚本
  11. JAVA中读取配置文件以及修改配置文件
  12. 花呗的24期利息计算器_花呗24期怎么算利息怎么算(花呗借600024期要还多少利息?)...
  13. php网站 视频马赛克,如何给视频加马赛克 菜鸟也能学会的视频加马赛克解决方案...
  14. 【转载】CSS常用英文字体介绍
  15. dvi接口引脚定义_为什么越来越多人用RS232接口,却还分不清DB9、DB25的引脚定义?...
  16. 高中计算机教室标语,高中教室标语精华
  17. 网址 URL 最后的斜杠 / 是作甚的?
  18. 当年“你说什么,我都能实现”的软件公司,后来都是怎么死的?
  19. linux mint卸载桌面环境,在Linux Mint 19/Ubuntu 18.04系统上安装Deepin桌面环境的方法
  20. 成都二手房长啥样 —— 基于链家数据

热门文章

  1. python变量定义问题_python 定义n个变量方法 (变量声明自动化)
  2. 最新蚂蚁金服Java面试题:Docker+秒杀设计+RocketMQ+亿级数据设计
  3. 1.1 字符串的旋转+1.2 字符串的包含
  4. 华为算法工程师-2020届实习招聘题
  5. 中文电子病例命名实体识别项目
  6. python replace()
  7. pipelineDB学习笔记-2. Stream (流)
  8. yum list失败
  9. mysql 修复表和优化表
  10. jQuery 结构分析