迁移学习是一种深度学习策略,它通过将解决一个问题所获得的知识应用于另一个不同但相关的问题来重用这些知识。例如,有3种类型的花:玫瑰、向日葵和郁金香。可以使用标准的预训练模型,如VGG16/19、ResNet50或Inception v3模型(在ImageNet上预训练了1000个输出类)对花卉图像进行分类,但是由于模型没有学习这些花卉类别,因此这样的模型无法正确识别它们。换句话说,它们是模型不知道的类。

图11-13所示的是预先训练的VGG16模型错误地对花卉图像进行了分类(代码留给读者作为练习)。其中flamingo(火鹤花)的可信度为0.83,daisy(雏菊)的可信度为0.43,artichoke(菊芋)的可信度为0.33。

图11-13 预先训练的VGG16模型对花卉图像的误分类

用Keras实现迁移学习

许多综合图像分类问题进行了预处理模型的训练。在使用卷积网络对猫与狗图像分类的语境中,以卷积层作为特征提取器,以全连接层作为分类器,如图11-14所示。

图11-14 卷积神经网络体系构架

由于标准模型(如VGG-16/19)相当庞大,并且针对许多图像进行了训练,因此它们能够为不同的类学习许多不同的特性。读者可以简单地重用卷积层作为特征提取器,学习低阶和高阶图像特征,并只训练全连接层权重(参数),这就是迁移学习。

当有一个简洁的训练集时,可以使用迁移学习,所处理的问题与之前训练的模型是一样的。如果有足够的数据,则可以调整卷积层,从头开始学习所有的模型参数,以便训练模型来学习与问题相关的更健壮的特性。

现在,用迁移学习对玫瑰、向日葵和郁金香花的图像。这些图像是从TensorFlow示例图像数据集中获得的3个类各用550张图片,总共1650张,这是小数量的图片,但也是使用迁移学习的好地方。使用每个类中的500个图像进行训练,保留每个类中的其余50个图像进行验证。另外,创建一个名为flower_photos的文件夹,其中包含两个子文件夹train和valid,并将训练图像和验证图像分别保存在这些文件夹中。文件夹结构应该如图11-15所示。

图11-15 文件夹结构图

加载卷积层的权重,只针对预先训练好的VGG16模型(设置include_top=False,不加载最后两个全连接层),将它当作分类器。注意,最后一层的形状尺寸为7×7×512。

使用ImageDataGenerator类来加载图像,并使用flow_from_directory()函数来生成成批的图像和标签。还将使用model.predict()函数来通过网络传递图像,得到一个7×7×512维的张量,然后将张量重新塑造成一个向量,并以同样的方式找到validation_features。

也就是说,用Keras实现迁移学习对VGG16模型进行部分训练,即它只会根据所拥有的训练图像来学习全连接层的权重,然后用它来预测类,如下面的代码所示。

from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
from keras import models, layers, optimizers
from keras.layers.normalization import BatchNormalization
from keras.preprocessing.image import load_img
# train only the top FC layers of VGG16, use weights learnt with ImageNet for the
convolution layers
vgg_model = VGG16(weights='imagenet', include_top=False, input_shape=(224,224, 3))
# the directory flower_photos is assumed to be on the current path
train_dir = './flower_photos/train'
validation_dir = './flower_photos/valid'
n_train = 500*3
n_val = 50*3
datagen = ImageDataGenerator(rescale=1./255)
batch_size = 25
train_features = np.zeros(shape=(n_train, 7, 7, 512))
train_labels = np.zeros(shape=(n_train,3))
train_generator = datagen.flow_from_directory(train_dir, target_size=(224,224),
batch_size=batch_size, class_mode='categorical', shuffle=True)
i = 0
for inputs_batch, labels_batch in train_generator:
features_batch = vgg_model.predict(inputs_batch)
train_features[i * batch_size : (i + 1) * batch_size] = features_batch
train_labels[i * batch_size : (i + 1) * batch_size] = labels_batch
i += 1
if i * batch_size >= n_train: break
train_features = np.reshape(train_features, (n_train, 7 * 7 * 512))
validation_features = np.zeros(shape=(n_val, 7, 7, 512))
validation_labels = np.zeros(shape=(n_val,3))
validation_generator = datagen.flow_from_directory(validation_dir,
target_size=(224, 224),
batch_size=batch_size, class_mode='categorical', shuffle=False)
i = 0
for inputs_batch, labels_batch in validation_generator:
features_batch = vgg_model.predict(inputs_batch)
validation_features[i * batch_size : (i + 1) * batch_size] =features_batch
validation_labels[i * batch_size : (i + 1) * batch_size] = labels_batch
i += 1
if i * batch_size >= n_val: break
validation_features = np.reshape(validation_features, (n_val, 7 * 7 * 512))

接下来,使用带有3个类的softmax输出层的简单前馈网络创建模型。然后必须对模型进行训练,如下面的代码所示。可以看到,在Keras中训练一个网络就像调用model.fit()函数一样简单。为了检验模型的性能,先看看哪些图片被错误分类。

