RNN LSTM 循环神经网络 (分类例子)

作者: Morvan    编辑: Morvan学习资料:机器学习-简介系列 什么是RNN

本代码基于网上这一份代码 code

本节的内容包括:

设置 RNN 的参数

这次我们会使用 RNN 来进行分类的训练 (Classification). 会继续使用到手写数字 MNIST 数据集. 让 RNN 从每张图片的第一行像素读到最后一行, 然后再进行分类判断. 接下来我们导入 MNIST 数据并确定 RNN 的各种参数(hyper-parameters):

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

tf.set_random_seed(1) # set random seed

# 导入数据

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# hyperparameters

lr = 0.001 # learning rate

training_iters = 100000 # train step 上限

batch_size = 128

n_inputs = 28 # MNIST data input (img shape: 28*28)

n_steps = 28 # time steps

n_hidden_units = 128 # neurons in hidden layer

n_classes = 10 # MNIST classes (0-9 digits)

接着定义 x, y 的 placeholder 和 weights, biases 的初始状况.

# x y placeholder

x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])

y = tf.placeholder(tf.float32, [None, n_classes])

# 对 weights biases 初始值的定义

weights = {

# shape (28, 128)

'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),

# shape (128, 10)

'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))

}

biases = {

# shape (128, )

'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),

# shape (10, )

'out': tf.Variable(tf.constant(0.1, shape=[n_classes, ]))

}

定义 RNN 的主体结构

接着开始定义 RNN 主体结构, 这个 RNN 总共有 3 个组成部分 ( input_layer, cell, output_layer). 首先我们先定义 input_layer:

def RNN(X, weights, biases):

# 原始的 X 是 3 维数据, 我们需要把它变成 2 维数据才能使用 weights 的矩阵乘法

# X ==> (128 batches * 28 steps, 28 inputs)

X = tf.reshape(X, [-1, n_inputs])

# X_in = W*X + b

X_in = tf.matmul(X, weights['in']) + biases['in']

# X_in ==> (128 batches, 28 steps, 128 hidden) 换回3维

X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])

接着是 cell 中的计算, 有两种途径:使用 tf.nn.rnn(cell, inputs) (不推荐原因). 但是如果使用这种方法, 可以参考这个代码;

使用 tf.nn.dynamic_rnn(cell, inputs) (推荐). 这次的练习将使用这种方式.

因 Tensorflow 版本升级原因, state_is_tuple=True 将在之后的版本中变为默认. 对于 lstm 来说, state可被分为(c_state, h_state).

# 使用 basic LSTM Cell.

lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)

init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32) # 初始化全零 state

如果使用tf.nn.dynamic_rnn(cell, inputs), 我们要确定 inputs 的格式. tf.nn.dynamic_rnn 中的 time_major 参数会针对不同 inputs 格式有不同的值.如果 inputs 为 (batches, steps, inputs) ==> time_major=False;

如果 inputs 为 (steps, batches, inputs) ==> time_major=True;

outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)

最后是 output_layer 和 return 的值. 因为这个例子的特殊性, 有两种方法可以求得 results.

方式一: 直接调用final_state 中的 h_state (final_state[1]) 来进行运算:

results = tf.matmul(final_state[1], weights['out']) + biases['out']

方式二: 调用最后一个 outputs (在这个例子中,和上面的final_state[1]是一样的):

# 把 outputs 变成 列表 [(batch, outputs)..] * steps

outputs = tf.unstack(tf.transpose(outputs, [1,0,2]))

results = tf.matmul(outputs[-1], weights['out']) + biases['out'] #选取最后一个 output

在 def RNN() 的最后输出 result

return results

定义好了 RNN 主体结构后, 我们就可以来计算 cost 和 train_op:

pred = RNN(x, weights, biases)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))

train_op = tf.train.AdamOptimizer(lr).minimize(cost)

在1.0 版本:tf.nn.softmax_cross_entropy_with_logits 有所改变#labels为实际值,标签;;;logits是计算出来的预测值。cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred))

train_op = tf.train.AdamOptimizer(lr).minimize(cost)

correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

训练 RNN

训练时, 不断输出 accuracy, 观看结果:

correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# init= tf.initialize_all_variables() # tf 马上就要废弃这种写法

# 替换成下面的写法:

init = tf.global_variables_initializer()

with tf.Session() as sess:

sess.run(init)

step = 0

while step * batch_size < training_iters:

batch_xs, batch_ys = mnist.train.next_batch(batch_size)

batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])

sess.run([train_op], feed_dict={

x: batch_xs,

y: batch_ys,

})

if step % 20 == 0:

print(sess.run(accuracy, feed_dict={

x: batch_xs,

y: batch_ys,

}))

step += 1

最终 accuracy 的结果如下:

0.1875

0.65625

0.726562

0.757812

0.820312

0.796875

0.859375

0.921875

0.921875

0.898438

0.828125

0.890625

0.9375

0.921875

0.9375

0.929688

0.953125

....

