tensorflow频域操作及梯度求取

最近尝试使用tensorflow中的傅立叶变换操作,主要涉及的op有tf.complextf.fft, tf.fft2d,tf.angle, 涉及的数据类型为tf.complex64, 这里主要介绍以下几个部分:

  • 一维离散傅立叶变换函数:tf.fft使用示例
  • 复数梯度求取算子:tf.gradients使用示例
  • 阻挡梯度的传播:tf.stop_gradient()使用示例
  • 高阶导数的求取:tf.gradients高阶示例

部分内容节选自tensorflow学习笔记(三十):tf.gradients 与 tf.stop_gradient() 与 高阶导数


一、一维离散傅立叶变换函数:tf.fft使用示例

import tensorflow as tf
# 两种初始化复数的办法
# 方法一:tf.cast():弊病,tf.cast()无法回传梯度
input_real=tf.constant([2,3])
input_complex=tf.cast(input_real ,tf.complex64)#方法二:分别设置实部和虚部,用complex拼接
real=tf.constant([2.25, 3.25])
imag=tf.constant([4.75, 5.75])
complex=tf.complex(real, imag)#方法三:初始化一个跟实部相同shape的虚部,用complex拼接:可以回传梯度
tf.complex(sqrtmag, tf.zeros(sqrtmag.shape))fft=tf.fft(complex)
angle=tf.angle(fft)

二、复数梯度求取算子:tf.gradients使用示例

1. grad_ys的测试:

这里要注意两点:

  1. grad_ys的shape应该与tf.gradients的第一个参数保持一致,[tf.convert_to_tensor([2.,2.,3.]),tf.convert_to_tensor([3.,2.,4.])][z1, z2]的shape是一致的
  2. 这里的数字实际上是给每一个维度的梯度增加一个乘法因子,如果都设置为1,则直接按照公式求取梯度即可,作用应该是在第三部分回传梯度时,可以自由设置不同维度的梯度值,但如果单纯只是需要求取表达式对于某一个变量的梯度,设置为1即可。
import tensorflow as tf
w1 = tf.get_variable('w1', shape=[3])
w2 = tf.get_variable('w2', shape=[3])
w3 = tf.get_variable('w3', shape=[3])
w4 = tf.get_variable('w4', shape=[3])z1 = w1 + w2+ w3
z2 = w3 + w4
grads = tf.gradients([z1, z2], [w1, w2, w3, w4], grad_ys=[tf.convert_to_tensor([2.,2.,3.]),tf.convert_to_tensor([3.,2.,4.])])
with tf.Session() as sess:tf.global_variables_initializer().run()print(sess.run(grads))

2. 复数tf.gradients和grad_ys的使用

使用示例