# now learn the FC layer parameters by training with the images we have
model = models.Sequential()
model.add(layers.Dense(512, activation='relu', input_dim=7 * 7 * 512))
model.add(BatchNormalization())
model.add(layers.Dropout(0.5))
model.add(layers.Dense(3, activation='softmax'))
model.compile(optimizer=optimizers.Adam(lr=1e-5),
loss='categorical_crossentropy', metrics=['acc'])
history = model.fit(train_features, train_labels, epochs=20,batch_size=batch_size,
validation_data=(validation_features,validation_labels))
filenames = validation_generator.filenames
ground_truth = validation_generator.classes
label2index = validation_generator.class_indices
# Getting the mapping from class index to class label
idx2label = dict((v,k) for k,v in label2index.items())
predictions = model.predict_classes(validation_features)
prob = model.predict(validation_features)
errors = np.where(predictions != ground_truth)[0]
print("No of errors = {}/{}".format(len(errors),n_val))
# No of errors = 13/150
pylab.figure(figsize=(20,12))
for i in range(len(errors)):
pred_class = np.argmax(prob[errors[i]])
pred_label = idx2label[pred_class]
original =load_img('{}/{}'.format(validation_dir,filenames[errors[i]]))
pylab.subplot(3,5,i+1), pylab.imshow(original), pylab.axis('off')
pylab.title('Original
label:{}\nPrediction:{}\nconfidence:{:.3f}'.format(
filenames[errors[i]].split('\\')[0], pred_label,
prob[errors[i]][pred_class]), size=15)
pylab.show()

运行上述代码,输出结果如图11-16所示。可以看到,在迁移学习模型的150幅图像中,验证数据集中有13幅图像被错误分类。

图11-16 迁移学习模型对验证数据集部分图像分类出错

最初使用的花卉图像(它们是验证数据集的一部分,没有用于训练迁移学习模型)现在被正确分类,如图11-17所示(代码实现作为练习留给读者)。

图11-17 利用迁移学习模型对之前错误分类图像的正确分类

本文摘自《Python图像处理实战》

[印度] 桑迪潘·戴伊(Sandipan Dey) 著,陈盈,邓军 译

  • 图像处理,计算机视觉人脸识别图像修复
  • 编程入门教程书籍零基础,深度学习爬虫
  • 用流行的Python图像处理库、机器学习库和深度学习库解决图像处理问题。

本书介绍如何用流行的Python 图像处理库、机器学习库和深度学习库解决图像处理问题。先介绍经典的图像处理技术,然后探索图像处理算法的演变历程,始终紧扣图像处理以及计算机视觉与深度学习方面的最新进展。全书共12 章,涵盖图像处理入门基础知识、应用导数方法实现图像增强、形态学图像处理、图像特征提取与描述符、图像分割,以及图像处理中的经典机器学习方法等内容。

本书适合Python 工程师和相关研究人员阅读,也适合对计算机视觉、图像处理、机器学习和深度学习感兴趣的软件工程师参考。

如果想进一步学习迁移学习,推荐《Python迁移学习》迪潘简·撒卡尔(Dipanjan Sarkar) 著,张浩然 译

  • 使用TensorFlow和Keras实现高级深度学习和神经网络模型

迁移学习是机器学习技术的一种,它可以从一系列机器学习问题的训练中获得知识,并将这些知识用于训练其他相似类型的问题。

本书分为3个部分:第1部分是深度学习基础,介绍了机器学习的基础知识、深度学习的基础知识和深度学习的架构;第2部分是迁移学习精要,介绍了迁移学习的基础知识和迁移学习的威力;第3部分是迁移学习案例研究,介绍了图像识别和分类、文本文档分类、音频事件识别和分类、DeepDream算法、风格迁移、自动图像扫描生成器、图像着色等内容。

本书适合数据科学家、机器学习工程师和数据分析师阅读,也适合对机器学习和迁移学习感兴趣的读者阅读。在阅读本书之前,希望读者对机器学习和Python编程有基本的掌握。

