以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了。在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在《实战Google深度学习框架》第二版这本书P166里只是提了一句,没有做出解答。

  书中说训练时和测试时使用的参数is_training都为True,然后给出了一个链接供参考。本人刚开始使用时也是按照书中的做法没有改动,后来从保存后的checkpoint中加载模型做预测时出了问题:当改变需要预测数据的batchsize时预测的label也跟着变,这意味着checkpoint里面没有保存训练中BN层的参数,使用的BN层参数还是从需要预测的数据中计算而来的。这显然会出问题,当预测的batchsize越大,假如你的预测数据集和训练数据集的分布一致,结果就越接近于训练结果,但如果batchsize=1,那BN层就发挥不了作用,结果很难看。

  那如果在预测时is_traning=false呢,但BN层的参数没有从训练中保存,那使用的就是随机初始化的参数,结果不堪想象。

  所以需要在训练时把BN层的参数保存下来,然后在预测时加载,参考几位大佬的博客,有了以下训练时添加的代码:

 1 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)2 with tf.control_dependencies(update_ops):3         train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)4 5 # 设置保存模型6 var_list = tf.trainable_variables()7 g_list = tf.global_variables()8 bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]9 bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
10 var_list += bn_moving_vars
11 saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

这样就可以在预测时从checkpoint文件加载BN层的参数并设置is_training=False。

最后要说的是,虽然这么做可以解决这个问题,但也可以利用预测数据来计算BN层的参数,不是说一定要保存训练时的参数,两种方案可以作为超参数来调节使用,看哪种方法的结果更好。此外是否使用保存的BN层μ和σ参数可以考虑一下test时候是单样本测试还是一组样本测试,一组样本测试时候可以重新计算μ和σ不使用保存的μ和σ(总体而言train参数在test时是true和false酌情测试选定)

感谢几位大佬的博客解惑:

  https://blog.csdn.net/dongjbstrong/article/details/80447110?utm_source=blogxgwz0

  http://www.cnblogs.com/hrlnw/p/7227447.html

https://blog.csdn.net/huowa9077/article/details/79696755------未尝试 https://blog.csdn.net/zaf0516/article/details/89958962---未尝试感觉不太合理

Tensorflow训练和预测中的BN层的坑(转载)相关推荐

  1. Tensorflow训练和预测中的BN层的坑(转)-训练和测试差异性巨大

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  2. caffe中的batchNorm层(caffe 中为什么bn层要和scale层一起使用)

    caffe中的batchNorm层 链接: http://blog.csdn.net/wfei101/article/details/78449680 caffe 中为什么bn层要和scale层一起使 ...

  3. tensorflow中的BN层实现

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow.keras imp ...

  4. 模型参数无法更新的原因:训练、预测中加入了print函数

    模型参数无法更新的问题排查以及解决 注释掉结构的方法 排查出错误 最终排查 进一步排查错误 loss的数值一致??? 进一步排查问题来源:预处理之中的标签处理出现错误!!! 灵感:model.trai ...

  5. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用隐藏层

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_dataINPUT_NODE = 784 # ...

  6. pytorch 模型中的bn层一键转化为同步bn(syncbn)

    pytorch 将模型中的所有BatchNorm2d layer转换为SyncBatchNorm layer: (单机多卡设置下) import torch.distributed as distdi ...

  7. 使用PaddleFluid和TensorFlow训练序列标注模型

    专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...

  8. 狠补基础-数学+算法角度讲解卷积层,激活函数,池化层,Dropout层,BN层,全链接层

    狠补基础-数学+算法角度讲解卷积层,激活函数,池化层,Dropout层,BN层,全链接层 在这篇文章中您将会从数学和算法两个角度去重新温习一下卷积层,激活函数,池化层,Dropout层,BN层,全链接 ...

  9. 网络骨架:Backbone(神经网络基本组成——BN层、全连接层)

    BN层 为了追求更高的性能,卷积网络被设计得越来越深,然而网络却变得难以训练收敛与调参.原因在于,浅层参数的微弱变化经过多层线性变化与激活函数后会被放大,改变了每一层的输入分布,造成深层的网络需要不断 ...

最新文章

  1. 三位数除以两位数竖式计算没有余数_二年级数学第三十课:有余数的除法 例4 试商...
  2. 阿里算法工程师公开机器学习路线,你的路走对了吗?
  3. ***客户端出现“无法完成连接尝试”的解决方法
  4. Unity 2017 Game Optimization 读书笔记 Dynamic Graphics (5) Shader优化
  5. [mybatis]缓存_缓存有关的设置以及属性
  6. 通过DOS命令nslookup查域名DNS服务器
  7. 吴裕雄 python 机器学习——多项式贝叶斯分类器MultinomialNB模型
  8. Linux下10 个最酷的 Linux 单行命令(转载)
  9. vue element-ui只有一条信息时默认选中按钮,且不能取消,多条信息时可以手动选择
  10. 【干货】神经网络初始化trick:大神何凯明教你如何训练网络!
  11. ASP.NET vNext MVC 6 电商网站开发实战
  12. 前端技术—CSS常用代码大全
  13. 近世代数——Part2 群:基础与子群 课后习题
  14. 完美世界国际版不用外挂多开的方法
  15. Vue使用Emoji表情
  16. 3dmax如何删除多余的时间帧
  17. sql server 无法为该请求检索数据
  18. 如何使用微信编辑器排版微信公众号内容?
  19. QQ互联第三方登录多应用用户登录打通
  20. 数据分析:基于Pandas的全球自然灾害分析与可视化

热门文章

  1. 数字后端基本概念-合集
  2. EXT--表单AJax提交后台,返回前端数据格式的转换
  3. OpenSuSE 网络配置
  4. 【转】CCScale9Sprite和CCControlButton
  5. java的cxf的maven_Maven+Spirng+Mybatis+CXF搭建WebService服务
  6. 你知道怎么离线安装全局 node 模块吗?
  7. 再让大家清爽一下,给加班的oscer们,哈
  8. CDN加速下载VSCode-1.57.1
  9. java类型转换 float类型转换_Java类型转换 – float(和long)到int
  10. 游戏经济系统分析:通货与交易