import tensorflow as tf
real=tf.constant([2.25, 3.25])
imag=tf.constant([4.75, 5.75])
complex=tf.complex(real, imag)fft=tf.fft(complex)
angle=tf.angle(fft)
# complex对real的实部求梯度
grads_complex_real=tf.gradients(complex,real,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.complex64))
# fft对real求梯度
grads_fft_real=tf.gradients(fft,real,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.complex64))
# fft对complex求梯
度
grads_fft_complex=tf.gradients(fft,complex,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.complex64))
# angle对fft求梯度
grads_angle_fft=tf.gradients(angle,fft,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.complex64))
# angle对complex求梯度
grads_angle_complex=tf.gradients(angle,complex,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.float32)
# angle对real求梯度
grads_angle_real=tf.gradients(angle,real,grad_ys=tf.convert_to_tensor([1,1],dtype=tf.float32))sess=tf.InteractiveSession()
print(sess.run([angle,grads_fft_complex])

3. 频域loss函数注意事项:

  1. 使用频域fft或者angle作为loss函数的输入时,由于离散傅立叶变换公式中(参见网页),某一个频率点的值由complex中每一个值和其对应的指数项,求和之后得到,其中指数项根据欧拉公式就是一个复数。公式总结如下:
    fft=∑n−1i=0(complex中每个离散点的值(复数)×欧拉复数因子(由fft和complex的坐标决定,m,n和u,v))fft=∑i=0n−1(complex中每个离散点的值(复数)×欧拉复数因子(由fft和complex的坐标决定,m,n和u,v))fft =\sum _{i=0}^{n-1}\left (complex中每个离散点的值(复数) \times 欧拉复数因子(由fft和complex的坐标决定,m,n和u,v) \right )
    因此,fft值的数量级相当于是255×width×height255×width×height255\times width\times height,因此将fft作为loss函数的一部分时,fft的值会非常的大,在梯度回传是,根据离散傅立叶反变换,complex中每一个点的梯度,同样由(fft中每个离散点的值×欧拉复数因子)(fft中每个离散点的值×欧拉复数因子)\left ( fft中每个离散点的值\times 欧拉复数因子\right )决定,所回传的实际梯度值的数量级应该是△(fft−label)×width×height△(fft−label)×width×height\triangle\left ( fft-label \right ) \times width\times height, 而△(fft−label)△(fft−label)\triangle\left ( fft-label \right )是fft和标签label之间的差值,其数量级和fft的数量级相同为 255×width×height255×width×height255\times width\times height, 最终回传到complex中每一个点的梯度的数量级应该是255×width2×height2255×width2×height2255\times width^{2}\times height^{2}
  2. 使用tf.angle进行梯度回传到real或者imag的适合,由于tf.angle = tf.atan(tf.imag/tf.real)), 当中存在除法运算,在训练过程中可能会出现NAN的error。具体原因应该就是这个除法,有待进一步确定,解决办法暂时还没想到,也有待进一步研究。

三、阻挡梯度的传播:tf.stop_gradient()使用示例

阻挡节点BP的梯度

import tensorflow as tf
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:tf.global_variables_initializer().run()print(sess.run(gradients))
#输出 [1.0, 1.0]

虽然c节点被stop了,但是ab还有从d传回的梯度,所以还是可以输出梯度值的。

