继上一篇用简单的卷积神经网络做mnist分类之后,本篇文章采用RNN替换CNN写了一个mnist分类实例。实例中包含两个文件:

train.py:数据加载和训练代码。

# coding=utf-8
import tensorflow as tf
import os
import model
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('dataset/', one_hot=True)tf.app.flags.DEFINE_integer('sequence_step', 28, 'step of input sequence')
tf.app.flags.DEFINE_integer('vector_size', 28, 'length of input vector')
tf.app.flags.DEFINE_integer('num_classes', 10, 'num of class')
tf.app.flags.DEFINE_float('lr', 0.001, 'learning rate')
tf.app.flags.DEFINE_integer('batch_size', 32, 'batch size')
tf.app.flags.DEFINE_integer('epochs', 50, 'num of epoch')
tf.app.flags.DEFINE_string('checkpoints', './checkpoints/model.ckpt', 'path of checkpoints')
tf.app.flags.DEFINE_boolean('continue_training', False, 'continue')
FLAGS = tf.app.flags.FLAGSdef main(_):input = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.sequence_step * FLAGS.vector_size])output = tf.placeholder(dtype=tf.int32, shape=[None, 10])# control GPU resource utilizationconfig = tf.ConfigProto(allow_soft_placement=True)config.gpu_options.allow_growth = Truesess = tf.Session(config=config)# networklogits = model.build_rnn(input, FLAGS.sequence_step, FLAGS.vector_size, FLAGS.batch_size)# losscross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=output))# optimitertrain_op = tf.train.AdamOptimizer().minimize(cross_entropy)# evaluationcorrect_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(output, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))with sess.as_default():# initialsaver = tf.train.Saver(max_to_keep=1000)sess.run(tf.global_variables_initializer())# Restore weights fileif FLAGS.continue_training:saver.restore(sess, FLAGS.checkpoints)# begin trainfor epoch in range(FLAGS.epochs):for k in range(int(mnist.train.num_examples / FLAGS.batch_size)):train_image, train_label = mnist.train.next_batch(FLAGS.batch_size)train_image = train_image / 255.0_, network, loss, acc = sess.run([train_op, logits, cross_entropy, accuracy], feed_dict={input: train_image, output: train_label})print('loss : %f accuracy : %f' % (loss, acc))test_image = mnist.test.images / 255.0#test_image = test_image.reshape([-1, FLAGS.sequence_step, FLAGS.vector_size])test_label = mnist.test.labelsindices = np.arange(len(test_image))np.random.shuffle(indices)test_index = indices[0:FLAGS.batch_size]print('精确率:', accuracy.eval({input: test_image[test_index], output: test_label[test_index]}))# Create directories if neededif not os.path.isdir("checkpoints"):os.makedirs("checkpoints")saver.save(sess, "%s/model.ckpt" % ("checkpoints"))if __name__ == '__main__':tf.app.run()

训练部分基本和上一篇CNN分类相同。

model.py:搭建了一个简单的循环神经网络,RNN的输入和CNN略有不同,针对28×28的图片,CNN采用[batch, height, width, channel]的tensor形式; RNN网络的输入将28行看成28个时间序列,每一个时间序列的输入是[1,28]。本例子的lstm单元设置128个节点。

