以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了。在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在《实战Google深度学习框架》第二版这本书P166里只是提了一句,没有做出解答。

  书中说训练时和测试时使用的参数is_training都为True,然后给出了一个链接供参考。本人刚开始使用时也是按照书中的做法没有改动,后来从保存后的checkpoint中加载模型做预测时出了问题:当改变需要预测数据的batchsize时预测的label也跟着变,这意味着checkpoint里面没有保存训练中BN层的参数,使用的BN层参数还是从需要预测的数据中计算而来的。这显然会出问题,当预测的batchsize越大,假如你的预测数据集和训练数据集的分布一致,结果就越接近于训练结果,但如果batchsize=1,那BN层就发挥不了作用,结果很难看。

  那如果在预测时is_traning=false呢,但BN层的参数没有从训练中保存,那使用的就是随机初始化的参数,结果不堪想象。

  所以需要在训练时把BN层的参数保存下来,然后在预测时加载,参考几位大佬的博客,有了以下训练时添加的代码:

![复制代码](https://imgconvert.csdnimg.cn/aHR0cHM6Ly9jb21tb24uY25ibG9ncy5jb20vaW1hZ2VzL2NvcHljb2RlLmdpZg)
 1 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)2 with tf.control_dependencies(update_ops):3         train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)4 5 # 设置保存模型6 var_list = tf.trainable_variables()7 g_list = tf.global_variables()8 bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]9 bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
10 var_list += bn_moving_vars
11 saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
![复制代码](https://imgconvert.csdnimg.cn/aHR0cHM6Ly9jb21tb24uY25ibG9ncy5jb20vaW1hZ2VzL2NvcHljb2RlLmdpZg)

这样就可以在预测时从checkpoint文件加载BN层的参数并设置is_training=False。

最后要说的是,虽然这么做可以解决这个问题,但也可以利用预测数据来计算BN层的参数,不是说一定要保存训练时的参数,两种方案可以作为超参数来调节使用,看哪种方法的结果更好。

感谢几位大佬的博客解惑:

  https://blog.csdn.net/dongjbstrong/article/details/80447110?utm_source=blogxgwz0

  http://www.cnblogs.com/hrlnw/p/7227447.html

Tensorflow训练和预测中的BN层的坑(转)-训练和测试差异性巨大相关推荐

  1. Tensorflow训练和预测中的BN层的坑(转载)

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  2. caffe中的batchNorm层(caffe 中为什么bn层要和scale层一起使用)

    caffe中的batchNorm层 链接: http://blog.csdn.net/wfei101/article/details/78449680 caffe 中为什么bn层要和scale层一起使 ...

  3. tensorflow中的BN层实现

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow.keras imp ...

  4. 模型参数无法更新的原因:训练、预测中加入了print函数

    模型参数无法更新的问题排查以及解决 注释掉结构的方法 排查出错误 最终排查 进一步排查错误 loss的数值一致??? 进一步排查问题来源:预处理之中的标签处理出现错误!!! 灵感:model.trai ...

  5. pytorch 模型中的bn层一键转化为同步bn(syncbn)

    pytorch 将模型中的所有BatchNorm2d layer转换为SyncBatchNorm layer: (单机多卡设置下) import torch.distributed as distdi ...

  6. 狠补基础-数学+算法角度讲解卷积层,激活函数,池化层,Dropout层,BN层,全链接层

    狠补基础-数学+算法角度讲解卷积层,激活函数,池化层,Dropout层,BN层,全链接层 在这篇文章中您将会从数学和算法两个角度去重新温习一下卷积层,激活函数,池化层,Dropout层,BN层,全链接 ...

  7. 网络骨架:Backbone(神经网络基本组成——BN层、全连接层)

    BN层 为了追求更高的性能,卷积网络被设计得越来越深,然而网络却变得难以训练收敛与调参.原因在于,浅层参数的微弱变化经过多层线性变化与激活函数后会被放大,改变了每一层的输入分布,造成深层的网络需要不断 ...

  8. 深度学习 | BN层原理浅谈

    深度学习 | BN层原理浅谈 文章目录 深度学习 | BN层原理浅谈 一. 背景 二. BN层作用 三. 计算原理 四. 注意事项 为什么BN层一般用在线性层和卷积层的后面,而不是放在激活函数后 为什 ...

  9. 模型压缩(一)通道剪枝-BN层

    论文:https://arxiv.org/pdf/1708.06519.pdf BN层中缩放因子γ与卷积层中的每个通道关联起来.在训练过程中对这些比例因子进行稀疏正则化,以自动识别不重要的通道.缩放因 ...

最新文章

  1. 使用eBPFBCC提取内核网络流量信息
  2. 安利 10 个 Intellij IDEA 实用插件
  3. 【html+css练习】小白使用html+css模拟音乐播放器构造了网页音乐播放器--1
  4. linux中利用脚本编写数组,shell脚本编程之数组
  5. Android开发实战二之Hello Android实例
  6. html5 开发工具_前端HTML5开发工具有哪些呢?
  7. 常见数据类型的手机二维码生成与识别格式参考
  8. Home Assistant系列 -- 设置界面语言与地理位置
  9. 手把手教你学习DSP_硬件设计
  10. Steam上传游戏包体的三种方法
  11. Linux中有关文件权限的详解
  12. 区别samtools faid产生的.fai文件功能和bwa index 产生的四个文件的功能
  13. 【Python 实战基础】Pandas如何从股票数据找出收盘价最低行
  14. 万字长文!推荐一款日志切割神器
  15. Mysql 计算当前日期是本月第几周:一个自定义算法
  16. IntelliJ idea (最新版)激活方法
  17. WEB API新增整理(三)
  18. 页面自动化之 selenium(一) 自动签到与签退
  19. 集成产品开发,不让你的产品变现脚踩西瓜皮
  20. 天啊!你居然还不知道如何防止缓存击穿?用布隆过滤器啊!!!

热门文章

  1. ahjesus自定义隐式转换和显示转换
  2. lync登录时一直停留在登录界面
  3. 根据共享文件夹的权限进行自动映射网络驱动器
  4. XP---VS05---部署个人网站初学者工具包---方案
  5. mysql5.7.11升级_MySQL升级从5.6.18到5.7.11
  6. 为什么红黑树查询快_为什么工程中都喜欢用红黑树,而不是其他平衡二叉查找树呢?...
  7. arcgis图层叠加不匹配
  8. 用C#读取数码相片的EXIF信息(一)
  9. android:id=@android:id/list,Logcat错误 - 内容必须有一个ListView的id属性是'android.R.id.list'...
  10. git 改local branch名字_最好的Git分支管理教程