(2016-09-03 08:35:36)

转载▼
   
MNIST 字符数据库每个字符(0-9) 对应一张28x28的一通道图片,可以将图片的每一行(或者每一列)当作特征,共28行。则可以通过输入大小为28,时间长度为28的RNN(lstm)对字符建模。对于同一个字符,比如0,其行与行之间的动态变化可以很好地被RNN表示,所有这些连续行的变化表征了某个字符的特定模式。因此可以使用RNN来进行字符识别。
Tensorflow提供了不错的RNN接口,基本思路是 
1. 建立RNN网络中的基本单元 cell; tf提供了很多中类型的cell, BasicRNNCell,BasicLSTMCell,LSTMCell 等等
2. 通过调用tf.nn.rnn函数或者tf.nn.dynamic_rnn 函数将cell连成RNN 网络。
这个历程参考了 https://github.com/aymericdamien/TensorFlow-Examples里面的一些代码,进行了简化和扩展。下面的参数设置能达到98%以上的识别率。
import input_data #这个直接使用上面链接中的input_data.py
mnist = input_data.read_data_sets('./MNIST_data/',one_hot=True)
import tensorflow as tf
import numpy as np
import time
#parameters
learning_rate = 0.001
training_iters = 100000
batch_size = 32
display_step = 10
#network parameters
n_input = 28 #特征维度,字符图片的每一行看成输入特征
n_steps = 28 #每张字符图片共有28行
n_hidden = 128 # 隐藏单元个数
n_classes =10 #类别总数,10个字符,因此类别为10
n_layers = 3 #RNN中有多少个cell
#tf Graph input
x = tf.placeholder("float32",[None,n_steps,n_input])
#rnn中的中间状态变量,包含cell 的状态(c_t)和每个cell 的输出状态(h_t) 对应 lstm的输出公式[ ht=o*tanh(ct)]
#本例中的多层RNN中,每一个CELL中的状态数目相等,因此输入状态变量是 2*n_hidden*n_layers],实际上是可以不相等的
#另外,可以提供初始状态,也可以不提供,让tf自动初始化
istate = tf.placeholder("float32",[None,2*n_hidden*n_layers])
y = tf.placeholder("float32",[None,n_classes])
#define weights, 设置weights 和biases为tf全局变量,weigths['hidden']  biases['hidden']参数代表对输入数据先进行一次线性变换(可选), weigths['out']  biases['out']代表了从RNN状态到字符类别的线性连接层的参数,在训练的过程中, weigths, biases会持续变化 
weights = {#'hidden':tf.Variable(tf.random_normal([n_input,n_hidden])),
'out':tf.Variable(tf.random_normal([n_hidden,n_classes]))}
biases = {#'hidden':tf.Variable(tf.random_normal([n_hidden])),
'out':tf.Variable(tf.random_normal([n_classes]))}
#define the LSTM network very simple,one cell
#基于一个基本LSTM cell的RNN网络
def RNN(_X,_istate,_weights,_biases):
    # 将输入数据由[ batch_size,nsteps,n_input] 变为  [  nsteps, batch_size,n_input]  
_X = tf.transpose(_X,[1,0,2]) 
_X = tf.reshape(_X,[-1,n_input])
#linear activation, not neccessary for the lstm model,can be ommited
#_X = tf.matmul(_X,_weights['hidden'])+_biases['hidden']  #输入rnn之前先加一层线性变换,可选 
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden,forget_bias=1.0)
_X = tf.split(0,n_steps,_X) # input a length T list of tensors
outputs,states = tf.nn.rnn(lstm_cell,_X,initial_state = _istate) #由于_X是list,输出ouputs也是lists, 长度为T,元素为[batch_size,hidden_units]的tensor,因此后面可以使用-1索引
return tf.matmul(outputs[-1],_weights['out'])+_biases['out']#rnn's ouput is a list of tensors
#lstm based on dyrnn
def DRNN(_X,_istate,_weights,_biases):
_X = tf.transpose(_X,[1,0,2]) # because the input format is batch_size*nsteps_n_input    
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden,forget_bias=1.0)
multi_cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell]*n_layers) # n_layers lstm cells,该函数接受[cell1,cell2,cell3] cell列表为参数,构建一个多层的RNN模型
#print(multi_cell.state_size)  #可以用rnn_cell的state_size获取rnn_cell的大小
#different from tf.nn.rnn(),input must be a tensor or a tuple of tensors
outputs,states = tf.nn.dynamic_rnn(multi_cell,_X,initial_state = _istate,time_major=True)
# if not set initial state, dtype must be set
#outputs,states = tf.nn.dynamic_rnn(multi_cell,_X,dtype=tf.float32,time_major=True)
#print(outputs.get_shape())
#outputs,states = tf.nn.dynamic_rnn(multi_cell,_X,dtype=tf.float32,time_major=True) #if not set initial state, must set dtype
last = tf.gather(outputs, int(outputs.get_shape()[0]) - 1) #dynamic_rnn's output is a tensor
return tf.matmul(last,_weights['out'])+_biases['out']   
#RNN
if n_layers <= 1:
pred = RNN(x,istate,weights,biases)
else:
#DRNN
pred = DRNN(x,istate,weights,biases)
#softmax交叉熵值损失
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred,y))
optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(cost)
#计算准确率
correct_pred = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
step = 1
total_time = 0.0
start_time = time.time()
while step*batch_size