import tensorflow as tf
import tensorflow.contrib.rnndef weight_variable(shape, stddev=0.1):initial = tf.truncated_normal(shape=shape, stddev=stddev)return tf.Variable(initial)def bias_variable(shape, alpha=0.1):initial = tf.constant(shape=shape, value=alpha)return tf.Variable(initial)def build_rnn(inputs, sequence_size, vector_size, batch_size):weights = {# shape (28, 128)'in': tf.Variable(tf.random_normal([vector_size, 128])),# shape (128, 10)'out': tf.Variable(tf.random_normal([128, 10]))}biases = {# shape (128, )'in': tf.Variable(tf.constant(0.1, shape=[128, ])),# shape (10, )'out': tf.Variable(tf.constant(0.1, shape=[10, ]))}inputs = tf.reshape(inputs, [-1, vector_size])inputs = tf.matmul(inputs, weights['in']) + biases['in']inputs = tf.reshape(inputs, [-1, sequence_size, 128])# lstm cell.lstm_cell = tf.contrib.rnn.BasicLSTMCell(128, forget_bias=1.0, state_is_tuple=True)# initinit_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)  # 初始化全零 state# Implement rnnoutputs, final_state = tf.nn.dynamic_rnn(lstm_cell, inputs, initial_state=init_state, time_major=False)results = tf.matmul(final_state[1], weights['out']) + biases['out']return results

运行结果:

loss : 0.000472 accuracy : 1.000000
loss : 0.000217 accuracy : 1.000000
loss : 0.000051 accuracy : 1.000000
loss : 0.000253 accuracy : 1.000000
loss : 0.000483 accuracy : 1.000000
精确率: 0.96875

如有相关问题,欢迎留言讨论。

tensorflow随笔——简单的循环神经网络分类实例相关推荐

  1. tensorflow随笔——简单的卷积神经网络分类实例

    下例是用tensorflow写的2层cnn+2层fc的一个卷积神经网络做mnist的分类例子,旨在简单明了,过一遍TF代码的分类流程. 实例只有两个文件: train.py:数据读取,模型训练. # ...

  2. 自然语言处理--Keras 实现LSTM循环神经网络分类 IMDB 电影评论数据集

    LSTM 对于循环网络的每一层都引入了状态(state)的概念,状态作为网络的记忆(memory).但什么是记忆呢?记忆将由一个向量来表示,这个向量与元胞中神经元的元素数量相同.记忆单元将是一个由 n ...

  3. 深度学习 -- TensorFlow(9)循环神经网络RNN

    目录 一.循环神经网络RNN介绍 二.Elman network && Jordan network 三.RNN的多种架构 1.一对一 2.多对一 3.多对多 4. 一对多 5.Seq ...

  4. 简单入门循环神经网络RNN:时间序列数据的首选神经网络

    更多深度文章,请关注:https://yq.aliyun.com/cloud 随着科学技术的发展以及硬件计算能力的大幅提升,人工智能已经从几十年的幕后工作一下子跃入人们眼帘.人工智能的背后源自于大数据 ...

  5. TensorFlow HOWTO 5.1 循环神经网络(时间序列)

    5.1 循环神经网络(时间序列) 循环神经网络(RNN)用于建模带有时间关系的数据.它的架构是这样的. 在最基本的 RNN 中,单元(方框)中的操作和全连接层没什么区别,都是线性变换和激活.它完全可以 ...

  6. 自然语言处理--Keras 实现循环神经网络分类 IMDB 电影评论数据集

    那么为什么要使用 RNN 呢?不一定要选择循环神经网络,与前馈网络或卷积神经网络相比,它训练和传递新样本的成本相对较高(循环层计算成本较高). 但是循环网络在记忆能力方面的特殊优势即记住句子中出现过的 ...

  7. TensorFlow深度学习实战---循环神经网络

    循环神经网络(recurrent neural network,RNN)-------------------------重要结构(长短时记忆网络( long short-term memory,LS ...

  8. BPTT-应用于简单的循环神经网络

    上面是一组序列变量,即四个变量 z1,z2,z3,z4z_1, z_2, z_3, z_4 中的任一 ziz_i 的值均依赖于它前面的变量 z1,z2,..,zi−1z_1, z_2, .., z_{ ...

  9. tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】

    之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别. 而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时 ...

最新文章

  1. AttributeError: ‘FPDF‘ object has no attribute ‘unifontsubset‘
  2. Stimulsoft Reports.Net基础教程(十):创建图表报表②
  3. 网久环境服务启动命令
  4. 超有趣的几个Linux小命令
  5. 一步步实现:JPA的基本增删改查CRUD(jpa基于hibernate)
  6. MFC:ListControl数据修改
  7. 社交营销产品设计思考
  8. Oracle设置权限和还原数据库
  9. 置springboot自带tomcat的最大连接数和最大并发数
  10. oracle select 变量_详解oracle数据库优化参数--cursor_sharing
  11. centos+gitlab+mysql_centos7安装配置gitlab(使用外部nginx)
  12. sizeof(class)分析
  13. python公共变量声明_Python变量声明
  14. 使用阿里云邮件推送服务架设自己邮件验证与推送体系
  15. Oracle数据库运维方案及优化
  16. 线性规划的大M法和非线性规划的拉格朗日乘子法
  17. 25个常用Matplotlib图的Python代码,干货收藏!
  18. html采集插件如何用,火车采集器插件功能详解
  19. 第13章 Python建模库介绍--Python for Data Analysis 2nd
  20. 【开发随机】JAVA+POI+自定义注解+反射构建自定义工具类实现快捷简便的Excel模板化导出(附demo代码)

热门文章

  1. Tarjan点的双联通(寻找割点)
  2. opencv 绘制图像轮廓
  3. python机器学习库keras——CNN卷积神经网络人脸识别
  4. 实现labelme批量json_to_dataset方法
  5. epoll浅析以及nio中的Selector
  6. linux 如何连通网络
  7. candence 知识积累3
  8. SQL Server 2000企业管理器中MMC无法创建管理单元的解决方法
  9. Unity3D吐槽2--AnimationEvent
  10. 不支持对系统目录进行即席更新