tensorflow应用:双向LSTM神经网络手写数字识别

  • 思路
  • Python程序1.建模训练保存
  • Tensorboard检查计算图及训练结果
  • 打开训练好的模型进行预测

思路

将28X28的图片看成28行像素,按行展开成28时间步,每时间步间对识别都有影响,故用双向LSTM神经元,其实每列间对识别也有影响,用卷积神经网络也许更合理,这里只是学习LSTM的用法。应该也可以用两个双向LSTM神经网络进行联合预测,一个按行扫描,一个按列扫描。

Python程序1.建模训练保存

# coding=utf-8
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只显示 warning 和 Error import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np###data (50000,784),(1000,784),(1000,784):
import pickle
import gzipdef load_data():f = gzip.open('../data/mnist.pkl.gz', 'rb')training_data, validation_data, test_data = pickle.load(f,encoding='bytes')f.close()return (training_data, validation_data, test_data)def vectorized_result(j):e = np.zeros(10)e[j] = 1.0return etraining_data, validation_data, test_data = load_data()
trainData_in=training_data[0][:50000]
trainData_out=[vectorized_result(j) for j in training_data[1][:50000]]
validData_in=validation_data[0]
validData_out=[vectorized_result(j) for j in validation_data[1]]
testData_in=test_data[0][:100]
testData_out=[vectorized_result(j) for j in test_data[1][:100]]#define constants
#unrolled through 28 time steps 28行对应28个时间步:
TIME_STEPS=28
#hidden LSTM units
NUM_HIDDEN=128
#???rows of 28 pixels 每行28个像素:
NUM_INPUT=28
#learning rate for adam
LEARNING_RATE=0.001
#mnist is meant to be classified in 10 classes(0-9).
NUM_CLASSES=10
#size of batch
BATCH_SIZE=1024TRAINING_EPOCHS=1##weights and biases of appropriate shape to accomplish above task
#out_weights=tf.Variable(tf.random_normal([NUM_HIDDEN,NUM_CLASSES]))
#双向神经网络的权重为单向的2倍尺度:
out_weights=tf.Variable(tf.random_normal([2*NUM_HIDDEN,NUM_CLASSES]))
out_bias=tf.Variable(tf.random_normal([NUM_CLASSES]))
#defining placeholders
#input image placeholder:
x_input=tf.placeholder("float",[None,TIME_STEPS,NUM_INPUT],name='x_input')
#input label placeholder:
y_desired=tf.placeholder("float",[None,NUM_CLASSES])
#processing the input tensor from [BATCH_SIZE,NUM_STEPS,NUM_INPUT] to "TIME_STEPS" number of [BATCH-SIZE,NUM_INPUT] tensors!:
#对输入的一个张量的第二维解包变成TIME_STEPS个张量!:
x_input_step=tf.unstack(x_input ,TIME_STEPS,1)#defining the network:
#def BiRNN(x_input_step,out_weights,out_bias):
#lstm_layer=rnn.BasicLSTMCell(NUM_HIDDEN,forget_bias=1.0)
#正向神经元:
lstm_fw_cell=rnn.BasicLSTMCell(NUM_HIDDEN,forget_bias=1.0)
#反向神经元:
lstm_bw_cell=rnn.BasicLSTMCell(NUM_HIDDEN,forget_bias=1.0)
#outputs,_=rnn.static_rnn(lstm_layer,x_input_step,dtype="float32")
#构建双向LSTM网络:
outputs,_,_=rnn.static_bidirectional_rnn( lstm_fw_cell,lstm_bw_cell,x_input_step,dtype="float32")
#converting last output of dimension [batch_size,num_hidden] to [batch_size,num_classes] by out_weight multiplication
z_prediction= tf.add(tf.matmul(outputs[-1],out_weights),out_bias,name='z_prediction')
#z_prediction=BiRNN(x_input_step, out_weights, out_bias)
#注意!z_prediction经softmax归一化后才是最终的输出,用于和标签比较,下面的损失函数中用了softmax哈交叉熵,跳过了求y_output这一步:
y_output=tf.nn.softmax(z_prediction,name='y_output')#loss_function:
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z_prediction,labels=y_desired),name='loss')
#optimization
opt=tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(loss)
#model evaluation
correct_prediction=tf.equal(tf.argmax(z_prediction,1),tf.argmax(y_desired,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#以下汇总一些参数用于TensorBoard:
for value in [loss]:tf.summary.scalar(value.op.name,value) #汇总的标签及值
summary_op=tf.summary.merge_all() #汇总合并#initialize variables:
init=tf.global_variables_initializer()with tf.Session() as sess:# 生成一个写日志的writer,并将当前的tensorflow计算图写入日志。# tensorflow提供了多种写日志文件的APIsummary_writer=tf.summary.FileWriter(r'C:\temp\log_simple_stats',sess.graph)sess.run(init)num_batches=int(len(trainData_in)/BATCH_SIZE)for epoch in range(TRAINING_EPOCHS):for i in range(num_batches):batch_x=trainData_in[i*BATCH_SIZE:(i+1)*BATCH_SIZE]batch_x=batch_x.reshape((BATCH_SIZE,TIME_STEPS,NUM_INPUT))#batch_y=trainData_out[i*BATCH_SIZE:(i+1)*BATCH_SIZE]#优化及日志结果!!!!!!:::::            _,summary=sess.run([opt,summary_op], feed_dict={x_input: batch_x, y_desired: batch_y})#写日志,将结果添加到汇总:summary_writer.add_summary(summary,global_step=epoch*num_batches+i)if i %10==0:acc=sess.run(accuracy,feed_dict={x_input:batch_x,y_desired:batch_y})los=sess.run(loss,feed_dict={x_input:batch_x,y_desired:batch_y})print('epoch:%4d,'%epoch,'%4d'%i)print("Accuracy ",acc)print("Loss ",los)print("__________________")                print("Finished!")print("Test Accuracy ",sess.run(accuracy,\feed_dict={x_input:testData_in.reshape((-1,TIME_STEPS,NUM_INPUT)),\y_desired:testData_out}))saver=tf.train.Saver()save_path=saver.save(sess,'../data')print('Model saved to %s' % save_path)    summary_writer.close()

Tensorboard检查计算图及训练结果

在终端运行:

Tensorboard --logdir= C:\temp\log_simple_stats

C:\Users\li\AppData\Local\Programs\Python\Python36\Scripts>tensorboard --logdir=C:\temp\log_simple_stats

TensorBoard 1.10.0 at http://li-PC:6006 (Press CTRL+C to quit)

用谷歌浏览器打开http://li-pc:6006/

打开训练好的模型进行预测

# coding=utf-8
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只显示 warning 和 Error ###data (50000,784),(1000,784),(1000,784):
import pickle
import gzip
import numpy as np
def load_data():f = gzip.open('../data/mnist.pkl.gz', 'rb')training_data, validation_data, test_data = pickle.load(f,encoding='bytes')f.close()return (training_data, validation_data, test_data)def vectorized_result(j):e = np.zeros(10)e[j] = 1.0return eimport tensorflow as tf
import matplotlib.pyplot as plt#unrolled through 28 time steps 28行对应28个时间步:
TIME_STEPS=28
#???rows of 28 pixels 每行28个像素:
NUM_INPUT=28training_data, validation_data, test_data = load_data()
testData_in=test_data[0]
testData_out=[vectorized_result(j) for j in test_data[1]]sess=tf.InteractiveSession()
new_saver=tf.train.import_meta_graph('../data.meta')
new_saver.restore(sess, '../data')
tf.get_default_graph().as_graph_def()
x_input=sess.graph.get_tensor_by_name('x_input:0')
y_output=sess.graph.get_tensor_by_name('y_output:0')try_input=testData_in[6]
try_desired=testData_out[6]
print(try_desired)
print(y_output.eval(feed_dict={x_input:\np.array([try_input]).reshape((-1,TIME_STEPS,NUM_INPUT))}))
try_input.resize(28,28)
plt.imshow(try_input,cmap='Greys_r')
plt.show()

[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[[1.13495190e-07 7.10399399e-06 5.81623426e-05 1.16373285e-05
7.91627526e-01 5.11910184e-04 1.04986066e-04 1.08990945e-01
1.73597573e-03 9.69514474e-02]]


模型只进行了一轮训练,40次更新,40X1024个样本,就准确识别了手写数字4,判断为数字4的概率是 7.91627526e-01

tensorflow应用:双向LSTM神经网络手写数字识别相关推荐

  1. python cnn代码详解图解_基于TensorFlow的CNN实现Mnist手写数字识别

    本文实例为大家分享了基于TensorFlow的CNN实现Mnist手写数字识别的具体代码,供大家参考,具体内容如下 一.CNN模型结构 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5* ...

  2. 前馈神经网络手写数字识别

    前馈神经网络手写数字识别 今天就来说说手写数字识别,我们上课的时候老师要求我们使用前馈神经网络和卷积神经网络两种神经网络实现手写数字识别.做一下这两个实验还真的挺有意思啊. 举个例子,识别图片中的 : ...

  3. 基于TensorFlow和mnist数据集的手写数字识别系统 ,可识别电话号码,识别准确率高,有对比实验,两组模型,可讲解代码

    基于TensorFlow和mnist数据集的手写数字识别系统 ,可识别电话号码,识别准确率高,有对比实验,两组模型,可讲解代码

  4. 教程 | 基于LSTM实现手写数字识别

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 基于tensorflow,如何实现一个简单的循环神经网络,完成手写 ...

  5. 深度学习--TensorFlow(项目)Keras手写数字识别

    目录 效果展示 基础理论 1.softmax激活函数 2.神经网络 3.隐藏层及神经元最佳数量 一.数据准备 1.载入数据集 2.数据处理 2-1.归一化 2-2.独热编码 二.神经网络拟合 1.搭建 ...

  6. Python纯手动搭建BP神经网络--手写数字识别

    1 实验介绍 实验要求: 实现一个手写数字识别程序, 如下图所示, 要求神经网络包含一个隐层, 隐层的神经元个数为 15. 整体思路:主要参考西瓜书第五章神经网络部分的介绍,使用批量梯度下降对神经网络 ...

  7. 【手写数字识别】RBM神经网络手写数字识别【含GUI Matlab源码 1109期】

    ⛄一.手写数字识别技术简介 1 案例背景 手写体数字识别是图像识别学科下的一个分支,是图像处理和模式识别研究领域的重要应用之一,并且具有很强的通用性.由于手写体数字的随意性很大,如笔画粗细.字体大小. ...

  8. 利用python卷积神经网络手写数字识别_卷积神经网络使用Python的手写数字识别

    为了使机器更智能,开发人员正在研究机器学习和深度学习技术.人类通过反复练习和重复执行任务来学习执行任务,从而记住了如何执行任务.然后,他大脑中的神经元会自动触发,它们可以快速执行所学的任务.深度学习与 ...

  9. 卷积神经网络 手写数字识别(包含Pytorch实现代码)

    Hello!欢迎来到六个核桃Lu! 运用卷积神经网络 实现手写数字识别 1 算法分析及设计 卷积神经网络: 图1-2 如图1-2,卷积神经网络由若干个方块盒子构成,盒子从左到右仿佛越来越小,但却越来越 ...

最新文章

  1. 自动编码器的评级预测
  2. rasa算法_(六)RASA NLU意图分类器
  3. 计算机进制转换图,计算机等级考试进制转换及常用函数
  4. ZooKeeper入门之数据模型和常用命令介绍
  5. 如何在用户登录时SAP时自动执行Tcode或者其他一些东西
  6. Redis集群的搭建(具体步骤)
  7. matebookxpro上鸿蒙系统,华为MateBook X Pro对比MacBook Pro该买谁?
  8. C++实现链式存储二叉树
  9. 如何在客户端终止一个已经发出的HTTP请求
  10. takePic and Videos
  11. 解析poj页面获取题目
  12. php xml 接口调用,php的SimpleXML方法读写XML接口文件实例解析
  13. Idea中maven项目中导入本地jar包
  14. 如何在Node.js中处理POST数据?
  15. 移动Web开发之流式布局笔记
  16. php导出excel出现乱码,完美解决phpexcel导出到xls文件出现乱码的问题
  17. opencv实现银行卡号识别
  18. rpc服务器不可用处于启用状态,电脑提示RPC服务器不可用怎么办?
  19. Android APP启动时出现白屏或者黑屏怎么办?
  20. YOLO系列(V1-V2-V3)

热门文章

  1. 通过NodeJS自动生成的MySQL的REST风格API
  2. 容器编排技术 -- Init 容器
  3. 在centos7上设置swap交换空间
  4. React v16版本 源码解读
  5. 【Java】统计字符个数
  6. C 语言实例 -求分数数列1/2+2/3+3/5+5/8+...的前n项和
  7. SSH客户端常用工具SecureCRT操作
  8. 感谢相信你鼓励你的人
  9. 超级玛丽程序_如何构建一个超级快速的微笑跟踪应用程序
  10. bootstrap快速入门_在5分钟内学习Bootstrap 4-快速入门指南