目录

6-1:推导RNN反向传播算法BPTT.

6-2P:设计简单RNN模型,分别用Numpy、Pytorch实现反向传播算子,并代入数值测试.


6-1:推导RNN反向传播算法BPTT.

在RNN中,输入序列为     \left \{ x^{(1)},x^{(2)},...,x^{(T)} \right \},输出序列为,标签序列为  \left \{ y^{(1)},y^{(2)},...,y^{(T)} \right \} ,其中的  均为列向量,模型的更新方式为:

损失函数为负对数似然函数,总体损失为每个时间步的损失之和:

需要更新的参数有 ,根据链式求导法则,先对进行求导。假设是 n 维向量, 则 也是 n维列向量, t 时刻的标签对应其中的第 i 维,则对应地定义  这样的向量:第 i 维为1,其余为0,则 :

其结果为n维列向量,考虑到序列之间的关系,根据公式 :

先计算最后的隐状态 的梯度:

这里对向量对向量求导结果,统一使用分母布局形式记录, n 维列向量 y对 m 维列向量 x 的导数矩阵维度为m×n ,即 :

这里经常用到这样一个式子,对于 ,其中z为标量,则

这样写是由矩阵元素摆放位置决定的。对于 τ 之前的隐状态  ,根据迭代系, 和  与之相关,故:

其中:

那么 :

也就是说,  的梯度可以递归的进行求解。接下来,我们写出b,W,U,V,c的梯度即可:

在计算时,需要将中间变量进行拆解,最后进行合成即可。

类似地

6-2P:设计简单RNN模型,分别用Numpy、Pytorch实现反向传播算子,并代入数值测试.