batch_xs,batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape((batch_size,n_steps,n_input)) #numpy nd array, the feed_dict cannot be tensors
#note that, the feed_dict should not contain any tensor, but should be nd-array      
sess.run(optimizer,feed_dict={x:batch_xs,y:batch_ys,istate:np.zeros((batch_size,2*n_hidden*n_layers))})
if step % display_step == 0:  
loss = sess.run(cost, feed_dict={x: batch_xs, y: batch_ys,
istate: np.zeros((batch_size, 2*n_hidden*n_layers))})        
acc = sess.run(accuracy,feed_dict={x:batch_xs,y:batch_ys,istate:np.zeros((batch_size,2*n_hidden*n_layers))})
print "Iter " + str(step*batch_size) + ", Minibatch Loss= " + "{:.6f}".format(loss) + \
", Training Accuracy= " + "{:.5f}".format(acc)
step = step+1
total_time = total_time+(time.time()-start_time)
print ("Optimization %d iterations, Finished in %.4f seconds!"%(training_iters,total_time))
#Doing some tests: Calculate accuracy for 256 mnist test images
test_len = 256
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print "Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label,
istate: np.zeros((test_len, 2*n_hidden*n_layers))})

TensorFlow RNN MNIST字符识别Demo快速了解TF RNN核心框架相关推荐

  1. tensorflow教程 开始——数据集:快速了解 tf.data

    参考文章:数据集:快速了解 数据集:快速了解 tf.data 从 numpy 数组读取内存数据. 逐行读取 csv 文件. 基本输入 学习如何获取数组的片段,是开始学习 tf.data 最简单的方式. ...

  2. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  3. 2_初学者快速掌握主流深度学习框架Tensorflow、Keras、Pytorch学习代码(20181211)

    初学者快速掌握主流深度学习框架Tensorflow.Keras.Pytorch学习代码 一.TensorFlow 1.资源地址: 2.资源介绍: 3.配置环境: 4.资源目录: 二.Keras 1.资 ...

  4. 重磅 | TensorFlow 2.0即将发布,所有tf.contrib将被弃用

    作者 | 阿司匹林 出品 | AI科技大本营(公众号ID:rgznai100) 上周,谷歌刚刚发布了 TensorFlow 1.10.0 版本(详见<TensorFlow 版本 1.10.0 发 ...

  5. 基于tensorflow的MNIST手写字识别

    一.卷积神经网络模型知识要点卷积卷积 1.卷积 2.池化 3.全连接 4.梯度下降法 5.softmax 本次就是用最简单的方法给大家讲解这些概念,因为具体的各种论文网上都有,连推导都有,所以本文主要 ...

  6. 使用Tensorflow操作MNIST数据

    MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例.而TensorFlow的封装让使用MNIST数据集变得更加方便.MNIST数据集是NIST数据集的 ...

  7. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  8. tensorflow saver_机器学习入门(6):Tensorflow项目Mnist手写数字识别-分析详解

    本文主要内容:Ubuntu下基于Tensorflow的Mnist手写数字识别的实现 训练数据和测试数据资料:http://yann.lecun.com/exdb/mnist/ 前面环境都搭建好了,直接 ...

  9. mnist手写数字识别python_Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】...

    本文实例讲述了Python tensorflow实现mnist手写数字识别.分享给大家供大家参考,具体如下: 非卷积实现 import tensorflow as tf from tensorflow ...

最新文章

  1. vue.js安装过程(npm安装)
  2. vue路由对象($route)参数简介
  3. SSH Web工程环境搭建总结
  4. Email 之父去世:他让邮件有了 @
  5. Apache Kylin v2.6.1 发布,开源的分布式分析引擎
  6. Linux - XShell - alt 快捷键的设置
  7. linux下root密码修改方法
  8. java中,如何实现输入一个正整数,并将这个数字反转输出,比如输入123,输出321
  9. linux启动有两个选择,RHEL5 用CentOS源升级,GRUB出现CentOS,RHEL两个启动项,选择哪一个?...
  10. 百度,你拿什么和谷歌争?| 畅言
  11. java session重复登录_Java开发网 - Hibernate:session中对象重复问题的解决方法(原创)...
  12. python得语言编程模式_一图看懂编程语言迁移模式:终点站是Python、Go、JS!
  13. 顶岗实习周记java方向_java软件开发——顶岗实习周记25篇
  14. CRCNN PCNN
  15. 浅谈关于Java中map这个类衍生的类
  16. PowerApps教程10-图表的设计与嵌入
  17. 个人网站Timonj(Personal website)
  18. 深度学习基础 - 余弦定理
  19. 微信小程序操作es简单搜索
  20. SU几种偏移测试 转自[蠢树]

热门文章

  1. 把你的面子撕下来扔到地上,狠狠踹几脚!
  2. python实现给视频添加字幕,并根据字幕添加语音
  3. 使用PHP从Access数据库中提取对象,第2部分
  4. 软件测试需要学习什么 3分钟带你了解软测的学习内容
  5. PC微信 HOOK 接口 (版本:3.6.0.18)
  6. 地图配色及网络地图比较
  7. 全面认识当前市面99%的大数据技术框架(附:各大厂大数据技术应用文章)
  8. 好看的css下拉框样式,实用的漂亮的下拉框-CUSTOM DROP-DOWN LIST STYLING
  9. SEO之网站标题间隔符
  10. 美团点评前端面试小结