想要将深度学习应用于小型图像数据集,一种常用且非常高效的方法是使用预训练网络。预训练网络(pretrained network)是一个保存好的网络,之前已在大型数据集(通常是大规模图像分类任务)上训练好。如果这个原始数据集足够大且足够通用,那么预训练网络学到的特征的空间层次结构可以有效地作为视觉世界的通用模型,因此这些特征可用于各种不同的计算机视觉问题,即使这些新问题涉及的类别和原始任务完全不同。举个例子,你在 ImageNet 上训练了一个网络(其类别主要是动物和日常用品),然后将这个训练好的网络应用于某个不相干的任务,比如在图像中识别家具。这种学到的特征在不同问题之间的可移植性,是深度学习与许多早期浅层学习方法相比的重要优势,它使得深度学习对小数据问题非常有效。
本例中,假设有一个在 ImageNet 数据集(140 万张标记图像,1000 个不同的类别)上训练好的大型卷积神经网络。ImageNet 中包含许多动物类别,其中包括不同种类的猫和狗,因此可以认为它在猫狗分类问题上也能有良好的表现。
使用预训练网络有两种方法:特征提取(feature extraction)和微调模型(fine-tuning)。两种方法我们都会介绍。首先来看特征提取。
特征提取是使用之前网络学到的表示来从新样本中提取出有趣的特征。然后将这些特征输入一个新的分类器,从头开始训练。
如前所述,用于图像分类的卷积神经网络包含两部分:首先是一系列池化层和卷积层,最后是一个密集连接分类器。第一部分叫作模型的卷积基(convolutional base)。对于卷积神经网络而言,特征提取就是取出之前训练好的网络的卷积基,在上面运行新数据,然后在输出上面训练一个新的分类器。

上图表示的为保持卷积基不变,改变分类器。
为什么仅重复使用卷积基?我们能否也重复使用密集连接分类器?一般来说,应该避免这么做。原因在于卷积基学到的表示可能更加通用,因此更适合重复使用。卷积神经网络的特征图表示通用概念在图像中是否存在,无论面对什么样的计算机视觉问题,这种特征图都可能很有用。但是,分类器学到的表示必然是针对于模型训练的类别,其中仅包含某个类别出现在整张图像中的概率信息。此外,密集连接层的表示不再包含物体在输入图像中的位置信息。密集连接层舍弃了空间的概念,而物体位置信息仍然由卷积特征图所描述。如果物体位置对于问题很重要,那么密集连接层的特征在很大程度上是无用的。
注意,某个卷积层提取的表示的通用性(以及可复用性)取决于该层在模型中的深度。模型中更靠近底部的层提取的是局部的、高度通用的特征图(比如视觉边缘、颜色和纹理),而更靠近顶部的层提取的是更加抽象的概念(比如“猫耳朵”或“狗眼睛”)。因此,如果你的新数据集与原始模型训练的数据集有很大差异,那么最好只使用模型的前几层来做特征提取,而不是使用整个卷积基。
本文中使用的卷积基为vgg16,我们来打印下vgg16的网络结构:

from tensorflow.keras.applications import VGG16'''
weights 指定模型初始化的权重检查点。include_top 指定模型最后是否包含密集连接分类器。默认情况下,这个密集连接分
类器对应于 ImageNet 的 1000 个类别。因为我们打算使用自己的密集连接分类器(只有
两个类别:cat 和 dog),所以不需要包含它。input_shape 是输入到网络中的图像张量的形状。这个参数完全是可选的,如果不传
入这个参数,那么网络能够处理任意形状的输入。
'''
conv_base=VGG16(weights='imagenet',include_top=False,input_shape=(150,150,3))conv_base.summary()

网络结构为:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 150, 150, 3)]     0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 150, 150, 64)      1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 150, 150, 64)      36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 75, 75, 64)        0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 75, 75, 128)       73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 75, 75, 128)       147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 37, 37, 128)       0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 37, 37, 256)       295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 37, 37, 256)       590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 37, 37, 256)       590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 18, 18, 256)       0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 18, 18, 512)       1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 18, 18, 512)       2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 18, 18, 512)       2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 9, 9, 512)         0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 9, 9, 512)         2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 9, 9, 512)         2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 9, 9, 512)         2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 4, 4, 512)         0
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0

