2019独角兽企业重金招聘Python工程师标准>>>

本节来介绍一下使用 RNN 的 LSTM 来做 MNIST 分类的方法,RNN 相比 CNN 来说,速度可能会慢,但可以节省更多的内存空间。

初始化

首先我们可以先初始化一些变量,如学习率、节点单元数、RNN 层数等:

learning_rate = 1e-3
num_units = 256
num_layer = 3
input_size = 28
time_step = 28
total_steps = 2000
category_num = 10
steps_per_validate = 100
steps_per_test = 500
batch_size = tf.placeholder(tf.int32, [])
keep_prob = tf.placeholder(tf.float32, [])

然后还需要声明一下 MNIST 数据生成器:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

接下来常规声明一下输入的数据,输入数据用 x 表示,标注数据用 y_label 表示:

x = tf.placeholder(tf.float32, [None, 784])
y_label = tf.placeholder(tf.float32, [None, 10])

这里输入的 x 维度是 [None, 784],代表 batch_size 不确定,输入维度 784,y_label 同理。

接下来我们需要对输入的 x 进行 reshape 操作,因为我们需要将一张图分为多个 time_step 来输入,这样才能构建一个 RNN 序列,所以这里直接将 time_step 设成 28,这样一来 input_size 就变为了 28,batch_size 不变,所以reshape 的结果是一个三维的矩阵:

x_shape = tf.reshape(x, [-1, time_step, input_size])

RNN 层

接下来我们需要构建一个 RNN 模型了,这里我们使用的 RNN Cell 是 LSTMCell,而且要搭建一个三层的 RNN,所以这里还需要用到 MultiRNNCell,它的输入参数是 LSTMCell 的列表。

所以我们可以先声明一个方法用于创建 LSTMCell,方法如下:

def cell(num_units):cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)return DropoutWrapper(cell, output_keep_prob=keep_prob)

这里还加入了 Dropout,来减少训练过程中的过拟合。

接下来我们再利用它来构建多层的 RNN:

cells = tf.nn.rnn_cell.MultiRNNCell([cell(num_units) for _ in range(num_layer)])

注意这里使用了 for 循环,每循环一次新生成一个 LSTMCell,而不是直接使用乘法来扩展列表,因为这样会导致 LSTMCell 是同一个对象,导致构建完 MultiRNNCell 之后出现维度不匹配的问题。

接下来我们需要声明一个初始状态:

h0 = cells.zero_state(batch_size, dtype=tf.float32)

然后接下来调用 dynamic_rnn() 方法即可完成模型的构建了:

output, hs = tf.nn.dynamic_rnn(cells, inputs=x_shape, initial_state=h0)

这里 inputs 的输入就是 x 做了 reshape 之后的结果,初始状态通过 initial_state 传入,其返回结果有两个,一个 output 是所有 time_step 的输出结果,赋值为 output,它是三维的,第一维长度等于 batch_size,第二维长度等于 time_step,第三维长度等于 num_units。另一个 hs 是隐含状态,是元组形式,长度即 RNN 的层数 3,每一个元素都包含了 c 和 h,即 LSTM 的两个隐含状态。

这样的话 output 的最终结果可以取最后一个 time_step 的结果,所以可以使用:

output = output[:, -1, :]

或者直接取隐藏状态最后一层的 h 也是相同的:

h = hs[-1].h

在此模型中,二者是等价的。但注意如果用于文本处理,可能由于文本长度不一,而 padding,导致二者不同。

输出层

接下来我们再做一次线性变换和 Softmax 输出结果即可:

# Output Layer
w = tf.Variable(tf.truncated_normal([num_units, category_num], stddev=0.1), dtype=tf.float32)
b = tf.Variable(tf.constant(0.1, shape=[category_num]), dtype=tf.float32)
y = tf.matmul(output, w) + b
# Loss
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y)

这里的 Loss 直接调用了 softmax_cross_entropy_with_logits 先计算了 Softmax,然后计算了交叉熵。

训练和评估

最后再定义训练和评估的流程即可,在训练过程中每隔一定的 step 就输出 Train Accuracy 和 Test Accuracy:

# Train
train = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy)# Prediction
correction_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))# Train
with tf.Session() as sess:sess.run(tf.global_variables_initializer())for step in range(total_steps + 1):batch_x, batch_y = mnist.train.next_batch(100)sess.run(train, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5, batch_size: batch_x.shape[0]})# Train Accuracyif step % steps_per_validate == 0:print('Train', step, sess.run(accuracy, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5,batch_size: batch_x.shape[0]}))# Test Accuracyif step % steps_per_test == 0:test_x, test_y = mnist.test.images, mnist.test.labelsprint('Test', step,sess.run(accuracy, feed_dict={x: test_x, y_label: test_y, keep_prob: 1, batch_size: test_x.shape[0]}))

运行

直接运行之后,只训练了几轮就可以达到 98% 的准确率:

Train 0 0.27
Test 0 0.2223
Train 100 0.87
Train 200 0.91
Train 300 0.94
Train 400 0.94
Train 500 0.99
Test 500 0.9595
Train 600 0.95
Train 700 0.97
Train 800 0.98