import tensorflow as tf
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)opt = tf.train.GradientDescentOptimizer(0.1)
gradients = tf.gradients(b, xs=tf.trainable_variables())tf.summary.histogram(gradients[0].name, gradients[0])# 这里会报错,因为gradients[0]是None
#其它地方都会运行正常,无论是梯度的计算还是变量的更新。总觉着tensorflow这么设计有点不好,
#不如改成流过去的梯度为0
train_op = opt.apply_gradients(zip(gradients, tf.trainable_variables()))print(gradients)
with tf.Session() as sess:tf.global_variables_initializer().run()print(sess.run(train_op))print(sess.run([w1, w2])

四、高阶导数的求取:tf.gradients高阶示例

tensorflow 求 高阶导数可以使用tf.gradients来实现。
Note: 有些optf没有实现其高阶导的计算,例如 tf.add …, 如果计算了一个没有实现 高阶导的 op的高阶导, gradients 会返回 None

import tensorflow as tfwith tf.device('/cpu:0'):a = tf.constant(1.)b = tf.pow(a, 2)grad = tf.gradients(ys=b, xs=a) # 一阶导print(grad[0])grad_2 = tf.gradients(ys=grad[0], xs=a) # 二阶导grad_3 = tf.gradients(ys=grad_2[0], xs=a) # 三阶导print(grad_3)with tf.Session() as sess:print(sess.run(grad_3))

tensorflow频域操作及梯度求取相关推荐

  1. Canny算子中的梯度求取及非最大值抑制(NMS)实现

    @Canny算子中的非最大值抑制(NMS)实现 canny算子中的非极大值抑制是在对图像进行梯度求取之后,在梯度方向进行的运算,也就是说此处的非极大值抑制是在对图像进行梯度求取后,在生成的梯度矩阵上求 ...

  2. TensorFlow 2.0 - 张量/自动求导/梯度下降

    文章目录 1. 张量 2. 自动求导.梯度下降 学习于:简单粗暴 TensorFlow 2 1. 张量 import tensorflow as tf print(tf.__version__) # ...

  3. 深度学习(11)TensorFlow基础操作七: 向前传播(张量)实战

    深度学习(11)TensorFlow基础操作七: 向前传播(张量)实战 1. 导包 2. 加载数据集 3. 转换数据类型 4. 查看x.shape, y.shape, x.dtype, y.dtype ...

  4. 用机器学习算法来求取战斗力公式

    一般游戏的战力公式,是一个线性回归方程: a*x+b*y+c*z+- =p 其中,p是战斗力,[a,b,c-]是属性,[x,y,z-]是属性价值. 属性一般包括:最大生命值,攻击力,防御力,闪避,暴击 ...

  5. 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...

  6. 深度学习(6)TensorFlow基础操作二: 创建Tensor

    深度学习(6)TensorFlow基础操作二: 创建Tensor 一. 创建方式 1. From Numpy,List 2. zeros,ones (1) tf.zeros() (2) tf.zero ...

  7. 深度学习(5)TensorFlow基础操作一: TensorFlow数据类型

    深度学习(5)TensorFlow基础操作一: TensorFlow数据类型 Data Container(数据载体) What's Tensor TF is a computing lib(科学计算 ...

  8. Tensorflow中的各种梯度处理gradient

    最近其实一直想自己手动创建op,这样的话好像得懂tensorflow自定义api/op的规则,设计前向与反向,注册命名,注意端口以及文件组织,最后可能还要需要重新编译才能使用.这一部分其实记得tens ...

  9. 登月图片消噪及圆周率的求取(Python数据分析)

    登月图片消噪 scipy.fftpack模块用来计算快速傅里叶变换 速度比传统傅里叶变换更快,是对之前算法的改进 图片是二维数据,注意使用fftpack的二维转变方法 # 所有的函数都可以使用正弦波表 ...

最新文章

  1. SAP RETAIL WRMO 补货监控
  2. 官方发布:深度学习高层API保姆级中文教程免费开放
  3. linux添加源地址ping,实战经验:Linux Source NAT在Ping场景下的应用
  4. HTML+CSS+JS实现 ❤️照相机快门图片动画特效❤️
  5. Qt工作笔记-Qt元对象系统解析【2合1】
  6. .NET轻量级任务管理类
  7. 【题解】牛客小白月赛16(部分题,待补充……)
  8. WebLogic12.1.1中跨域问题的探讨以及几种常见中间件中跨域问题的解决方法
  9. asp.net mysql数据库连接字符串_如何让您的ASP.NET数据库连接字符串是安全的
  10. MySQL正则表达式的问题
  11. 百度地图生成器,图标消失,中文乱码和自定义名字undefind
  12. MFQPPDCS测试理论(海盗派测试分析)
  13. 职业规划(一)怎么写简历
  14. 荣耀10青春版支持鸿蒙吗,荣耀10青春版详细评测:又一款年轻群体收割机
  15. [破解]天草初级笔记
  16. win10+node@16 安装特定版本 node-sass
  17. 防火墙用户管理和入侵防御简介
  18. 34.驱动--块设备驱动
  19. 外链怎么做?看看外链代发的这些黑幕!
  20. Unknown database ‘ ‘

热门文章

  1. ftp、go-fastdfs、HelpManual、redis、git、ngnix
  2. 电子签名合同的有效期是多久
  3. 两天价网站背后重重迷雾:做个网站究竟要多少钱
  4. css路径自动加上了路径_CSS和关键路径
  5. 大写金额换算器iOS版源代码
  6. 忘记手机密码怎么用计算机解开,手机忘记密码怎么办?教你三种方法帮你搞定!...
  7. JavaScript经典pdf书籍推荐
  8. 你不得不掌握的前端提交规范(git cz)
  9. ubuntu18断电后recovering journal一直卡在开机界面
  10. 分析型CRM软件能帮到你什么?