正文共10537个字,8张图,预计阅读时间27分钟。

原文:Implementing Batch Normalization in Tensorflow(https://r2rt.com/implementing-batch-normalization-in-tensorflow.html)
来源:R2RT

译者注:本文基于一个最基础的全连接网络,演示如何构建Batch Norm层、如何训练以及如何正确进行测试,玩转这份示例代码是理解Batch Norm的最好方式。

文中代码可在jupyter notebook环境下运行:

  • nn_withBN.ipynb(https://github.com/EthanYuan/TensorFlow-Zero-to-N/blob/master/TF1_4/nn_withBN.ipynb),

  • nn_withBN_ok.ipynb(https://github.com/EthanYuan/TensorFlow-Zero-to-N/blob/master/TF1_4/nn_withBN_ok.ipynb)

批标准化,是Sergey Ioffe和Christian Szegedy在2015年3月的论文BN2015(https://arxiv.org/pdf/1502.03167v3.pdf)中提出的一种简单、高效的改善神经网络性能的方法。论文BN2015中,Ioffe和Szegedy指出批标准化不仅能应用更高的学习率、具有正则化器的效用,还能将训练速度提升14倍之多。本文将基于TensorFlow来实现批标准化。

问题的提出

批标准化所要解决的问题是:模型参数在学习阶段的变化,会使每个隐藏层输出的分布也发生改变。这意味着靠后的层要在训练过程中去适应这些变化。

批标准化的概念

为了解决这个问题,论文BN2015提出了批标准化,即在训练时作用于每个神经元激活函数(比如sigmoid或者ReLU函数)的输入,使得基于每个批次的训练样本,激活函数的输入都能满足均值为0,方差为1的分布。对于激活函数σ(Wx+b),应用批标准化后变为σ(BN(Wx+b)),其中BN代表批标准化。

批标准化公式

对一批数据中的某个数值进行标准化,做法是先减去整批数据的均值,然后除以整批数据的标准差√(σ2+ε)。注意小的常量ε加到方差中是为了防止除零。给定一个数值xi,一个初始的批标准化公式如下:

上面的公式中,批标准化对激活函数的输入约束为正态分布,但是这样一来限制了网络层的表达能力。为此,可以通过乘以一个新的比例参数γ,并加上一个新的位移参数β,来让网络撤销批标准化变换。γ和β都是可学习参数。

加入γ和β后得到下面最终的批标准化公式:

基于TensorFlow实现批标准化

我们将把批标准化加进一个有两个隐藏层、每层包含100个神经元的全连接神经网络,并展示与论文BN2015中图1(b)和(c)类似的实验结果。

需要注意,此时该网络还不适合在测试期使用。后面的“模型预测”一节中将会阐释其中的原因,并给出修复版本。

Imports,config

import numpy as np, tensorflow as tf, tqdm

from tensorflow.examples.tutorials.mnist                       import input_data

import matplotlib.pyplot as plt %matplotlib inline mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


# Generate predetermined random weights so the networks are similarly initialized

w1_initial = np.random.normal(size=(784,100)).astype(np.float32) w2_initial = np.random.normal(size=(100,100)).astype(np.float32) w3_initial = np.random.normal(size=(100,10)).astype(np.float32)

# Small epsilon value for the BN transform

epsilon = 1e-3

Building the  graph

# Placeholders

x = tf.placeholder(tf.float32, shape=[None, 784]) y_ = tf.placeholder(tf.float32, shape=[None, 10])


# Layer 1 without BNw1 = tf.Variable(w1_initial) b1 = tf.Variable(tf.zeros([100])) z1 = tf.matmul(x,w1)+b1 l1 = tf.nn.sigmoid(z1)

下面是经过批标准化的第一层:

# Layer 1 with BN

w1_BN = tf.Variable(w1_initial)

# Note that pre-batch normalization bias is ommitted. The effect of this bias would be

# eliminated when subtracting the batch mean. Instead, the role of the bias is performed

# by the new beta variable. See Section 3.2 of the BN2015 paper.

z1_BN = tf.matmul(x,w1_BN)

# Calculate batch mean and variance

batch_mean1, batch_var1 = tf.nn.moments(z1_BN,[0])

# Apply the initial batch normalizing transform

z1_hat = (z1_BN - batch_mean1) / tf.sqrt(batch_var1 + epsilon)

# Create two new parameters, scale and beta (shift)

scale1 = tf.Variable(tf.ones([100]))

beta1 = tf.Variable(tf.zeros([100]))

# Scale and shift to obtain the final output of the batch normalization

# this value is fed into the activation function (here a sigmoid)

BN1 = scale1 * z1_hat + beta1

l1_BN = tf.nn.sigmoid(BN1)


# Layer 2 without BNw2 = tf.Variable(w2_initial) b2 = tf.Variable(tf.zeros([100])) z2 = tf.matmul(l1,w2)+b2 l2 = tf.nn.sigmoid(z2)

TensorFlow提供了tf.nn.batch_normalization,我用它定义了下面的第二层。这与上面第一层的代码行为是一样的。查阅开源代码在这里(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn_impl.py#L911)。

# Layer 2 with BN, using Tensorflows built-in BN function

w2_BN = tf.Variable(w2_initial) z2_BN = tf.matmul(l1_BN,w2_BN) batch_mean2, batch_var2 = tf.nn.moments(z2_BN,[0]) scale2 = tf.Variable(tf.ones([100])) beta2 = tf.Variable(tf.zeros([100])) BN2 = tf.nn.batch_normalization(z2_BN,batch_mean2,batch_var2,beta2,scale2,epsilon) l2_BN = tf.nn.sigmoid(BN2)


# Softmaxw3 = tf.Variable(w3_initial) b3 = tf.Variable(tf.zeros([10])) y  = tf.nn.softmax(tf.matmul(l2,w3)+b3) w3_BN = tf.Variable(w3_initial) b3_BN = tf.Variable(tf.zeros([10])) y_BN  = tf.nn.softmax(tf.matmul(l2_BN,w3_BN)+b3_BN)


# Loss, optimizer and predictions cross_entropy = -tf.reduce_sum(y_*tf.log(y)) cross_entropy_BN = -tf.reduce_sum(y_*tf.log(y_BN)) train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) train_step_BN = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy_BN) correct_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) correct_prediction_BN = tf.equal(tf.arg_max(y_BN,1),tf.arg_max(y_,1)) accuracy_BN = tf.reduce_mean(tf.cast(correct_prediction_BN,tf.float32))

training the network

zs, BNs, acc, acc_BN = [], [], [], [] sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer())for i in tqdm.tqdm(range(40000)):    batch = mnist.train.next_batch(60)    train_step.run(feed_dict={x: batch[0], y_: batch[1]})    train_step_BN.run(feed_dict={x: batch[0], y_: batch[1]})    if i % 50 is 0:        res = sess.run([accuracy,accuracy_BN,z2,BN2],feed_dict={x: mnist.test.images, y_: mnist.test.labels})        acc.append(res[0])        acc_BN.append(res[1])        zs.append(np.mean(res[2],axis=0)) # record the mean value of z2 over the entire test set        BNs.append(np.mean(res[3],axis=0)) # record the mean value of BN2 over the entire test setzs, BNs, acc, acc_BN = np.array(zs), np.array(BNs), np.array(acc), np.array(acc_BN)

速度和精度的提升

如下所示,应用批标准化后,精度和训练速度均有可观的改善。论文BN2015中的图2显示,批标准化对于其他网络架构也同样具有重要作用。

fig, ax = plt.subplots() ax.plot(range(0,len(acc)*50,50),acc, label='Without BN') ax.plot(range(0,len(acc)*50,50),acc_BN, label='With BN') ax.set_xlabel('Training steps') ax.set_ylabel('Accuracy') ax.set_ylim([0.8,1]) ax.set_title('Batch Normalization Accuracy') ax.legend(loc=4) plt.show()

激活函数输入的时间序列图示

下面是网络第2层的前5个神经元的sigmoid激活函数输入随时间的分布情况。批标准化在消除输入的方差/噪声上具有显著的效果。

fig, axes = plt.subplots(5, 2, figsize=(6,12)) fig.tight_layout()for i, ax in enumerate(axes):    ax[0].set_title("Without BN")    ax[1].set_title("With BN")    ax[0].plot(zs[:,i])    ax[1].plot(BNs[:,i])

模型预测

使用批标准化模型进行预测时,使用批量样本自身的均值和方差会适得其反。想象一下单个样本进入我们训练的模型会发生什么?激活函数的输入将永远为零(因为我们做的是均值为0的标准化),而且无论输入是什么,我们总得到相同的结果。

验证如下:

predictions = [] correct = 0for i in range(100):    pred, corr = sess.run([tf.arg_max(y_BN,1), accuracy_BN],                         feed_dict={x: [mnist.test.images[i]], y_: [mnist.test.labels[i]]})    correct += corr    predictions.append(pred[0])print("PREDICTIONS:", predictions)print("ACCURACY:", correct/100)


PREDICTIONS: [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]ACCURACY: 0.02

我们的模型总是输出8,在MNIST的前100个样本中8实际上只有2个,所以精度只有2%。

修改模型的测试期行为

为了修复这个问题,我们需要将批均值和批方差替换成全局均值和全局方差。详见论文BN2015的3.1节。但是这会造成,上面的模型想正确的工作,就只能一次性的将测试集所有样本进行预测,因为这样才能算出理想的全局均值和全局方差。

为了使批标准化模型适用于测试,我们需要在测试前的每一步批标准化操作时,都对全局均值和全局方差进行估算,然后才能在做预测时使用这些值。和我们需要批标准化的原因一样(激活输入的均值和方差在训练时会发生变化),估算全局均值和方差最好在其依赖的权重更新完成后,但是同时进行也不算特别糟,因为权重在训练快结束时就收敛了。

现在,为了基于TensorFlow来实现修复,我们要写一个batch_norm_wrapper函数,来封装激活输入。这个函数会将全局均值和方差作为tf.Variables来存储,并在做标准化时决定采用批统计还是全局统计。为此,需要一个is_training标记。当is_training == True,我们就要在训练期学习全局均值和方差。代码骨架如下:

def batch_norm_wrapper(inputs, is_training):    ...    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)    if is_training:        mean, var = tf.nn.moments(inputs,[0])        ...        # learn pop_mean and pop_var here        ...        return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)    else:        return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)

注意变量节点声明了 trainable = False,因为我们将要自行更新它们,而不是让最优化器来更新。

在训练期间,一个计算全局均值和方差的方法是指数平滑法,它很简单,且避免了额外的工作,我们应用如下:

decay = 0.999 # use numbers closer to 1 if you have more data

train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))

train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

最后,我们需要解决如何调用这些训练期操作。为了完全可控,你可以把它们加入到一个graph collection(可以看看下面链接的TensorFlow源码),但是简单起见,我们将会在每次计算批均值和批方差时都调用它们。为此,当is_training为True时,我们把它们作为依赖加入了batch_norm_wrapper的返回值中。最终的batch_norm_wrapper函数如下:

# this is a simpler version of Tensorflow's 'official' version. See:

# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/layers.py#L102

def batch_norm_wrapper(inputs, is_training, decay = 0.999):    scale = tf.Variable(tf.ones([inputs.get_shape()[-1]]))    beta = tf.Variable(tf.zeros([inputs.get_shape()[-1]]))    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)