什么是迁移学习?什么时候使用迁移学习?相关推荐

  1. 读“基于深度学习的图像风格迁移研究综述”有感

    前言 关于传统非参数的图像风格迁移方法和现如今基于深度学习的图像风格迁移方法. 基于深度学习的图像风格迁移方法:基于图像迭代和模型迭代的两种方法的优缺点. 基于深度学习的图像风格迁移方法的存在问题及其 ...

  2. ICML 2020 | 小样本学习首次引入领域迁移技术,屡获新SOTA结果

    2020-06-22 02:19:23 本文介绍的是ICML2020论文<Few-Shot Learning as Domain Adaptation: Algorithm and Analys ...

  3. 深度学习不得不会的迁移学习Transfer Learning

    http://blog.itpub.net/29829936/viewspace-2641919/ 2019-04-18 10:04:53 目录 一.概述 二.什么是迁移学习? 2.1 模型的训练与预 ...

  4. 迁移学习中的负迁移:综述

    点击上面"脑机接口社区"关注我们 更多技术干货第一时间送达 导读 迁移移学习(TL)试图利用来自一个或多个源域的数据或知识来促进目标域的学习.由于标记成本.隐私问题等原因,当目标域 ...

  5. 深度学习实战-图像风格迁移

    图像风格迁移 文章目录 图像风格迁移 简介 画风迁移 图像风格捕捉 图像风格迁移 图像风格内插 补充说明 简介 利用卷积神经网络实现图像风格的迁移. 画风迁移 简单来说就是将另一张图像的绘画风格在不改 ...

  6. 量化交易有因子动物园 深度学习里有模型动物园(ModelZoo)又叫模型市场基于深度学习的增量学习,迁移学习等技术发展而来【调研】

    前言 随着迁移模型的概念流行起来,就像快乐会传染样,自然语言处理,计算机视觉,生成模型,强化学习,非监监督学习,语音识别 这几个领域内部产生了大量的可复用可迁移学习的基础模型,领域之间的方法也在互相学 ...

  7. OUC暑期培训(深度学习)——第五周学习记录:ShuffleNet EfficientNet 迁移学习

    第五周学习:ShuffleNet & EfficientNet & 迁移学习 Part 1 视频学习 1.ShuffleNet V1 ShuffleNet和MobileNet一样想,应 ...

  8. 3.2 实战项目二(手工分析错误、错误标签及其修正、快速地构建一个简单的系统(快速原型模型)、训练集与验证集-来源不一致的情况(异源问题)、迁移学习、多任务学习、端到端学习)

    手工分析错误 手工分析错误的大多数是什么 猫猫识别,准确率90%,想提升,就继续猛加材料,猛调优?     --应该先做错误分析,再调优! 把识别出错的100张拿出来, 如果发现50%是"把 ...

  9. 数据异质性会影响深度学习变化检测模型的迁移能力,请列出提升模型迁移性的解决思路...

    数据异质性会导致深度学习变化检测模型的迁移能力降低.可以采用以下解决思路来提升模型的迁移性: 数据预处理: 对于不同类型的数据进行标准化处理,使得模型能够更好的适应不同的数据类型. 模型正则化: 通过 ...

  10. TensorFlow练手项目三:使用VGG19迁移学习实现图像风格迁移

    使用VGG19迁移学习实现图像风格迁移 2020.3.15 更新: 使用Python 3.7 + TensorFlow 2.0的实现: 有趣的深度学习--使用TensorFlow 2.0实现图片神经风 ...

最新文章

  1. 美媒人工智能(AI)代表了计算的优点,没有人类推理的缺点
  2. golang中的strings.Join
  3. 关于ORACLE 10g中“ORA-12541:TNS:no listener”的问题解决方案
  4. 分布式服务常见问题—访问量统计如何做?
  5. android m权限工具类,android M权限适配,简单工具类
  6. Spring - shortcuts
  7. log添加 oracle redo_Oracle更改redo log大小 or 增加redo log组
  8. 从零开始pytorch手写字母识别
  9. 常用的Wi-Fi产品调试测试工具
  10. 【SSM】SSM框架介绍
  11. GhostXP_SP3 PCOS技术快速装机版 5.7(优化细节 力争完美)
  12. 神经网络控制系统的特点,神经网络控制的优点
  13. 米家app扫描不到石头机器人_石头扫地机器人T7评测:能驾驭豪宅的高端旗舰?...
  14. 计算机算样本标准偏差,计算器中的总体标准差和样本标准差有什么区别
  15. 大厂软件测试流程完整版
  16. 极路由 刷linux,极路由1s刷openwrt不完全教程
  17. 解决cumcm17问题的代码记录(待改正)
  18. 【哪吒社区Java技能树 打卡day2】Java学习路线总结(思维导图篇)
  19. 亵渎小说介绍_从PHP过渡到:亵渎神灵,虚张声势还是常识?
  20. 使用MATLABsimulinkstm32mat_targetstm32cubemx开发stm32

热门文章

  1. Vulkan是什么?和我一起完成一个简单的Vulkan应用程序
  2. Android Studio校园二手交易市场app
  3. 【Redis笔记】一起学习Redis | 如何利用Redis实现一个分布式锁?
  4. 开机后黑屏看不到桌面_电脑开机黑屏只有鼠标怎么办?电脑开机后不显示桌面的多种解决方法...
  5. 苹果审核团队_如何才能跟 App Store 审核团队有效沟通?
  6. CocosEditor For JS (Cocos2d-JS) 教程聚合和代码下载
  7. HTML5—网页三兄弟
  8. 红黑树 插入算法(一)
  9. 数据结构算法学习 之 红黑树
  10. 【Linux】一张图让你读懂Linux内核运行原理