tf.gradients()解析及grad_ys在xs为(?, 1)时的理解

问题简介

使用tensorflow1.15学习时,有一项tf.gradients的代码,其中用到了grad_ys这个参数,经过一些解析,得到了一些自己的理解

原代码

    def fwd_gradients_1(self, U, x):g = tf.gradients(U, x, grad_ys=self.dummy_x1_tf)[0]return tf.gradients(g, self.dummy_x1_tf)[0]

这里面的U是经过神经网络之后的output,shape为[250, 500],xxx是input,shape为[250, 1]。在加上grad_ys之后,最后得到的是[250, 500]的值,代表的是每个UijU_{ij}Uij​对xix_{i}xi​的导数。
tf.gradients()的参数为:

tf.gradients(ys, xs, grad_ys=None, name='gradients',colocate_gradients_with_ops=False,gate_gradients=False,aggregation_method=None,stop_gradients=None)

tf.gradients的输出结果是一个列表,列表中的是array格式的导数。我需要的代码中只需要理解前三个参数,i.e. ys, xs, grad_ys即可。
首先,ys和xs我理解的是一定要有关系,这个关系的意思是必须在代码中经过代码运算可以通过xs得到ys。而grad_ys是导数的权重参数。ys和xs的行数相同,这是因为行代表样本个数。
grad_ys的重点解析,grad_ds是对导数的加权,len(ys)=len(grad_ys),我尝试了很多,最好使用shape(grad_ds)=shape(ys), 其实这是我自己的论文代码的特殊性,只需要对单个xxx进行求导。
这时,tf.gradients(y,x)tf.gradients(y, x)tf.gradients(y,x)为:

也就是说,dy_dx是一个(2,1)的数组,将w1+w2+w3w_{1}+w_{2}+w_{3}w1​+w2​+w3​重复 x.shape[0]x.shape[0]x.shape[0]次。

例子1 代码:

import tensorflow as tf
import numpy as npx = tf.constant([1., 2.],dtype=tf.float32, shape=(2,1))
w = tf.constant([2.,3.5,4.0], dtype = tf.float32, shape=[1, 3])
y = tf.matmul(x,w)
dy_dx = tf.gradients(y, x)
with tf.Session() as sess:print(sess.run(dy_dx))

得到的结果为

可以看到,一共有两个xxx,所以得到的是一个(2,1)的重复数组, 其中9.5=w1+w2+w3=2+3.5+49.5=w_{1}+w_{2}+w_{3}=2+3.5+49.5=w1​+w2​+w3​=2+3.5+4,当加上grad_ys时,得到的是乘以权重的导数,令

dy_dx_gra = tf.gradients(ys, xs, grad_ys)为:

例子2

import tensorflow as tf
import numpy as npx = tf.constant([1., 2.], dtype=tf.float32, shape=(2,1))
w = tf.constant([2.,3.5,4.0], dtype=tf.float32, shape=[1, 3])
y = tf.matmul(x, w)
gr = tf.constant([1,2,3,4,5,6],shape=[2,3],dtype=tf.float32)
dy_dx = tf.gradients(y, x, grad_ys=gr)
with tf.Session() as sess:print(sess.run(dy_dx))

结果为

此时的21=1∗2+2∗3.5+3∗4,49.5=4∗2+5∗3.5+6∗421=1*2+2*3.5+3*4,49.5=4*2+5*3.5+6*421=1∗2+2∗3.5+3∗4,49.5=4∗2+5∗3.5+6∗4。到这里为止,得到的导数都是所有的y对x的导数的和:
例如,例子1中的9.5

神经网络的求导原则

对于神经网络,无论进行多少层的计算,我们要的是最后的output对xxx的导数

例子1,xxx为一列时的结果

import tensorflow as tf
import numpy as npdef weight_variables(shape):weight = tf.Variable(tf.random_normal(shape=shape), dtype=tf.float32)return weightdef bias_variables(shape):bias = tf.Variable(tf.constant(0.0, shape=shape), dtype=tf.float32)return biasx1 = tf.constant([0., 1., 2., 3., 4., 5., 6.], dtype=tf.float32)
x1 = x1[:, None]
t1 = tf.constant([3, 2, 4, 5, 6, 7, 6], dtype=tf.float32)
t1 = t1[:, None]
x = tf.concat((x1, t1), axis=1)w_1 = weight_variables([1, 4])b_1 = bias_variables([4])y_1 = tf.matmul(x1, w_1) + b_1w_2 = weight_variables([4, 4])
w = tf.matmul(w_1, w_2)
b_2 = bias_variables([4])
y_2 = tf.matmul(y_1, w_2) + b_2
gr = tf.ones([7, 4])
y_x1 = tf.gradients(y_2, x1)[0]
y_x12 = tf.gradients(y_2, x1, grad_ys=gr)[0]
y_x12_1 = tf.gradients(y_x12, gr)[0]
y_x11 = y_x1[0]
with tf.Session() as sess:tf.global_variables_initializer().run()print(sess.run(y_2))print('*' * 30)print(sess.run(w))print(sess.run(w).shape)print('+' * 20)print(sess.run(y_x1))print(sess.run(y_x12))print(sess.run(y_x12_1))

