RNN入门学习

原文地址:http://blog.csdn.net/hjimce/article/details/49095371

作者:hjimce

一、相关理论

RNN(Recurrent Neural Networks)中文名又称之为:循环神经网络(原来还有一个递归神经网络,也叫RNN,搞得我有点混了,菜鸟刚入门,对不上号)。在计算机视觉里面用的比较少,我目前看过很多篇计算机视觉领域的相关深度学习的文章,就除了2015 ICCV的一篇图像语意分割文献《Conditional Random Fields as Recurrent Neural Networks》有提到RNN这个词外,目前还未见到其它的把RNN用到图像上面。RNN主要用于序列问题,如自然语言、语音音频等领域,相比于CNN来说,简单很多,CNN包含:卷积层、池化层、全连接层、特征图等概念,RNN基本上就仅仅只是三个公式就可以搞定了,因此对于RNN我们只需要知道三个公式就可以理解RNN了。说实话,一开是听到循环神经网络这个名子,感觉好难的样子,因为曾经刚开始学CNN的时候,也有很多不懂的地方。还是不啰嗦了,……开始前,我们先回顾一下,简单的MLP三层神经网络模型:

简单MLP模型

上面那个图是最简单的浅层网络模型了,x为输入,s为隐藏层神经元,o为输出层神经元。然后U、V就是我们要学习的参数了。上面的图很简单,每层神经元的个数就只有一个,我们可以得到如下公式:

(1)隐藏层神经元的激活值为:

s=f(u*x+b1)

(2)然后输出层的激活值为:

o=f(v*s+b2)

这就是最简单的三层神经网络模型的计算公式了,如果对上面的公式,还不熟悉,建议还是看看神经网络的书,打好基础先。而其实RNN网络结构图,仅仅是在上面的模型上,加了一条连接线而已,RNN结构图:

RNN结构图

看到结构图,是不是觉得RNN网络好像很简单的样子,至少没有像CNN过程那么长。从上面的结构图看,RNN网络基础结构,就只有一个输入层、隐藏层、输出层,看起来好像跟传统浅层神经网络模型差不多(只包含输出层、隐藏层、输出层),唯一的区别是:上面隐藏层多了一天连接线,像圆圈一样的东东,而那条线就是所谓的循环递归,同时那个圈圈连接线也多了个一个参数W。还是先看一下RNN的展开图,比较容易理解:

我们直接看,上面展开图中,Ot的计算流程,看到隐藏层神经元st的输入包含了两个:来时xt的输入、来自st-1的输入。于是RNN,t时刻的计算公式如下:

(1)t时刻,隐藏层神经元的激活值为:

st=f(u*xt+w*st-1+b1)

(2)t时刻,输出层的激活值为:

ot=f(v*st+b2)

是不是感觉上面的公式,跟一开始给出的MLP,公式上就差那么一点点。仅仅只是上面的st计算的时候,在函数f变量计算的时候,多个一个w*s t-1。

二、源码实现

下面结合代码,了解代码层面的RNN实现:

