1. 什么是迁移学习

迁移学习(Transfer Learning)是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算法都是为了解决单个任务而设计的,但是促进迁移学习的算法的开发是机器学习社区持续关注的话题。 迁移学习对人类来说很常见,例如,我们可能会发现学习识别苹果可能有助于识别梨,或者学习弹奏电子琴可能有助于学习钢琴。

找到目标问题的相似性,迁移学习任务就是从相似性出发,将旧领域(domain)学习过的模型应用在新领域上

2. 为什么需要迁移学习?

  • 大数据与少标注的矛盾:虽然有大量的数据,但往往都是没有标注的,无法训练机器学习模型。人工进行数据标定太耗时。
  • 大数据与弱计算的矛盾:普通人无法拥有庞大的数据量与计算资源。因此需要借助于模型的迁移。
  • 普适化模型与个性化需求的矛盾:即使是在同一个任务上,一个模型也往往难以满足每个人的个性化需求,比如特定的隐私设置。这就需要在不同人之间做模型的适配。
  • 特定应用(如冷启动)的需求

3. VGG的例子

3.1 环境

Tensorflow 2.1

3.2 准备工作

下载VGG 的权重可以自动下载也可以离线下载。
下载要训练的图片。这个里图片包含五种类型的花(‘daisy’,‘dandelion’,‘roses’,‘sunflowers’,‘tulips’)

https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
然后解压放在你的项目地下这个目录里 flower_photos

3.3 简要说明

基于VGG的迁移学习, VGG 的权重不训练了,因为已经训练好了。
但是要去掉全连接层,加上我们的全连接层就好。我们只有简单训练一下我全连接层就可以了。

3.4 训练集与验证集的结果

验证集上准确率 80%左右

87/290 [============================>.] - ETA: 0s - loss: 0.4844 - categorical_accuracy: 0.8358
288/290 [============================>.] - ETA: 0s - loss: 0.4836 - categorical_accuracy: 0.8364
289/290 [============================>.] - ETA: 0s - loss: 0.4831 - categorical_accuracy: 0.8366save_weight 36 0.5442695867802415290/290 [==============================] - 35s 119ms/step - loss: 0.4835 - categorical_accuracy: 0.8365 - val_loss: 0.5443 - val_categorical_accuracy: 0.8071

3.5 训练的完整代码

