使用MNIST数据集,在TensorFlow上实现基础LSTM网络

By 路雪2017年9月29日 13:39

本文介绍了如何在 TensorFlow 上实现基础 LSTM 网络的详细过程。作者选用了 MNIST 数据集,本文详细介绍了实现过程。

长短期记忆(LSTM)是目前循环神经网络最普遍使用的类型,在处理时间序列数据时使用最为频繁。关于 LSTM 的更加深刻的洞察可以看看这篇优秀的博客:http://colah.github.io/posts/2015-08-Understanding-LSTMs/。

我们的目的

这篇博客的主要目的就是使读者熟悉在 TensorFlow 上实现基础 LSTM 网络的详细过程。

我们将选用 MNIST 作为数据集。

  1. from tensorflow.examples.tutorials.mnist import input_data

  2. mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

MNIST 数据集

MNIST 数据集包括手写数字的图像和对应的标签。我们可以根据以下内置功能从 TensorFlow 上下载并读取数据。

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

数据被分成 3 个部分:

1. 训练数据(mnist.train):55000 张图像

2. 测试数据(mnist.test):10000 张图像

3. 验证数据(mnist.validation):5000 张图像

数据的形态

讨论一下 MNIST 数据集中的训练数据的形态。数据集的这三个部分的形态都是一样的。

训练数据集包括 55000 张 28x28 像素的图像,这些 784(28x28)像素值被展开成一个维度为 784 的单一向量,所有 55000 个像素向量(每个图像一个)被储存为形态为 (55000,784) 的 numpy 数组,并命名为 mnist.train.images。

所有这 55000 张图像都关联了一个类别标签(表示其所属类别),一共有 10 个类别(0,1,2...9),类别标签使用独热编码的形式表示。因此标签将作为形态为 (55000,10) 的数组保存,并命名为 mnist.train.labels。

为什么要选择 MNIST?

LSTM 通常用来解决复杂的序列处理问题,比如包含了 NLP 概念(词嵌入、编码器等)的语言建模问题。这些问题本身需要大量理解,那么将问题简化并集中于在 TensorFlow 上实现 LSTM 的细节(比如输入格式化、LSTM 单元格以及网络结构设计),会是个不错的选择。

MNIST 就正好提供了这样的机会。其中的输入数据是一个像素值的集合。我们可以轻易地将其格式化,将注意力集中在 LSTM 实现细节上。

实现

在动手写代码之前,先规划一下实现的蓝图,可以使写代码的过程更加直观。

VANILLA RNN

循环神经网络按时间轴展开的时候,如下图所示:

图中:

1.x_t 代表时间步 t 的输入;

2.s_t 代表时间步 t 的隐藏状态,可看作该网络的「记忆」;

3.o_t 作为时间步 t 时刻的输出;

4.U、V、W 是所有时间步共享的参数,共享的重要性在于我们的模型在每一时间步以不同的输入执行相同的任务。

当把 RNN 展开的时候,网络可被看作每一个时间步都受上一时间步输出影响(时间步之间存在连接)的前馈网络。

两个注意事项

为了更顺利的进行实现,需要清楚两个概念的含义:

1.TensorFlow 中 LSTM 单元格的解释;

2. 数据输入 TensorFlow RNN 之前先格式化。

TensorFlow 中 LSTM 单元格的解释

在 TensorFlow 中,基础的 LSTM 单元格声明为:

  1. tf.contrib.rnn.BasicLSTMCell(num_units)

这里,num_units 指一个 LSTM 单元格中的单元数。num_units 可以比作前馈神经网络中的隐藏层,前馈神经网络的隐藏层的节点数量等于每一个时间步中一个 LSTM 单元格内 LSTM 单元的 num_units 数量。下图可以帮助直观理解:

每一个 num_units LSTM 单元都可以看作一个标准的 LSTM 单元:

