[开发技巧]·keras如何冻结网络层

在使用keras进行进行finetune有时需要冻结一些网络层加速训练

keras中提供冻结单个层的方法:layer.trainable = False

这个应该如何使用?下面给大家一些例子

1.冻结model所有网络层

base_model = DenseNet121(include_top=False, weights="imagenet",input_shape=(224, 224, 3))

for layer in base_model.layers:

layer.trainable = False

2.冻结model某些网络层

在keras中除了从model.layers取得layer,我们还可以通过model.get_layer(layer_name)获取。

base_model = VGG19(weights='imagenet')

base_model.get_layer('block4_pool').trainable = False

你可能会疑问,我不知道layer_name该怎么办呢?答案是通过model.summary()输出一下,

如下所示,最左面一列就是layer_name(注意是括号外面的>

__________________________________________________________________________________________________

Layer (type) Output Shape Param # Connected to

==================================================================================================

input_1 (InputLayer) (None, 224, 224, 3) 0

__________________________________________________________________________________________________

NASNet (Model) (None, 7, 7, 1056) 4269716 input_1[0][0]

__________________________________________________________________________________________________

resnet50 (Model) (None, 7, 7, 2048) 23587712 input_1[0][0]

__________________________________________________________________________________________________

densenet121 (Model) (None, 7, 7, 1024) 7037504 input_1[0][0]

__________________________________________________________________________________________________

global_average_pooling2d_1 (Glo (None, 1056) 0 NASNet[1][0]

__________________________________________________________________________________________________

global_average_pooling2d_2 (Glo (None, 2048) 0 resnet50[1][0]

__________________________________________________________________________________________________

global_average_pooling2d_3 (Glo (None, 1024) 0 densenet121[1][0]

__________________________________________________________________________________________________

concatenate_5 (Concatenate) (None, 4128) 0 global_average_pooling2d_1[0][0]

global_average_pooling2d_2[0][0]

global_average_pooling2d_3[0][0]

__________________________________________________________________________________________________

dropout_1 (Dropout) (None, 4128) 0 concatenate_5[0][0]

__________________________________________________________________________________________________

classifier (Dense) (None, 200) 825800 dropout_1[0][0]

==================================================================================================

Total params: 35,720,732

Trainable params: 825,800

Non-trainable params: 34,894,932

__________________________________________________________________________________________________

None

hope this helps

本文同步分享在 博客“小宋是呢”(CSDN)。

如有侵权,请联系 support@oschina.cn 删除。

本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。

keras冻结_[开发技巧]·keras如何冻结网络层相关推荐

  1. keras 升级_如何入门Keras?

    Keras 是一款用 Python 编写的高级神经网络 API,由François Chollet发明,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行.Keras 的开 ...

  2. python 滤波_[开发技巧]·Python极简实现滑动平均滤波(基于Numpy.convolve)

    [开发技巧]·Python极简实现滑动平均滤波(基于Numpy.convolve) ​ 1.滑动平均概念 滑动平均滤波法(又称递推平均滤波法),时把连续取N个采样值看成一个队列 ,队列的长度固定为N ...

  3. keras冻结_【连载】深度学习第22讲:搭建一个基于keras的迁移学习花朵识别系统(附数据)...

    在上一讲中,和大家探讨了迁移学习的基本原理,并利用 keras 基于 VGG16 预训练模型简单了在 mnist 数据集上做了演示.鉴于大家对于迁移学习的兴趣,本节将继续基于迁移学习利用一些花朵数据搭 ...

  4. 深度学习技巧应用6-神经网络中模型冻结-迁移学习技巧

    大家好,我是微学AI,今天给大家介绍一下深度学习技巧应用6-神经网络中模型冻结:迁移学习的技巧,迁移学习中的部分模型冻结是一种利用预训练模型来解决新问题的技巧,是计算机视觉,自然语言处理等任务里面最重 ...

  5. cnn keras 实现_在iOS应用中实现Keras CNN

    cnn keras 实现 I first thought about image classification in an app through watching the TV show Silic ...

  6. keras中lstm参数_如何使用Keras为自定义NER构建深度神经网络

    在这篇文章中,我们将学习如何使用Keras创建一个简单的神经网络来从非结构化文本数据中提取信息(NER). 模型架构 在这里,我们将使用BILSTM + CRF层.LSTM层用于过滤不需要的信息,将仅 ...

  7. keras安装_代码详解:构建一个简单的Keras+深度学习REST API

    在本教程中,我们将介绍一个简单的方法来获取Keras模型并将其部署为REST API.本文所介绍的示例将作为你构建自己的深度学习API的模板/起点--你可以扩展代码,根据API端点的可伸缩性和稳定性对 ...

  8. 使用keras进行深度学习_如何在Keras中通过深度学习对蝴蝶进行分类

    使用keras进行深度学习 A while ago I read an interesting blog post on the website of the Dutch organization V ...

  9. python 技巧视频教程_扣丁学堂Python视频教程之Python开发技巧

    扣丁学堂Python视频教程之Python开发技巧 2018-07-25 14:09:44 808浏览 关于Python开发的技巧小编在上篇文章已经给大家分享过一些,本篇文章扣丁学堂 神秘eval: ...

最新文章

  1. 自动驾驶关键技术分解和流程
  2. 皮一皮:男女的不同...
  3. 计算机网络技术与应用教程期末考试,2011大学计算机网络技术与应用教程客观题期末复习(含判断题,属于公共课程,使用)...
  4. 瑞士科学家3D打印出5纳米厚的传感器
  5. bootloader搞定,1.67秒!
  6. 2020 操作系统第三天复习(习题总结)
  7. Java char所占用的字节_关于unicode:为什么Java char原语占用2个字节的内存?
  8. javaScript原型及继承
  9. java编码给出二维数组List<List<Integer>>matrix,输出每列最小的值
  10. 美国虚拟主机大打安全牌争抢国内高端外贸主机市场
  11. android alarmmanager定时任务,AlarmManager 实现定时任务
  12. 拓扑图是用什么软件画的?
  13. spring cloud 项目打包时,有一个数据库配置的是现场的库,所以一直不成功,怎么办?
  14. 2023年天津中德应用技术大学专升本飞行器制造工程专业考试大纲
  15. Android 集成极光推送和厂商通道
  16. Unet++语义分割网络(网络结构分析+代码分析)
  17. linux运维cadn,Aprende an elaborar un amasamiento tГЎntrico citaciГіn
  18. 今天分享一个用Python来爬取小说的小脚本!(附源码)
  19. Yolov5可以看到虽然有结果图片,但是并没有框出识别结果
  20. 沙箱环境--虚拟环境

热门文章

  1. 优化PhoneGAP的Splashscreen 类
  2. JS正则表达式详解(转)
  3. 企业的核心竞争力是什么
  4. RichTextBox中表格不能折行的问题
  5. python开发系列
  6. Sql日期时间格式转换
  7. 聊一聊双十一背后的技术 - 不一样的秒杀技术, 裸秒
  8. PHP上传方式base64图片的接收方式
  9. JS的parseInt
  10. uni-app中使用lodash_uniapp适配到微信小程序注意事项