在深度学习中,RNN是处理序列数据的有效方法之一,也是深度的一种很好的体现,本文将简单介绍RNN的工作方式,以及针对IMDB数据集的简单实践

RNN简介

RNN(Recurrent Neural Network),在基本的全连接层上迭代一层或多层带有历史信息(h)的RNN神经单元(RNN cell),使神经网络能够处理具有上下文关联的序列数据,能够有效减少隐层的参数量,提升训练效率和准确率

为了更好的说明RNN的工作原理,我们带入一个具体的目标,就是评价情感分析,如图所示:

我们所要做的就是通过下方由单词组成的评论来确定其情感是积极还是消极。我们把语句定义为x,输出定义为y,输出的结果即:P(y|x)
这里的embedding操作可以简单理解为一个线性和,即

Oi=x@weighti+biasi

但这样简单的线性传递操作之后,只能通过每一个单词的含义来判定情感,无法关联到上下文,为了保存并处理上下文的语义,我们给线性操作附加一个历史信息h。如果这样处理,那我们完全可以省略掉针对每一个单词不同的weight,而使用一个公共的weight用于单词提取,称为weightx,同理偏置称为biasx,此时引入历史信息h,初始化h0为全零,则公式修改为:

Oi=x@weightx+hi@weighth+biasx=hi+1

每一次计算的输出和传递给下一层的历史信息其实是相同的,这里分开来写是为了下一篇LSTM留坑;而所谓的传递给下一层,实际上可以由同一个RNNcell迭代完成,这也是RNN名字的由来
说完了公式,我们回到神经网络的根基,也就是梯度的求解
额外的参数定义:
t表示第t个句子,或者t时刻
激活函数——tan()
则:
ht=tan(x*weightx+ht-1@weighth)
yt=weighto*ht
这里我们忽略偏置
则损失函数的梯度由链式法则可以写为:

第一个导数,由于损失函数和t时刻输出yt是直接关联的,因此第一个导数就是我们定义的损失函数对yt的直接求导,已知
第二个导数,当前时刻输出yt对当前时刻历史信息ht的导数在公式中可直接看出为weighto,已知
第三个导数,
令f=tanh(x),由ht公式可知

推导过程请自行演算
第四个导数,对tan激活函数求导后再对weighth 求导即可,已知

综上可知,RNN梯度的复杂度需要对时间轴进行展开,复杂程度很高,因此需要用到TensorFlow等框架进行计算

IMDB数据集和RNN网络的简单实践

对于数据集的加载可以直接使用TensorFlow2下的Keras中Dataset直接导入,如果下载速度很慢可能是因为……你懂得

total_words = 10000
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=total_words)

接下来做数据预处理

max_review_len = 80
# x_train:[b, 80]
# x_test: [b, 80]
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)

训练集和测试集构建

db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(1000).batch(batchsz, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz, drop_remainder=True)

简单起见,我们只设计一层的RNN网络,自定义一个RNN网络用于训练

class MyRNN(keras.Model):def __init__(self, units):super(MyRNN, self).__init__()# [b, 64]self.state = [tf.zeros([batchsz, units])]# self.state1 = [tf.zeros([batchsz, units])]# transform text to embedding representation# [b, 80] => [b, 80, 100]self.embedding = layers.Embedding(total_words, embedding_len,input_length=max_review_len)# [b, 80, 100] , h_dim: 64# RNN: cell1 ,cell2, cell3# SimpleRNNself.rnn_cell = layers.SimpleRNNCell(units, dropout=0.2)# self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.5)# fc, [b, 80, 100] => [b, 64] => [b, 1]self.fc= layers.Dense(1)def call(self, inputs, training=None):# [b, 80]x = inputs# embedding: [b, 80] => [b, 80, 100]x = self.embedding(x)# rnn cell compute# [b, 80, 100] => [b, 64]state = self.state# state1 = self.state1for word in tf.unstack(x, axis=1): # word: [b, 100]# h1 = x*wxh+h0*whh# out: [b, 64]out, state = self.rnn_cell(word, state, training)# out: [b, 64] => [b, 1]x = self.fc(out)# p(y is pos|x)prob = tf.sigmoid(x)return prob

然后使用TensorFlow2中的compile and fit功能即可实现训练和测试,给出笔者的运行结果

整体来看运行的正确率达到82%,没有达到很高的原因在于层数太少,仅仅简单实现了一层的RNN网络,同时可以发现笔者使用了随机种子,这样的随机RNN如果更换成更加贴合数据的因子就能够有所提升

以上就是全部内容,笔者目前研究生在读,所了解到的知识有限,欢迎大佬们留言一起交流学习