结果中y_2是输出值,w是两层神经网络的权重w_1,w_2矩阵相乘后的结果,y_x1是没有grad_ys的导数结果,y_x12是grad_ys全为1的结果,全为1的时候导数结果不变,y_x12_1是将导数拆开之后的结果,即,每个y对每个x的导数
结果如下:

y_2 = [[  0.           0.           0.           0.        ][  2.4336433    0.90281177  -2.450363    -1.3428192 ][  4.8672867    1.8056235   -4.900726    -2.6856384 ][  7.30093      2.7084355   -7.3510895   -4.028458  ][  9.734573     3.611247    -9.801452    -5.371277  ][ 12.168217     4.514058   -12.251815    -6.7140956 ][ 14.60186      5.416871   -14.702179    -8.056915  ]]
w = [[ 2.4336433   0.90281177 -2.4503632  -1.3428191 ]]
y_x1 = [[-0.45672727][-0.45672727][-0.45672727][-0.45672727][-0.45672727][-0.45672727][-0.45672727]]
y_x12 = [[-0.45672727][-0.45672727][-0.45672727][-0.45672727][-0.45672727][-0.45672727][-0.45672727]]
y_x12_1 = [[ 2.4336433   0.90281177 -2.450363   -1.3428192 ][ 2.4336433   0.90281177 -2.450363   -1.3428192 ][ 2.4336433   0.90281177 -2.450363   -1.3428192 ][ 2.4336433   0.90281177 -2.450363   -1.3428192 ][ 2.4336433   0.90281177 -2.450363   -1.3428192 ][ 2.4336433   0.90281177 -2.450363   -1.3428192 ][ 2.4336433   0.90281177 -2.4503632  -1.3428192 ]]-0.45672727 = 2.4336433-0.90281177-2.450363-1.3428192

例子2,xxx为2列时的结果

x1 = tf.constant([0., 1., 2., 3., 4., 5.], shape=[3,2], dtype=tf.float32)w_1 = weight_variables([2, 4])b_1 = bias_variables([4])y_1 = tf.matmul(x1, w_1) + b_1w_2 = weight_variables([4, 4])
w = tf.matmul(w_1, w_2)
b_2 = bias_variables([4])
y_2 = tf.matmul(y_1, w_2) + b_2
gr = tf.ones([3, 4])
y_x1 = tf.gradients(y_2, x1)[0]
y_x12 = tf.gradients(y_2, x1, grad_ys=gr)[0]
y_x12_1 = tf.gradients(y_x12, gr)[0]
with tf.Session() as sess:tf.global_variables_initializer().run()print(sess.run(y_2))print('*' * 30)print(sess.run(w_1))print('+' * 20)print(sess.run(w))print(sess.run(w).shape)print('+' * 20)print(sess.run(y_x1))print(sess.run(y_x12))print(sess.run(y_x12_1))结果为:
y_2 =[[ -1.7961562  -1.5814966  -1.5294867   1.7090911][ -7.378949   -2.5710013  -5.1707554   4.268697 ][-12.961741   -3.5605054  -8.812023    6.828303 ]]w = [[-0.99524033  1.0867442  -0.29114777 -0.42928803][-1.7961562  -1.5814966  -1.5294867   1.7090911 ]]
y_x1 = [[-0.62893206 -3.1980484 ][-0.62893206 -3.1980484 ][-0.62893206 -3.1980484 ]]-0.62893206 = -0.99524033+ 1.0867442  -0.29114777 -0.42928803
-3.1980484 = -1.7961562-1.5814966-1.5294867+1.7090911
y_x12 =  [[-0.62893206 -3.1980484 ][-0.62893206 -3.1980484 ][-0.62893206 -3.1980484 ]]
y_x12_1 = [[-2.7913966  -0.49475235 -1.8206344   1.2798029 ][-2.7913966  -0.49475235 -1.8206344   1.2798029 ][-2.7913966  -0.49475235 -1.8206344   1.2798029 ]]-2.7913966=-0.99524033-1.7961562
-0.49475235=1.0867442-1.5814966
-1.8206344=-0.29114777-1.5294867
1.2798029=-0.42928803+1.7090911

