Tensorflow之RNN,LSTM

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
tensorflow之RNN
循环神经网络做手写数据集分类
"""import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data#设置随机数来比较两种计算结果
tf.set_random_seed(1)#导入手写数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)#设置参数
lr = 0.001
training_iters = 100000
batch_size = 128n_inputs = 28   # MNIST 输入为图片(img shape: 28*28)对应到图片像素的一行
n_steps = 28    # time steps 对应到图片有多少列
n_hidden_units = 128   # 隐藏层神经元个数
n_classes = 10      # MNIST分类结果为10#定义权重
weights = {#(28,128)'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units]))#(128,10)'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))}
#定义bias
biases = {# (128, )'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),# (10, )'out': tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
}def RNN(X, weights, biases):#作为cell输入的隐藏层#######################################################输入层#将输入shape从X三维输入变为二维(128 batch * 28 steps, 128 hidden)X = tf.reshape(X, [-1,n_inputs])#隐藏层# X_in = (128 batch * 28 steps, 128 hidden)X_in = tf.matmul(X, weights['in']) + biases['in']# 传给cell时需要将二维转为三维X_in ==> (128 batch, 28 steps, 128 hidden)X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])#cell########################################################LSTM cell forget_bias=1.0表示最开始学习我们不希望忘掉任何state,    #state_is_tuple=True这个为true表示记录每个时间点的cell状态和输出值,以后会默认为truecell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units,forget_bias=1.0,state_is_tuple=True)#将lstm cell 分成两部分(c_state, h_state),对应到lstm一个是主线c_state(没有cell的遗忘),   #支线是h_state(有cell的遗忘),zero_state将每个t时间的cell初始化为0,init_state = cell.zero_state(batch_size, dtype=tf.float32)#outputs为lstm所有输出结果包括每个时刻cell的state,和输出值,final_state为最后的结果,   #time_major参数表示时间序列的位置是否为输入数据的第一个维度,由于我们是在第二个维度,所以为falseoutputs, final_state = tf.nn.dynamic_rnn(cell, X_in, initial_state=init_state, time_major=False)#1.将隐藏层的输出作为最后结果,只有一个结果#results = tf.matmul(final_state[1], weights['out']) + biases['out']#2.将每一步的结果输出到lists,在对outputs unstack后[1,0, 2]是将outputs list中每个tuple中元素对应展开tf.unstack(tf.transpose(outputs, [1, 0, 2]))results = tf.matmul(outputs[-1], weights['out']) + biases['out'] # shape = (128, 10)return resultspred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) with tf.Session() as sess:init = tf.global_variables_initializer()sess.run(init)step = 0while step * batch_size < training_iters:batch_xs, batch_ys = mnist.train.next_batch(batch_size)batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])sess.run([train_op], feed_dict={x: batch_xs,y: batch_ys,})if step % 20 == 0:print(sess.run(accuracy, feed_dict={x: batch_xs,y: batch_ys,}))step += 1

转载于:https://www.cnblogs.com/xmeo/p/7230723.html

Tensorflow之RNN,LSTM相关推荐

  1. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  2. tensorflow lstm 实现 RNN / LSTM 的关键几个步骤 多层通俗易懂

    https://blog.csdn.net/Jerr__y/article/details/61195257?depth_1-utm_source=distribute.pc_relevant.non ...

  3. Tensorflow使用CNN卷积神经网络以及RNN(Lstm、Gru)循环神经网络进行中文文本分类

    Tensorflow使用CNN卷积神经网络以及RNN(Lstm.Gru)循环神经网络进行中文文本分类 本案例采用清华大学NLP组提供的THUCNews新闻文本分类数据集的一个子集进行训练和测试http ...

  4. TensorFlow中RNN实现的正确打开方式

    上周写的文章<完全图解RNN.RNN变体.Seq2Seq.Attention机制>介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主要内容为: ...

  5. DL之LSTM:基于tensorflow框架利用LSTM算法对气温数据集训练并回归预测

    DL之LSTM:基于tensorflow框架利用LSTM算法对气温数据集训练并回归预测 目录 输出结果 核心代码 输出结果 数据集 tensorboard可视化 iter: 0 loss: 0.010 ...

  6. 使用PaddleFluid和TensorFlow训练RNN语言模型

    专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...

  7. TensorFlow中RNN实现的正确打开方式(转)

    上周写的文章<完全图解RNN.RNN变体.Seq2Seq.Attention机制>介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主要内容为: ...

  8. 利用Tensorflow构建RNN并对序列数据进行建模

    利用Tensorflow构建RNN并对序列数据进行建模 对文本处理处理任务的方法中,一般将TF-IDF向量作为特征输入.显然的缺陷是:这种方法丢失了输入的文本序列中每个单词的顺序. 对一般的前馈神经网 ...

  9. 【Language model】使用RNN LSTM训练语言模型 写出45°角仰望星空的文章

    开篇 这篇文章主要是实战内容,不涉及一些原理介绍,原理介绍为大家提供一些比较好的链接: 1. Understanding LSTM Networks : RNN与LSTM最为著名的文章,贴图和内容都恰 ...

最新文章

  1. LeetCode-笔记-48.旋转图像
  2. iOS 修改textholder的颜色
  3. JavaScript之充实文档的内容
  4. 轻松搞定项目中的空指针异常Caused by: java.lang.NullPointerException: null
  5. python毕业设计论文-25 行 Python 代码毕业设计实现人脸识别
  6. BZOJ1082: [SCOI2005]栅栏
  7. 线性回归与梯度下降法
  8. 处理大并发的30条数据库规范
  9. 如何上传本地图片到PictureBox控件
  10. C#多线程学习之:Monitor类
  11. react(89)--设置只能正整数
  12. 如何证明接口中的域是static final的?
  13. K8s与Docker
  14. vscode使用相关配置
  15. UE4物理模块(二)---建立物体碰撞
  16. 提高抗打击能力_“玻璃娃娃”不可取,抗打击能力培养很重要,父母别忽视
  17. PHP判断浏览器类型及版本
  18. 点进来,你看到的就是全网最全c++11新特性(持续整理更新中)
  19. 英语foteball足球
  20. 电脑断电word文档不见了恢复

热门文章

  1. java学习笔记④MySql数据库--03/04 DQL查询
  2. BZOJ1226 SDOI2009学校食堂(状压dp)
  3. 敏捷武士:看敏捷高手交付卓越软件pdf
  4. web项目从Myeclipse迁移到idea的二三事
  5. 阿里云CentOS服务器挂载数据盘
  6. 云计算之路-黎明前的黑暗:20130424网站故障经过
  7. 04级函授计算机等级考试练习.rar
  8. myabtis 数字+逗号 传参问题 $和#
  9. 快速入门系列之 Scala 语言 GitChat连接
  10. input=file 浏览时只显示指定excel文件,筛选特定文件类型