from tensorflow.keras.applications import VGG16
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import tensorflow.keras.preprocessing.image as image
import os as osvgg16=VGG16(input_shape = (224,224,3),  include_top=False)best_model =vgg16l_layer=len(best_model.layers)new_model=keras.Sequential(best_model)
for i in range(l_layer-1):best_model.layers[i].trainable = Falsenew_output=keras.layers.Dense(5,activation=tf.nn.softmax,kernel_initializer=tf.initializers.Constant(0.001))
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
new_model.add(global_average_layer)
new_model.add(new_output)new_model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.categorical_crossentropy,# metrics=['accuracy'])metrics=[keras.metrics.categorical_accuracy])new_model.summary()#雏菊,蒲公英, 郁金香
label_names={'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
label_key=['daisy','dandelion','roses','sunflowers','tulips']train_datagen = image.ImageDataGenerator(rescale=1 / 255,rotation_range=40,  # 角度值,0-180.表示图像随机旋转的角度范围width_shift_range=0.2,  # 平移比例,下同height_shift_range=0.2,shear_range=0.2,  # 随机错切变换角度zoom_range=0.2,  # 随即缩放比例horizontal_flip=True,  # 随机将一半图像水平翻转validation_split=0.2,fill_mode='nearest'  # 填充新创建像素的方法
)IMG_SIZE = 224
BATCH_SIZE = 32
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
pic_folder = './flower_photos'train_generator = train_datagen.flow_from_directory(directory=pic_folder,target_size=IMG_SHAPE[:-1],color_mode='rgb',classes=None,class_mode='categorical',batch_size=10,subset='training',shuffle=True)validation_generator = train_datagen.flow_from_directory(directory=pic_folder,target_size=IMG_SHAPE[:-1],color_mode='rgb',classes=None,class_mode='categorical',batch_size=10,subset='validation',shuffle=True)current_max_loss = 9999
weight_file='./weightsf/model.h5'if os.path.isfile(weight_file):print('load weight')new_model.load_weights(weight_file)def save_weight(epoch, logs):global current_max_lossif (logs['val_loss'] is not None and logs['val_loss'] < current_max_loss):current_max_loss = logs['val_loss']print('save_weight', epoch, current_max_loss)new_model.save_weights(weight_file)batch_print_callback = keras.callbacks.LambdaCallback(on_epoch_end=save_weight
)callbacks = [tf.keras.callbacks.EarlyStopping(patience=4, monitor='val_loss'),batch_print_callback,# keras.callbacks.ModelCheckpoint('./weights/model.h5', save_best_only=True),tf.keras.callbacks.TensorBoard(log_dir='logsf')
]history = new_model.fit_generator(train_generator, steps_per_epoch=290, epochs=40, callbacks=callbacks,validation_data=validation_generator, validation_steps=70)
print(history)def show_result(history):plt.plot(history.history['loss'])plt.plot(history.history['val_loss'])plt.plot(history.history['categorical_accuracy'])plt.plot(history.history['val_categorical_accuracy'])plt.legend(['loss', 'val_loss', 'categorical_accuracy', 'val_categorical_accuracy'],loc='upper left')plt.show()print(history)show_result(history)

Tensorflow 2.1 迁移学习 基于VGG相关推荐

  1. Tensorflow官网——迁移学习和微调部分解读

    import matplotlib.pyplot as plt import numpy as np import os import tensorflow as tf#数据预处理 #数据下载 fro ...

  2. 绒毛动物探测器:通过TensorFlow.js中的迁移学习识别浏览器中的自定义对象

    目录 起点 MobileNet v1体系结构上的迁移学习 修改模型 训练新模式 运行物体识别 终点线 下一步是什么?我们可以检测到脸部吗? 下载TensorFlowJS-Examples-master ...

  3. 基于特征的对抗迁移学习论文_学界 | 综述论文:四大类深度迁移学习

    选自arXiv 作者:Chuanqi Tan.Fuchun Sun.Tao Kong. Wenchang Zhang.Chao Yang.Chunfang Liu 机器之心编译 参与:乾树.刘晓坤 本 ...

  4. 【深度学习】一文看懂 (Transfer Learning)迁移学习(pytorch实现)

    前言 你会发现聪明人都喜欢"偷懒", 因为这样的偷懒能帮我们节省大量的时间, 提高效率. 还有一种偷懒是 "站在巨人的肩膀上". 不仅能看得更远, 还能看到更多 ...

  5. 迁移学习 Transfer Learning—通俗易懂地介绍(常见网络模型pytorch实现)

    前言 你会发现聪明人都喜欢"偷懒", 因为这样的偷懒能帮我们节省大量的时间, 提高效率. 还有一种偷懒是 "站在巨人的肩膀上". 不仅能看得更远, 还能看到更多 ...

  6. 迁移学习笔记3: TCA, Finetune, 与Triplet Network(元学习)

    主要想讲的内容有: TCA, Finetune, Triplet Network 迁移学习与元学习有哪几类方法 想讲的目标(但不一定完全能写完, 下一次笔记补充): 分别属于什么方法, 处于什么位置, ...

  7. TensorFlow 笔记6--迁移学习

    TensorFlow 笔记6–迁移学习 参考文档:https://github.com/ageron/handson-ml/blob/master/11_deep_learning.ipynb 一.冻 ...

  8. 整理学习之深度迁移学习

    迁移学习(Transfer Learning)通俗来讲就是学会举一反三的能力,通过运用已有的知识来学习新的知识,其核心是找到已有知识和新知识之间的相似性,通过这种相似性的迁移达到迁移学习的目的.世间万 ...

  9. 迁移学习---迁移学习基础概念、分类

    迁移学习提出背景 在机器学习.深度学习和数据挖掘的大多数任务中,我们都会假设training和inference时,采用的数据服从相同的分布(distribution).来源于相同的特征空间(feat ...

最新文章

  1. html广告20s倒计时,一段广告倒计时退出代码
  2. oracle数据库 名词,Oracle数据库名词解释
  3. mac 上开发需要的软件
  4. SQL经典面试题及答案
  5. 使用ST05 研究product extension field deletion
  6. redis淘汰策略面试题_redis有哪些数据淘汰策略
  7. 时间同步绝对是一个大问题
  8. 菜鸟学习笔记:Java提升篇6(IO流2——数据类型处理流、打印流、随机流)
  9. Java线程之间的协作
  10. 深入浅出UML类图(二)
  11. 亚马逊CloudFront
  12. python类封装成dl_第7.9节 案例详解:Python类封装
  13. Apache Shiro(一)——Shiro简介
  14. PostGIS 报错libcrypto
  15. 小米运动蓝牙耳机使用说明书-如果第二次切换到配对状态
  16. Jetson TK1 配置
  17. node api框架_使用Web API,Node和Nexmo从浏览器发送SMS
  18. 揭秘刘德华感恩立志的少年时光
  19. 物联网信息安全复习笔记
  20. [KDL库学习]KDL库安装与使用

热门文章

  1. cocos2d-x 3.0rc开发指南:Windows下Android环境搭建
  2. Oracle 11g安装(window)的7个服务
  3. [导入]ASP常用函数:doAlert()
  4. 极度丝滑!CentOS/Unbuntu系统下快速设置虚拟内存,一行命令快速搞定!!!
  5. 【Python】Python库之Web信息提取
  6. 面向对象设计原则之2-开放闭合原则
  7. mitmproxy https抓包的原理是什么?
  8. java @Column 引发的一点思考
  9. linux shell 读取文件脚本
  10. JAMstack简介:现代Web的体系结构