本项目数据集来自kaggle竞赛,地址:

https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data

数据的训练集放在train文件夹下,测试集放在test文件夹下,其中train文件夹下的图片命名方式为cat.0.jpg,一直到cat.12499.jpg,然后是dog.0.jpg直到dog.12499.jpg,共25000张图片。测试集图片命名格式为1.jpg~12500.jpg共12500张图片。我们需要用训练集对模型进行训练,然后在测试集上“考试”,提交kaggle查看考试结果。

图片载入

这里介绍两种图片的载入方式。

第一种方法将所有图片加载成ndarray格式,在本系列第一讲就介绍过了,这里再讲一种简单些的处理方法。

import cv2
import numpy as np
from tqdm import tqdmdef load_train(n, img_size):X = np.zeros((n, img_size, img_size, 3), dtype=np.uint8)y = np.zeros((n, 2), dtype=np.uint8)for i in tqdm(range(n//2)): #tqdm给载入过程增加了进度条X[i] = cv2.resize(cv2.cvtColor(cv2.imread('train/cat.%d.jpg' % i), cv2.COLOR_BGR2RGB), (img_size, img_size)) #读入图片 + 转成RGB + resizeX[i+n//2] = cv2.resize(cv2.cvtColor(cv2.imread('train/dog.%d.jpg' % i), cv2.COLOR_BGR2RGB), (img_size, img_size))y[:n//2, 0] = 1 #one-hot编码0为(1,0),1为(0,1)y[n//2:, 1] = 1return X,ydef load_test(n,img_size):X = np.zeros((n, img_size, img_size, 3), dtype=np.uint8)for i in tqdm(range(n)):#test图片从1.jpg开始X[i] = cv2.resize(cv2.cvtColor(cv2.imread('test/%d.jpg' % (i+1)), cv2.COLOR_BGR2RGB), (img_size, img_size)) return X

第二种载入方式比较简单,借用Keras的ImageDataGeneratorflow_from_directory函数,直接从目录生成数据生成器,很方便。在使用之前需要按照要求布置图片集,将不同类别的图片放到不同的文件夹,具体来讲,train/cat下放所有猫的图片,train/dog下放所有狗的图片。

from keras.preprocessing.image import ImageDataGeneratorbatch_size = 16
gen = ImageDataGenerator() #实例化
train_generator = gen.flow_from_directory("train", image_size, shuffle=False, batch_size=batch_size)
test_generator = gen.flow_from_directory("test", image_size, shuffle=False, batch_size=batch_size, class_mode=None)

这里要注意,测试集由于没有label,生成test_generator的函数需加参数class_mode=None

迁移学习模型

上篇文章讨论过图片归一化预处理时,ndarray的dtype会变成float,内存占用是uint8格式的4倍,如果直接全部载入可能会造成内存错误OOM。所以我们选择在模型中进行预处理。

首先第一步,选择我们要使用的预训练模型,这里以ResNet50为例,看keras是如何进行迁移学习的。

from keras.applications import *base_model = ResNet50(input_tensor=inputs, weights='imagenet', include_top=False)

这里解释一下,keras将一些表现比较好的预训练模型做进了库里,我们可以直接用函数调用。其中input_tensor需传入一个tensor,weights可以选择None也就是只加载整个模型不加载权重,一般在训练集图片基本与'imagenet'中的class无关时我们选择从头训练模型,这里我们要借用其权重所以weights='imagenet'include_top表示是否去掉最后的全连接层,由于原模型有1000个类别而我们只有2个类别,所以需要去掉然后自己搭建最终的全连接层。

这里我们用预训练模型提取训练集的特征向量来进行预测,具体做法就是使用模型和权重让训练集正向传播,在最后一层后面(ResNet50去掉全连接最后一层为AveragePooling )进行全局平均池化(gap),得到特征向量。

这次需要搭建一个到gap的模型,并把预处理函数放入模型内。

from keras.models import *
from keras.layers import *input_tensor = Input((224, 224, 3))
inputs = input_tensor
x = Lambda(resnet50.preprocess_input)(inputs) #preprocess_input函数因预训练模型而异
base_model = ResNet50(input_tensor=x, weights='imagenet', include_top=False)
x = base_model(x)
outputs = GlobalAveragePooling2D()(x)
model = Model(inputs, outputs)

看下此时的模型结构

得到了模型,接下来正向传播得到GAP层的特征向量。

对于第一种图片载入方法,有以下两种方法提取特征向量

先得到训练数据和测试数据

X_train, y_train = load_train(25000,224)
X_test = load_test(12500,224)

第一种方法,直接predict。

train_features = model.predict(X_train)
test_features = model.predict(X_test)

第二种采用ImageDataGenerator中的flow函数,用生成器的方式,此种方式可以进行数据增强。

datagen = ImageDataGenerator()
train_generator = datagen.flow(X_train,batch_size=16, shuffle=False)
test_generator = datagen.flow(X_test,batch_size=16, shuffle=False)train_features = model.predict_generator(train_generator,  verbose=1)
test_features = model.predict_generator(test_generator,  verbose=1)

采用第二种图片载入方法,即使用flow_from_directory函数得到数据生成器,提取特征向量:

train_features = model.predict_generator(train_generator,  verbose=1)
test_features = model.predict_generator(test_generator,  verbose=1)

至此,我们提取出了本数据集的特征向量(bottleneck features),训练集特征向量的shape为(25000,2048),测试集的为(12500,2048)。

有一点需要注意,提取特征向量进行预测的做法只有在数据集与'imagenet'高度类似的情况下才可以进行,即只对特征向量到分类层的全连接进行训练,前面模型的层全部冻结。当然,为了进一步提高精度,可以用训练集在预训练模型imagenet权重的基础上继续训练,得到更好的特征向量来预测,这是提高方向。

以此时的特征向量作为训练集来进行预测

X_train = train_features
X_test = test_features

搭建最后的全连接层并进行训练

from keras.callbacks import ModelCheckpointinputs = Input((X_train.shape[1:]))
x = Dropout(0.4)(inputs)
x = Dense(2, activation='softmax')(x)
model = Model(inputs, x)from sklearn.model_selection import train_test_split
X_train,X_val,y_train,y_val = train_test_split(X_train,y_train,test_size=0.2,random_state=0)checkpointer = ModelCheckpoint(filepath='weights.best.hdf5', verbose=1, save_best_only=True) #保存最好模型权重
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])epochs = 10history = model.fit(X_train, y_train,validation_data=(X_val, y_val),epochs=epochs,callbacks=[checkpointer],verbose=1)

