一、背景

有时候在训练模型的时候,会有这样的需求:某个loss,只想影响一部分网络参数的更新,而另外一部分网络参数不想受这个loss的影响,特别是像多目标的多塔结构的模型。

二、实现

a = weight1 + weight2
a_stopped = tf.stop_gradient(a)
y3 = a_stopped + weight3gradients1 = tf.gradients(y3, [weight1, weight2, weight3], grad_ys=[tf.convert_to_tensor([1., 2.])])
gradients2 = tf.gradients(y3, [weight3], grad_ys=[tf.convert_to_tensor([1., 2.])])
print(gradients1)  # [None, None, < tf.Tensor 'gradients_1/grad_ys_0:0' shape = (2,) dtype = float32 >]
with tf.Session() as sess:sess.run(tf.global_variables_initializer())'''下面代码会报错因为weight1、weight2 的梯度被停止,程序试图去求一个None的梯度,所以报错注释掉gradients1求 gradients2 就又正确了'''# print(sess.run(gradients1))print(sess.run(gradients2))

对于y3来说,就相当于把a_stopped从变量(和weight1和weight2有关)变成一个常量,所以对他求导不需要在执行链式求导法则,对常量求偏导就是等于0

tensorflow 冻结梯度相关推荐

  1. tensorflow计算网络占用内存_详细图解神经网络梯度下降法(tensorflow计算梯度)...

    1.什么是梯度 各个方向的偏微分组成的向量 ​ 举例说明,z对x的偏微分和对y的偏微分如下,则梯度是(-2x,2y)的这样一个向量 ​ 在光滑连续函数的每个点上,都可以计算一个梯度,也就是一个向量,用 ...

  2. 【Tensorflow】Tensorflow 自定义梯度

    目录 前言 自定义梯度 说明 gradient_override_map的使用 多输入与多输出op 利用stop_gradient 参考 [fishing-pan:https://blog.csdn. ...

  3. Python使用tensorflow中梯度下降算法求解变量最优值

    TensorFlow是一个用于人工智能的开源神器,是一个采用数据流图(data flow graphs)用于数值计算的开源软件库.数据流图使用节点(nodes)和边线(edges)的有向图来描述数学计 ...

  4. tensorflow随机梯度下降算法使用滑动平均模型

    在采用随机梯度下降算法训练神经网络时,使用滑动平均模型可以提高最终模型在测试集数据上的表现.在Tensflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模 ...

  5. tensorflow 计算梯度: tf.gradient() 与 tf.GradientTape()

    tensorflow就是版本繁多,同一个功能有n种实现方式,之前一直用tf.gradient()计算梯度,今天发现还有tf.GradientTape. 1. tf.gradient() 参考:Tens ...

  6. TensorFlow教程——梯度爆炸与梯度裁剪

    在较深的网络,如多层CNN或者非常长的RNN,由于求导的链式法则,有可能会出现梯度消失(Gradient Vanishing)或梯度爆炸(Gradient Exploding )的问题.(这部分知识后 ...

  7. tensorflow实现梯度累计,再回传

    由于主机显卡只有12g的显存,且只装了一块30系列的卡,因此在跑代码时难免会遇到batch_size不能太大的尴尬,因此可以通过,梯度累计的方式进行优化,来变相扩大batch_size.这样的操作在p ...

  8. TensorFlow实现梯度下降法求解一元和多元线性回归问题

    使用TensorFlow求解一元线性回归问题 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt# 设 ...

  9. tensorflow 教程 梯度下降法实现线性回归问题

    视频 https://www.bilibili.com/video/BV1jK4y187yB?p=52

最新文章

  1. 测试用例的设计方法(三)
  2. django-pure-pagination 组件使用
  3. COS系统的前端演变和发展
  4. IDEA引入外部jar包的方法
  5. caffeine 线程私有的ReadBuffer实现
  6. 卡尔曼滤波 -- 从推导到应用(二)
  7. Oracle数据库基础知识(一)
  8. [易语言]易语言实现简单的答题软件
  9. iozone磁盘读写测试工具的使用以及命令详解、下载(网站最详细讲解步骤)
  10. 使用Aforge 开发的摄像头,有拍照,录像,设置帧率,分辨率等,以及对视频以及相机等的控制
  11. 昂达v891w可以用u盘linux,安卓、Win8随便用 昂达V891w双系统平板测试(转载)
  12. Win7中自带截图工具
  13. 文本分类概述(nlp)
  14. 中宠股份第三季度营收7.55亿元:增速环比持续下滑,净利润转降
  15. 99道python测试题
  16. NSGA3算法及其MATLAB版本实现(转载)
  17. codeforces csp复赛训练利器---初识
  18. C++怎样获取当前系统时间?
  19. Odoo tree视图使用js添加按钮(以及跳转页面)
  20. SparkSQL比MapReduce快的原因

热门文章

  1. JCL,JES运作流程
  2. 如何利用(微软学术)/(google学术)/google网页,聚焦最新科技文献,并获得PDF版
  3. uniapp,video视频播放不了,页面显示不完整
  4. MRTG (Multi Router Traffic Grapher)
  5. Centos 7 安装 OpenResty api 网关 Orange
  6. 计算机工程师花了三年建模女朋友,网易工作十年游戏建模师,还没有女朋友,是怎样的一种体验?...
  7. python抓取360图片之马自达
  8. DNS安全(一)DNS缓存投毒与防护
  9. NYOJ 82 迷宫寻宝
  10. windows 10 python 3.7.9 install rosbag