以上图表来自博客(地址:http://colah.github.io/posts/2015-08-Understanding-LSTMs/),该博客有效介绍了 LSTM 的概念。
数据输入 TensorFlow RNN 之前先格式化

在 TensorFlow 中最简单的 RNN 形式是 static_rnn,在 TensorFlow 中定义如下:

  1. tf.static_rnn(cell,inputs)

虽然还有其它的注意事项,但在这里我们仅关注这两个。

inputs 引数接受形态为 [batch_size,input_size] 的张量列表。列表的长度为将网络展开后的时间步数,即列表中每一个元素都分别对应网络展开的时间步。比如在 MNIST 数据集中,我们有 28x28 像素的图像,每一张都可以看成拥有 28 行 28 个像素的图像。我们将网络按 28 个时间步展开,以使在每一个时间步中,可以输入一行 28 个像素(input_size),从而经过 28 个时间步输入整张图像。给定图像的 batch_size 值,则每一个时间步将分别收到 batch_size 个图像。详见下图说明:

由 static_rnn 生成的输出是一个形态为 [batch_size,n_hidden] 的张量列表。列表的长度为将网络展开后的时间步数,即每一个时间步输出一个张量。在这个实现中我们只需关心最后一个时间步的输出,因为一张图像的所有行都输入到 RNN,预测即将在最后一个时间步生成。

现在,所有的困难部分都已经完成,可以开始写代码了。只要理清了概念,写代码过程是很直观的。

代码

在开始的时候,先导入一些必要的依赖关系、数据集,并声明一些常量。设定 batch_size=128 、 num_units=128。

  1. import tensorflow as tf

  2. from tensorflow.contrib import rnn

  3. #import mnist dataset

  4. from tensorflow.examples.tutorials.mnist import input_data

  5. mnist=input_data.read_data_sets("/tmp/data/",one_hot=True)

  6. #define constants

  7. #unrolled through 28 time steps

  8. time_steps=28

  9. #hidden LSTM units

  10. num_units=128

  11. #rows of 28 pixels

  12. n_input=28

  13. #learning rate for adam

  14. learning_rate=0.001

  15. #mnist is meant to be classified in 10 classes(0-9).

  16. n_classes=10

  17. #size of batch

  18. batch_size=128

现在设置占位、权重以及偏置变量(用于将输出的形态从 [batch_size,num_units] 转换为 [batch_size,n_classes]),从而可以预测正确的类别。

  1. #weights and biases of appropriate shape to accomplish above task

  2. out_weights=tf.Variable(tf.random_normal([num_units,n_classes]))

  3. out_bias=tf.Variable(tf.random_normal([n_classes]))

  4. #defining placeholders

  5. #input image placeholder

  6. x=tf.placeholder("float",[None,time_steps,n_input])

  7. #input label placeholder

  8. y=tf.placeholder("float",[None,n_classes])

现在我们得到了形态为 [batch_size,time_steps,n_input] 的输入,我们需要将其转换成形态为 [batch_size,n_inputs] 、长度为 time_steps 的张量列表,从而可以将其输入 static_rnn。

  1. #processing the input tensor from [batch_size,n_steps,n_input] to "time_steps" number of [batch_size,n_input] tensors

  2. input=tf.unstack(x ,time_steps,1)

现在我们可以定义网络了。我们将利用 BasicLSTMCell 的一个层,将我们的 static_rnn 从中提取出来。

  1. #defining the network

  2. lstm_layer=rnn.BasicLSTMCell(num_units,forget_bias=1)

  3. outputs,_=rnn.static_rnn(lstm_layer,input,dtype="float32")

我们只考虑最后一个时间步的输入,从中生成预测。

  1. #converting last output of dimension [batch_size,num_units] to [batch_size,n_classes] by out_weight multiplication

  2. prediction=tf.matmul(outputs[-1],out_weights)+out_bias

定义损失函数、优化器和准确率。

  1. #loss_function

  2. loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))

  3. #optimization

  4. opt=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

  5. #model evaluation

  6. correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))

  7. accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

现在我们已经完成定义,可以开始运行了。

  1. #initialize variables

  2. init=tf.global_variables_initializer()

  3. with tf.Session() as sess:

  4.    sess.run(init)

  5.    iter=1

  6.    while iter<800:

  7.        batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)

  8.        batch_x=batch_x.reshape((batch_size,time_steps,n_input))

  9.        sess.run(opt, feed_dict={x: batch_x, y: batch_y})

  10.        if iter %10==0:

  11.            acc=sess.run(accuracy,feed_dict={x:batch_x,y:batch_y})

  12.            los=sess.run(loss,feed_dict={x:batch_x,y:batch_y})

  13.            print("For iter ",iter)

  14.            print("Accuracy ",acc)

  15.            print("Loss ",los)

  16.            print("__________________")

  17.        iter=iter+1

需要注意的是我们的每一张图像在开始时被平坦化为 784 维的单一向量,函数 next_batch(batch_size) 必须返回这些 784 维向量的 batch_size 批次数。因此它们的形态要被改造成 [batch_size,time_steps,n_input],从而可以被我们的占位符接受。

我们还可以计算模型的准确率:

  1. #calculating test accuracy

  2. test_data = mnist.test.images[:128].reshape((-1, time_steps, n_input))

  3. test_label = mnist.test.labels[:128]

  4. print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

在运行的时候,模型的测试准确率为 99.21%。

这篇博客旨在让读者熟悉 TensorFlow 中 RNN 的实现细节。我们将会在 TensorFlow 中建立更加复杂的模型以更有效的利用 RNN。敬请期待!

声明:本文由机器之心编译出品,原文来自GitHub,转载请查看要求,机器之心对于违规侵权者保有法律追诉权。

