之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别。
而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时间轴上慢慢展开,有点类似我们大脑认识事物时会有相关的短期记忆。

这次我们使用RNN来识别手写数字。

首先导入数据并定义各种RNN的参数:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data# 导入数据
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)# RNN各种参数定义
lr = 0.001 #学习速率
training_iters = 100000 #循环次数
batch_size = 128
n_inputs = 28 #手写字的大小是28*28,这里是手写字中的每行28列的数值
n_steps = 28 #这里是手写字中28行的数据,因为以一行一行像素值处理的话,正好是28行
n_hidden_units = 128 #假设隐藏单元有128个
n_classes = 10 #因为我们的手写字是0-9,因此最后要分成10个类

接着定义输入、输出以及各权重的形状:

# 定义输入和输出的placeholder
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])# 对weights和biases初始值定义
weights = {# shape(28, 128)'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),# shape(128 , 10)'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}biases = {# shape(128, )'in':tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),# shape(10, )'out':tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
}

定义 RNN 的主体结构

最主要的就是定义RNN的主体结构。

def RNN(X, weights, biases):# X在输入时是一批128个,每批中有28行,28列,因此其shape为(128, 28, 28)。为了能够进行 weights 的矩阵乘法,我们需要把输入数据转换成二维的数据(128*28, 28)X = tf.reshape(X, [-1, n_inputs])# 对输入数据根据权重和偏置进行计算, 其shape为(128batch * 28steps, 128 hidden)X_in = tf.matmul(X, weights['in']) + biases['in']# 矩阵计算完成之后,又要转换成3维的数据结构了,(128batch, 28steps, 128 hidden)X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])# cell,使用LSTM,其中state_is_tuple用来指示相关的state是否是一个元组结构的,如果是元组结构的话,会在state中包含主线状态和分线状态lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)# 初始化全0stateinit_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)# 下面进行运算,我们使用dynamic rnn来进行运算。每一步的运算输出都会存储在outputs中,states中存储了主线状态和分线状态,因为我们前面指定了state_is_tuple=True# time_major用来指示关于时间序列的数据是否在输入数据中第一个维度中。在本例中,我们的时间序列数据位于第2维中,第一维的数据只是batch数据,因此要设置为False。outputs, states = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)# 计算结果,其中states[1]为分线state,也就是最后一个输出值results = tf.matmul(states[1], weights['out']) + biases['out']return results

训练RNN

定义好了 RNN 主体结构后, 我们就可以来计算 cost 和 train_op:

pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)

训练时, 不断输出 accuracy, 观看结果:

correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)step = 0while step*batch_size < training_iters:batch_xs,batch_ys = mnist.train.next_batch(batch_size)batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])sess.run([train_op], feed_dict={x:batch_xs, y:batch_ys})if step % 20 == 0:print(sess.run(accuracy, feed_dict={x:batch_xs, y:batch_ys}))step += 1

最终 accuracy 的结果如下:

E:\Python\Python36\python.exe E:/learn/numpy/lesson3/main.py
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
2018-02-20 20:30:52.769108: I C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\36\tensorflow\core\platform\cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX
0.09375
0.710938
0.8125
0.789063
0.820313
0.882813
0.828125
0.867188
0.921875
0.90625
0.921875
0.890625
0.898438
0.945313
0.914063
0.945313
0.929688
0.96875
0.96875
0.929688
0.953125
0.945313
0.960938
0.992188
0.953125
0.9375
0.929688
0.96875
0.960938
0.945313

完整代码

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data# 导入数据
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)# RNN各种参数定义
lr = 0.001 #学习速率
training_iters = 100000 #循环次数
batch_size = 128
n_inputs = 28 #手写字的大小是28*28,这里是手写字中的每行28列的数值
n_steps = 28 #这里是手写字中28行的数据,因为以一行一行像素值处理的话,正好是28行
n_hidden_units = 128 #假设隐藏单元有128个
n_classes = 10 #因为我们的手写字是0-9,因此最后要分成10个类# 定义输入和输出的placeholder
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])# 对weights和biases初始值定义
weights = {# shape(28, 128)'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),# shape(128 , 10)'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}biases = {# shape(128, )'in':tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),# shape(10, )'out':tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
}def RNN(X, weights, biases):# X在输入时是一批128个,每批中有28行,28列,因此其shape为(128, 28, 28)。为了能够进行 weights 的矩阵乘法,我们需要把输入数据转换成二维的数据(128*28, 28)X = tf.reshape(X, [-1, n_inputs])# 对输入数据根据权重和偏置进行计算, 其shape为(128batch * 28steps, 128 hidden)X_in = tf.matmul(X, weights['in']) + biases['in']# 矩阵计算完成之后,又要转换成3维的数据结构了,(128batch, 28steps, 128 hidden)X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])# cell,使用LSTM,其中state_is_tuple用来指示相关的state是否是一个元组结构的,如果是元组结构的话,会在state中包含主线状态和分线状态lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)# 初始化全0stateinit_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)# 下面进行运算,我们使用dynamic rnn来进行运算。每一步的运算输出都会存储在outputs中,states中存储了主线状态和分线状态,因为我们前面指定了state_is_tuple=True# time_major用来指示关于时间序列的数据是否在输入数据中第一个维度中。在本例中,我们的时间序列数据位于第2维中,第一维的数据只是batch数据,因此要设置为False。outputs, states = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)# 计算结果,其中states[1]为分线state,也就是最后一个输出值results = tf.matmul(states[1], weights['out']) + biases['out']return resultspred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)step = 0while step*batch_size < training_iters:batch_xs,batch_ys = mnist.train.next_batch(batch_size)batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])sess.run([train_op], feed_dict={x:batch_xs, y:batch_ys})if step % 20 == 0:print(sess.run(accuracy, feed_dict={x:batch_xs, y:batch_ys}))step += 1