可以看出来 LSTM 在做 MNIST 字符分类的任务上还是比较有效的。

转载于:https://my.oschina.net/u/3720876/blog/1632604

芝麻HTTP:TensorFlow LSTM MNIST分类相关推荐

  1. TensorFlow使用--MNIST分类学习(BP神经网络)

    目录 测试结果:根据测试集和验证集的验证,训练好的模型识别率可以达到96% 代码块 代码分块解析 保存训练好的神经网络并调用 小测试:将MNIST数据集中的图片显示出来 传送门 激活函数相关 soft ...

  2. TensorFlow使用--MNIST分类学习入门(感知机)

    目录 MNIST简单介绍 代码块 解决问题: 代码剖析: 学习传送门: MNIST简单介绍 MNIST是一组经过预处理的手写数字图片数据集,其中每个样本都是一张长28.宽28的灰度图片,其中包含一个0 ...

  3. TensorFlow实现mnist书写数字分类,出现please use urllib or similar directly错误。

    TensorFlow实现mnist书写数字分类, 在使用 from tensorflow.examples.tutorials.mnist import input_data mnist = inpu ...

  4. TensorFlow的MNIST手写数字分类问题

    一.简介MNIST TensorFlow编程学习的入门一般都是基于MNIST手写数字数据集和Cifar(包括cifar-10和cifar-100)数据集,因为它们都比较小,一般的设备即可进行训练和测试 ...

  5. 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门

    2017年2月16日,Google正式对外发布Google TensorFlow 1.0版本,并保证本次的发布版本API接口完全满足生产环境稳定性要求.这是TensorFlow的一个重要里程碑,标志着 ...

  6. 基于tensorflow的MNIST手写字识别

    一.卷积神经网络模型知识要点卷积卷积 1.卷积 2.池化 3.全连接 4.梯度下降法 5.softmax 本次就是用最简单的方法给大家讲解这些概念,因为具体的各种论文网上都有,连推导都有,所以本文主要 ...

  7. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  8. 深度学习之循环神经网络(11-a)LSTM情感分类问题代码

    深度学习之循环神经网络(11-a)LSTM情感分类问题代码 1. Cell方式 代码 运行结果 2. 层方式 代码 运行结果 1. Cell方式 代码 import os import tensorf ...

  9. tensorflow saver_机器学习入门(6):Tensorflow项目Mnist手写数字识别-分析详解

    本文主要内容:Ubuntu下基于Tensorflow的Mnist手写数字识别的实现 训练数据和测试数据资料:http://yann.lecun.com/exdb/mnist/ 前面环境都搭建好了,直接 ...

  10. lstm 文本分类_带有lstm和单词嵌入的灾难推文上的文本分类

    lstm 文本分类 This was my first Kaggle notebook and I thought why not write it on Medium too? Ť他是我第一次Kag ...

最新文章

  1. c#总结最近的几项重要代码
  2. 被强制007加班!他愤而把加班证据发给国外供应商和大客户!客户纷纷取消订单!他去度假,老板气疯!...
  3. 中科大提出统一输入过滤框架InFi:首次理论分析可过滤性,支持全数据模态
  4. 2018年强化学习领域十篇重要论文(附源码)
  5. 【转】IOS的各种后台情况的实现
  6. 缓存框架(Java缓存)与框架缓存(介绍mybatis缓存)
  7. python最简单的架构_Python实现简单状态框架的方法
  8. T-SQL语句学习(三)
  9. python——Tkinter图形化界面及threading多线程
  10. Spark的新方案UnifiedMemoryManager内存管理模型分析
  11. 吊炸天的Kubernetes微服务图形化管理工具:Kuboard,必须推荐给你!
  12. 数字化转型没有银弹,不破不立,如何破?如何立?
  13. 找回foxmail下的邮箱密码
  14. 云计算技术概述与入门
  15. java md5加密64位_MD5加密的Java实现
  16. 影响Google Adsense广告单价高低的因素分析获取更高的收入
  17. JavaEE企业级实战项目 智牛股第五天 Netty的使用和项目数据库搭建
  18. ${}和`${}`的用法
  19. 基于STM32F103 HAL库 MB85RS128 驱动程序
  20. Scala编程——下界介绍与实例分析

热门文章

  1. 2021-09-1017. 电话号码的字母组合
  2. 14Penrose广义逆(II)
  3. 代价函数的作用(2)--机器学习
  4. openCV,C++接口,cv::Mat矩阵数据元素读取
  5. 最新Activity与Fragment完全理解
  6. ppt生成eps文件_eps是什么格式怎么打开?全面解析图片的eps是什么格式
  7. java mysql大小写_java – 使用select where where Mysql在Mysql中区分大小写
  8. 关于Redis启动时报权限不够(-bash: /usr/local/bin/redis-server: Permission denied)
  9. Mybatis和Spring整合逆向工程
  10. android 模仿今日头条ViewPager+TabLayout