练习利用LSTM实现手写数字分类任务

MNIST数据集中图片大小为28*28.

按照行进行展开成28维的特征向量。

考虑到这28个的向量之间存在着顺序依赖关系,我们可以将他们看成是一个长为28的输入序列,将其输入到LSTM中,LSTM可以从中提取到序列特征,再将此序列特征用一层全联接作为分类器,分类器输出10种分类类别。

综合代码

import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers import fully_connectedimport input_data
mnist = input_data.read_data_sets('MNIST_data/',one_hot = True)
#one_hot = True 独热编码,类似[0,0,0,1,0,0,0,0,0,0]这种形式,等价于class=3n_inputs  = 28  #表示输入神经元的个数
n_steps   = 28  #表示序列长度
n_neurons = 150 #表示LSTM中隐藏层和输出层神经元呢个数
n_outputs = 10  #是最终分类器输出的类别数,mnist数据集是10分类任务learning_rate = 0.01 #优化方法的学习率X = tf.placeholder(tf.float32,[None,n_steps,n_inputs])
Y_labels = tf.placeholder(tf.int32,[None,n_outputs])basic_cell = tf.contrib.rnn.BasicLSTMCell(n_neurons,forget_bias = 1.0, state_is_tuple = True)
#获取一层LSTM网络,参数1是每个cell的输出神经元个数,参数2是遗忘的偏置,参数3表示双状态outneurons, states = tf.nn.dynamic_rnn(basic_cell,X,dtype = tf.float32)
#outneurons得到了输出序列logits = fully_connected(tf.transpose(outneurons,perm = [1,0,2])[-1], n_outputs,activation_fn = None)
#在这里由于outneurons的维度为[batch_size,n_steps,n_inputs]的形式,而我们只需要最后一个cell对于所有batch的输出,因此把前两个维度调换一下,再取用[-1]取到最后一个cell对于所有batch的输出。shape为[batch_size,n_inputs]
#将其接到一层全连接网络作为分类器得到logitscross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels = Y_labels,logits = logits)
loss = tf.reduce_mean(cross_entropy)
#对logits用softmax做归一化,计算其对于样本labels的交叉熵的和,取均值作为损失函数lossoptimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
trainop = optimizer.minimize(loss)
#申请一个优化器,用来最后小化损失函数losscorrect = tf.equal(tf.argmax(logits,1),tf.argmax(Y_labels,1))
#分析正确率accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))batch_size = 64
init = tf.global_variables_initializer()
with tf.Session() as sess:init.run()for i in range(10000):x_batch, y_batch = mnist.train.next_batch(batch_size)x_batch = x_batch.reshape([-1,n_steps,n_inputs])sess.run(trainop,feed_dict = {X : x_batch,Y_labels : y_batch})if i % 200 == 0:print('train accuracy =',sess.run(accuracy,feed_dict = {X : x_batch,Y_labels : y_batch}))X_test = mnist.test.images.reshape((-1,n_steps,n_inputs))Y_test = mnist.test.labelsprint('test accuracy =',sess.run(accuracy,feed_dict = {X : X_test,Y_labels : Y_test}))

评估

实验表明求得得准确率可达到99%。

疑问

我将BasicLSTMCell换成BasicRNNCell就无法训练,这是为什么呢?难道跟LSTM有遗忘们相关吗?

练习利用LSTM实现手写数字分类任务相关推荐

  1. pytorch利用rnn通过sin预测cos 利用lstm预测手写数字

    一.利用rnn通过sin预测cos 1.首先可视化一下数据 import numpy as np from matplotlib import pyplot as plt def show(sin_n ...

  2. 基于tensorflow+RNN的MNIST数据集手写数字分类

    2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...

  3. 机器学习算法(九): 基于线性判别LDA模型的分类(基于LDA手写数字分类实践)

    机器学习算法(九): 基于线性判别模型的分类 1.前言:LDA算法简介和应用 1.1.算法简介 线性判别模型(LDA)在模式识别领域(比如人脸识别等图形图像识别领域)中有非常广泛的应用.LDA是一种监 ...

  4. 利用CNN进行手写数字识别

    资源下载地址:https://download.csdn.net/download/sheziqiong/85884967 资源下载地址:https://download.csdn.net/downl ...

  5. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  6. 如何为MNIST手写数字分类开发CNN

    导言 MNIST手写数字分类问题是计算机视觉和深度学习中使用的标准数据集. 虽然数据集得到了有效的解决,但它可以作为学习和实践如何开发,评估和使用卷积深度学习神经网络从头开始进行图像分类的基础.这包括 ...

  7. 独家 | 如何从头开始为MNIST手写数字分类建立卷积神经网络(附代码)

    翻译:张睿毅 校对:吴金笛 本文约9300字,建议阅读20分钟. 本文章逐步介绍了卷积神经网络的建模过程,最终实现了MNIST手写数字分类. MNIST手写数字分类问题是计算机视觉和深度学习中使用的标 ...

  8. 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类

    多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...

  9. Keras入门实战(1):MNIST手写数字分类

    目录 1)首先我们加载Keras中的数据集 2)网络架构 3)选择编译(compile参数) 4)准备图像数据 5) 训练模型 6)测试数据 前面的博客中已经介绍了如何在Ubuntu下安装Keras深 ...

最新文章

  1. wxWidgets:调试 WxWindow 应用程序
  2. dotnet core 数据库
  3. MSSQL 标量函数
  4. 计算机上机模拟试题答案,2016计算机二级上机模拟试题及答案
  5. 《深入PHP:面向对象、模式与实践》(二)
  6. Tomcat打包时多项目共享jar和精确指定jar版本
  7. C/C++ OpenCV之Canny边缘检测
  8. python3的文件读写模式
  9. Java成神之路——UML类关系图
  10. 获取当前日期是本年的第几周java与mysql获取值不一致
  11. 计算机编程的 20 年变迁!
  12. sublime Text 2使用小技巧
  13. 瑞友天翼 v5.1.0.6 远程打印跳行、跳页、错位问题解决方法
  14. 《老路用得上的商学课》36-40学习笔记
  15. 机器学习中的多分类任务详解
  16. iOS-CYLTabBarController【好用的TabbarController】
  17. 华为OD社招Java岗面经,已OFFER
  18. 如何使用sz命令下载较大文件到本地
  19. 摘:一张废手机卡的作用
  20. STM32 I2C通信操作24C02写数据、读数据

热门文章

  1. bootstrap grid php,bootstrap grid用法
  2. 关于前端性能优化问题,认识网页加载过程和防抖节流
  3. And Then There Was One POJ - 3517(变形约瑟夫环+规律)
  4. linux jdk1.7 tomcat mysql_Linux环境搭建 jdk+tomcat+mysql
  5. 常用决策树集成模型Random Forest、Adaboost、GBDT详解
  6. CoreJava 笔记总结-第三章 Java的基本程序设计结构
  7. 牛客小白月赛12:月月给华华出题(欧拉函数)
  8. Codeforces Global Round 12 C1 C2. Errich-Tac-Toe 思维构造 好题
  9. M - Kill the tree 计蒜客 - 42552(2019icpc徐州/树的重心/树形dp)
  10. 字符串hash(一)