if is_training:        batch_mean, batch_var = tf.nn.moments(inputs,[0])        train_mean = tf.assign(pop_mean,       pop_mean * decay + batch_mean * (1 - decay))        train_var = tf.assign(pop_var,     pop_var * decay + batch_var * (1 - decay))

with tf.control_dependencies([train_mean, train_var]):

return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)

else:

return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)

实现正常测试

现在为了证明修复后的代码可以正常测试,我们使用batch_norm_wrapper重新构建模型。注意,我们不仅要在训练时做一次构建,在测试时还要重新做一次构建,所以我们写了一个build_graph函数(实际的模型对象往往也是这么封装的):

def build_graph(is_training):    # Placeholders    x = tf.placeholder(tf.float32, shape=[None, 784])    y_ = tf.placeholder(tf.float32, shape=[None, 10])

# Layer 1    w1 = tf.Variable(w1_initial)    z1 = tf.matmul(x,w1)    bn1 = batch_norm_wrapper(z1, is_training)    l1 = tf.nn.sigmoid(bn1)

#Layer 2    w2 = tf.Variable(w2_initial)    z2 = tf.matmul(l1,w2)    bn2 = batch_norm_wrapper(z2, is_training)    l2 = tf.nn.sigmoid(bn2)

# Softmax    w3 = tf.Variable(w3_initial)    b3 = tf.Variable(tf.zeros([10]))    y  = tf.nn.softmax(tf.matmul(l2, w3))