多模型融合

既然我们提取特征向量只训练最后的全连接层,那么是不是可以将多个模型的特征向量相串接进行融合呢?深度学习从不嫌特征多,只怕特征差。Garbage in,garbage out。

考虑到不同的预训练模型需要不同的输入图片尺寸和预处理函数,这里使用函数统一处理:

import h5py
from keras.applications import *
from keras.models import *
from keras.layers import *
from keras.preprocessing.image import ImageDataGeneratordef FeatureeExtract(MODEL,img_size,func=None)    inputs = Input((img_size,img_size,3)) #实例化一个tensorx = inputsx = Lambda(func)(x) #增加预处理函数层base_model = MODEL(input_tensor=x, weights='imagenet', include_top=False) model = Model(base_model.input, GlobalAveragePooling2D()(base_model.output))datagen = ImageDataGenerator() #后续可考虑数据增强train_generator = datagen.flow(X_train,batch_size=16, shuffle=False)test_generator = datagen.flow(X_test,batch_size=16, shuffle=False)train_features = model.predict_generator(train_generator,  verbose=1) test_features = model.predict_generator(test_generator,  verbose=1)# 保存bottleneck特征with h5py.File('%s_data.h5'%MODEL.__name__) as h:h.create_dataset("train",data = train_features)h.create_dataset("test",data = test_features)    h.create_dataset('label',data = y_train)

在keras文档中的预处理函数,根据在imagenet数据集上的预测准确率,排行前三的是InceptionResNetV2、Xception、InceptionV3,考虑用这三个模型进行融合。

FeatureExtract(InceptionResNetV2, 299, inception_resnet_v2.preprocess_input)
FeatureExtract(Xception, 299, xception.preprocess_input)
FeatureExtract(InceptionV3, 299, inception_v3.preprocess_input)