如果你觉得这篇文章或视频对你的学习很有帮助, 请你也分享它, 让它能再次帮助到更多的需要学习的人.

莫烦没有正式的经济来源, 如果你也想支持 莫烦Python 并看到更好的教学内容, 请拉倒屏幕最下方, 赞助他一点点, 作为鼓励他继续开源的动力.

lstm python_5.8 莫烦 Python RNN LSTM 循环神经网络 (分类例子)相关推荐

  1. 莫烦 python_5.1 莫烦 Python Classification 分类学习

    Classification 分类学习 作者: Mark JingNB 编辑: Morvan 学习资料: 这次我们会介绍如何使用TensorFlow解决Classification(分类)问题. 之前 ...

  2. RNN LSTM 循环神经网络 (分类例子)

    学习资料: 相关代码 为 TF 2017 打造的新版可视化教学代码 机器学习-简介系列 什么是RNN 机器学习-简介系列 什么是LSTM RNN 本代码基于网上这一份代码 code 设置 RNN 的参 ...

  3. tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】

    之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别. 而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时 ...

  4. 【莫烦Python】机器要说话 NLP 自然语言处理教程 W2V Transformer BERT Seq2Seq GPT 笔记

    [莫烦Python]机器要说话 NLP 自然语言处理教程 W2V Transformer BERT Seq2Seq GPT 笔记 教程与代码地址 P1 NLP行业大佬采访 P2 NLP简介 P3 1. ...

  5. 理论——RNN(循环神经网络)与LSTM(长短期记忆神经网络)

    这里写目录标题 RNN 背景 结构 应用 梯度消失.爆炸 LSTM 长期依赖问题 LSTM网络 结构 RNN 背景 人类的思考具有连续性,我们常联系过去的经验来理解现在.比如阅读时我们常提及的&quo ...

  6. 自然语言菜鸟学习笔记(七):RNN(循环神经网络)及变体(LSTM、GRU)理解与实现(TensorFlow)

    目录 前言 RNN(循环神经网络) 为什么要用循环神经网络(RNN)? 循环神经网络(RNN)可以处理什么类型的任务? 多对一问题 一对多问题 多对多问题 循环神经网络结构 单层网络情况 正向传播 反 ...

  7. tkinter 笔记 checkbutton 勾选项 (莫烦python笔记)

    和前面radiobutton的区别在于,radiobutton 各选项只能勾选一个,checkbutton可以勾选多个,也可以不勾选 1 主体框架 还是一样的 import tkinter as tk ...

  8. tkinter笔记:scale 尺度 (莫烦python笔记)

    1 主题框架 还是一样的 import tkinter as tkwindow = tk.Tk() #创建窗口window.title('my window') #窗口标题window.geometr ...

  9. tkinter 笔记: radiobutton 选择按钮(莫烦python笔记)

    1 主体框架还是那个主体框架 window = tk.Tk() window.title('my window') window.geometry('500x500') 2 设置tkinter的文字变 ...

  10. tkinter 笔记:列表部件 listbox (莫烦python 笔记)

    1  主体框架 主体框架部分还是 import tkinter as tkwindow = tk.Tk() #创建窗口window.title('my window') #窗口标题window.geo ...

最新文章

  1. 0x62.图论 - 最小生成树
  2. 5G 信令流程 — 5GC 的业务请求(Service Request)
  3. raptor算法流程图例题_RAPTOR流程图+算法程序设计教程
  4. [NOI2007] 货币兑换 (dp+李超树维护凸包)
  5. Azure实践之如何批量为资源组虚拟机创建alert
  6. 实例:评审速度与缺陷密度之间的相关性
  7. 问题四十四:怎么用ray tracing画空间任意位置的圆环的任意片段
  8. 防爆技术在工业电子秤中的最新应用(转)
  9. 运动控制卡应用开发教程之Python
  10. Visual C++ 2010 Express 下载及安装教程
  11. Scratch案例——放烟花
  12. wowza拉流和推流接口备忘
  13. vbs格式编程教程基础
  14. 桔梗载药上浮 柴胡升于左 升麻生于右
  15. 【数据分析】重要环节--缺失值怎么处理
  16. 苹果账号申请流程——99刀(个人版或公司版 ),299刀(企业版)
  17. 区块链技术的风险!(转载)
  18. .NET MongoDB Driver GridFS 2.2原理及使用示例
  19. 【英语语法入门】第41讲 原形不定式(2)感官动词
  20. Odoo16正式版于2022年9月12日发布

热门文章

  1. 联想A590刷机方法
  2. UEditor编辑器保存数据到数据库
  3. 怎样做小游戏挖金子(VC,源码3)
  4. 插件学习笔记:搜索引擎ElasticSearch
  5. UNIX网络编程第三版
  6. BLUE引擎检查放入装备的名称全名脚本
  7. 如何设置qq支持临时会话聊天
  8. 如何使用dos启动mysql数据库_如何使用dos命令启动停止mysql数据库?
  9. matlab程序改写python3
  10. 如何使用monitor(DDMS)抓取traceview文件