import torch
import numpy as npclass RNNCell:def __init__(self, weight_ih, weight_hh,bias_ih, bias_hh):self.weight_ih = weight_ihself.weight_hh = weight_hhself.bias_ih = bias_ihself.bias_hh = bias_hhself.x_stack = []self.dx_list = []self.dw_ih_stack = []self.dw_hh_stack = []self.db_ih_stack = []self.db_hh_stack = []self.prev_hidden_stack = []self.next_hidden_stack = []# temporary cacheself.prev_dh = Nonedef __call__(self, x, prev_hidden):self.x_stack.append(x)next_h = np.tanh(np.dot(x, self.weight_ih.T)+ np.dot(prev_hidden, self.weight_hh.T)+ self.bias_ih + self.bias_hh)self.prev_hidden_stack.append(prev_hidden)self.next_hidden_stack.append(next_h)# clean cacheself.prev_dh = np.zeros(next_h.shape)return next_hdef backward(self, dh):x = self.x_stack.pop()prev_hidden = self.prev_hidden_stack.pop()next_hidden = self.next_hidden_stack.pop()d_tanh = (dh + self.prev_dh) * (1 - next_hidden ** 2)self.prev_dh = np.dot(d_tanh, self.weight_hh)dx = np.dot(d_tanh, self.weight_ih)self.dx_list.insert(0, dx)dw_ih = np.dot(d_tanh.T, x)self.dw_ih_stack.append(dw_ih)dw_hh = np.dot(d_tanh.T, prev_hidden)self.dw_hh_stack.append(dw_hh)self.db_ih_stack.append(d_tanh)self.db_hh_stack.append(d_tanh)return self.dx_listif __name__ == '__main__':np.random.seed(123)torch.random.manual_seed(123)np.set_printoptions(precision=6, suppress=True)rnn_PyTorch = torch.nn.RNN(4, 5).double()rnn_numpy = RNNCell(rnn_PyTorch.all_weights[0][0].data.numpy(),rnn_PyTorch.all_weights[0][1].data.numpy(),rnn_PyTorch.all_weights[0][2].data.numpy(),rnn_PyTorch.all_weights[0][3].data.numpy())nums = 3x3_numpy = np.random.random((nums, 3, 4))x3_tensor = torch.tensor(x3_numpy, requires_grad=True)h3_numpy = np.random.random((1, 3, 5))h3_tensor = torch.tensor(h3_numpy, requires_grad=True)dh_numpy = np.random.random((nums, 3, 5))dh_tensor = torch.tensor(dh_numpy, requires_grad=True)h3_tensor = rnn_PyTorch(x3_tensor, h3_tensor)h_numpy_list = []h_numpy = h3_numpy[0]for i in range(nums):h_numpy = rnn_numpy(x3_numpy[i], h_numpy)h_numpy_list.append(h_numpy)h3_tensor[0].backward(dh_tensor)for i in reversed(range(nums)):rnn_numpy.backward(dh_numpy[i])print("numpy_hidden :\n", np.array(h_numpy_list))print("torch_hidden :\n", h3_tensor[0].data.numpy())print("-----------------------------------------------")print("dx_numpy :\n", np.array(rnn_numpy.dx_list))print("dx_torch :\n", x3_tensor.grad.data.numpy())print("------------------------------------------------")print("dw_ih_numpy :\n",np.sum(rnn_numpy.dw_ih_stack, axis=0))print("dw_ih_torch :\n",rnn_PyTorch.all_weights[0][0].grad.data.numpy())print("------------------------------------------------")print("dw_hh_numpy :\n",np.sum(rnn_numpy.dw_hh_stack, axis=0))print("dw_hh_torch :\n",rnn_PyTorch.all_weights[0][1].grad.data.numpy())print("------------------------------------------------")print("db_ih_numpy :\n",np.sum(rnn_numpy.db_ih_stack, axis=(0, 1)))print("db_ih_torch :\n",rnn_PyTorch.all_weights[0][2].grad.data.numpy())print("-----------------------------------------------")print("db_hh_numpy :\n",np.sum(rnn_numpy.db_hh_stack, axis=(0, 1)))print("db_hh_torch :\n",rnn_PyTorch.all_weights[0][3].grad.data.numpy())
numpy_hidden :[[[ 0.4686   -0.298203  0.741399 -0.446474  0.019391][ 0.365172 -0.361254  0.426838 -0.448951  0.331553][ 0.589187 -0.188248  0.684941 -0.45859   0.190099]][[ 0.146213 -0.306517  0.297109  0.370957 -0.040084][-0.009201 -0.365735  0.333659  0.486789  0.061897][ 0.030064 -0.282985  0.42643   0.025871  0.026388]][[ 0.225432 -0.015057  0.116555  0.080901  0.260097][ 0.368327  0.258664  0.357446  0.177961  0.55928 ][ 0.103317 -0.029123  0.182535  0.216085  0.264766]]]
torch_hidden :[[[ 0.4686   -0.298203  0.741399 -0.446474  0.019391][ 0.365172 -0.361254  0.426838 -0.448951  0.331553][ 0.589187 -0.188248  0.684941 -0.45859   0.190099]][[ 0.146213 -0.306517  0.297109  0.370957 -0.040084][-0.009201 -0.365735  0.333659  0.486789  0.061897][ 0.030064 -0.282985  0.42643   0.025871  0.026388]][[ 0.225432 -0.015057  0.116555  0.080901  0.260097][ 0.368327  0.258664  0.357446  0.177961  0.55928 ][ 0.103317 -0.029123  0.182535  0.216085  0.264766]]]
-----------------------------------------------
dx_numpy :[[[-0.643965  0.215931 -0.476378  0.072387][-1.221727  0.221325 -0.757251  0.092991][-0.59872  -0.065826 -0.390795  0.037424]][[-0.537631 -0.303022 -0.364839  0.214627][-0.815198  0.392338 -0.564135  0.217464][-0.931365 -0.254144 -0.561227  0.164795]][[-1.055966  0.249554 -0.623127  0.009784][-0.45858   0.108994 -0.240168  0.117779][-0.957469  0.315386 -0.616814  0.205634]]]
dx_torch :[[[-0.643965  0.215931 -0.476378  0.072387][-1.221727  0.221325 -0.757251  0.092991][-0.59872  -0.065826 -0.390795  0.037424]][[-0.537631 -0.303022 -0.364839  0.214627][-0.815198  0.392338 -0.564135  0.217464][-0.931365 -0.254144 -0.561227  0.164795]][[-1.055966  0.249554 -0.623127  0.009784][-0.45858   0.108994 -0.240168  0.117779][-0.957469  0.315386 -0.616814  0.205634]]]
------------------------------------------------
dw_ih_numpy :[[3.918335 2.958509 3.725173 4.157478][1.261197 0.812825 1.10621  0.97753 ][2.216469 1.718251 2.366936 2.324907][3.85458  3.052212 3.643157 3.845696][1.806807 1.50062  1.615917 1.521762]]
dw_ih_torch :[[3.918335 2.958509 3.725173 4.157478][1.261197 0.812825 1.10621  0.97753 ][2.216469 1.718251 2.366936 2.324907][3.85458  3.052212 3.643157 3.845696][1.806807 1.50062  1.615917 1.521762]]
------------------------------------------------
dw_hh_numpy :[[ 2.450078  0.243735  4.269672  0.577224  1.46911 ][ 0.421015  0.372353  0.994656  0.962406  0.518992][ 1.079054  0.042843  2.12169   0.863083  0.757618][ 2.225794  0.188735  3.682347  0.934932  0.955984][ 0.660546 -0.321076  1.554888  0.833449  0.605201]]
dw_hh_torch :[[ 2.450078  0.243735  4.269672  0.577224  1.46911 ][ 0.421015  0.372353  0.994656  0.962406  0.518992][ 1.079054  0.042843  2.12169   0.863083  0.757618][ 2.225794  0.188735  3.682347  0.934932  0.955984][ 0.660546 -0.321076  1.554888  0.833449  0.605201]]
------------------------------------------------
db_ih_numpy :[7.568411 2.175445 4.335336 6.820628 3.51003 ]
db_ih_torch :[7.568411 2.175445 4.335336 6.820628 3.51003 ]
-----------------------------------------------
db_hh_numpy :[7.568411 2.175445 4.335336 6.820628 3.51003 ]
db_hh_torch :[7.568411 2.175445 4.335336 6.820628 3.51003 ]

ref:RNN的反向传播-BPTT - 知乎 (zhihu.com)