将提取出的特征向量分别保存为h5文件储存,以便我们复现现在的结果,现在提取出来并串接在一起。

X_train = []
X_test = []
for filename in ['InceptionResNetV2_data.h5','InceptionV3_data.h5','Xception_data.h5']:with h5py.File(filename,'r') as h:X_train.append(np.array(h['train']))X_test.append(np.array(h['test']))y_train = np.array(h['label'])
X_train = np.concatenate(X_train,axis=1)#将三个模型得到的X_train拼接
X_test = np.concatenate(X_test,axis=1)from sklearn.utils import shuffle
np.random.seed(10)  #可以设置随机种子,以后每次打乱都是一样的。
X_train,y_train = shuffle(X_train,y_train)

然后跟上面单模型一样进行训练即可。

10个epoch内的训练曲线,效果很好,验证准确率达到了99.82%,可以对测试集进行预测并提交kaggle进行测试查看成绩了。

model.load_weights('weights.best.hdf5')y_pred = model.predict(X_test)
y_pred = y_pred.ravel() #y_pred变为一维
y_pred = y_pred.clip(min=0.005,max=0.995) #把预测结果clip到[0.005,0.995]
Num = []
imgs = os.listdir("test/")
for i in range(len(imgs)):Num.append(int(imgs[i].split('.')[0]))import pandas as pd
df = pd.DataFrame({'label':y_pred},index=Num)
df.sort_index(inplace=True) #对DataFrame以index排序
df2 = pd.DataFrame({'id':np.arange(1,nb_test+1),'label':df['label']})
df2.to_csv('submit.csv',index=None)

这里有个小trick,由于训练集与验证集同分布,而与测试集分布略有差异,得到的预测结果需要做一个clip,将[0,1]clip到[0.005,0.995]。原因是最终loss的计算方式特性:

由上式logloss计算方法可以看出,在某个样本结果预测正确的时候logloss为0,这当然很好,但如果样本预测错误,实际为0预测为1(或者实际为1预测为0),其logloss为+∞,这当然是无法接受的。由此将预测结果限制在[0.005,0.995],就算极端的预测错误其单个logloss也只有2,对整个大局的影响不大。

最终提交kaggle的得分达到了0.03798,在public leaderboard上排名第8。

后续工作

  • 实际上在做本项目的时候我进行异常值处理,剔除了一些异常的图片。这些图片里有不是猫狗标成猫狗的,也有是猫狗但是太小无法识别的,也有一些其他情况。由于篇幅原因不在这里展开,只说下做法。利用预训练模型的imagenet权重对所有图片运行预测,然后对照imagenet中猫和狗的类别index,如果预测的top30里都没有猫和狗,把它判定为异常图片。可以用多个预训练模型得到的异常图片进行并集,最终剔除了200多张。
  • 考虑到25000张训练集图片已经足够,并且训练过程中内存占用不小,再加上现阶段已经可以达成基准模型的要求,尚未考虑数据增强。实际上可使用ImageDataGenerator对图片进行数据增强,即对图片采用旋转角度、上下左右平移等操作生成新的图片以扩充训练集。
  • 提取特征向量之前先用训练集在模型'imagenet'权重的基础上再训练,找到正确率最高的特征向量,然后再进行多模型融合。
  • 本项目融合了三个预训练模型,还可考虑更多更好模型的融合。

相关阅读:

stawary:Keras做图片分类(一):图片的导入与处理​zhuanlan.zhihu.com

stawary:Keras做图片分类(二):图片的分批读取和数据增强​zhuanlan.zhihu.com

stawary:Keras做图片分类(三):Keras CNN模型cifar-10实战​zhuanlan.zhihu.com

——————————————————————————————————————

喜欢我的文章,或者希望了解更多机器学习、人工智能等相关知识、动态的小伙伴可以关注公众号,并进交流群。这里有一群志同道合的小伙伴可以一起交流学习、转行、打比赛等诸多经验。

