Keras/TensorFlow mixed use: setting trainable=False has no effect

This is a problem I ran into recently; let me describe it first:
I have a pre-trained model (for example VGG16) that I want to modify, say by adding a fully connected layer on top. For various reasons the optimization has to be done with TensorFlow, and by default a TF optimizer updates the weights of everything in tf.trainable_variables(). That is exactly where the problem is: even though the VGG16 layers are set to trainable=False, the TF optimizer still updates their weights.
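The root cause, for context: in TF 1.x, calling Optimizer.minimize() without a var_list makes the optimizer fall back to everything in the tf.GraphKeys.TRAINABLE_VARIABLES collection. A minimal sketch with throwaway variables (the names here are purely illustrative, not from the original setup):

import tensorflow as tf
x = tf.placeholder(tf.float32, shape=(None, 10))
w = tf.Variable(tf.zeros((10, 1)))  # lands in the TRAINABLE_VARIABLES collection
loss = tf.reduce_mean(tf.matmul(x, w))
# These two calls are equivalent: omitting var_list means "train everything trainable"
step_a = tf.train.AdamOptimizer().minimize(loss)
step_b = tf.train.AdamOptimizer().minimize(loss, var_list=tf.trainable_variables())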

That is the problem. After a fair amount of googling I finally found a solution; below, let's reconstruct the whole issue step by step.

trainable=False has no effect

First, load the pre-trained VGG16 and set trainable=False on it:

from keras.applications import VGG16
import tensorflow as tf
from keras import layers
# Load the pre-trained model
base_model = VGG16(include_top=False)
# Inspect the trainable variables
tf.trainable_variables()
[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv2/bias:0' shape=(64,) dtype=float32_ref>,
 ...
 <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]
# (output truncated -- every kernel and bias of the 13 VGG16 conv layers is listed;
#  the *_1 names appear because VGG16 was built more than once in this session)
# Set trainable=False
# base_model.trainable = False seems to work as well
for layer in base_model.layers:
    layer.trainable = False

After setting trainable=False, inspect the trainable variables again: the list has not changed at all, i.e. the setting has no effect on the TensorFlow side (what the flag does change on the Keras side is sketched right after the output below).

# Inspect the trainable variables again
tf.trainable_variables()
[<tf.Variable 'block1_conv1/kernel:0' shape=(3, 3, 3, 64) dtype=float32_ref>,
 <tf.Variable 'block1_conv1/bias:0' shape=(64,) dtype=float32_ref>,
 ...
 <tf.Variable 'block5_conv3_1/kernel:0' shape=(3, 3, 512, 512) dtype=float32_ref>,
 <tf.Variable 'block5_conv3_1/bias:0' shape=(512,) dtype=float32_ref>]
# (identical to the output above -- every VGG16 variable is still in the collection)
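What the flag does change is Keras's own bookkeeping: freezing the layers empties the model-level trainable_weights list, which is what Keras optimizers consult. tf.trainable_variables() reads the TRAINABLE_VARIABLES collection, which the variables were added to when they were created, and flipping layer.trainable afterwards does not touch it. A quick check, assuming the code above was run in the same session:

# Keras-side view after the freeze loop
print(len(base_model.trainable_weights))      # 0 -- Keras sees nothing left to train
print(len(base_model.non_trainable_weights))  # 26 -- 13 conv kernels + 13 biases
# TensorFlow-side view: the collection the optimizer reads is unchanged
print(len(tf.trainable_variables()))          # still contains every VGG16 variable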

The fix

The fix is to create the pre-trained model inside one variable_scope and the layers that should be trained inside another, fetch the variables of the trainable scope with tf.get_collection, and finally pass that list to the TF optimizer through the var_list argument.

from keras import models
with tf.variable_scope('base_model'):
    base_model = VGG16(include_top=False, input_shape=(224, 224, 3))
with tf.variable_scope('xxx'):
    model = models.Sequential()
    model.add(base_model)
    model.add(layers.Flatten())
    model.add(layers.Dense(10))
# Fetch the variables that need to be trained
trainable_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'xxx')
trainable_var
[<tf.Variable 'xxx_2/dense_1/kernel:0' shape=(25088, 10) dtype=float32_ref>,
 <tf.Variable 'xxx_2/dense_1/bias:0' shape=(10,) dtype=float32_ref>]
# Define a TF optimizer for training; assume there is a loss
loss = model.output / 2  # made up on the spot, just for the demo
train_step = tf.train.AdamOptimizer().minimize(loss, var_list=trainable_var)
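To confirm that the frozen weights really stay put, here is a sketch of a single training step. It assumes the solution code above was run once in a fresh graph; the random input, the choice of variables to compare and the initialisation loop are my additions for illustration:

import numpy as np
from keras import backend as K

sess = K.get_session()  # reuse the session in which Keras initialised the weights
# Adam creates slot variables of its own; initialise whatever is still uninitialised
uninitialized = [v for v in tf.global_variables()
                 if not sess.run(tf.is_variable_initialized(v))]
sess.run(tf.variables_initializer(uninitialized))

frozen_kernel = base_model.layers[1].weights[0]  # first VGG16 conv kernel
new_bias = trainable_var[-1]                     # bias of the new Dense layer
frozen_before, bias_before = sess.run([frozen_kernel, new_bias])

# One training step on random data (the model expects 224x224x3 inputs)
x = np.random.rand(2, 224, 224, 3).astype('float32')
sess.run(train_step, feed_dict={model.input: x})

frozen_after, bias_after = sess.run([frozen_kernel, new_bias])
print(np.allclose(frozen_before, frozen_after))  # True  -> VGG16 was not touched
print(np.allclose(bias_before, bias_after))      # False -> only the new head moved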

Summary

  • When mixing Keras and TensorFlow, setting trainable=False in Keras has no effect on TensorFlow: the variables stay in tf.trainable_variables().
  • The fix is to separate the variables with variable_scope, then fetch the ones to train with tf.get_collection, and finally restrict training through the var_list argument of the TF optimizer (an alternative based on Keras's own trainable_weights is sketched below).
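An alternative worth noting, which is my addition and not part of the original post: Keras already removes frozen layers from model.trainable_weights, so that list can be handed to the TF optimizer directly, with no extra variable scopes:

# Assumes base_model, model and loss are defined as above
for layer in base_model.layers:
    layer.trainable = False  # frozen layers drop out of model.trainable_weights
# model.trainable_weights now holds only the new Dense kernel and bias
train_step = tf.train.AdamOptimizer().minimize(loss, var_list=model.trainable_weights)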
