【已解决】使用keras对resnet, inception3进行fine-tune出现训练集准确率很高但验证集很低的问题(BN)
最近用keras跑基于resnet50,inception3的一些迁移学习的实验,遇到一些问题。通过查看github和博客发现是由于BN层导致的,国外已经有人总结并提了一个PR(虽然并没有被merge到Keras官方库中),并写了一篇博客,也看到知乎有人翻译了一遍:Keras的BN你真的冻结对了吗
当保存模型后再加载模型去预测时发现与直接预测结果不一致也可能是BN层的问题。
总结:
keras中通常用trainable这个参数来控制某一层的权重是否更新,例如trainable可以控制BN中的是否变化。
TF为后端时,BN有一个参数是training,控制归一化时用的是当前Batch的均值和方差(训练模式)还是移动均值和方差(测试模式),这个参数由Keras的K.learning_phase控制。若只设置trainable是不会影响BN的training参数。
冻结时某一层时,我们希望这一层的状态和预训练模型中的状态一致
我们通常希望训练和测试时网络中的配置一致,但BN训练和测试时的配置是不一样的,而frozen这个行为放大了这种不一致,导致精度下降。训练时用了新数据集的均值和方差去做归一化,测试时用了旧数据集的移动均值和方差去做归一化
为了让训练和测试尽量一致,避免精度下降,有两种方案,一种是在测试时也用旧数据集的移动均值和方差
另一种方案是在训练时也只用旧数据集的移动均值和方差,这是Keras作者fchollet在GitHub issue里回复的方案:在定义模型时,手动将training参数设为False(可以通过显式设置BN的training参数,或者通过设置learning_phase来隐式改变training参数),我觉得其实这种workaround还是挺好用的,而且也更符合frozen的意图,即:
显式设置:
x = BatchNormalization()(y, training=False)
隐式设置:
# Set up inference-mode baseK.set_learning_phase(0)inputs = Input(...)x = layer1(...)(inputs)x = layer2(...)(x)...x = layerN(...)(x)# Add training-mode layersK.set_learning_phase(1)x = layerNp1(...)(x)x = layerNp2(...)(x)
不可否认的是,默认的Frozen的BN的行为在迁移学习中确实是有training这个坑存在的,个人认为fchollet的修复方法更简单一点,并且这种方式达到的效果和使用预训练网络提取特征,单独训练分类层达到的效果是一致的,当你真的想要冻结BN层的时候,这种方式更符合冻结的这个动机;但在测试时使用新数据集的移动均值和方差一定程度上也是一种domain adaption。
译文:
虽然Keras节省了我们很多编码时间,但Keras中BN层的默认行为非常怪异,坑了我(此处及后续的“我”均指原文作者)很多次。Keras的默认行为随着时间发生过许多的变化,但仍然有很多问题以至于现在Keras的GitHub上还挂着几个相关的issue。在这篇文章中,我会构建一个案例来说明为什么Keras的BN层对迁移学习并不友好,并给出对Keras BN层的一个修复补丁,以及修复后的实验效果。
1. Introduction
这一节我会简要介绍迁移学习和BN层,以及learning_phase的工作原理,Keras BN层在各个版本中的变化。如果你已经了解过这些知识,可以直接跳到第二节(译者注:1.3和1.4跟这个问题还是比较相关的,不全是背景)。
1.1 迁移学习在深度学习中非常重要
深度学习在过去广受诟病,原因之一就是它需要太多的训练数据了。解决这个限制的方法之一就是迁移学习。
假设你现在要训练一个分类器来解决猫狗二分类问题,其实并不需要几百万张猫猫狗狗的图片。你可以只对预训练模型顶部的几层卷积层进行微调。因为预训练模型是用图像数据训练的,底层卷积层可以识别线条,边缘或者其他有用的模式作为特征使用,所以可以用预训练模型的权重作为一个很好的初始化值,或者只对模型的一部分用自己数据进行训练。
Keras包含多种预训练模型,并且很容易Fine-tune,更多细节可以查阅Keras官方文档。
1.2 Batch Normalization是个啥
BN在2014年由Loffe和Szegedy提出,通过将前一层的输出进行标准化解决梯度消失问题,并减小了训练达到收敛所需的迭代次数,从而减少训练时间,使得训练更深的网络成为可能。具体原理请看原论文,简单来说,BN将每一层的输入减去其在Batch中的均值,除以它的标准差,得到标准化的输入,此外,BN也会为每个单元学习两个因子来还原输入。从下图可以看到加了BN之后Loss下降更快,最后能达到的效果也更好。
1.3 Keras中的learning_phase是啥
网络中有些层在训练时和推导时的行为是不同的。最重要的两个例子就是BN和Dropout层。对BN层,训练时我们需要用mini batch的均值和方差来缩放输入。在推导时,我们用训练时统计到的累计均值和方差对推导的mini batch进行缩放。
Keras用learning_phase机制来告诉模型当前的所处的模式。假如用户没有手工指定的话,使用fit()时,网络默认将learning_phase设为1,表示训练模式。在预测时,比如调用predict()和evaluate()方法或者在fit()的验证步骤中,网络将learning_phase设为0,表示测试模式。用户可以静态地,在model或tensor添加到一个graph中之前,将learning_phase设为某个值(虽然官方不推荐手动设置),设置后,learning_phase就不可以修改了。
1.4 不同版本中的Keras是如何实现BN的
Keras中的BN训练时统计当前Batch的均值和方差进行归一化,并且使用移动平均法累计均值和方差,给测试集用于归一化。
Keras中BN的行为变过几次,但最重要的变更发生在2.1.3这个版本。2.1.3之前,当BN被冻结时(trainable=False),它仍然会更新mini batch的移动均值和方差,并用于测试,造成用户的困扰(一副没有冻结住的样子)。
这种设计是错误的。考虑Conv1-Bn-Conv2-Conv3这样的结构,如果BN层被冻结住了,应该无事发生才对。当Conv2处于冻结状态时,如果我们部分更新了BN,那么Conv2不能适应更新过的mini-batch的移动均值和方差,导致错误率上升。
在2.1.3及之后,当BN层被设为trainable=False时,Keras中不再更新mini batch的移动均值和方差,测试时使用的是预训练模型中的移动均值和方差,从而达到冻结的效果, But is that enough? Not if you are using Transfer Learning.
2. 问题描述与解决方案
我会介绍问题的根源以及解决方案(一个Keras补丁)的技术实现。同时我也会提供一些样例来说明打补丁前后模型的准确率变化。
2.1 问题描述
2.1.3版本后,当Keras中BN层冻结时,在训练中会用mini batch的均值和方差统计值以执行归一化。我认为更好的方式应该是使用训练中得到的移动均值和方差(译者注:这样不就退回2.1.3之前的做法了)。原因和2.1.3的修复原因相同,由于冻结的BN的后续层没有得到正确的训练,使用mini batch的均值和方差统计值会导致较差的结果。
假设你没有足够的数据训练一个视觉模型,你准备用一个预训练Keras模型来Fine-tune。但你没法保证新数据集在每一层的均值和方差与旧数据集的统计值的相似性。注意哦,在当前的版本中,不管你的BN有没有冻结,训练时都会用mini-batch的均值和方差统计值进行批归一化,而在测试时你也会用移动均值方差进行归一化。因此,如果你冻结了底层并微调顶层,顶层均值和方差会偏向新数据集,而推导时,底层会使用旧数据集的统计值进行归一化,导致顶层接收到不同程度的归一化的数据。
如上图所示,假设我们从Conv K+1层开始微调模型,冻结左边1到k层。训练中,1到K层中的BN层会用训练集的mini batch统计值来做归一化,然而,由于每个BN的均值和方差与旧数据集不一定接近,在Relu处的丢弃的数据量与旧数据集会有很大区别,导致后续K+1层接收到的输入和旧数据集的输入范围差别很大,后续K+1层的初始权重不能恰当处理这种输入,导致精度下降。尽管网络在训练中可以通过对K+1层的权重调节来适应这种变化,但在测试模式下,Keras会用预训练数据集的均值和方差,改变K+1层的输入分布,导致较差的结果。
2.2 如何检查你是否受到了这个问题的影响
分别将learning_phase这个变量设置为1或0进行预测,如果结果有显著的差别,说明你中招了。不过learning_phase这个参数通常不建议手工指定,learning_phase不会改变已经编译后的模型的状态,所以最好是新建一个干净的session,在定义graph中的变量之前指定learning_phase。
检查AUC和ACC,如果acc只有50%但auc接近1(并且测试和训练表现有明显不同),很可能是BN迷之缩放的锅。类似的,在回归问题上你可以比较MSE和Spearman‘s correlation来检查。
2.3 如何修复
如果BN在测试时真的锁住了,这个问题就能真正解决。实现上,需要用trainable这个标签来真正控制BN的行为,而不仅是用learning_phase来控制。具体实现在GitHub上。
主要是通过安装补丁:作者提供了三个版本的补丁,安装自己需要的版本就可以
pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@bugfix/trainable_bn
或者
pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@fork/keras2.2.4
用了这个补丁之后,BN冻结后,在训练时它不会使用mini batch均值方差统计值进行归一化,而会使用在训练中学习到的统计值,避免归一化的突变导致准确率的下降**。如果BN没有冻结,它也会继续使用训练集中得到的统计值。**
原文:
By applying the above fix, when a BN layer is frozen it will no longer use the mini-batch statistics but instead use the ones learned during training. As a result, there will be no discrepancy between training and test modes which leads to increased accuracy. Obviously when the BN layer is not frozen, it will continue using the mini-batch statistics during training.
2.4 评估这个补丁的影响
虽然这个补丁是最近才写好的,但其中的思想已经在各种各样的workaround中验证过了。这些workaround包括:将模型分成两部分,一部分冻结,一部分不冻结,冻结部分只过一遍提取特征,训练时只训练不冻结的部分。为了增加说服力,我会给出一些例子来展示这个补丁的真实影响。
- 我会用一小块数据来刻意过拟合模型,用相同的数据来训练和验证模型,那么在训练集和验证集上都应该达到接近100%的准确率。
- 如果验证的准确率低于训练准确率,说明当前的BN实现在推导中是有问题的。
- 预处理在generator之外进行,因为keras2.1.5中有一个相关的bug,在2.1.6中修复了。
- 在推导时使用不同的learning_phase设置,如果两种设置下准确率不同,说明确实中招了。
代码如下:
import numpy as np
from keras.datasets import cifar10
from scipy.misc import imresizefrom keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.models import Model, load_model
from keras.layers import Dense, Flatten
from keras import backend as Kseed = 42
epochs = 10
records_per_class = 100# We take only 2 classes from CIFAR10 and a very small sample to intentionally overfit the model.
# We will also use the same data for train/test and expect that Keras will give the same accuracy.
(x, y), _ = cifar10.load_data()def filter_resize(category):# We do the preprocessing here instead in the Generator to get around a bug on Keras 2.1.5.return [preprocess_input(imresize(img, (224,224)).astype('float')) for img in x[y.flatten()==category][:records_per_class]]x = np.stack(filter_resize(3)+filter_resize(5))
records_per_class = x.shape[0] // 2
y = np.array([[1,0]]*records_per_class + [[0,1]]*records_per_class)# We will use a pre-trained model and finetune the top layers.
np.random.seed(seed)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)for layer in model.layers[:140]:layer.trainable = Falsefor layer in model.layers[140:]:layer.trainable = Truemodel.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(ImageDataGenerator().flow(x, y, seed=42), epochs=epochs, validation_data=ImageDataGenerator().flow(x, y, seed=42))# Store the model on disk
model.save('tmp.h5')# In every test we will clear the session and reload the model to force Learning_Phase values to change.
print('DYNAMIC LEARNING_PHASE')
K.clear_session()
model = load_model('tmp.h5')
# This accuracy should match exactly the one of the validation set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))print('STATIC LEARNING_PHASE = 0')
K.clear_session()
K.set_learning_phase(0)
model = load_model('tmp.h5')
# Again the accuracy should match the above.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))print('STATIC LEARNING_PHASE = 1')
K.clear_session()
K.set_learning_phase(1)
model = load_model('tmp.h5')
# The accuracy will be close to the one of the training set on the last iteration.
print(model.evaluate_generator(ImageDataGenerator().flow(x, y, seed=42)))
输出如下:
Epoch 10/10
1/7 [===>..........................] - ETA: 3s - loss: 0.0354 - acc: 1.0000
2/7 [=======>......................] - ETA: 3s - loss: 0.0381 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0354 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0828 - acc: 0.9688
5/7 [====================>.........] - ETA: 1s - loss: 0.0791 - acc: 0.9750
6/7 [========================>.....] - ETA: 0s - loss: 0.0794 - acc: 0.9792
7/7 [==============================] - 8s 1s/step - loss: 0.0704 - acc: 0.9838 - val_loss: 0.3615 - val_acc: 0.8600DYNAMIC LEARNING_PHASE
[0.3614931714534759, 0.86]STATIC LEARNING_PHASE = 0
[0.3614931714534759, 0.86]STATIC LEARNING_PHASE = 1
[0.025861846953630446, 1.0]
如上文所述,验证集准确率确实要差一些。
训练完成后,我们做了三个实验,DYNAMIC LEARNING_PHASE是默认操作,由Keras内部机制动态决定learning_phase,static两种是手工指定learning_phase,分为设为0和1.当learning_phase设为1时,验证集的效果提升了,因为模型正是使用训练集的均值和方差统计值来训练的,而这些统计值与冻结的BN中存储的值不同,冻结的BN中存储的是预训练数据集的均值和方差,不会在训练中更新,会在测试中使用。这种BN的行为不一致性导致了推导时准确率下降。
加了补丁后的效果:
Epoch 10/10
1/7 [===>..........................] - ETA: 4s - loss: 0.0251 - acc: 1.0000
2/7 [=======>......................] - ETA: 2s - loss: 0.0228 - acc: 1.0000
3/7 [===========>..................] - ETA: 2s - loss: 0.0217 - acc: 1.0000
4/7 [================>.............] - ETA: 1s - loss: 0.0249 - acc: 1.0000
5/7 [====================>.........] - ETA: 1s - loss: 0.0244 - acc: 1.0000
6/7 [========================>.....] - ETA: 0s - loss: 0.0239 - acc: 1.0000
7/7 [==============================] - 9s 1s/step - loss: 0.0290 - acc: 1.0000 - val_loss: 0.0127 - val_acc: 1.0000DYNAMIC LEARNING_PHASE
[0.012697912137955427, 1.0]STATIC LEARNING_PHASE = 0
[0.012697912137955427, 1.0]STATIC LEARNING_PHASE = 1
[0.01744014158844948, 1.0]
模型收敛得更快,改变learning_phase也不再影响模型的准确率了,因为现在BN都会使用训练集的均值和方差进行归一化。
2.5 这个修复在真实数据集上表现如何
我们用Keras预训练的ResNet50,在CIFAR10上开展实验,只训练分类层10个epoch,以及139层以后5个epoch。没有用补丁的时候准确率为87.44%,用了之后准确率为92.36%,提升了5个点。
2.6 其他层是否也要做类似的修复呢?
Dropout在训练时和测试时的表现也不同,但Dropout是用来避免过拟合的,如果在训练时也将其冻结在测试模式,Dropout就没用了,所以Dropout被frozen时,我们还是让它保持能够随机丢弃单元的现状吧。
参考文献:
https://zhuanlan.zhihu.com/p/56225304
http://blog.datumbox.com/the-batch-normalization-layer-of-keras-is-broken/
【已解决】使用keras对resnet, inception3进行fine-tune出现训练集准确率很高但验证集很低的问题(BN)相关推荐
- 使用resnet, inception3进行fine-tune出现训练集准确率很高但验证集很低的问题
向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程 公众号:datayx 最近用keras跑基于resnet50,inception3的一些迁移学习的实验,遇到一些问题 ...
- keras训练模型,训练集的准确率很高,但是测试集准确率很低的原因
今天在测试模型时发现一个问题,keras训练模型,训练集准确率很高,测试集准确率很低,因此记录一下希望能帮助大家也避坑: 首先keras本身不同的版本都有些不同的或大或小的bug,包括之前也困扰过我的 ...
- [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码
环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...
- 已解决Using TensorFlow backend.
已解决(tensorflow .keras导入报错)FutureWarning: Conversion of the second ard. In future, it will be treated ...
- 新遇到的系统编译问题!已解决!
今天做C语言题真的遇到了很多麻烦.先是很慢很慢的做了几道程序设计. 然后在结构体这块又遇到了系统编译错误. 不过经过丰哥和百度的指导,已解决,如果大家有类似问题,可以参考. ------------- ...
- npm缺少css-loader,/style-compiler,stylus-loader问题,npm没有权限无法全局更新问题【已解决】
npm缺少css-loader,/style-compiler,stylus-loader问题,npm没有权限无法全局更新问题[已解决] 参考文章: (1)npm缺少css-loader,/style ...
- Myeclipse中导入项目后java类中汉字注释出现乱码问题(已解决)
Myeclipse中导入项目后java类中汉字注释出现乱码问题(已解决) 参考文章: (1)Myeclipse中导入项目后java类中汉字注释出现乱码问题(已解决) (2)https://www.cn ...
- 【已解决】关于SQL2008 “不允许保存更改。您所做的更改要求删除并重新创建以下表。您对无法重新创建的标进行了更改或者启用了‘阻止保存要求重新创建表的更改’” 解决方案
[已解决]关于SQL2008 "不允许保存更改.您所做的更改要求删除并重新创建以下表.您对无法重新创建的标进行了更改或者启用了'阻止保存要求重新创建表的更改'" 解决方案 参考文章 ...
- 关于div的滚动条滚动到底部,内容显示不全的问题。(已解决)
关于div的滚动条滚动到底部,内容显示不全的问题.(已解决) 参考文章: (1)关于div的滚动条滚动到底部,内容显示不全的问题.(已解决) (2)https://www.cnblogs.com/th ...
最新文章
- 二、如何读入图片、显示图像?
- 三个轻量级WebServer--lighttpd,thttpd,shttpd介绍
- 《敏捷迭代开发:管理者指南》—第2章2.5节渐进开发和自适应开发
- LeetCode 152. Maximum Product Subarray--动态规划--C++,Python解法
- Myisamchk小工具使用手册
- c#大圣之路笔记——c# SqlDataReader和SqlDataAdapter区别
- This document is opened by another project error message
- python如何输入多行数据合并_Python如何将多行数据合并成一行|python如何实现excle数据合并...
- python iterable对象_一篇文章看懂 Python iterable,
- android 使用4大组件的源码,Android Jetpack架构组件之 Paging(使用、源码篇)
- jQuery调用WebService详解
- C++例4.11 求两个或三个正整数中的最大数,用带有默认参数的函数实现。
- 如何利用Arcmap模型构建器处理NC格式数据
- java网站渗透测试_如何进行Web渗透测试
- 文华财经指标公式大全,通达信指标加密破解DLL加密防破解技术方法
- python实现视频ai换脸_Python如何实现AI换脸功能 Python实现AI换脸功能代码
- 君莫笑系列视频学习(5)(终)
- unity 移动开发优化二 图形优化,脚本优化概述
- python从列表中随机抽取n个元素
- Android P 如何挂载system镜像到根目录
热门文章
- MindMaster思维导图及亿图图示会员 优惠活动
- 2021年危险化学品生产单位安全生产管理人员考试题及危险化学品生产单位安全生产管理人员最新解析
- 按快捷键进不去bios问题解决
- LeetCode题解(1425):带限制的子序列和(Python)
- 在线图片处理工具大全!ps可以下岗了。
- qiankun微前端应用间通信实现
- Vue项目安装XLSX成功后,生成项目报错:“export ‘default‘ (imported as ‘XLSX‘) was not found in ‘xlsx‘
- 谷歌搜索 site命令 指定网站搜索
- Node.js 模块化的操作,简单明了的代码帮助你明白后端的实现和前端之前的交互,及解决跨域等问题
- Linux命令手动清除缓存