keras提取模型中的某一层_Keras做图片分类(四):迁移学习--猫狗大战实战相关推荐

  1. keras提取模型中的某一层_keras获得某一层或者某层权重的输出实例

    一个例子: print("Loading vgg19 weights...") vgg_model = VGG19(include_top=False, weights='imag ...

  2. keras提取模型中的某一层_keras K.function获取某层的输出操作

    如下所示: from keras import backend as K from keras.models import load_model models = load_model('models ...

  3. keras提取模型中的某一层_Tensorflow笔记:高级封装——Keras

    前言 之前在<Tensorflow笔记:高级封装--tf.Estimator>中介绍了Tensorflow的一种高级封装,本文介绍另一种高级封装Keras.Keras的特点就是两个字--简 ...

  4. python模型保存save_浅谈keras保存模型中的save()和save_weights()区别

    今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别. 我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5.同样是h5文件用save ...

  5. SSL协议工作在OSI模型中的哪一层?

    首先我们来看看什么是SSL协议(引申出TLS): SSL(Secure Sockets Layer 安全套接层),及其继任者传输层安全(Transport Layer Security,TLS)是为网 ...

  6. OSI的七层模型,网线,网卡,集线器,交换机,路由器分别工作在七层模型中的哪一层?

    OSI七层网络模型由下至上为1至7层,分别为物理层(Physical layer),数据链路层(Data link layer),网络层(Network layer),传输层(Transport la ...

  7. keras迁移学习猫狗大战-Vgg16

    VGG16模型 转载vgg16-Bubbliiiing VGG是由Simonyan 和Zisserman在文献<Very Deep Convolutional Networks for Larg ...

  8. 贝叶斯文本分类python_scikit_learn 中朴素贝叶斯算法做文本分类的 实践总结

    朴素贝叶斯算法对于分类非常高效 想了解的可以参考这篇博文:贝叶斯从浅入深详细解析,详细例子解释 - zwan0518的专栏 - 博客频道 - CSDN.NET贝叶斯从浅入深 先来做个小小总结说明 在这 ...

  9. keras从入门到放弃(十七)使用预训练网络VGG迁移学习

    VGG16网络是13层卷积层,运算起来非常的忙,如果使用CPU基本跑不了 import keras from keras import layers import numpy as np import ...

最新文章

  1. SQL中where与having的区别
  2. 【组合数学】指数生成函数 ( 指数生成函数性质 | 指数生成函数求解多重集排列 )
  3. windows 技巧篇-查看文件夹被那个进程占用,文件夹占用解除方法
  4. 3种python调用其他脚本的方法,你还知道其他的方法吗?
  5. TCP/IP 7.2 OSPF 虚链路
  6. 遍历一个数据去掉最后一个元素的样式
  7. JavaScript创建Element元素/标签的工具/方法
  8. Python 检测字符串开始值String.StartsWith 方法
  9. python 函数对象(函数式编程 lambda、map、filter、reduce)、闭包(closure)
  10. MAC下利用Github 、hexo、 多说、百度统计 建立个人博客指南
  11. matlab 定义离散函数,matlab离散点拟合函数
  12. 学习web渗透测试国内、国外在线网站
  13. 什么是IDC?数据中心该如何选择?
  14. Leetcode_128_Longest Consecutive Sequence
  15. 你需要了解的群体重测序都在这里(一)
  16. 【管理者】精读德鲁克教授《卓有成效的管理者》(一)
  17. IDEA无法启动Tomcat显示[localhost-startStop-1] org.apache.catalina.startup.HostConfig.deployDirector
  18. 关于三星SSD的固态优化
  19. 这个国外大学生的作弊神器,竟是乔布斯的老朋友做的?
  20. 纠错输出码(Error Correcting Output Code, ECOC)

热门文章

  1. [收藏]Mysql日期和时间函数
  2. Q78:规则网格(Regular Grids)——Ray Tracing中的一种加速技术
  3. 问题二十二:C++中怎么添加log开关
  4. C++折半查找的实现
  5. 如何提高使用物联网卡应用的安全性
  6. 静态文件之static+url控制系统(萌新笔记)
  7. Django MTV - 模型层 - (专题)知识要点与实战案例
  8. 安徽信息技术初中会考上机考试模拟_初中信息技术会考模拟试题
  9. Yarn无法查看日志: Aggregation may not be complete, Check back later or try the nodemanager at xxxx:xxxx
  10. mysql轻量在线管理工具_重磅推荐!我在Github找到一个超级轻量、灵活的SQL工具...