『教程』Batch Normalization 层介绍

基础知识

下面有莫凡的对于批处理的解释:

fc_mean,fc_var = tf.nn.moments(Wx_plus_b,axes=[0],# 想要 normalize 的维度, [0] 代表 batch 维度# 如果是图像数据, 可以传入 [0, 1, 2], 相当于求[batch, height, width] 的均值/方差, 注意不要加入 channel 维度
)
scale = tf.Variable(tf.ones([out_size]))
shift = tf.Variable(tf.zeros([out_size]))
epsilon = 0.001
Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b,fc_mean,fc_var,shift,scale,epsilon)
# 上面那一步, 在做如下事情:
# Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)
# Wx_plus_b = Wx_plus_b * scale + shift

tf.contrib.layers.batch_norm:封装好的批处理类

class batch_norm():'''batch normalization层'''def __init__(self, epsilon=1e-5,momentum=0.9, name='batch_norm'):'''初始化:param epsilon:    防零极小值:param momentum:   滑动平均参数:param name:       节点名称'''with tf.variable_scope(name):self.epsilon = epsilonself.momentum = momentumself.name = namedef __call__(self, x, train=True):# 一个封装了的会在内部调用batch_normalization进行正则化的高级接口return tf.contrib.layers.batch_norm(x,decay=self.momentum,        # 滑动平均参数updates_collections=None,epsilon=self.epsilon,scale=True,is_training=train,          # 影响滑动平均scope=self.name)

1.

Note: when training, the moving_mean and moving_variance need to be updated.
    By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
    need to be added as a dependency to the `train_op`. For example:
    
    ```python
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)
    ```
    
    One can set updates_collections=None to force the updates in place, but that
    can have a speed penalty, especially in distributed settings.

2.

is_training: Whether or not the layer is in training mode. In training mode
        it would accumulate the statistics of the moments into `moving_mean` and
        `moving_variance` using an exponential moving average with the given
        `decay`. When it is not in training mode then it would use the values of
        the `moving_mean` and the `moving_variance`.

tf.nn.batch_normalization:原始接口封装使用

实际上tf.contrib.layers.batch_norm对于tf.nn.moments和tf.nn.batch_normalization进行了一次封装,这个类又进行了一次封装(主要是制订了一部分默认参数),实际操作时可以仅仅使用tf.contrib.layers.batch_norm函数,它已经足够方便了。

添加了滑动平均处理之后,也就是不使用封装,直接使用tf.nn.moments和tf.nn.batch_normalization实现的batch_norm函数:

def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):with tf.variable_scope(scope):# beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)# gamma = tf.get_variable(name='gamma', shape=[n_out],#                         initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')ema = tf.train.ExponentialMovingAverage(decay=decay)def mean_var_with_update():ema_apply_op = ema.apply([batch_mean,batch_var])with tf.control_dependencies([ema_apply_op]):return tf.identity(batch_mean),tf.identity(batch_var)# identity之后会把Variable转换为Tensor并入图中,# 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制mean,var = tf.cond(phase_train,mean_var_with_update,lambda: (ema.average(batch_mean),ema.average(batch_var)))normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)return normed

另一种将滑动平均展开了的方式,

def batch_norm(x, size, training, decay=0.999):beta = tf.Variable(tf.zeros([size]), name='beta')scale = tf.Variable(tf.ones([size]), name='scale')pop_mean = tf.Variable(tf.zeros([size]))pop_var = tf.Variable(tf.ones([size]))epsilon = 1e-3batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))def batch_statistics():with tf.control_dependencies([train_mean, train_var]):return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')def population_statistics():return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm')return tf.cond(training, batch_statistics, population_statistics)

注, tf.cond:流程控制,参数一True,则执行参数二的函数,否则执行参数三函数。

『TensorFlow』批处理类相关推荐

  1. 『TensorFlow』专题汇总

    TensorFlow函数查询 『TensorFlow』0.x_&_1.x版本框架改动汇总 『TensorFlow』函数查询列表_数值计算 『TensorFlow』函数查询列表_张量属性调整 『 ...

  2. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

  3. 『TensorFlow』模型保存和载入方法汇总

    一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 参数名称 功能说明 默认值 var_list Saver中存储变 ...

  4. 『TensorFlow』模型载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  5. python 动漫卡通人物图片大全,『TensorFlow』DCGAN生成动漫人物头像_下

    一.计算图效果以及实际代码实现 计算图效果 实际模型实现 相关介绍移步我的github项目. 二.生成器与判别器设计 生成器 相关参量, 噪声向量z维度:100 标签向量y维度:10(如果有的话) 生 ...

  6. 『TensorFlow』DCGAN生成动漫人物头像_下

    『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...

  7. 『TensorFlow』函数查询列表_张量属性调整

    博客园 首页 新随笔 新文章 联系 订阅 管理 『TensorFlow』函数查询列表_张量属性调整 数据类型转换Casting 操作 描述 tf.string_to_number (string_te ...

  8. 『TensorFlow』第七弹_保存载入会话_霸王回马

    首更: 由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe ...

  9. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

最新文章

  1. labview 随笔记录
  2. linux 下网络流量监控
  3. 机器学习笔记:线性规划,梯度下降
  4. [LeetCode]*105.Construct Binary Tree from Preorder and Inorder Traversal
  5. 在 Red HatAS4下添加网卡驱动!!
  6. 牛逼!Python的类和对象(长文系列第⑤篇)
  7. ScanTailor-ScanTailor 强大的多方位的满足处理扫描图片的需求
  8. c 形参 可变 入门
  9. linux安装.AppImage后缀安装包
  10. Java 初始化块
  11. OSI七(八)层结构 TCP/IP 4层结构
  12. #565. 「LibreOJ Round #10」mathematican 的二进制(期望 + 分治NTT)
  13. ubuntu安装ulipad
  14. 硬件电路设计基础知识
  15. 中学生怎样才能合理使用计算机,浅析中学生计算机的使用
  16. 解决java.lang.IllegalArgumentException: Scrapped or attached views may not be recycled. isScrap:false
  17. iOS 查看Realm数据库表
  18. 第四章第十节数据资产盘点-形成数据资产目录
  19. 解决PHP7中微信(小程序)mcrypt_module_open() 无法使用的解决方法
  20. Java架构师学习路线图

热门文章

  1. gitee创建ssh公钥
  2. java web接收tcp_Java多线程实现TCP网络Socket编程(C/S通信)
  3. java的equals什么作用_java当中equals函数的作用小结
  4. tek示波器软件_给示波器以云空间,泰克发布突破性的数据协同软件TekDrive
  5. php采到的数据自动修改入库,基于PHP的简单采集数据入库程序【续篇】_php实例...
  6. mc1.8.1怎么局域网java_同一台电脑同时装jdk1.8和jdk1.7
  7. 计算机组成原理在线实验,《计算机组成原理》实验.doc
  8. html头部协议,TCP/IP协议头部结构体
  9. Centos7 ifconfig这个命令没找到的解决方法
  10. voip和rtc_WebRTC与VoIP的对比