RNN初探(vanilla RNN)

前言

实习工作需要,不得不入个新坑。

为什么需要Recurrent Neuron Network (RNN)

全连接神经网络(FCN)和卷积神经网络(CNN)所针对的输入对象相互之间可以没有关系,不分先后顺序,比如如果要对猫和狗的图像进行分类,猫和狗的输入顺序是无所谓的。不过,如果要识别视频中狗的动作,那么就需要一个新的网络(当然这里就是RNN啦)来分析这种序列数据。

举例

序列数据其实在生活中无处不在:

  • 机器翻译
  • 异常检测(图像)
  • 股票分析预测
  • 天气与蚊虫繁殖的关系…
    只要有序列数据,我们就可以用RNN来进行分析。

优点

  1. 无论序列长度如何,模型都月相同的输入大小 (same input size)
  2. 在每一步中都有可能使用相同的转换函数fff和相同的参数

RNN模型结构

输入

序列数据:x=[x(1),x(2),...,x(t),...,x(τ)]\bold x = [\bold x^{(1)},\bold x^{(2)},..., \bold x^{(t)},...,\bold x^{(\tau)}]x=[x(1),x(2),...,x(t),...,x(τ)], ttt代表的是时间节点,τ\tauτ代表的是月τ\tauτ个时间节点,每一个x(t)x^{(t)}x(t)都是一个d维度的向量。

公式

h(t)=f(h(t−1),x(t),θ)\bold h^{(t)} = f(\bold h^{(t-1)}, \bold x^{(t)}, \theta)h(t)=f(h(t−1),x(t),θ),h(t)\bold h^{(t)}h(t)代表的是在时间节点ttt的隐层(hidden layer),fff是激活函数。这个公式的意思是:在时间节点t的隐层和在时间节点t的输入,上一个时间节点(t-1)的隐层以及一个θ\thetaθ(所有参数的统计,h和x拼起来之后乘一个W得到下一个h,简单来说W和b就是θ\thetaθ)有关,所以这个公式确实表达了序列数据中输出结果与先后顺序有关的这种思想。

结构