[python] view plaincopy
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu Oct 08 17:36:23 2015
  4. @author: Administrator
  5. """
  6. import numpy as np
  7. import codecs
  8. data = open('text.txt', 'r').read() #读取txt一整个文件的内容为字符串str类型
  9. chars = list(set(data))#去除重复的字符
  10. print chars
  11. #打印源文件中包含的字符个数、去重后字符个数
  12. data_size, vocab_size = len(data), len(chars)
  13. print 'data has %d characters, %d unique.' % (data_size, vocab_size)
  14. #创建字符的索引表
  15. char_to_ix = { ch:i for i,ch in enumerate(chars) }
  16. ix_to_char = { i:ch for i,ch in enumerate(chars) }
  17. print char_to_ix
  18. hidden_size = 100 # 隐藏层神经元个数
  19. seq_length = 20 #
  20. learning_rate = 1e-1#学习率
  21. #网络模型
  22. Wxh = np.random.randn(hidden_size, vocab_size)*0.01 # 输入层到隐藏层
  23. Whh = np.random.randn(hidden_size, hidden_size)*0.01 # 隐藏层与隐藏层
  24. Why = np.random.randn(vocab_size, hidden_size)*0.01 # 隐藏层到输出层,输出层预测的是每个字符的概率
  25. bh = np.zeros((hidden_size, 1)) #隐藏层偏置项
  26. by = np.zeros((vocab_size, 1)) #输出层偏置项
  27. #inputs  t时刻序列,也就是相当于输入
  28. #targets t+1时刻序列,也就是相当于输出
  29. #hprev t-1时刻的隐藏层神经元激活值
  30. def lossFun(inputs, targets, hprev):
  31. xs, hs, ys, ps = {}, {}, {}, {}
  32. hs[-1] = np.copy(hprev)
  33. loss = 0
  34. #前向传导
  35. for t in xrange(len(inputs)):
  36. xs[t] = np.zeros((vocab_size,1)) #把输入编码成0、1格式,在input中,为0代表此字符未激活
  37. xs[t][inputs[t]] = 1
  38. hs[t] = np.tanh(np.dot(Wxh, xs[t]) + np.dot(Whh, hs[t-1]) + bh) # RNN的隐藏层神经元激活值计算
  39. ys[t] = np.dot(Why, hs[t]) + by # RNN的输出
  40. ps[t] = np.exp(ys[t]) / np.sum(np.exp(ys[t])) # 概率归一化
  41. loss += -np.log(ps[t][targets[t],0]) # softmax 损失函数
  42. #反向传播
  43. dWxh, dWhh, dWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
  44. dbh, dby = np.zeros_like(bh), np.zeros_like(by)
  45. dhnext = np.zeros_like(hs[0])
  46. for t in reversed(xrange(len(inputs))):
  47. dy = np.copy(ps[t])
  48. dy[targets[t]] -= 1 # backprop into y
  49. dWhy += np.dot(dy, hs[t].T)
  50. dby += dy
  51. dh = np.dot(Why.T, dy) + dhnext # backprop into h
  52. dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity
  53. dbh += dhraw
  54. dWxh += np.dot(dhraw, xs[t].T)
  55. dWhh += np.dot(dhraw, hs[t-1].T)
  56. dhnext = np.dot(Whh.T, dhraw)
  57. for dparam in [dWxh, dWhh, dWhy, dbh, dby]:
  58. np.clip(dparam, -5, 5, out=dparam) # clip to mitigate exploding gradients
  59. return loss, dWxh, dWhh, dWhy, dbh, dby, hs[len(inputs)-1]
  60. #预测函数,用于验证,给定seed_ix为t=0时刻的字符索引,生成预测后面的n个字符
  61. def sample(h, seed_ix, n):
  62. x = np.zeros((vocab_size, 1))
  63. x[seed_ix] = 1
  64. ixes = []
  65. for t in xrange(n):
  66. h = np.tanh(np.dot(Wxh, x) + np.dot(Whh, h) + bh)#h是递归更新的
  67. y = np.dot(Why, h) + by
  68. p = np.exp(y) / np.sum(np.exp(y))
  69. ix = np.random.choice(range(vocab_size), p=p.ravel())#根据概率大小挑选
  70. x = np.zeros((vocab_size, 1))#更新输入向量
  71. x[ix] = 1
  72. ixes.append(ix)#保存序列索引
  73. return ixes
  74. n, p = 0, 0
  75. mWxh, mWhh, mWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
  76. mbh, mby = np.zeros_like(bh), np.zeros_like(by) # memory variables for Adagrad
  77. smooth_loss = -np.log(1.0/vocab_size)*seq_length # loss at iteration 0
  78. while n<20000:
  79. #n表示迭代网络迭代训练次数。当输入是t=0时刻时,它前一时刻的隐藏层神经元的激活值我们设置为0
  80. if p+seq_length+1 >= len(data) or n == 0:
  81. hprev = np.zeros((hidden_size,1)) #
  82. p = 0 # go from start of data
  83. #输入与输出
  84. inputs = [char_to_ix[ch] for ch in data[p:p+seq_length]]
  85. targets = [char_to_ix[ch] for ch in data[p+1:p+seq_length+1]]
  86. #当迭代了1000次,
  87. if n % 1000 == 0:
  88. sample_ix = sample(hprev, inputs[0], 200)
  89. txt = ''.join(ix_to_char[ix] for ix in sample_ix)
  90. print '----\n %s \n----' % (txt, )
  91. # RNN前向传导与反向传播,获取梯度值
  92. loss, dWxh, dWhh, dWhy, dbh, dby, hprev = lossFun(inputs, targets, hprev)
  93. smooth_loss = smooth_loss * 0.999 + loss * 0.001
  94. if n % 100 == 0: print 'iter %d, loss: %f' % (n, smooth_loss) # print progress
  95. # 采用Adagrad自适应梯度下降法,可参看博文:http://blog.csdn.net/danieljianfeng/article/details/42931721
  96. for param, dparam, mem in zip([Wxh, Whh, Why, bh, by],
  97. [dWxh, dWhh, dWhy, dbh, dby],
  98. [mWxh, mWhh, mWhy, mbh, mby]):
  99. mem += dparam * dparam
  100. param += -learning_rate * dparam / np.sqrt(mem + 1e-8) #自适应梯度下降公式
  101. p += seq_length #批量训练
  102. n += 1 #记录迭代次数

参考文献:

1、http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns/

2、http://blog.csdn.net/danieljianfeng/article/details/42931721

3、声明:上面的源码例子是从github下载的,具体忘了是从哪个作者,非商业用途,仅供学习参考,如有侵权请联系博主删除

**********************作者:hjimce   时间:2015.10.23  联系QQ:1393852684   地址:http://blog.csdn.net/hjimce   原创文章,版权所有,转载请保留本行信息(不允许删除)

深度学习(十一)RNN入门学习相关推荐

  1. 强化学习(Reinforcement Learning)入门学习--01

    强化学习(Reinforcement Learning)入门学习–01 定义 Reinforcement learning (RL) is an area of machine learning in ...

  2. 流媒体学习-WebRTC全面入门学习-1

    一.初始WebRTC 1.WebRTC 就是音视频处理+即时通讯的开源库 音视频处理中ffmpeg和WebRTC是两个很重要的一部分,ffmpeg注重与数据音视频的编解码,文件的后处理.WebRTC整 ...

  3. matlab 编程学习,matlab编程入门学习(4)

    之前把matlab的一些基本知识点讲解了下,下面继续讲函数的部分 第五章. 自定义函数 5.1.简单介绍 好的编程习惯把大的程序分解成函数,有很多的好处,例如,程序部分的独立检测,代码的可复用性,避免 ...

  4. 用Python爬取了拉勾网的招聘信息+详细教程+趣味学习+快速爬虫入门+学习交流+大神+爬虫入门...

    关于 一直埋头学习,不知当前趋势,这是学习一门技术过程中最大的忌讳.刚好利用python爬虫,抓取一下拉勾网关于python职位的一些基本要求,不仅能知道岗位的基本技能要求,还能锻炼一下代码能力,学以 ...

  5. 学习dajango+sqlite3入门学习

    1 新建项目 django-admin.py startproject website1 2 启动服务器,查看是否正常 manage.py runserver 3 进入工程website1一级文件夹下 ...

  6. 汇编入门学习笔记 (十二)—— int指令、port

    疯狂的暑假学习之  汇编入门学习笔记 (十二)--  int指令.port 參考: <汇编语言> 王爽 第13.14章 一.int指令 1. int指令引发的中断 int n指令,相当于引 ...

  7. 计算机指令int,汇编入门学习笔记 (十二)—— int指令、端口

    疯狂的暑假学习之  汇编入门学习笔记 (十二)--  int指令.端口 参考: <汇编语言> 王爽 第13.14章 一.int指令 1. int指令引发的中断 int n指令,相当于引发一 ...

  8. JBox2d入门学习二 -----我的小鸟

    入门学习一当中我学会了如何定义并且创建一个世界,在世界当中定义并且创建一个刚体,并尝试给刚体一个力.最近比较忙..现在抽空实现了一个类似于愤怒小鸟的例子,先看看图吧.   贴代码,注解写的比较详细了, ...

  9. 算法竞赛入门学习(篇一)

    算法竞赛入门学习 算法竞赛入门学习,本文习题来自牛客网教程. 一.枚举与贪心 优化枚举的基本思路,减少枚举次数 选择合适的枚举对象 选择合适的枚举方向--排除非法或不是最优的情况 选择合适的数据维护方 ...

最新文章

  1. ionic4 select 去掉确定取消按钮_word文档中的水印如何去掉,有三种方法,你最喜欢哪种?...
  2. C++ string中find ,rfind 等函数 用法总结及示例
  3. linux 检测日志文件内容变化
  4. 3DSlicer8:FAQ-2
  5. 见识过世界的强大,才能拥有掌握世界的力量
  6. Java Spring Security示例教程中的2种设置LDAP Active Directory身份验证的方法
  7. 前端学习(3257):js高级教程(1)准备
  8. 主角的创建与选择 Learn Unreal Engine (with C++)
  9. python脚本怎么打印日志_python 接口测试1 --如何创建和打印日志文件
  10. php自定义函数参数,php自定义函数的参数
  11. 给文章添加目录的方法
  12. 火狐浏览器设置url编码_关于不同浏览器对URL编码的分析(转)
  13. lsoci mysql_flask项目从sqlite3升级的mysql数据库
  14. 集群故障处理之处理思路以及健康状态检查(三十二)
  15. 启动Jmeter录制代理进行录制,报 jmeter.protocol.http.proxy.ProxyControl
  16. C++ class 和 struct 构造函数
  17. 摄像镜头型号参数分类
  18. 外贸常用术语_13个常用的国际贸易术语详解
  19. ChatGPT 会开源吗?
  20. 购买学生服务器、备案域名、搭建博客菜鸟级教程

热门文章

  1. 随机梯度下降(Stochastic gradient descent)和 批量梯度下降(Batch gradient descent )的公式对比、实现对比
  2. Apache Kafka-Spring Kafka生产消费@KafkaListener源码解析
  3. Shell-alias在Shell脚本中的使用
  4. Android日期分组,按查询分组在列表视图android中显示一些意...
  5. java web 播放音频_使用Java ME以流形式播放Web服务器上的音乐文件
  6. /bin/bash: jar: command not found
  7. Linux——进程系列知识详述(操作系统、PCB进程控制块、查看进程状态等)
  8. 台式电脑键盘f1是计算机怎么取消,开机F1怎么取消,教您开机F1怎么取消
  9. 深入理解ROS技术 【2】ROS下的模块详解(66-128)
  10. wdcp导出mysql_phpmyadmin导入导出mysql(只适用WDCP系统)