最后的特征图形状为 (4, 4, 512)。我们将在这个特征上添加一个密集连接分类器。
接下来,下一步有两种方法可供选择。
1、在你的数据集上运行卷积基,将输出保存成硬盘中的 Numpy 数组,然后用这个数据作为输入,输入到独立的密集连接分类器中(与本书第一部分介绍的分类器类似)。这种方法速度快,计算代价低,因为对于每个输入图像只需运行一次卷积基,而卷积基是目前流程中计算代价最高的。但出于同样的原因,这种方法不允许你使用数据增强。
2、在顶部添加 Dense 层来扩展已有模型(即 conv_base),并在输入数据上端到端地运行整个模型。这样你可以使用数据增强,因为每个输入图像进入模型时都会经过卷积基。但出于同样的原因,这种方法的计算代价比第一种要高很多。

1 不使用数据增强的快速特征提取

训练代码如下:

from cProfile import label
from statistics import mode
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import optimizers
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image as kimage
from tensorflow.keras.applications import VGG16
import numpy as npconv_base=VGG16(weights='imagenet',include_top=False,input_shape=(150,150,3))#训练样本的目录
train_dir='./dataset/training_set/'
#验证样本的目录
validation_dir='./dataset/validation_set/'
#测试样本目录
test_dir='./dataset/test_set/'datagen=ImageDataGenerator(rescale=1./255)
batch_size=20
#卷积基提取特征
def extract_features(dir,sample_count):features=np.zeros(shape=(sample_count,4,4,512))labels=np.zeros(shape=(sample_count,))generator=datagen.flow_from_directory(dir,target_size=(150,150),batch_size=batch_size,class_mode='binary')i=0for inputs_batch,lables_batch in generator:features_batch=conv_base.predict(inputs_batch)features[i*batch_size:(i+1)*batch_size]=features_batchlabels[i*batch_size:(i+1)*batch_size]=lables_batchi+=1if i*batch_size>=sample_count:breakreturn features,labelsif __name__=='__main__':#提取卷积特征train_features, train_labels = extract_features(train_dir, 3200) validation_features, validation_labels = extract_features(validation_dir, 800) test_features, test_labels = extract_features(test_dir, 1000)#将特征展平 以便传入全连接层train_features=train_features.reshape(3200,-1)validation_features=validation_features.reshape(800,-1)test_features=test_features.reshape(1000,-1)#构建训练网络model=models.Sequential()model.add(layers.Dense(units=256,activation='relu',input_dim=4*4*512))model.add(layers.Dropout(rate=0.25))model.add(layers.Dense(units=1,activation='sigmoid'))model.compile(optimizer=optimizers.RMSprop(lr=2e-5),loss='binary_crossentropy',metrics=['acc'])history = model.fit(train_features, train_labels,epochs=30,batch_size=20,validation_data=(validation_features, validation_labels))test_eval=model.evaluate(x=test_features,y=test_labels)print(test_eval)acc = history.history['acc']val_acc = history.history['val_acc']loss = history.history['loss']val_loss = history.history['val_loss']epochs = range(1, len(acc) + 1)plt.plot(epochs, acc, 'bo', label='Training acc')plt.plot(epochs, val_acc, 'b', label='Validation acc')plt.title('Training and validation accuracy')plt.legend()plt.figure()plt.plot(epochs, loss, 'bo', label='Training loss')plt.plot(epochs, val_loss, 'b', label='Validation loss')plt.title('Training and validation loss')plt.legend()plt.show()

准确率的变化曲线如下:

损失函数的变化曲线如下:

采用这种方法的验证集准确率可以达到90%。

2 采用数据增强的特征提取

