tensorflow中定义的tf.Variable时,可以通过trainable属性控制这个变量是否可以被优化器更新。但是,tf.Variable的trainable属性是只读的,我们无法动态更改这个只读属性。在定义tf.Variable时,如果指定trainable=True,那么会把这个Variable添加到“可被训练的变量”集合中。

把trainable指定为布尔变量是不管用的,trainable只在定义变量的那一瞬间有用。

    # trainable只能是bool值,不能是张量trainable = tf.Variable(False, dtype=tf.bool)loss = tf.Variable(3.0, dtype=tf.float32, trainable=trainable)train_op = tf.train.AdamOptimizer(0.01).minimize(loss)with tf.Session()as sess:sess.run(tf.global_variables_initializer())for i in range(100):_, lo = sess.run([train_op, loss], feed_dict={trainable: i % 10 < 5})print('epoch', i, 'loss', lo)    

在定义Variable变量的那一瞬间,如果trainable=true,这个变量就会被添加到可被训练的变量集合中去。当定义optimizer的minimize张量时,minimize张量就会读取可被训练的变量集合并构建张量。此后,即便可被训练的变量集合发生改变,minimize张量也不会再去管哪些变量不能被训练了。

    """如果optimizer的全部变量都是不可训练的,tensorflow会抛出异常所以在这里使用两个变量,两个变量轮流变得可调节:return:"""x = tf.Variable(3.0, dtype=tf.float32)y = tf.Variable(13.0, dtype=tf.float32)train_op = tf.train.AdamOptimizer(0.01).minimize(tf.abs(y - x))with tf.Session()as sess:sess.run(tf.global_variables_initializer())print("trainable_variables is a function")print(tf.trainable_variables, type(tf.trainable_variables()))print(tf.trainable_variables())print("tf.GraphKeys has several string key")print(tf.GraphKeys.TRAINABLE_VARIABLES, type(tf.GraphKeys.TRAINABLE_VARIABLES))print("tf.get_collection can get something by tf.GraphKeys")col = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)print(col, type(col))print("try remove x from trainable variables")del col[col.index(x)]  # 此处虽然可被训练的变量集合变化了,但是train_op已经定义完了print(tf.trainable_variables())print('=======')for i in range(100):_, xx, yy = sess.run([train_op, x, y])print('epoch', i, xx, yy)  # 此处x和y都会变化

tf.GraphKeys

tf.GraphKeys中包含了所有默认集合的名称,可以通过查看__dict__发现具体集合。

tf.GraphKeys.GLOBAL_VARIABLES:global_variables被收集在名为tf.GraphKeys.GLOBAL_VARIABLES的colletion中,包含了模型中的通用参数

tf.GraphKeys.TRAINABLE_VARIABLES:tf.Optimizer默认只优化tf.GraphKeys.TRAINABLE_VARIABLES中的变量。

  • tf.global_variables() GLOBAL_VARIABLES
    存储和读取checkpoints时,使用其中所有变量
    跨设备全局变量集合
  • tf.trainable_variables() TRAINABLE_VARIABLES
    训练时,更新其中所有变量
    存储需要训练的模型参数的变量集合
  • tf.moving_average_variables() MOVING_AVERAGE_VARIABLES
    ExponentialMovingAverage对象会生成此类变量
    实用指数移动平均的变量集合
  • tf.local_variables() LOCAL_VARIABLES
    在global_variables()之外,需要用tf.init_local_variables()初始化
    进程内本地变量集合
  • tf.model_variables() MODEL_VARIABLES
    Key to collect model variables defined by layers.
    进程内存储的模型参数的变量集合
  • QUEUE_RUNNERS 并非存储variables,存储处理输入的QueueRunner
  • SUMMARIES 并非存储variables,存储日志生成相关张量

除了以上的函数外(上表中最后两个集合并非变量集合,为了方便一并放在这里),还可以使用tf.get_collection(集合名)获取集合中的变量,不过这个函数更多与tf.get_collection(集合名)搭配使用,操作自建集合。

Summary被收集在名为tf.GraphKeys.UMMARIES的colletion中,Summary是对网络中Tensor取值进行监测的一种Operation,这些操作在图中是“外围”操作,不影响数据流本身,调用tf.scalar_summary系列函数时,就会向默认的collection中添加一个Operation。

我们也可以自定义变量集合、操作集合,这在正则化参数时非常有用。

x1 = tf.constant(1.0)
l1 = tf.nn.l2_loss(x1)
x2 = tf.constant([2.5, -0.3])
l2 = tf.nn.l2_loss(x2)
tf.add_to_collection("losses", l1)
tf.add_to_collection("losses", l2)
losses = tf.get_collection('losses')
loss_total = tf.add_n(losses)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
losses_val = sess.run(losses)
loss_total_val = sess.run(loss_total)

我说