深度学习——RNN原理与TensorFlow2下的IMDB简单实践相关推荐

  1. 深度学习算法原理_用于对象检测的深度学习算法的基本原理

    深度学习算法原理 You just got a new drone and you want it to be super smart! Maybe it should detect whether ...

  2. 没人说得清深度学习的原理 只是把它当作一个黑箱来使

    没人说得清深度学习的原理 只是把它当作一个黑箱来使 人类正在慢慢接近世界的本质--物质只是承载信息模式的载体.人脑之外的器官都只是保障这一使命的给养舰队. 自从去年AlphaGo 完虐李世乭,深度学习 ...

  3. Python大数据综合应用 :零基础入门机器学习、深度学习算法原理与案例

    机器学习.深度学习算法原理与案例实现暨Python大数据综合应用高级研修班 一.课程简介 课程强调动手操作:内容以代码落地为主,以理论讲解为根,以公式推导为辅.共4天8节,讲解机器学习和深度学习的模型 ...

  4. 深度学习模型轻量化(下)

    深度学习模型轻量化(下) 2.4 蒸馏 2.4.1 蒸馏流程 蒸馏本质是student对teacher的拟合,从teacher中汲取养分,学到知识,不仅仅可以用到模型压缩和加速中.蒸馏常见流程如下图所 ...

  5. TensorFlow深度学习算法原理与编程实战 人工智能机器学习技术丛书

    作者:蒋子阳 著 出版社:中国水利水电出版社 品牌:智博尚书 出版时间:2019-01-01 TensorFlow深度学习算法原理与编程实战 人工智能机器学习技术丛书 ISBN:97875170682 ...

  6. python原理书籍_python书籍推荐:《深入浅出深度学习:原理剖析与Python实践》

    在过去的这十年,深度学习已经席卷了整个科技界和工业界,2016年谷歌阿尔法狗打败围棋世界冠军李世石,更是使其成为备受瞩目的技术焦点. 今日,小编就为大家推荐一本能让初学者和"老司机" ...

  7. 深度学习环境配置10——Ubuntu下的torch==1.7.1环境配置

    深度学习环境配置10--Ubuntu下的torch==1.7.1环境配置 注意事项 一.2022/9/18更新 学习前言 各个版本pytorch的配置教程 环境内容 环境配置 一.Anaconda安装 ...

  8. 深度学习环境配置5——windows下的torch-cpu=1.2.0环境配置

    深度学习环境配置5--windows下的torch-cpu=1.2.0环境配置 注意事项 一.2021/10/8更新 学习前言 各个版本pytorch的配置教程 环境内容 环境配置 一.Anacond ...

  9. 深度学习环境配置1——windows下的tensorflow-gpu=1.13.2环境配置

    深度学习环境配置1--windows下的tensorflow-gpu=1.13.2环境配置 注意事项 一.2021/9/11更新 二.2021/7/8更新 三.2020/11/5更新 学习前言 环境内 ...

最新文章

  1. 文件节点的linux指令,Java工程师必学的Linux命令(一)文件与目录管理
  2. windows server2003 升级到windows server2012
  3. NameValueCollection类总结和一个例子源码
  4. python方法调用名字不一样_python中调用父类同名方法
  5. Linux系统下按了Ctrl+s锁定屏幕后怎么办?
  6. 廖雪峰JavaScript学习笔记(基础及数据类型、变量)
  7. html中的保存功能代码怎么写,java保存html代码怎么写
  8. Android MediaRecorder录制视频提示start failed的解决办法
  9. 转: ImageMagick 命令行的图片处理工具(客户端与服务器均可用)
  10. 转:CRC校验之模2除法
  11. linux java keytool_JDK自带的keytool证书工具详解
  12. Ubuntu18.04、Ubuntu20.04之ROS安装教程
  13. Android 自定义和可下载字体
  14. Java 删除session实现退出登录
  15. 计算机网络设置端口转发,怎么设置路由器端口转发功能?
  16. ERROR: Cannot uninstall 'wrapt'. It is a distutils installed project and thus we cannot accurately
  17. Android Room 数据库详解
  18. springboot2------自定义消息转换器
  19. Android 实现一键反混淆功能
  20. 新手追高,熟手突破,老手抄底,高手回撤,庄家筹码,机构算法!

热门文章

  1. sqlserver/mysql按天,按小时,按分钟统计连续时间段数据
  2. Leetcode No.145 **
  3. 9本java程序员必读的书(附下载地址)
  4. CobarClient源码分析
  5. .net获取地址栏中的url
  6. [转]对C#泛型中的new()约束的一点思考
  7. Array.from()
  8. spring4.2更好的应用事件
  9. Android简易实战教程--第四十七话《使用OKhttp回调方式获取网络信息》
  10. Mysql导出函数、存储过程