# Loss, Optimizer and Predictions    cross_entropy = -tf.reduce_sum(y_*tf.log(y))    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)    correct_prediction = tf.equal(tf.arg_max(y,1),tf.arg_max(y_,1))    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

return (x, y_), train_step, accuracy, y, tf.train.Saver()


#Build training graph, train and save the trained modelsess.close() tf.reset_default_graph() (x, y_), train_step, accuracy, _, saver = build_graph(is_training=True) acc = []with tf.Session() as sess:    sess.run(tf.global_variables_initializer())

for i in tqdm.tqdm(range(10000)):        batch = mnist.train.next_batch(60)        train_step.run(feed_dict={x: batch[0], y_: batch[1]})

if i % 50 is 0:    res = sess.run([accuracy],feed_dict={x: mnist.test.images, y_: mnist.test.labels})    acc.append(res[0])    saved_model = saver.save(sess, './temp-bn-save') print("Final accuracy:", acc[-1])


Final accuracy: 0.9721

现在应该一切正常了,我们重复上面的实验:

tf.reset_default_graph() (x, y_), _, accuracy, y, saver = build_graph(is_training=False) predictions = [] correct = 0with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    saver.restore(sess, './temp-bn-save')

for i in range(100):        pred, corr = sess.run([tf.arg_max(y,1), accuracy],      feed_dict={x: [mnist.test.images[i]], y_: [mnist.test.labels[i]]})        correct += corr        predictions.append(pred[0]) print("PREDICTIONS:", predictions) print("ACCURACY:", correct/100)