tensorflow臃肿庞杂,设计者的设计水平远远比不上keras。
tensorflow臃肿庞杂,做了许多外围操作。比如为变量起名字,把变量添加到集合中,使用summary来监控训练中产生的数据。这些操作都不是核心操作,分清核心操作和扩展操作非常重要。

  • 基本操作:如加减乘除、矩阵乘法等运算
  • python语言操作:基本上是一些外围操作如collection,summary,dataset等。tf.gfile中定义了一堆文件操作,比python自带的文件操作要高效易用。
  • 函数级封装:把经常使用的基本操作定义成一个函数,如softmax、wx_b、cross_entropy等。
  • 层级封装:定义一些常见层,如全连接层、卷积层等。
  • 模型封装:keras中有Model,Tensorflow不好意思直接拿来用,起了个名叫“Estimator”。

optimizer其实也是一种封装,optimizer其实就是对变量执行assign操作。除了使用反向传播,我们也可以自己定义基于遗传算法的optimizer。

拦截optimizer的梯度更新过程实现动态trainable

optimizer计算梯度的过程是应用梯度的过程是两个步骤。计算梯度张量返回一个grad_and_vars列表,应用梯度需要grad_and_vars列表作为参数。

我们可以建立(loss,exemp)到minize张量的映射。

    # 拦截梯度更新过程class MyOptimizer:def __init__(self, optimizer: tf.train.Optimizer):self.optimizer = optimizerself.operations = dict()def minimize(self, loss, exemp):"""注意:因为minimize操作是在sess运行时运行的,如果总是创建新操作,GPU内存会溢出"""k = ' '.join(sorted([i.name for i in exemp])) + loss.nameif k not in self.operations:a = [i for i in tf.trainable_variables() if i not in exemp]grad_vars = self.optimizer.compute_gradients(loss, a)op = self.optimizer.apply_gradients(grad_vars)self.operations[k] = opreturn self.operations[k]x = tf.Variable(3.0, dtype=tf.float32)y = tf.Variable(31.0, dtype=tf.float32)loss = tf.abs(x - y)"""为了初始化optimizer中的一些信息,所以需要来一个加的operation形成一个张量"""optimizer = MyOptimizer(tf.train.AdamOptimizer(0.01))train_op = optimizer.minimize(loss, [])with tf.Session()as sess:sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))for i in range(100):exemp = [x if i % 10 < 5 else y]_, xx, yy, lo = sess.run([optimizer.minimize(loss, exemp=exemp), x, y, loss])print('epoch', i, 'x', xx, 'y', yy, 'loss', lo)

这种方法的缺点在于使用loss和exemp作为key,如果key太多,定义的张量就会变多,这样会产生很多变量。

尝试优化一下,使用loss作为key。

        def __init__(self, optimizer: tf.train.Optimizer):self.optimizer = optimizerself.operations = dict()def minimize(self, loss, exemp):"""注意:因为minimize操作是在sess运行时运行的,如果总是创建新操作,GPU内存会溢出"""if loss.name not in self.operations:grad_vars = self.optimizer.compute_gradients(loss)self.operations[loss.name] = grad_varsgrad_vars = self.operations[loss.name]exemp = set(exemp)grad_vars = list(filter(lambda x: x[1] not in exemp, grad_vars))op = self.optimizer.apply_gradients(grad_vars)return opx = tf.Variable(3.0, dtype=tf.float32)y = tf.Variable(31.0, dtype=tf.float32)loss = tf.abs(x - y)"""为了初始化optimizer中的一些信息,所以需要来一个加的operation形成一个张量"""optimizer = MyOptimizer(tf.train.AdamOptimizer(0.01))train_op = optimizer.minimize(loss, [])with tf.Session()as sess:sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))for i in range(100):exemp = [x if i % 10 < 5 else y]_, xx, yy, lo = sess.run([optimizer.minimize(loss, exemp=exemp), x, y, loss])print('epoch', i, 'x', xx, 'y', yy, 'loss', lo)

这种方法其实更差劲,因为apply_gradients依旧会创建许多张量(许多tf.assign_sub张量),而第一种方法反倒没有那么多的张量。

梯度更新的过程其实就是一堆assign操作。

    # 拦截梯度更新过程class MyOptimizer:def __init__(self, optimizer: tf.train.Optimizer):self.optimizer = optimizerself.operations = dict()def minimize(self, loss, exemp):"""注意:因为minimize操作是在sess运行时运行的,如果总是创建新操作,GPU内存会溢出"""if loss.name not in self.operations:grad_vars = self.optimizer.compute_gradients(loss)op = [(variable, tf.assign_sub(variable, self.optimizer._lr * grad)) for grad, variable in grad_vars]self.operations[loss.name] = opgrad_vars = self.operations[loss.name]op = [x[1] for x in grad_vars if x[0] not in exemp]return opx = tf.Variable(3.0, dtype=tf.float32)y = tf.Variable(31.0, dtype=tf.float32)loss = tf.abs(x - y)"""为了初始化optimizer中的一些信息,所以需要来一个加的operation形成一个张量"""optimizer = MyOptimizer(tf.train.AdamOptimizer(0.01))train_op = optimizer.minimize(loss, [])with tf.Session()as sess:sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))for i in range(100):exemp = [x if i % 10 < 5 else y]_, xx, yy, lo = sess.run([optimizer.minimize(loss, exemp=exemp), x, y, loss])print('epoch', i, 'x', xx, 'y', yy, 'loss', lo)

