摘要: 循环神经网络是如何工作的?如何构建一个Elman循环神经网络?在这里,教你手把手创建一个Elman循环神经网络进行简单的序列预测。

本文以最简单的RNNs模型为例:Elman循环神经网络,讲述循环神经网络的工作原理,即便是你没有太多循环神经网络(RNNs)的基础知识,也可以很容易的理解。为了让你更好的理解RNNs,我们使用Pytorch张量包和autograd库从头开始构建Elman循环神经网络。该文中完整代码在Github上是可实现的。

在这里,假设你对前馈神经网络略有了解。Pytorch和autograd库更为详细的内容请查看我的其他教程。

Elman循环神经网络

Jeff Elman首次提出了Elman循环神经网络,并发表在论文《Finding structure in time》中:它只是一个三层前馈神经网络,输入层由一个输入神经元x1和一组上下文神经元单元{c1 … cn}组成。隐藏层前一时间步的神经元作为上下文神经元的输入,在隐藏层中每个神经元都有一个上下文神经元。由于前一时间步的状态作为输入的一部分,因此我们可以说,Elman循环神经网络拥有一定的内存——上下文神经元代表一个内存。

预测正弦波

现在,我们来训练RNNs学习正弦函数。在训练过程中,一次只为模型提供一个数据,这就是为什么我们只需要一个输入神经元x1,并且我们希望在下一时间步预测该值。输入序列x由20个数据组成,并且目标序列与输入序列相同。

模型实现

首先导入包。

接下来,设置模型的超参数。设置输入层的大小为7(6个上下文神经元和1个输入神经元),seq_length用来定义输入和目标序列的长度。

生成训练数据:x是输入序列,y是目标序列。

创建两个权重矩阵。大小为(input_size,hidden_size)的矩阵w1用于隐藏连接的输入,大小为(hidden_size,output_size)的矩阵w2用于隐藏连接的输出。 用零均值的正态分布对权重矩阵进行初始化。

定义forward方法,其参数为input向量、context_state向量和两个权重矩阵,连接input和context_state创建xh向量。对xh向量和权重矩阵w1执行点积运算,然后用tanh函数作为非线性函数,在RNNs中tanh比sigmoid效果要好。 然后对新的context_state和权重矩阵w2再次执行点积运算。 我们想要预测连续值,因此这个阶段不使用任何非线性。

请注意,context_state向量将在下一时间步填充上下文神经元。 这就是为什么我们要返回context_state向量和out。

训练

训练循环的结构如下:

1.外循环遍历每个epoch。epoch被定义为所有的训练数据全部通过训练网络一次。在每个epoch开始时,将context_state向量初始化为0。

2.内部循环遍历序列中的每个元素。执行forward方法进行正向传递,该方法返回pred和context_state,将用于下一个时间步。然后计算均方误差(MSE)用于预测连续值。执行backward()方法计算梯度,然后更新权重w1和w2。每次迭代中调用zero_()方法清除梯度,否则梯度将会累计起来。最后将context_state向量包装放到新变量中,以将其与历史值分离开来。

训练期间产生的输出显示了每个epoch的损失是如何减少的,这是一个好的衡量方式。损失的逐渐减少则意味着我们的模型正在学习。

预测

一旦模型训练完毕,我们就可以进行预测。在序列的每一步我们只为模型提供一个数据,并要求模型在下一个步预测一个值。

预测结果如下图所示:黄色圆点表示预测值,蓝色圆点表示实际值,二者基本吻合,因此模型的预测效果非常好。

结论

在这里,我们使用了Pytorch从零开始构建一个基本的RNNs模型,并且学习了如何将RNNs应用于简单的序列预测问题。

原文链接

干货好文,请关注扫描以下二维码:

使用PyTorch从零开始构建Elman循环神经网络相关推荐

  1. PyTorch如何构建和实验神经网络

    点击上方"视学算法",马上关注 真爱,请设置"星标"或点个"在看" 作者 | Tirthajyoti Sarkar 来源 | Medium ...

  2. [Pytorch系列-61]:循环神经网络 - 中文新闻文本分类详解-3-CNN网络训练与评估代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  3. [Pytorch系列-60]:循环神经网络 - 中文新闻文本分类详解-2-LSTM网络训练与评估代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  4. [Pytorch系列-58]:循环神经网络 - 词向量的自动构建与模型训练代码示例

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  5. 【Python 初学者】从零开始构建自己的神经网络

    此图为使用神经网络预测猫狗案例. 原创:CSDN/知乎:川川菜鸟 文章目录 什么是神经网络? 训练神经网络 前向传播 损失函数 反向传播 完整应用 提问 结束语 什么是神经网络? 大多数神经网络的介绍 ...

  6. 独家 | 数据科学家的必备读物:从零开始用 Python 构建循环神经网络(附代码)...

    作者:Faizan Shaikh 翻译:李文婧 校对:张一豪 本文约4300字,建议阅读10+分钟. 本文带你快速浏览典型NN模型核心部分,并教你构建RNN解决相关问题. 引言 人类不会每听到一个句子 ...

  7. 【theano-windows】学习笔记十九——循环神经网络

    前言 前面已经介绍了RBM和CNN了,就剩最后一个RNN了,抽了一天时间简单看了一下原理,但是没细推RNN的参数更新算法BPTT,全名是Backpropagation Through Time. [注 ...

  8. 【火炉炼AI】深度学习004-Elman循环神经网络

    [火炉炼AI]深度学习004-Elman循环神经网络 (本文所使用的Python库和版本号: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotlib ...

  9. 从零开始学Pytorch(十)之循环神经网络基础

    本节介绍循环神经网络,下图展示了如何基于循环神经网络实现语言模型.我们的目的是基于当前的输入与过去的输入序列,预测序列的下一个字符.循环神经网络引入一个隐藏变量HHH,用HtH_{t}Ht​表示HHH ...

最新文章

  1. android manifest 分辨率,android程序界面自动适应屏幕分辨率例子
  2. AI人脸识别 生物识别 活体检测 的发展历程
  3. 如何在VS2013配置CUDA,并编译生成DLL
  4. zyUpload+struct2完成文件上传
  5. arr数组怎么取值_JS 面试之数组的几个不 low 操作
  6. 机器学习之判别/生成模型小结
  7. mybatis项目报错:java.sql.SQLException: ORA-00911: 无效字符 解决方法
  8. 【计算机组成原理笔记】计算机的基本组成
  9. 负载均衡之让nginx跑起来
  10. input框,需要隐式显示的时候,不让它自动填充的办法
  11. 使用MySQL中的对象数组查询JSON列
  12. .Net Frame安装心得
  13. 科研网站大全,你值得拥有!
  14. ArcGIS——地理配准操作
  15. quartz 表结构 mysql_Quartz表结构说明
  16. 计算机的软键盘在哪里,如何调出软键盘_怎么在电脑上调出软键盘_如何调出搜狗软键盘-Guide信息网...
  17. android 来电默认铃声,android – 来电动态覆盖默认铃声
  18. java定义一个周长类三角形_point类 三点的三角形的周长、面积 编程求解矩形和圆面积 java 三角形的定义...
  19. Nginx之一:Nginx的编译安装
  20. Anacoda的用途

热门文章

  1. c++ 读取访问权限冲突_Linux系统利用可执行文件的Capabilities实现权限提升
  2. c4d启动无反应_浙江无填料喷雾式冷却塔
  3. 【LeetCode笔记】剑指 Offer 13-. 机器人的运动范围 (Java、dfs)
  4. gns3中两个路由器分别连接主机然后分析ip数据转发报文arp协议_ARP协议在同网段及跨网段下的工作原理...
  5. 圆平移后的方程变化_平移法解题
  6. 蓝牙连接不上车要hfp_鹅厂又要霸屏,连接四部剧将袭,冲着主创颜值不追不行啦...
  7. android 上下数字滚动_原来PPT数字还有这么高大上的展示方式
  8. JAVA物体运动检测_基于OpenCv的运动物体检测算法
  9. 登上热搜!这可能是中国最穷的211大学
  10. 2020,这些前沿技术成全球关注热点