训练代码如下:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import optimizers
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image as kimage
from tensorflow.keras.applications import VGG16
import numpy as np'''
ImageDataGenerator可完成读取图像数据
读取图像文件
将jpeg图像解码为RGB像素网络
将这些像素转换到为浮点型张量并缩放到0~1之间
'''
#训练样本的目录
train_dir='./dataset/training_set/'
#验证样本的目录
validation_dir='./dataset/validation_set/'
#测试样本目录
test_dir='./dataset/test_set/'#训练样本生成器
#注意数据增强只能用于训练数据,不能用于验证数据和测试数据
'''
进行数据增强
'''
#设置数据增强
'''
rotation_range 是角度值(在 0~180 范围内),表示图像随机旋转的角度范围。
width_shift 和 height_shift 是图像在水平或垂直方向上平移的范围(相对于总宽
度或总高度的比例)。
shear_range 是随机错切变换的角度。
zoom_range 是图像随机缩放的范围。
horizontal_flip 是随机将一半图像水平翻转。如果没有水平不对称的假设(比如真
实世界的图像),这种做法是有意义的。
fill_mode是用于填充新创建像素的方法,这些新像素可能来自于旋转或宽度/高度平移。
'''
train_datagen=ImageDataGenerator(rescale=1./255,rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest')train_generator=train_datagen.flow_from_directory(directory=train_dir,target_size=(150,150),class_mode='binary',batch_size=20
)#验证样本生成器
validation_datagen=ImageDataGenerator(rescale=1./255)
validation_generator=train_datagen.flow_from_directory(directory=validation_dir,target_size=(150,150),class_mode='binary',batch_size=20
)#测试样本生成器
test_datagen=ImageDataGenerator(rescale=1./255)
test_generator=train_datagen.flow_from_directory(directory=test_dir,target_size=(150,150),class_mode='binary',batch_size=20
)if __name__=='__main__':conv_base=VGG16(weights='imagenet',include_top=False,input_shape=(150,150,3))#冻结卷积基 保证其权重在训练过程中不变conv_base.trainable=False#构建训练网络model=models.Sequential()model.add(conv_base)model.add(layers.Flatten())model.add(layers.Dense(units=256,activation='relu'))model.add(layers.Dense(units=1,activation='sigmoid'))model.compile(optimizer=optimizers.RMSprop(learning_rate=1e-4),loss='binary_crossentropy',metrics=['acc'])model.summary()#拟合模型history=model.fit_generator(train_generator,steps_per_epoch=100,epochs=100,validation_data=validation_generator,validation_steps=50)#测试测试集的准确率test_eval=model.evaluate_generator(test_generator)print(test_eval)acc = history.history['acc']val_acc = history.history['val_acc']loss = history.history['loss']val_loss = history.history['val_loss']epochs = range(1, len(acc) + 1)plt.plot(epochs, acc, 'bo', label='Training acc')plt.plot(epochs, val_acc, 'b', label='Validation acc')plt.title('Training and validation accuracy')plt.legend()plt.figure()plt.plot(epochs, loss, 'bo', label='Training loss')plt.plot(epochs, val_loss, 'b', label='Validation loss')plt.title('Training and validation loss')plt.legend()plt.show()

使用数据增强的准确率变化曲线:

使用数据增强的损失函数变化曲线:

keras深度学习之猫狗分类三(特征提取)相关推荐

  1. PyTorch深度学习实战 | 猫狗分类

    本文内容使用TensorFlow和Keras建立一个猫狗图片分类器. 图1 猫狗图片 01.安装TensorFlow和Keras库 TensorFlow是一个采用数据流图(data flow grap ...

  2. 深度学习实现猫狗分类

    本文使用vgg网络实现对猫狗分类. 可以当做图像分类的一个baseline. 一.前期工作 数据:直接到kaggle上下载相应的数据集即可. 1.导入模块 # 数据import torchfrom t ...

  3. 基于tensorflow深度学习的猫狗分类识别

  4. 【深度学习】猫狗识别TensorFlow2实验报告

    实验二:猫狗识别 一.实验目的 利用深度学习实现猫狗动物识别,采用Kaggle提供的公开数据集,训练深度学习模型,对测试集猫狗中的图片准确分类.通过该实验掌握深度学习中基本的CV处理过程. 二.实验原 ...

  5. Tensorflow 学习之猫狗分类案例

    Tensorflow 学习之猫狗分类案例 本人一直在Cousera上学习Ng Andrew老师的Tensorflow课程,在本次猫狗分类案例当中,我对课程做了相应的记录,呈现在此,一方面加深学习的印象 ...

  6. 基于MATLAB的Alexnet迁移学习进行猫狗分类(数据集:Kaggle)

    基本介绍 软件:Matlab R2018b 数据集:Kaggle猫狗数据集 网络:AlexNet 前期准备 数据集 Kaggle猫狗数据集猫与狗用于训练的图片(train)分别12500张,每张图片的 ...

  7. 华为云深度学习kaggle猫狗识别

    使用华为云深度学习服务完成kaggle猫狗识别竞赛 参考: kaggle猫狗竞赛kernel第一名的代码 Tensorflow官网代码 华为云DLS服务github代码 1. 环境配置与数据集处理 首 ...

  8. 产品经理之深度学习促进产品之分类(三)

    微信:18091589062   高级产品经理 钱波 深度学习收到公众关注度越来越高,产业投资也越来越多,想要进入这个行业,产品经理必然要熟悉很多概念,理解这些概念,并且获得相应的案例知识.本章依然和 ...

  9. Keras框架下的猫狗识别(一)

    Tensorflow学习(使用jupyter notebook) Keras框架下的猫狗识别(二) Keras框架下的猫狗识别(三) 数据预处理 Tensorflow学习(使用jupyter note ...

  10. Keras深度学习使用VGG16预训练神经网络实现猫狗分类

    Keras深度学习使用VGG16预训练神经网络实现猫狗分类 最近刚刚接触深度学习不久,而Keras呢,是在众多的深度学习框架中,最适合上手的,而猫狗的图像分类呢,也算是计算机视觉中的一个经典案例,下面 ...

最新文章

  1. 找到表中某一列值相同的记录,而且只要其中一条记录的sql
  2. Spring Boot 2.4版本前后的分组配置变化及对多环境配置结构的影响
  3. linux问题排查常用命令详解
  4. 【Python 必会技巧】使用 Python 追加写入 json 文件或更改 json 文件中的值
  5. 谷歌浏览器桌面通知 HTML5 Chrome Desktop Notifications
  6. 华为上机--质数因子
  7. 腾讯视频 android 2倍,腾讯视频多倍速播放产品设计小结
  8. 用16进制编辑器编写一个DLL文件
  9. Mac系统如何安装Eclipse并搭建Android开发环境
  10. MATLAB中MRE误差怎么算,『怎样用excel 求RMSE(均方根误差)和MRE(平均相对误差),不知道选计算函数中的哪个,非常谢谢。』excle怎么算均方误差...
  11. 应届生offer指南
  12. 有这5类人最难成为银行的优质客户!
  13. PostgreSQL 分区表一点也不差
  14. 阿里云建站之模板建站的核心优势有哪些?
  15. 安卓开发: Jetpack compose + kotlin 实现 俄罗斯方块游戏
  16. 财务会计基础(一)概念
  17. 最新多屏群控技术---手机控制手机/苹果群控/IOS群控/实时同步操作群控功能讲解以及入门教程
  18. 十月常见算法考题、最长递增子序列,Leetcode第300题最长上升子序列的变种,我没见过乔丹,今天詹姆斯就是我的神!
  19. js_实现网页自动跳转
  20. 在Asset Store上购买unity插件

热门文章

  1. 绘图工具 Gliffy 使用简介
  2. FreeRtos在RH850 D1L芯片上移植
  3. 内连接和外连接的区别--举例
  4. JUnit4单元测试入门教程
  5. 计算机编程php网页源码水果网上销售系统mysql数据库web结构html布局
  6. 如何实现远程给PLC上下载程序?
  7. 导出数据到txt文本
  8. togaf简介(一)
  9. GfK十大洞见揭示物联网时代正全面开启
  10. java web在线购物_JAVAWEB网上商城购物系统