使用MNIST数据集,在TensorFlow上实现基础LSTM网络相关推荐

  1. mnist数据集在FATE上应用

    mnist数据集在FATE上应用 ** 一.下载mnist数据集 ** 我用阿里云盘分享了「MNIST」,复制这段内容打开「阿里云盘」App 即可获取 链接:https://www.aliyundri ...

  2. 深度学习4:使用MNIST数据集(tensorflow)

    本文将介绍MNIST数据集的数据格式和使用方法,使用到的是tensorflow中封装的类,包含代码. MNIST数据集来源于这里, 如果希望下载原始格式的数据集,可以从这里下载.而本文中讲解的是已经使 ...

  3. 【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

    「@Author:Runsen」 GAN 是使用两个神经网络模型训练的生成模型.一种模型称为生成网络模型,它学习生成新的似是而非的样本.另一个模型被称为判别网络,它学习区分生成的例子和真实的例子. 生 ...

  4. 英伟达RTX 3080值不值得抢?在TensorFlow上训练了卷积网络

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习> ...

  5. 从Tensorflow代码中理解LSTM网络

    目录 RNN LSTM 参考文档与引子 缩略词  RNN (Recurrent neural network) 循环神经网络  LSTM (Long short-term memory) 长短期记忆人 ...

  6. 基于pytorch的MNIST数据集的四层CNN,测试准确率99.77%

    基于pytorch的MNIST数据集的四层CNN,测试准确率99.77% MNIST数据集 环境配置 文件存储结构 代码 引入库 调用GPU 初始化变量 导入数据集并进行数据增强 导入测试集 加载测试 ...

  7. 读取mnist数据集方法大全(train-images-idx3-ubyte.gz,train-labels.idx1-ubyte等)(python读取gzip文件)

    文章目录 gzip包 keras读取mnist数据集 本地读取mnist数据集 下载数据集 解压读取 方法一 方法二 gzip包读取 读取bytes数据 注:import导入的包如果未安装使用pip安 ...

  8. 关于mnist数据集下载的相关问题

    文章目录 问题描述:在Tensorflow 2.0.1版本中下载mnist数据集 原因分析: 解决方案: 问题描述:在Tensorflow 2.0.1版本中下载mnist数据集 from tensor ...

  9. 科学家在类脑芯片上实现类似LSTM的功能,能效高1000倍

    来源:机器学习研究组订阅 格拉茨技术大学的计算机科学家在 Nature 子刊上发表的一篇论文表明,他们找到了一种在神经形态芯片上模拟 LSTM 的方案,可以让类脑神经形态芯片上的 AI 算法能效提高约 ...

最新文章

  1. zoj 1670 Jewels from Heaven
  2. ML之xgboost:利用xgboost算法(sklearn+7CrVa)训练mushroom蘑菇数据集(22+1,6513+1611)来预测蘑菇是否毒性(二分类预测)
  3. 132. 小组队列【队列 模拟】
  4. PyQt4编程之简短地做出多个选择框
  5. iptables与tomcat
  6. bzoj 4515: [Sdoi2016]游戏
  7. 信息学奥赛一本通 1077:统计满足条件的4位数 | OpenJudge NOI 1.5 26
  8. JDBC解析9_UpdateWithResultSet
  9. mysql的engine不同,导致事物回滚失败的问题
  10. weblogic运行但局域网内无法访问控制台
  11. [转载]计算机端口详解
  12. android默认打开方式修改,教你修改安卓智能手机默认打开方式
  13. mysql 判断质数_质数(素数)判断算法总结
  14. 浅谈网银USB key使用原理与安全策略
  15. 使用Python Snap7读取西门子触摸板 Dint LReal(int double)数据
  16. table 表格边框线去重
  17. spring data jpa 出现Not a managed type
  18. 2010-2019中国企业所有跨国并购数据
  19. eclipse快捷键以及使用技巧大全
  20. 圣天诺LDK - HL

热门文章

  1. 用RAM存储器构造能够依次读取各存储单元内容的电路
  2. java jtextfield设置不可见_java – JPanel设置为不可见,除默认值之外的组合框选择将其设置为可见,但组件丢失...
  3. java 8 string_String.join() --Java8中String类新增方法
  4. java手动输入函数_Java程序如何添加一个函数,如果玩家输入错误,它将返回代码。...
  5. tf.arg_max
  6. linux 可执行文件_linux中ELF二进制程序解析
  7. 数据中台实战(八):如何打造可以支撑N条产品线的标签平台
  8. 深度学习核心技术精讲100篇(五十八)- 如何量化医学图像分割中的置信度?
  9. 经典!MATLAB线性等分linspace()函数,精确等分点数
  10. 强化学习(十九) AlphaGo Zero强化学习原理