PREDICTIONS: [7, 2, 1, 0, 4, 1, 4, 9, 6, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5, 4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2, 4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3, 7, 4, 6, 4, 3, 0, 7, 0, 2, 9, 1, 7, 3, 2, 9, 7, 7, 6, 2, 7, 8, 4, 7, 3, 6, 1, 3, 6, 9, 3, 1, 4, 1, 7, 6, 9]ACCURACY: 0.99

原文链接:https://www.jianshu.com/p/b2d2f3c7bfc7

查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:

www.leadai.org

请关注人工智能LeadAI公众号,查看更多专业文章

大家都在看


LSTM模型在问答系统中的应用

基于TensorFlow的神经网络解决用户流失概览问题

最全常见算法工程师面试题目整理(一)

最全常见算法工程师面试题目整理(二)

TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络

装饰器 | Python高级编程

今天不如来复习下Python基础

黑猿大叔-译文 | TensorFlow实现Batch Normalization相关推荐

  1. tensorflow 的 Batch Normalization 实现(tf.nn.moments、tf.nn.batch_normalization)

    tensorflow 在实现 Batch Normalization(各个网络层输出的归一化)时,主要用到以下两个 api: tf.nn.moments(x, axes, name=None, kee ...

  2. tensorflow中batch normalization的用法

    转载网址:如果侵权,联系我删除 https://www.cnblogs.com/hrlnw/p/7227447.html https://www.cnblogs.com/eilearn/p/97806 ...

  3. 谈Tensorflow的Batch Normalization

    tensorflow中关于BN(Batch Normalization)的函数主要有两个,分别是: tf.nn.moments tf.nn.batch_normalization 关于这两个函数,官方 ...

  4. 谈谈Tensorflow的Batch Normalization

    上海站 | 高性能计算之GPU CUDA培训 4月13-15日 三天密集式学习 快速带你晋级 阅读全文 > 正文共4488个字,4张图,预计阅读时间12分钟. tensorflow中关于BN(B ...

  5. tensorflow没有这个参数_解决TensorFlow中Batch Normalization参数没有保存的问题

    batch normalization的坑我真的是踩到要吐了,几个月前就踩了一次,看了网上好多资料,虽然跑通了但是当时没记录下来,结果这次又遇到了.时隔几个月,已经忘得差不多了,结果又花了半天重新踩了 ...

  6. batch normalization详解

    1.引入BN的原因 1.加快模型的收敛速度 2.在一定程度上缓解了深度网络中的"梯度弥散"问题,从而使得训练深层网络模型更加容易和稳定. 3.对每一批数据进行归一化.这个数据是可以 ...

  7. Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作...

    使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...

  8. 3.1 Tensorflow: 批标准化(Batch Normalization)

    ##BN 简介 背景 批标准化(Batch Normalization )简称BN算法,是为了克服神经网络层数加深导致难以训练而诞生的一个算法.根据ICS理论,当训练集的样本数据和目标样本集分布不一致 ...

  9. 深度学习总结:用pytorch做dropout和Batch Normalization时需要注意的地方,用tensorflow做dropout和BN时需要注意的地方,

    用pytorch做dropout和BN时需要注意的地方 pytorch做dropout: 就是train的时候使用dropout,训练的时候不使用dropout, pytorch里面是通过net.ev ...

最新文章

  1. AngularJs Cookie 的使用
  2. 添加别名_ssh别名免密登陆服务器
  3. 【bzoj4009】[HNOI2015]接水果 DFS序+树上倍增+整体二分+树状数组
  4. python六十四课——高阶函数练习题(一)
  5. [Swift]LeetCode86. 分隔链表 | Partition List
  6. flask mysql分页,Flask分页的实现方法
  7. 安全技术可以采用计算机安全,2017年计算机三级《信息安全技术》习题
  8. nginx配置多个二级子域名
  9. SSH远程连接:简单的连接
  10. 安装maven过程并配置IDEA的全过程
  11. PMP项目管理学习心得分享
  12. php和vue实现智商在线测试题
  13. netty权威指南 微云_Netty权威指南 第2版.pdf
  14. 华为查看mpls的命令_华为BGP基本命令
  15. 基于 CIM 的智慧社区总体框架
  16. 离散数学-----自然数系统
  17. 华为GAUSS数据库常用的单行操作函数介绍
  18. whale 帷幄:营销自动化saas系统 saas营销系统是什么意思
  19. php ean13,php生成EAN_13标准条形码实例_PHP教程
  20. 如何安装Python中numpy,在DOS验证下一步步解决安装问题(DOS下从python的验证到pip验证到Numpy安装成功)

热门文章

  1. c语言常用库函数使用方法,c语言常用库函数使用方法及用途
  2. centos7 安装 php-fpm_centos7中如何安装 php-fpm(nginx)
  3. oracle恢复RAC到单机
  4. 编写python扩展模块_《深度剖析CPython解释器》27. 使用Python/C API编写扩展模块:编写扩展模块的整体流程...
  5. 生成注释_java基础- Java编程规范与注释
  6. 什么叫matlab仿真,【图片】求助帖:哪位matlab大神能告诉我这个仿真这能得出什么结论呢_matlab吧_百度贴吧...
  7. java dochain,Java filter中的chain.doFilter详解
  8. 表达式* ptr ++和++ * ptr是否相同?
  9. 工作笔记-关于工具函数的编写问题
  10. Qt组件中的双缓冲无闪烁绘图