转载于:https://www.cnblogs.com/dreampursuer/p/8231770.html

tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】相关推荐

  1. RNN循环神经网络的直观理解:基于TensorFlow的简单RNN例子

    RNN 直观理解 一个非常棒的RNN入门Anyone Can learn To Code LSTM-RNN in Python(Part 1: RNN) 基于此文章,本文给出我自己的一些愚见 基于此文 ...

  2. RNN LSTM 循环神经网络 (分类例子)

    学习资料: 相关代码 为 TF 2017 打造的新版可视化教学代码 机器学习-简介系列 什么是RNN 机器学习-简介系列 什么是LSTM RNN 本代码基于网上这一份代码 code 设置 RNN 的参 ...

  3. RNN循环神经网络的自我理解:基于Tensorflow的简单句子使用(通俗理解RNN)

    解读tensorflow之rnn: 该开始接触RNN我们都会看到这样的张图:  如上图可以看到每t-1时的forward的结果和t时的输入共同作为这一次forward的输入 所以RNN存在一定的弊端, ...

  4. [译] RNN 循环神经网络系列 2:文本分类

    原文地址:RECURRENT NEURAL NETWORKS (RNN) – PART 2: TEXT CLASSIFICATION 原文作者:GokuMohandas 译文出自:掘金翻译计划 本文永 ...

  5. 神经网络学习小记录2——利用tensorflow构建循环神经网络(RNN)

    神经网络学习小记录2--利用tensorflow构建循环神经网络(RNN) 学习前言 RNN简介 tensorflow中RNN的相关函数 tf.nn.rnn_cell.BasicLSTMCell tf ...

  6. PyTorch-09 循环神经网络RNNLSTM (时间序列表示、RNN循环神经网络、RNN Layer使用、时间序列预测案例、RNN训练难题、解决梯度离散LSTM、LSTM使用、情感分类问题实战)

    PyTorch-09 循环神经网络RNN&LSTM (时间序列表示.RNN循环神经网络.RNN Layer使用.时间序列预测案例(一层的预测点的案例).RNN训练难题(梯度爆炸和梯度离散)和解 ...

  7. Recurrent Neural Networks(RNN) 循环神经网络初探

    1. 针对机器学习/深度神经网络"记忆能力"的讨论 0x1:数据规律的本质是能代表此类数据的通用模式 - 数据挖掘的本质是在进行模式提取 数据的本质是存储信息的介质,而模式(pat ...

  8. 【机器学习】RNN循环神经网络

    循环神经网络归属: 领域:机器学习 方向:自然语言处理 贡献:自动文本生成 循环神经网络实际应用: 生活中因为原始数据都是序列化的,比如自然语言,语音处理,时间序列问题(股票价格)等问题, 这个时候需 ...

  9. RNN 循环神经网络系列 5: 自定义单元

    原文地址:RECURRENT NEURAL NETWORK (RNN) – PART 5: CUSTOM CELLS 原文作者:GokuMohandas 译文出自:掘金翻译计划 本文永久链接:gith ...

最新文章

  1. rdp连接工具_如何在Windows10中清除RDP连接历史记录?
  2. STM32硬件错误HardFault_Handler的处理方法
  3. asp.net mvc自动完成
  4. 计算机一级查找同类型文件,如何快捷找出电脑内的重复文件
  5. Linux - 简单设置 vim (tab, 行号, 换行)
  6. 危险!!!也许你的web网站或服务正在悄无声息地被SQL注入
  7. iptables,haproxy转发ftp(21端口)
  8. Response.End()在Webform和ASP.NET MVC下的表现差异
  9. python彩票预测与分析_130期祥子大乐透预测奖号:后区大小分析
  10. c++删除文件delete_关于macos删除快捷键,你知道这些区别用法吗?
  11. 下载网页中内嵌的PDF
  12. html打开显示不全,打开浏览器网页显示不全 网页显示不正常解决方法
  13. [离散数学]真值表法求主析、合取范式(Java实现)
  14. YOLOV5网络结构
  15. 学习《华为基本法》(13):市场营销
  16. css实现波浪球效果图,用css实现圆形波浪效果图
  17. 越南支付 移动快速增长 Fpay
  18. python 递增递减数列
  19. Direct3D 11 总结 —— 6 绘制图片
  20. 批量爬取百度贴吧里的标题及链接

热门文章

  1. PHP访问MySQL数据库
  2. Linux多线程编程(一)---多线程基本编程
  3. ARM微处理器的体系结构
  4. matlab中对伺服电机,基于Matlab的伺服电机Modbus通讯研究
  5. 使用钉钉接收gitlab仓库的推送消息
  6. Redis的设计与实现之整数集合和压缩列表
  7. 共享智能指针编程实验
  8. c++关联容器的容器操作(和顺序容器都支持的操作)详细解释,基础于c++primer 5th 表 9.2 (持续更新)
  9. yum -y install与yum install有什么不同
  10. 1976年,提出公钥密码体制概念的学者