https://yq.aliyun.com/articles/202939

Mnist: BATCH_SIZE X 784 array

CCN:BATCH_SIZE X28X28 -->BATCH_SIZE X28x28X1 array

LSTM:28(NUM_STEPS)个BATCH_SIZE X28 list

先试试数据变换:

# coding=utf-8
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只显示 warning 和 Error import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np
sess=tf.Session()
x=np.array([[[111,112],[121,122]],[[211,212],[221,222]],[[311,312],[321,322]]])
x1=np.transpose(x, [1,0,2])
x2=np.stack(x,1)
x2_u=tf.unstack(x,2,1)
x3=np.reshape(x1,[-1,2])
x4=np.split(x3,2,0)
print(x.shape)
print(x)
print(x1.shape)
print(x1)
print(x2.shape)
print(x2)
print('x2_u is list')
print(sess.run(x2_u))
print(x3.shape)
print(x3)
print('x4 is list')
print(x4)
(3, 2, 2)
[[[111 112][121 122]][[211 212][221 222]][[311 312][321 322]]]
(2, 3, 2)
[[[111 112][211 212][311 312]][[121 122][221 222][321 322]]]
(2, 3, 2)
[[[111 112][211 212][311 312]][[121 122][221 222][321 322]]]
x2_u is list
[array([[111, 112],[211, 212],[311, 312]]), array([[121, 122],[221, 222],[321, 322]])]
(6, 2)
[[111 112][211 212][311 312][121 122][221 222][321 322]]
x4 is list
[array([[111, 112],[211, 212],[311, 312]]), array([[121, 122],[221, 222],[321, 322]])]

再试试ANN-LSTM,对每个时间步,网络结构为28X128X10(NUM_INPUT x NUM_HIDDEN x NUM_CLASSES),输入为每行像素,关键步骤:

定义一个LSTM元胞:

lstm_layer=rnn.BasicLSTMCell(NUM_HIDDEN,forget_bias=1)

构建网络:

outputs,_=rnn.static_rnn(lstm_layer,x_input_step,dtype="float32")

注意,输入x_input_step为list:NUM_STEPS个array:(BATCH_SIZE , NUM_INPUT)

输出outputs为一个输出列表(其中每个元素对应一个输入),长度为NUM_STEPS;

另一个输出为states,元胞的最终状态。