这里的y_x12_1是经过grad_ys解包出来的,但是并不对,因为y对x的解包,应该是对xxx的第一列有shape(grad_ys)个导数,对xxx的第二列有shape(grad_ys)个导数。
但是怎么解出来,还有待考证。

tensorflow中tf.gradients()解析相关推荐

  1. Tensorflow中tf.ConfigProto()详解

    参考Tensorflow Machine Leanrning Cookbook tf.ConfigProto()主要的作用是配置tf.Session的运算方式,比如gpu运算或者cpu运算 具体代码如 ...

  2. tensorflow中tf.nn.xw_plus_b

    tf.nn.xw_plus_b((x, weights) + biases) 相当于tf.matmul(x, weights) + biases #-*-coding:utf8-*- import t ...

  3. TensorFlow中 tf.space_to_depth()函数的用法

    目录 一.函数定义 二.解释范例 三.代码验证 一.函数定义 通俗易懂些,就是把输入为[batch, height, width, channels]形式的Tensor,其在height和width维 ...

  4. TensorFlow 中 tf.app.flags.FLAGS 的用法介绍

    转载自:https://blog.csdn.net/lyc_yongcai/article/details/73456960 下面介绍 tf.app.flags.FLAGS 的使用,主要是在用命令行执 ...

  5. tensorflow中tf.random_normal和tf.truncated_normal的区别

    1.tf.truncated_normal使用方法 tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=No ...

  6. tensorflow中tf.get_variable()函数详解

    如果变量存在,函数tf.get_variable()会返回现有的变量:如果变量不存在,会根据给定形状和初始值创建一个新的变量. def get_variable(name, shape=None, d ...

  7. tensorflow 里metrics_深入理解TensorFlow中的tf.metrics算子

    [IT168 技术]01 概述 本文将深入介绍Tensorflow内置的评估指标算子,以避免出现令人头疼的问题. tf.metrics.accuracy() tf.metrics.precision( ...

  8. 【转】tensorflow中的batch_norm以及tf.control_dependencies和tf.GraphKeys.UPDATE_OPS的探究

    笔者近来在tensorflow中使用batch_norm时,由于事先不熟悉其内部的原理,因此将其错误使用,从而出现了结果与预想不一致的结果.事后对其进行了一定的调查与研究,在此进行一些总结. 一.错误 ...

  9. 中tile函数_HelpGirlFriend 系列 --- tensorflow 中的张量运算思想

    GirlFriend 在复现论文的时候,我发现她不太会将通用数学公式转化为张量运算公式,导致 tensorflow 无法通过并行的方式优化其论文复现代码的运行速率. 这里对给 GirlFriend 讲 ...

最新文章

  1. 未来智能制造就是跨界大数据
  2. Callback到Promise再到Async进化初探
  3. php实现多商家开发,Thinkphp5.0实战-仿百度糯米开发多商家电商平台学习注意事项...
  4. 【BZOJ1623】 [Usaco2008 Open]Cow Cars 奶牛飞车 贪心
  5. ubuntu server设置时区和更新时间
  6. 安装thinkphp5后访问public index.php 报错require(/www/wwwroot/test3.com/public/../vendor/autoload.php)
  7. This time, ZTE has released the world‘s first
  8. pytorch 实现半圆数据分类
  9. 查看Linux系统版本的命令
  10. 第六章 prototype和constructor
  11. EditPlus3.1工具以及Js插件(打包下载)
  12. cnpack 菜单顺序
  13. windows保护无法启动修复服务器,解决使用sfc命令提示“windows 资源保护无法启动修复服务”的方法...
  14. UDP进程terminated
  15. 关于Mysql中的生日提醒
  16. 计算机培训简报膜报,第二期计算机培训简报(第十二期)
  17. ChatGPT 大行其道,带你走近 AIGC
  18. 一头 一头百兆全双工 自动协商 测试
  19. Zigbee系列 学习笔记五(信道选择)
  20. 2019年团体程序设计天梯赛总结

热门文章

  1. 计算机硬件创新,最新发现与创新:让计算机硬件性能发挥到极致
  2. instagram akp_如何备份您的社交媒体帐户-Facebook,Twitter,Google +和Instagram
  3. 2192. 有向无环图中一个节点的所有祖先(邻接表 加 拓扑排序)
  4. 人工雨量计_遥测雨量计与人工雨量观测对比分析
  5. 飞桨常规赛:PALM眼底彩照中黄斑中央凹定位-9月第1名方案
  6. Extjs操作Dom
  7. MATLAB1770太阳黑子,太阳黑子周期matlab仿真
  8. keep2share 购买的激活码但激活不了
  9. 气体流量开关的全球与中国市场2022-2028年:技术、参与者、趋势、市场规模及占有率研究报告
  10. 课堂笔记5(大学生作业)