tensorflow频域操作及梯度求取
tensorflow频域操作及梯度求取
最近尝试使用tensorflow中的傅立叶变换操作,主要涉及的op有tf.complex
,tf.fft
, tf.fft2d
,tf.angle
, 涉及的数据类型为tf.complex64
, 这里主要介绍以下几个部分:
- 一维离散傅立叶变换函数:tf.fft使用示例
- 复数梯度求取算子:tf.gradients使用示例
- 阻挡梯度的传播:tf.stop_gradient()使用示例
- 高阶导数的求取:tf.gradients高阶示例
部分内容节选自tensorflow学习笔记(三十):tf.gradients 与 tf.stop_gradient() 与 高阶导数
一、一维离散傅立叶变换函数:tf.fft使用示例
import tensorflow as tf
# 两种初始化复数的办法
# 方法一:tf.cast():弊病,tf.cast()无法回传梯度
input_real=tf.constant([2,3])
input_complex=tf.cast(input_real ,tf.complex64)#方法二:分别设置实部和虚部,用complex拼接
real=tf.constant([2.25, 3.25])
imag=tf.constant([4.75, 5.75])
complex=tf.complex(real, imag)#方法三:初始化一个跟实部相同shape的虚部,用complex拼接:可以回传梯度
tf.complex(sqrtmag, tf.zeros(sqrtmag.shape))fft=tf.fft(complex)
angle=tf.angle(fft)
二、复数梯度求取算子:tf.gradients使用示例
1. grad_ys
的测试:
这里要注意两点:
- grad_ys的shape应该与tf.gradients的第一个参数保持一致,
[tf.convert_to_tensor([2.,2.,3.]),tf.convert_to_tensor([3.,2.,4.])]
和[z1, z2]
的shape是一致的 - 这里的数字实际上是给每一个维度的梯度增加一个乘法因子,如果都设置为1,则直接按照公式求取梯度即可,作用应该是在第三部分回传梯度时,可以自由设置不同维度的梯度值,但如果单纯只是需要求取表达式对于某一个变量的梯度,设置为1即可。
import tensorflow as tf
w1 = tf.get_variable('w1', shape=[3])
w2 = tf.get_variable('w2', shape=[3])
w3 = tf.get_variable('w3', shape=[3])
w4 = tf.get_variable('w4', shape=[3])z1 = w1 + w2+ w3
z2 = w3 + w4
grads = tf.gradients([z1, z2], [w1, w2, w3, w4], grad_ys=[tf.convert_to_tensor([2.,2.,3.]),tf.convert_to_tensor([3.,2.,4.])])
with tf.Session() as sess:tf.global_variables_initializer().run()print(sess.run(grads))
2. 复数tf.gradients和grad_ys的使用
使用示例
import tensorflow as tf
real=tf.constant([2.25, 3.25])
imag=tf.constant([4.75, 5.75])
complex=tf.complex(real, imag)fft=tf.fft(complex)
angle=tf.angle(fft)
# complex对real的实部求梯度
grads_complex_real=tf.gradients(complex,real,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.complex64))
# fft对real求梯度
grads_fft_real=tf.gradients(fft,real,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.complex64))
# fft对complex求梯
度
grads_fft_complex=tf.gradients(fft,complex,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.complex64))
# angle对fft求梯度
grads_angle_fft=tf.gradients(angle,fft,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.complex64))
# angle对complex求梯度
grads_angle_complex=tf.gradients(angle,complex,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.float32)
# angle对real求梯度
grads_angle_real=tf.gradients(angle,real,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.float32))sess=tf.InteractiveSession()
print(sess.run([angle,grads_fft_complex])
3. 频域loss函数注意事项:
- 使用频域fft或者angle作为loss函数的输入时,由于离散傅立叶变换公式中(参见网页),某一个频率点的值由complex中每一个值和其对应的指数项,求和之后得到,其中指数项根据欧拉公式就是一个复数。公式总结如下:
fft=∑n−1i=0(complex中每个离散点的值(复数)×欧拉复数因子(由fft和complex的坐标决定,m,n和u,v))fft=∑i=0n−1(complex中每个离散点的值(复数)×欧拉复数因子(由fft和complex的坐标决定,m,n和u,v))fft =\sum _{i=0}^{n-1}\left (complex中每个离散点的值(复数) \times 欧拉复数因子(由fft和complex的坐标决定,m,n和u,v) \right )
因此,fft
值的数量级相当于是255×width×height255×width×height255\times width\times height,因此将fft
作为loss
函数的一部分时,fft
的值会非常的大,在梯度回传是,根据离散傅立叶反变换,complex
中每一个点的梯度,同样由(fft中每个离散点的值×欧拉复数因子)(fft中每个离散点的值×欧拉复数因子)\left ( fft中每个离散点的值\times 欧拉复数因子\right )决定,所回传的实际梯度值的数量级应该是△(fft−label)×width×height△(fft−label)×width×height\triangle\left ( fft-label \right ) \times width\times height, 而△(fft−label)△(fft−label)\triangle\left ( fft-label \right )是fft
和标签label
之间的差值,其数量级和fft的数量级相同为 255×width×height255×width×height255\times width\times height, 最终回传到complex
中每一个点的梯度的数量级应该是255×width2×height2255×width2×height2255\times width^{2}\times height^{2} - 使用tf.angle进行梯度回传到real或者imag的适合,由于
tf.angle = tf.atan(tf.imag/tf.real))
, 当中存在除法运算,在训练过程中可能会出现NAN的error。具体原因应该就是这个除法,有待进一步确定,解决办法暂时还没想到,也有待进一步研究。
三、阻挡梯度的传播:tf.stop_gradient()使用示例
阻挡节点BP
的梯度
import tensorflow as tf
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:tf.global_variables_initializer().run()print(sess.run(gradients))
#输出 [1.0, 1.0]
虽然c
节点被stop
了,但是a
,b
还有从d
传回的梯度,所以还是可以输出梯度值的。
import tensorflow as tf
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)opt = tf.train.GradientDescentOptimizer(0.1)
gradients = tf.gradients(b, xs=tf.trainable_variables())tf.summary.histogram(gradients[0].name, gradients[0])# 这里会报错,因为gradients[0]是None
#其它地方都会运行正常,无论是梯度的计算还是变量的更新。总觉着tensorflow这么设计有点不好,
#不如改成流过去的梯度为0
train_op = opt.apply_gradients(zip(gradients, tf.trainable_variables()))print(gradients)
with tf.Session() as sess:tf.global_variables_initializer().run()print(sess.run(train_op))print(sess.run([w1, w2])
四、高阶导数的求取:tf.gradients高阶示例
tensorflow
求 高阶导数可以使用tf.gradients
来实现。
Note
: 有些op
,tf
没有实现其高阶导的计算,例如 tf.add …
, 如果计算了一个没有实现 高阶导的 op
的高阶导, gradients
会返回 None
。
import tensorflow as tfwith tf.device('/cpu:0'):a = tf.constant(1.)b = tf.pow(a, 2)grad = tf.gradients(ys=b, xs=a) # 一阶导print(grad[0])grad_2 = tf.gradients(ys=grad[0], xs=a) # 二阶导grad_3 = tf.gradients(ys=grad_2[0], xs=a) # 三阶导print(grad_3)with tf.Session() as sess:print(sess.run(grad_3))
tensorflow频域操作及梯度求取相关推荐
- Canny算子中的梯度求取及非最大值抑制(NMS)实现
@Canny算子中的非最大值抑制(NMS)实现 canny算子中的非极大值抑制是在对图像进行梯度求取之后,在梯度方向进行的运算,也就是说此处的非极大值抑制是在对图像进行梯度求取后,在生成的梯度矩阵上求 ...
- TensorFlow 2.0 - 张量/自动求导/梯度下降
文章目录 1. 张量 2. 自动求导.梯度下降 学习于:简单粗暴 TensorFlow 2 1. 张量 import tensorflow as tf print(tf.__version__) # ...
- 深度学习(11)TensorFlow基础操作七: 向前传播(张量)实战
深度学习(11)TensorFlow基础操作七: 向前传播(张量)实战 1. 导包 2. 加载数据集 3. 转换数据类型 4. 查看x.shape, y.shape, x.dtype, y.dtype ...
- 用机器学习算法来求取战斗力公式
一般游戏的战力公式,是一个线性回归方程: a*x+b*y+c*z+- =p 其中,p是战斗力,[a,b,c-]是属性,[x,y,z-]是属性价值. 属性一般包括:最大生命值,攻击力,防御力,闪避,暴击 ...
- 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)
日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...
- 深度学习(6)TensorFlow基础操作二: 创建Tensor
深度学习(6)TensorFlow基础操作二: 创建Tensor 一. 创建方式 1. From Numpy,List 2. zeros,ones (1) tf.zeros() (2) tf.zero ...
- 深度学习(5)TensorFlow基础操作一: TensorFlow数据类型
深度学习(5)TensorFlow基础操作一: TensorFlow数据类型 Data Container(数据载体) What's Tensor TF is a computing lib(科学计算 ...
- Tensorflow中的各种梯度处理gradient
最近其实一直想自己手动创建op,这样的话好像得懂tensorflow自定义api/op的规则,设计前向与反向,注册命名,注意端口以及文件组织,最后可能还要需要重新编译才能使用.这一部分其实记得tens ...
- 登月图片消噪及圆周率的求取(Python数据分析)
登月图片消噪 scipy.fftpack模块用来计算快速傅里叶变换 速度比传统傅里叶变换更快,是对之前算法的改进 图片是二维数据,注意使用fftpack的二维转变方法 # 所有的函数都可以使用正弦波表 ...
最新文章
- SAP RETAIL WRMO 补货监控
- 官方发布:深度学习高层API保姆级中文教程免费开放
- linux添加源地址ping,实战经验:Linux Source NAT在Ping场景下的应用
- HTML+CSS+JS实现 ❤️照相机快门图片动画特效❤️
- Qt工作笔记-Qt元对象系统解析【2合1】
- .NET轻量级任务管理类
- 【题解】牛客小白月赛16(部分题,待补充……)
- WebLogic12.1.1中跨域问题的探讨以及几种常见中间件中跨域问题的解决方法
- asp.net mysql数据库连接字符串_如何让您的ASP.NET数据库连接字符串是安全的
- MySQL正则表达式的问题
- 百度地图生成器,图标消失,中文乱码和自定义名字undefind
- MFQPPDCS测试理论(海盗派测试分析)
- 职业规划(一)怎么写简历
- 荣耀10青春版支持鸿蒙吗,荣耀10青春版详细评测:又一款年轻群体收割机
- [破解]天草初级笔记
- win10+node@16 安装特定版本 node-sass
- 防火墙用户管理和入侵防御简介
- 34.驱动--块设备驱动
- 外链怎么做?看看外链代发的这些黑幕!
- Unknown database ‘ ‘