###data (50000,784),(1000,784),(1000,784):
import pickle
import gzipdef load_data():f = gzip.open('../data/mnist.pkl.gz', 'rb')training_data, validation_data, test_data = pickle.load(f,encoding='bytes')f.close()return (training_data, validation_data, test_data)def vectorized_result(j):e = np.zeros(10)e[j] = 1.0return etraining_data, validation_data, test_data = load_data()
trainData_in=training_data[0][:50000]
trainData_out=[vectorized_result(j) for j in training_data[1][:50000]]
validData_in=validation_data[0]
validData_out=[vectorized_result(j) for j in validation_data[1]]
testData_in=test_data[0][:100]
testData_out=[vectorized_result(j) for j in test_data[1][:100]]#define constants
#unrolled through 28 time steps 28行对应28个时间步:
TIME_STEPS=28
#hidden LSTM units
NUM_HIDDEN=128
#???rows of 28 pixels 每行28个像素:
NUM_INPUT=28
#learning rate for adam
LEARNING_RATE=0.001
#mnist is meant to be classified in 10 classes(0-9).
NUM_CLASSES=10
#size of batch
BATCH_SIZE=128TRAINING_EPOCHS=30##weights and biases of appropriate shape to accomplish above task
out_weights=tf.Variable(tf.random_normal([NUM_HIDDEN,NUM_CLASSES]))
out_bias=tf.Variable(tf.random_normal([NUM_CLASSES]))
#defining placeholders
#input image placeholder:
x_input=tf.placeholder("float",[None,TIME_STEPS,NUM_INPUT])
#input label placeholder:
y_desired=tf.placeholder("float",[None,NUM_CLASSES])
#processing the input tensor from [BATCH_SIZE,NUM_STEPS,NUM_INPUT] to "TIME_STEPS" number of [BATCH-SIZE,NUM_INPUT] tensors!:
#对输入的一个张量的第二维解包变成TIME_STEPS个张量!:
x_input_step=tf.unstack(x_input ,TIME_STEPS,1)#defining the network:
lstm_layer=rnn.BasicLSTMCell(NUM_HIDDEN,forget_bias=1)
outputs,_=rnn.static_rnn(lstm_layer,x_input_step,dtype="float32")
#converting last output of dimension [batch_size,num_hidden] to [batch_size,num_classes] by out_weight multiplication
z_prediction=tf.matmul(outputs[-1],out_weights)+out_bias#loss_function:
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=z_prediction,labels=y_desired))
#optimization
opt=tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(loss)
#model evaluation
correct_prediction=tf.equal(tf.argmax(z_prediction,1),tf.argmax(y_desired,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#initialize variables:
init=tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init)num_batches=int(len(trainData_in)/BATCH_SIZE)for epoch in range(TRAINING_EPOCHS):for i in range(num_batches):batch_x=trainData_in[i*BATCH_SIZE:(i+1)*BATCH_SIZE]batch_x=batch_x.reshape((BATCH_SIZE,TIME_STEPS,NUM_INPUT))#batch_y=trainData_out[i*BATCH_SIZE:(i+1)*BATCH_SIZE]            sess.run(opt, feed_dict={x_input: batch_x, y_desired: batch_y})if i %10==0:acc=sess.run(accuracy,feed_dict={x_input:batch_x,y_desired:batch_y})los=sess.run(loss,feed_dict={x_input:batch_x,y_desired:batch_y})print('epoch:%4d,'%epoch,'%4d'%i)print("Accuracy ",acc)print("Loss ",los)print("__________________")

TensorFlow十三 LSTM练习相关推荐

  1. 使用tensorflow建模LSTM的详细步骤通俗易懂解读

    使用tensorflow建模LSTM的详细步骤人性化解读 一步步条理清晰的写tensorflow代码 Understanding LSTM in Tensorflow(MNIST dataset) L ...

  2. Tensorflow实现LSTM详解

    关于什么是 LSTM 我就不详细阐述了,吴恩达老师视频课里面讲的很好,我大概记录了课上的内容在吴恩达<序列模型>笔记一,网上也有很多写的好的解释,比如:LSTM入门.理解LSTM网络 然而 ...

  3. Tensorflow使用LSTM实现中文文本分类(1)

    前言 使用Tensorflow,利用LSTM进行中文文本的分类. 数据集格式如下: ''' 体育 马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 来到沈阳,国奥队依然没有摆脱雨水的 ...

  4. 使用Tensorflow训练LSTM+Attention中文标题党分类

    这里用Tensorflow中LSTM+Attention模型训练一个中文标题党的分类模型,并最后用Java调用训练好的模型. 数据预处理 首先根据语料和实验数据训练词向量word2vec模型,这个有很 ...

  5. TensorFlow搭建LSTM实现多变量时间序列预测(负荷预测)

    目录 I. 前言 II. 数据处理 III. LSTM模型 IV. 训练/测试 V. 源码及数据 I. 前言 在前面的一篇文章TensorFlow搭建LSTM实现时间序列预测(负荷预测)中,我们利用L ...

  6. TensorFlow搭建LSTM实现时间序列预测(负荷预测)

    目录 I. 前言 II. 数据处理 III. 模型 IV. 训练/测试 V. 源码及数据 I. 前言 前面已经写过不少时间序列预测的文章: 深入理解PyTorch中LSTM的输入和输出(从input输 ...

  7. 情感分析之电影评论分析-基于Tensorflow的LSTM

    1. 深度学习在自然语言处理中的应用 自然语言处理是教会机器如何去处理或者读懂人类语言的系统,目前比较热门的方向,包括如下几类: 对话系统 - 比较著名的案例有:Siri,Alexa 和 Cortan ...

  8. 如何基于TensorFlow使用LSTM和CNN实现时序分类任务

    https://www.jiqizhixin.com/articles/2017-09-12-5 By 蒋思源2017年9月12日 09:54 时序数据经常出现在很多领域中,如金融.信号处理.语音识别 ...

  9. Kesci:Tensorflow 实现 LSTM——时间序列预测(超详细)

    云脑项目3 -真实业界数据的时间序列预测挑战 https://www.kesci.com/home/project/5a391c670e1fc52691fde623 这篇文章将讲解如何使用lstm进行 ...

最新文章

  1. model存数据_Jepsen 测试框架在图数据库 Nebula Graph 中的实践
  2. openJDK与JDK的区别
  3. 2.7 usb摄像头之usb摄像头描述符打印
  4. python gzipped source tarball,下载及安装Python详细步骤
  5. python项目-马哥教育官网-专业Linux培训班,Python培训机构
  6. python打开后的界面-Python - tkinter:打开和关闭对话框窗口
  7. linux 下的 initrd ramdisk
  8. Java基础学习笔记三 Java基础语法
  9. 鸿蒙系统被烧毁,华为鸿蒙操作系统再次被质疑 国产是原罪
  10. 浏览器兼容之JavaScript篇——已在IE、FF、Chrome测试
  11. 雅迪发布高端智能电动车G5 这个售价真的会有人买吗?
  12. loadrunner12 + ie11 无internet, 代码中文乱码
  13. Linux Futex的设计与实现(转)
  14. 主题模型(Topic Model)与LDA算法
  15. 将APPDATA 迁出C盘
  16. 【工作感想】 关于前后端分离的问题
  17. 怎么注册Github?用手机2分钟完成注册,互联网就是互相连接
  18. 进程系列(三)-进程的基本用法(打开文件示列)
  19. 转载 2015A国赛优秀论文
  20. 沉降观测曲线图 沉降观测汇总_沉降观测曲线图都有哪些

热门文章

  1. Vue.js 运行机制全局概览
  2. 吊炸天!一行命令快速部署大规模K8S集群!!!
  3. 容器编排技术 -- Kubernetes Volume
  4. Docker快速搭建docker-nfs-server服务器
  5. 使用adduser命令在Debian Linux中创建用户
  6. 2021 npm安装Electron失败解决方法
  7. cookie的设置与取值
  8. Java 时间处理整理
  9. 【详细说明】nginx反向代理wss websocket
  10. mybatis 配置详解