参考资料

https://www.cnblogs.com/hellcat/p/9006904.html

转载于:https://www.cnblogs.com/weiyinfu/p/9973022.html

tensorflow动态设置trainable相关推荐

  1. 微信小程序动态设置 tabBar

    微信小程序开发交流qq群   173683895    承接微信小程序开发.扫码加微信. 使用微信提供的API wx.setTabBarItem(Object object) 动态设置 tabBar ...

  2. Silverlight动态设置WCF服务Endpoint

    去年12月收到一位朋友的邮件,咨询Silverlight使用WCF服务,应用部署后一直无法访问的问题,通过几次交流,才发现在他的项目中,全部使用静态URL作为WCF服务的Endpoint地址,后来修改 ...

  3. 【WPF】动态设置Binding的ConverterParameter转换器参数

    原文:[WPF]动态设置Binding的ConverterParameter转换器参数 问题:XAML中,想要在一个Bingding语句中再次Bingding. Source="{Bindi ...

  4. [C# 开发技巧系列]如何动态设置屏幕分辨率

    原文 http://www.cnblogs.com/zhili/archive/2013/05/23/ChangeResolution.html 因为最近在MSDN论坛和stackflow中看到一些朋 ...

  5. html设置根rem,经过js动态设置根元素的rem方案

    rem目前是响应式开发移动端一个很重要也是经常使用的一个元素,可是在网上看的各类文章都会超级懵逼.因此我在下面给出两个方案,也列举出使用方法,让你们一目了然.前提是设计稿以750为准.其中测试的设计稿 ...

  6. android 动态设置padding,Android动态设置控件大小以及设定margin以及padding值

    http://www.aichengxu.com/Java/73893.htm Android动态设置控件大小以及设定margin以及padding值,有需要的朋友可以参考下. 一.概述 在andro ...

  7. Android 如何在xmL 里面动态设置padding

    如题,Android 如何在xmL 里面动态设置padding 有时候,你的布局加载完成之后,你findViewByid 找到控件,设置padding 会导致白条,布局闪动,那怎么办呢? 你是不是就想 ...

  8. Flex 学习笔记 动态设置itemRenderer

    Tree.DataGrid经常要设置自己制定样式或特殊的UI,我们需要呈现器,经常使用外部呈现器(作为项目渲染器使用的自定义组件在MXML或ActionScript编写),我们需要用到itemRend ...

  9. 微信小程序首页index.js获取不到app.js中动态设置的globalData的原因以及解决方法

    微信小程序首页index.js获取不到app.js中动态设置的globalData的原因以及解决方法 参考文章: (1)微信小程序首页index.js获取不到app.js中动态设置的globalDat ...

最新文章

  1. Django模型之数据库操作-查询
  2. 回首这一年,其实我还是一样!
  3. 把自己的思想记录下来
  4. cxgrid 行合并单元格_【Excel VBA】如何批量撤销合并单元格?
  5. 【数据结构与算法】之深入解析“水壶问题”的求解思路与算法示例
  6. NET问答: C# 中是否有最高效的方式对大文件做 checksum ?
  7. c# mvvm模式获取当前窗口_【自学C#】I 书 12 异常处理
  8. c 不安装oracle,安装oracle 10g 的艰难之旅
  9. UVa10082 - WERTYU
  10. Splunk数据处理
  11. java导出到txt_Java生成TXT文本并下载
  12. chrome官网下载win64离线安装包
  13. matlab心电信号处理,基于MATLAB的心电信号的数字滤波处理
  14. 前端和后端哪个工资更高呢?
  15. Linux网卡流量限制
  16. 将excel中的合并单元格拆分并填充数据
  17. 科创板|龙软科技国科环宇等4公司中止审核
  18. 计算机网络与通信之局域网
  19. web课程设计网页规划与设计:个人毕设网站设计 —— 二手书籍(11个页面) HTML+CSS+JavaScript
  20. pubmed文献批量化下载器

热门文章

  1. mybatis plus 入门
  2. 二十四 Redis消息订阅事务持久化
  3. [转]学会使用DB2指令
  4. Java每天学习一点点 09.10.13
  5. The requested lisk key xxx could not be resolved as a collection type.
  6. 谁动过你的电脑?小姐姐们要学会保护好自己电脑里的小秘密呀
  7. 女友晚安之后依然在线:python男友用20行代码写了个小工具
  8. [转]Linux内核基础与常用命令总结
  9. 设计模式20_观察者
  10. Hive的10种常用优化总结,再也不怕MapReduce分配不均了