L5W1作业1 手把手实现循环神经网络_追寻远方的人的博客-CSDN博客

NNDL 作业9:分别使用numpy和pytorch实现BPTT相关推荐

  1. NNDL 作业3:分别使用numpy和pytorch实现FNN例题

    目录 一.过程推导--了解BP原理 二.数值计算 三.代码实现- numpy手推 + pytorch自动 1.对比[numpy]和[pytorch]程序,总结并陈述. (1)使用numpy实现 (2) ...

  2. 神经网络与深度学习 作业3:分别使用numpy和pytorch实现FNN例题

    目录 一.过程推导 - 了解BP原理 二.数值计算 - 手动计算,掌握细节 三.代码实现 - numpy手推 + pytorch自动 (1)使用numpy实现 (2)使用pytorch实现 (3)思考 ...

  3. NNDL 作业6:基于CNN的XO识别

    实现卷积-池化-激活 Numpy版本:手工实现 卷积-池化-激活 自定义卷积算子.池化算子实现,源码如下: import numpy as npx = np.array([[-1, -1, -1, - ...

  4. NNDL 作业8:RNN-简单循环网络

    简单循环网络(Simple Recurrent Network,SRN)是只有一个隐藏层的神经网络. 目录 1.使用Numpy实现SRN 2.在1的基础上,增加激活函数tanh 3.分别使用nn.RN ...

  5. NNDL 作业5:卷积

    目录 作业1 1. 图1使用卷积核​编辑,输出特征图 2. 图1使用卷积核​编辑,输出特征图 3. 图2使用卷积核​编辑,输出特征图 4. 图2使用卷积核​编辑,输出特征图 5. 图3使用卷积核​编辑 ...

  6. NNDL 作业7:第五章课后题

    目录 习题5-2 证明宽卷积具有交换性,即公式(5.13) 习题5-3 分析卷积神经网络中用1×1的卷积核的作用 习题5-4 对于一个输入为100×100×256的特征映射组,使用3×3的卷积核,输出 ...

  7. 深度盘点Python11个主流框架:Pandas、Django、Matplotlib、Numpy、PyTorch......

    六月份TIOBE编程语言排行榜,位居第二名的Python与第一名C语言之间的差距正在逐渐缩小.Python如此受欢迎一方面得益于它崇尚简洁的编程哲学,另一方面是因为强大的第三方库生态. 要说杀手级的库 ...

  8. 深度盘点 Python11 个主流框架:Pandas、Django、Matplotlib、Numpy、PyTorch......

    六月份TIOBE编程语言排行榜,位居第二名的Python与第一名C语言之间的差距正在逐渐缩小.Python如此受欢迎一方面得益于它崇尚简洁的编程哲学,另一方面是因为强大的第三方库生态. 要说杀手级的库 ...

  9. python numpy.arry, pytorch.Tensor及原生list相互转换

    文章目录 python numpy.arry, pytorch.Tensor及原生list相互转换 1 原生list转numpy list 2 numpy.array 转原生list 3 numpy. ...

最新文章

  1. python input 数字_Python:raw_input读取数字的问题
  2. php实现最大公约数,php求最大公约数
  3. Spring.net 模块组成
  4. Springmvc文件上传(servlet3.0)/下载(ssm)以及坑点
  5. 联想android刷机教程视频,联想s939刷机教程(刷官方系统)
  6. 从零点五开始用Unity做半个2D战棋小游戏(四)
  7. Ubuntu Linux配置IP地址
  8. selenium - Select类 - 下拉框
  9. 物联网落地三大困境破解
  10. 锻炼编程能力的10个游戏:通关既巅峰!
  11. Cognos开发自定义排序规则的报表和自定义排名报表
  12. 详解OpenCV中的cvCreateMat()函数
  13. webstorm破解之jar包破解(2018)
  14. 仿iReader 阅读器(swift)
  15. Video Extractor监控视频侦查取证分析系统
  16. 用2008系统安装k3服务器,金蝶K3SQL-Server-2008-R2安装方法介绍
  17. 数据结构实现排队系统
  18. P1463 [POI2001][HAOI2007]反素数 题解
  19. [笑语天下]风景、照片与评论古今
  20. 22-04-23 西安 javaSE(14)文件流、缓冲流、转换流、对象流、标准流、关闭IO资源的封装类IOUtils(纳命来)

热门文章

  1. 冷热水恒压供水系统,变频器控制,模拟量输入和输出处理,温度控制,流量计算控制
  2. pom.xml文件带有删除线的解决方案
  3. 硬核图解,30张图带你搞懂、路由器,集线器,交换机,网桥,光猫有啥区别?
  4. 分享88个ASP江湖论坛源码,总有一款适合您
  5. 群辉-利用自带套件DNSserver实现免改登录地址访问内外网NAS系统
  6. 抢东西用的时间软件_番茄工作法app怎么用?这款时间管理便签软件教你
  7. 大数据CDH安装详细教程
  8. 两年车间技术员转型大数据开发,说说转型这点事儿
  9. JS限制H5页面只能在手机微信中打开总结
  10. golang重写区块链——0.5 区块链中钱包、地址和签名的实现