Keras/TensorFlow mixed code: trainable=False has no effect
This is a problem I ran into recently. First, the setup: I have a pretrained model (e.g. VGG16) that I want to modify, say by adding a fully connected layer on top. For various reasons I have to do the optimization with raw TensorFlow, and by default a TF optimizer updates every variable in tf.trainable_variables(). That is exactly where the problem lies: even though the VGG16 layers are set to trainable=False, the TF optimizer still updates their weights.
That is the problem. After a lot of searching I finally found a fix; let's reproduce the whole issue step by step.
trainable=False has no effect
First, import the pretrained VGG16 and set trainable=False on it:
```python
from keras.applications import VGG16
import tensorflow as tf
from keras import layers

# Import the pretrained model
base_model = VGG16(include_top=False)
# Inspect the trainable variables
tf.trainable_variables()
```
```
[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 ...
 <tf.Variable 'block5_conv3/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>]
```
```python
# Set trainable=False on every layer
# (base_model.trainable = False appears to work as well)
for layer in base_model.layers:
    layer.trainable = False
```
After setting trainable=False, inspect the trainable variables again: nothing has changed, i.e. the setting had no effect.
```python
# Inspect the trainable variables again
tf.trainable_variables()
```
```
[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 ...  # the full VGG16 list, exactly as before
 <tf.Variable 'block5_conv3/bias:0' shape=(512,) dtype=float32_ref>]
```
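The reason is that in TF1's graph model a variable is added to the GraphKeys.TRAINABLE_VARIABLES collection once, at creation time; flipping the Keras layer.trainable flag afterwards never removes it. A toy illustration in plain Python (no TensorFlow involved, purely an analogy):

```python
# Plain-Python analogy: the collection records a variable at creation
# time, so a flag flipped later is never consulted by the collection.
TRAINABLE_COLLECTION = []

class Var:
    def __init__(self, name):
        self.name = name
        self.trainable = True
        TRAINABLE_COLLECTION.append(self)  # registered once, at creation

v = Var('block1_conv1/kernel')
v.trainable = False  # Keras-style flag; the collection is unaffected
print(v in TRAINABLE_COLLECTION)  # still True
```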
The fix
The fix is to build the imported model inside one variable_scope and the variables that should be trained inside another, fetch the trainable subset with tf.get_collection, and finally pass that subset to the TF optimizer via var_list:
```python
from keras import models

with tf.variable_scope('base_model'):
    base_model = VGG16(include_top=False, input_shape=(224, 224, 3))

with tf.variable_scope('xxx'):
    model = models.Sequential()
    model.add(base_model)
    model.add(layers.Flatten())
    model.add(layers.Dense(10))
```
```python
# Fetch the variables that should be trained
trainable_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'xxx')
trainable_var
```
```
[<tf.Variable 'xxx_2/dense_1/kernel:0' shape=(25088, 10) dtype=float32_ref>,
 <tf.Variable 'xxx_2/dense_1/bias:0' shape=(10,) dtype=float32_ref>]
```
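tf.get_collection(key, scope) filters the collection by matching the scope argument against the start of each variable's name. A minimal stand-in for that filtering, using plain strings (the names below are made up for illustration):

```python
# Hypothetical variable names; tf.get_collection(key, scope) effectively
# keeps only the entries whose name starts with the given scope prefix.
all_trainable = [
    'base_model/block1_conv1/kernel:0',
    'base_model/block1_conv1/bias:0',
    'xxx/dense_1/kernel:0',
    'xxx/dense_1/bias:0',
]
selected = [name for name in all_trainable if name.startswith('xxx')]
print(selected)  # ['xxx/dense_1/kernel:0', 'xxx/dense_1/bias:0']
```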
```python
# Define a TF optimizer for training; assume we have some loss
loss = tf.reduce_mean(model.output)  # dummy loss, just for demonstration
train_step = tf.train.AdamOptimizer().minimize(loss, var_list=trainable_var)
```
Summary
- In mixed Keras/TensorFlow code, setting trainable=False on the Keras side has no effect on TensorFlow's tf.trainable_variables().
- The workaround is to separate the variables with variable_scope, fetch the trainable subset with tf.get_collection, and pass it to the TF optimizer via var_list.
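The whole recipe boils down to: partition the variables by scope, and hand only the new head's variables to the optimizer. A toy sketch of the var_list semantics in plain Python (no TensorFlow; the names and values below are made up):

```python
# One gradient-descent step that touches only the variables in var_list,
# mirroring minimize(loss, var_list=...): frozen weights stay unchanged.
params = {'base_model/w': 1.0, 'xxx/dense/w': 1.0}
grads = {name: 0.5 for name in params}  # pretend gradients
var_list = [n for n in params if n.startswith('xxx')]  # only the new head
lr = 0.1
for name in var_list:
    params[name] -= lr * grads[name]
# base_model/w is untouched; xxx/dense/w has taken one update step
print(params['base_model/w'], params['xxx/dense/w'])
```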