左侧是将隐层折叠起来的样子,其实很好理解。xxx就是输入的序列数据,hhh代表了所有隐层,ooo就是输出(预测结果)了,LLL是loss,yyy就是真实数据(标签)。(注意:yyy->LLL<-ooo表示yyy和ooo进行比较得到Loss
接下来是三个权重矩阵:UUU代表从输入到隐层(input-to-hidden)的权重矩阵,WWW代表从一个隐层到下一个隐层(hidden-to-hidden)的权重矩阵,VVV代表了从最后一个隐层到输出(hidden-to-output)的权重矩阵。
右侧是将隐层展开,我们可以发现UUU,WWW,VVV三者全是相同的,这也就是在前面优点中所提到的,或者其实就是参数共享。其次,在每一个时间节点ttt下,有x(t),h(t),o(t),L(t),y(t)x^{(t)}, h^{(t)},o^{(t)},L^{(t)},y^{(t)}x(t),h(t),o(t),L(t),y(t)一一对应。
了解了最基础的RNN结构,之后对双向RNN,多层的RNN的理解也会变得更容易。

前向传播(forward propagation)

公式

从时间节点t=1t=1t=1到t=τt=\taut=τ,前向传播是以下面的方式进行的:
a(t)=b+Wh(t−1)+Ux(t)\bold{a^{(t)} = b + Wh^{(t-1)}+Ux^{(t)}}a(t)=b+Wh(t−1)+Ux(t)
h(t)=tanh⁡(a(t))\bold{h^{(t)}} = \tanh\bold{(a^{(t)})}h(t)=tanh(a(t)),tanh⁡\tanhtanh作为激活函数
o(t)=c+Vh(t)\bold{o^{(t)} = c + Vh^{(t)}}o(t)=c+Vh(t)
y^(t)=softmax(o(t))\bold{\hat{y}^{(t)}} = softmax(\bold{o^{(t)}})y^​(t)=softmax(o(t))
L(t)=J(y^(t),y(t))L^{(t)} = J(\bold{\hat{y}^{(t)},y^{(t)}})L(t)=J(y^​(t),y(t)),JJJ作为计算loss的函数(L2L_2L2​或者交叉熵等等)
RNN的前向传播其实和CNN的大同小异

反向传播(back-propagation through time)

公式

大家应该注意到了,RNN的反向传播的名字是BPTT,比CNN多了个TT,其实就是多了一个从时间尽头向时间开始反向传播的通道,下面我们来介绍一下BPTT吧。

从图上看其实非常显而易见,就是要找那几个红色的梯度加上在前向传播中引入的bbb和ccc的梯度。也就是:∇VL,∇WL,∇UL,∇bL,∇cL\nabla_\bold{V}L,\nabla_\bold{W}L,\nabla_\bold{U}L,\nabla_\bold{b}L,\nabla_\bold{c}L∇V​L,∇W​L,∇U​L,∇b​L,∇c​L
公式如下:
∇cL=∑t∇o(t)L\nabla_\bold{c}L=\sum_{t}\nabla_\bold{o^{(t)}}L∇c​L=∑t​∇o(t)​L
∇bL=∑tdiag(1−(h(t))2)∇h(t)L\nabla_\bold{b}L=\sum_{t}diag(1-(\bold{h}^{(t)})^2)\nabla_\bold{h^{(t)}}L∇b​L=∑t​diag(1−(h(t))2)∇h(t)​L (diag()diag()diag()代表对角矩阵,也就是除了对角线,其他位置全是0)
∇VL=∑t(∇o(t)L)h(t)T\nabla_\bold{V}L=\sum_{t}(\nabla_\bold{o^{(t)}}L)\bold{h}^{(t)^T}∇V​L=∑t​(∇o(t)​L)h(t)T
∇WL=∑tdiag(1−(h(t))2)(∇h(t)L)h(t−1)T\nabla_\bold{W}L=\sum_{t}diag(1-(\bold{h}^{(t)})^2)(\nabla_\bold{h^{(t)}}L)\bold{h}^{(t-1)^T}∇W​L=∑t​diag(1−(h(t))2)(∇h(t)​L)h(t−1)T
∇UL=∑tdiag(1−(h(t))2)(∇h(t)L)x(t−1)T\nabla_\bold{U}L=\sum_{t}diag(1-(\bold{h}^{(t)})^2)(\nabla_\bold{h^{(t)}}L)\bold{x}^{(t-1)^T}∇U​L=∑t​diag(1−(h(t))2)(∇h(t)​L)x(t−1)T

RNN初探(vanilla RNN)相关推荐

  1. 【深度学习】深入浅出CRF as RNN(以RNN形式做CRF后处理)

    [深度学习]深入浅出CRF as RNN(以RNN形式做CRF后处理) 文章目录 1 概述 2 目标 3 思路 4 简述 5 论文原文5.1 Introduction5.2 相关工作5.3 关键步骤 ...

  2. DL之RNN:基于RNN实现模仿贴吧留言

    DL之RNN:基于RNN实现模仿贴吧留言 目录 输出结果 代码设计 输出结果 更新-- 代码设计 注:CPU上跑的较慢,建议GPU运行代码

  3. 【PyTorch】4 姓氏分类RNN实战(Simple RNN)——18 种起源语言的数千种姓氏分类

    使用char-RNN对姓氏进行分类 1. 准备数据 2. 将名称转换为张量 3. 建立网络 4. 准备训练 5. 训练网络 6. 评估结果 7. 全部代码 小结 这是官方NLP From Scratc ...

  4. 机器学习笔记 RNN初探 LSTM

    1 引入 一个input的属性 会受到其前后文的影响-->神经网络需要记忆 这里"Taipei"的属性(destination还是source)受到前面的动词"ar ...

  5. Pytorch RNN(详解RNN+torch.nn.RNN()实现)

    目录 一.RNN简介 二.RNN简介2 三.pytorch RNN 3.1    定义RNN()

  6. 【RNN】基于RNN的动态系统参数辨识matlab仿真

    1.软件版本 matlab2017b 2.本算法理论知识 3.部分源码 clc; clear; close all; warning off; addpath 'func\'data = xlsrea ...

  7. CNN与RNN对比 CNN+RNN组合方式

    CNN和RNN几乎占据着深度学习的半壁江山,所以本文将着重讲解CNN+RNN的对比,以及各种组合方式. 一.CNN与RNN对比 1. CNN卷积神经网络与RNN递归神经网络直观图 2. 相同点: 传统 ...

  8. 深度学习框架PyTorch一书的学习-第四章-神经网络工具箱nn

    参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 本章介绍的nn模块是构建与autogr ...

  9. 深度学习框架之PyTorch

    文章目录 1 PyTorch简介 2 PyTorch入门 2.1 Tensor 2.2 自动微分Autograd 2.3 神经网络 2.4 损失函数 2.5 优化器 2.6 数据加载与预处理 2.7 ...

最新文章

  1. 解决sql2014的distribution系统库distribution.mdf过大问题
  2. 使用PyTorch从零开始实现YOLO-V3目标检测算法 (三)
  3. DIB位图(Bitmap)的读取和保存
  4. tensorflow的tf.transpose()简单使用
  5. MySQL5.6主从复制(读写分离)方案
  6. MFC中树形控件的应用——电话簿
  7. 注解java_Java注解教程及自定义注解
  8. nginx+tocmat ip_hash做负载均衡时,一台tomcat宕机时没有转发问题
  9. TCP 和 UDP 的区别 TCP 和 UDP 详解
  10. 中国各主要大城市经纬度数据
  11. 灰度思维,黑白决策(下)
  12. 串口通讯---实现 PC 端之间串口连接传输文件
  13. 2021-2027全球及中国数控钻机行业研究及十四五规划分析报告
  14. Gitlab+猪齿鱼 实现自动化部署
  15. 2020电信宽带费用_电信宽带套餐价格表2020
  16. Centos8 磁力链BT地址
  17. 【Android应用开发详解】第01期:第三方授权认证(一)实现第三方授权登录、分享以及获取用户资料
  18. 计算机考研和就业pk,考研PK就业:提高自身竞争力比文凭更重要
  19. 与爱同行,育润走进贫困家庭,助推公益事业
  20. W5100S SPI+DMA 中的片选信号处理

热门文章

  1. C - 喵帕斯之天才算数少女
  2. 如何判断域名的潜力和价值?
  3. ubuntu使用清华源pip安装pytorch
  4. 用python使图形动起来?
  5. Archive一个Microsoft Teams里创建的Team
  6. 10亿级存储挑战!看一看、微信广告、微信支付、小程序都在用的存储系统究竟是怎么扛住的?!
  7. VR数字沙盘高度还原未来房屋实
  8. Acala 全球征文精选
  9. 【信仰充值中心】Firefox 96 正式版用户特性介绍
  10. springboot集成security