中:https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html

英:https://blog.csdn.net/baimafujinji/article/details/80743814

小例子:

首先载入VGG16的权重
接下来在初始化好的VGG网络上添加我们预训练好的模型
最后将最后一个卷积块的层数冻结,然后以很低的学习率开始训练(我们只选择最后一个卷积块进行训练,是因为训练样本很少,而VGG模型层数很多,全部训练肯定不能训练好,会过拟合。 其次fine-tune时由于是在一个已经训练好的模型上进行的,故权值更新应该是一个小范围的,以免破坏预训练好的特征)
首先构造VGG16模型:

model = Sequential()
model.add(ZeroPadding2D((1, 1), input_shape=(3, img_width, img_height)))model.add(Convolution2D(64, 3, 3, activation='relu', name='conv1_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(64, 3, 3, activation='relu', name='conv1_2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, 3, 3, activation='relu', name='conv2_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, 3, 3, activation='relu', name='conv2_2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

加载VGG16训练好的权重(我们只要全连接层以前的权重):

assert os.path.exists(weights_path), 'Model weights not found (see "weights_path" variable in script).'
f = h5py.File(weights_path)
for k in range(f.attrs['nb_layers']):if k >= len(model.layers):# we don't look at the last (fully-connected) layers in the savefilebreakg = f['layer_{}'.format(k)]weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]model.layers[k].set_weights(weights)
f.close()
print('Model loaded.')

然后在VGG16结构基础上添加一个简单的分类器及预训练好的模型:

top_model = Sequential()
top_model.add(Flatten(input_shape=model.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(1, activation='sigmoid'))
top_model.load_weights(top_model_weights_path)
# add the model on top of the convolutional base
model.add(top_model)

把随后一个卷积块前的权重设置为不训练:

for layer in model.layers[:25]:layer.trainable = False
model.compile(loss='binary_crossentropy',optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),metrics=['accuracy'])

Keras:Transfer learning相关推荐

  1. AI入门:Transfer Learning(迁移学习)

    迁移学习是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中 Pokemon Dataset 通过网络上收集宝可梦的图片,制作图像分类数据集.我收集了5种 ...

  2. 深度学习不得不会的迁移学习Transfer Learning

    http://blog.itpub.net/29829936/viewspace-2641919/ 2019-04-18 10:04:53 目录 一.概述 二.什么是迁移学习? 2.1 模型的训练与预 ...

  3. 迁移学习(transfer learning)与finetune的关系?【finetune只是transfer learning的一种手段】

    目录 1.迁移学习简介 2.为什么要迁移学习? 3.迁移学习的几种方式 1)Transfer Learning: 2)Extract Feature Vector: 3)Fine-tune: 4.三种 ...

  4. 迁移学习(Transfer Learning)概述及代码实现(full version)

    基于PaddlePaddle的李宏毅机器学习--迁移学习 大噶好,我是黄波波.希望能和大家共进步,错误之处恳请指出! 百度AI Studio个人主页, 我在AI Studio上获得白银等级,点亮2个徽 ...

  5. 迁移学习(Transfer Learning)概述及代码实现

    基于PaddlePaddle的李宏毅机器学习--迁移学习 大噶好,我是黄波波,希望能和大家共进步,错误之处恳请指出! 百度AI Studio个人主页, 我在AI Studio上获得白银等级,点亮2个徽 ...

  6. 连续学习入门(一):Continual Learning / Incremental Learning / Life Long Learning 问题背景及研究挑战

    说明:本系列文章若无特别说明,则在技术上将 Continual Learning(连续学习)等同于 Incremental Learning(增量学习).Lifelong Learning(终身学习) ...

  7. 声音克隆_论文翻译:2019_Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis

    论文:2019_Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis 翻译总结:只需 ...

  8. 简述迁移学习(Transfer Learning)

    目录 定义 变形 例子 总结 定义 迁移学习(Transfer learning) 顾名思义就是把已训练好的模型(预训练模型)参数迁移到新的模型来帮助新模型训练.考虑到大部分数据或任务都是存在相关性的 ...

  9. 基于Keras Application和Densenet迁移学习(transfer learning)的乳腺癌图像分类模型(良性、恶性)

    基于Keras Application和Densenet迁移学习(transfer learning)的乳腺癌图像分类模型(良性.恶性) 概论: 美国癌症学会官方期刊发表<2018年全球癌症统计 ...

最新文章

  1. 设计模式之抽象工厂模式(Abstract Factory)摘录
  2. 在vue2.x项目中怎么引入Element UI
  3. python括号生成_Python括号生成器的问题
  4. redis List的用途及常用命令
  5. 使用 Carla 和 Python 的自动驾驶汽车第 4 部分 —— 强化学习Action
  6. [ASP.NET Core 3.1]浏览器嗅探解决部分浏览器丢失Cookie问
  7. python执行oracle命令_如何使用cx\U Oracle运行非查询sql命令?
  8. java supplier_现代化的 Java (二十一)——宏和生成宏
  9. MySQL服务器的启动与停止
  10. 软件设计师考试大纲2018
  11. Access宏学习总结
  12. 计算机科学的主要目标,学习计算机的主要目的是什么?
  13. ftp工具FileZilla下载安装配置
  14. 基于ssm Vue+elementui农家乐管理系统java 项目源码介绍
  15. 为什么业务中很少用到设计模式
  16. 录音音频如何转换为mp3格式
  17. Win7旗舰版开机不需要输入密码登录
  18. 【Unity入门】软件Unity Hub和Unity的安装和简单尝试
  19. 功能超级强大的计算器程序 免费开源 全部源码
  20. Android4.2 Quectel EC20 R2.1模块移植

热门文章

  1. SQL注入:6、SQLMAP的使用
  2. 操作系统之进程管理:12、生产者消费者问题和多级生产者多级消费者问题
  3. 数组经典题之杨辉三角变形
  4. 3-6:常见任务和主要工具之正则表达式
  5. Python 参数传入sys.argv和getopt.getopt()的用法
  6. centos下 Tcpreplay 重放数据(流量采集重放)
  7. PySpider问题记录http599
  8. golang——strconv包常用函数
  9. 数据结构期末复习(に)--链式栈定义及使用
  10. thinkphp3.2 代码生成并点击验证码