关于什么是 LSTM 我就不详细阐述了,吴恩达老师视频课里面讲的很好,我大概记录了课上的内容在吴恩达《序列模型》笔记一,网上也有很多写的好的解释,比如:LSTM入门、理解LSTM网络

然而,理解挺简单,上手写的时候还是遇到了很多的问题,网上大部分的博客都没有讲清楚 cell 参数的设置,在我看了N多篇文章后终于搞明白了,写出来让大家少走一些弯路吧!

如上图是一个LSTM的单元,可以应用到多种RNN结构中,常用的应该是 one-to-manymany-to-many


下面介绍 many-to-many 这种结构:

  1. batch_size:批度训练大小,即让 batch_size 个句子同时训练。
  2. time_steps:时间长度,即句子的长度
  3. embedding_size:组成句子的单词的向量长度(embedding size)
  4. hidden_size:隐藏单元数,一个LSTM结构是一个神经网络(如上图就是一个LSTM单元),每个小黄框是一个神经网络,小黄框的隐藏单元数就是hidden_size,那么这个LSTM单元就有 4*hidden_size 个隐藏单元。
  5. 每个LSTM单元的输出 C、h,都是向量,他们的长度都是当前 LSTM 单元的 hidden_size。
  6. n_words:语料库中单词个数。

实现方式一:

import tensorflow as tf
import numpy as np
from tensorflow.contrib import rnndef add_layer(inputs, in_size, out_size, activation_function=None):  # 单层神经网络weights = tf.Variable(tf.random_normal([in_size, out_size]))baises = tf.Variable(tf.zeros([1, out_size]) + 0.1)wx_b = tf.matmul(inputs, weights) + baisesif activation_function is None:outputs = wx_belse:outputs = activation_function(wx_b)return outputsn_words = 15
embedding_size = 8
hidden_size = 8  # 一般hidden_size和embedding_size是相同的
batch_size = 3
time_steps = 5w = tf.Variable(tf.random_normal([n_words, embedding_size], stddev=0.01))  # 模拟参数 W
sentence = tf.Variable(np.arange(15).reshape(batch_size, time_step, 1))    # 模拟训练的句子:3条句子,每个句子5个单词  shape(3,5,1)
input_s = tf.nn.embedding_lookup(w, sentence)  # 将单词映射到向量:每个单词变成了size为8的向量  shape=(3,5,1,8)
input_s = tf.reshape(input_s, [-1, 5, 8])        # shape(3,5,8)with tf.name_scope("LSTM"):  # trustlstm_cell = rnn.BasicLSTMCell(hidden_size, state_is_tuple=True, name='lstm_layer') h_0 = tf.zeros([batch_size, embedding_size])  # shape=(3,8)c_0 = tf.zeros([batch_size, embedding_size])  # shape=(3,8)state = rnn.LSTMStateTuple(c=c_0, h=h_0)      # 设置初始状态outputs = []for i in range(time_steps):  # 句子长度if i > 0: tf.get_variable_scope().reuse_variables()  # 名字相同cell使用的参数w就一样,为了避免重名引起别的的问题,设置一下变量重用output, state = lstm_cell(input_s[:, i, :], state)     # output:[batch_size,embedding_size]  shape=(3,8)outputs.append(output)     # outputs:[TIME_STEP,batch_size,embedding_size]  shape=(5,3,8)path = tf.concat(outputs, 1)   # path:[batch_size,embedding_size*TIME_STEP]   shape=(3, 40)path_embedding = add_layer(path, time_step * embedding_size, embedding_size)  # path_embedding:[batch_size, embedding_size]with tf.Session() as s:s.run(tf.global_variables_initializer())# 因为使用的参数数量都还比较小,打印一些变量看看就能明白是怎么操作的print(s.run(outputs))print(s.run(path_embedding))

比如一批训练64句话,每句话20个单词,每个词向量长度为200,隐藏层单元个数为128
那么训练一批句子,输入的张量维度是[64,20,200],ht,ct​ 的维度是[128],那么LSTM单元参数矩阵的维度是[128+200,4x128],
在时刻1,把64句话的第一个单词作为输入,即输入一个[64,200]的矩阵,由于会和 ht 进行concat,输入矩阵变成了[64,200+128],输入矩阵会和参数矩阵[200+128,4x128]相乘,输出为[64,4x128],也就是每个黄框的输出为[64,128],黄框之间会进行一些操作,但不改变维度,输出依旧是[64,128],即每个句子经过LSTM单元后,输出的维度是128,所以每个LSTM输出的都是向量,包括Ct,ht,所以它们的长度都是当前LSTM单元的hidden_size 。那么我们就知道cell_output的维度为[64,128]
之后的时刻重复刚才同样的操作,那么outputs的维度是[20,64,128].
softmax相当于全连接层,将outputs映射到vocab_size个单词上,进行交叉熵误差计算。
然后根据误差更新LSTM参数矩阵和全连接层的参数。

实现方式二:

测试数据链接:https://pan.baidu.com/s/1j9sgPmWUHM5boM5ekj3Q2w 提取码:go3f

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tfdata = pd.read_excel("seq_data.xlsx")  # 读取序列数据
data = data.values[1:800]   # 取前800个
normalize_data = (data - np.mean(data)) / np.std(data)  # 标准化数据
s = np.std(data)
m = np.mean(data)
time_step = 96   # 序列段长度
rnn_unit = 8     # 隐藏层节点数目
lstm_layers = 2  # cell层数
batch_size = 7   # 序列段批处理数目
input_size = 1   # 输入维度
output_size = 1  # 输出维度
lr = 0.006       # 学习率train_x, train_y = [], []
for i in range(len(data) - time_step - 1):x = normalize_data[i:i + time_step]y = normalize_data[i + 1:i + time_step + 1]train_x.append(x.tolist())train_y.append(y.tolist())
X = tf.placeholder(tf.float32, [None, time_step, input_size])  # shape(?,time_step, input_size)
Y = tf.placeholder(tf.float32, [None, time_step, output_size])  # shape(?,time_step, out_size)
weights = {'in': tf.Variable(tf.random_normal([input_size, rnn_unit])),'out': tf.Variable(tf.random_normal([rnn_unit, 1]))}
biases = {'in': tf.Variable(tf.constant(0.1, shape=[rnn_unit, ])),'out': tf.Variable(tf.constant(0.1, shape=[1, ]))}
def lstm(batch):w_in = weights['in']b_in = biases['in']input = tf.reshape(X, [-1, input_size])input_rnn = tf.matmul(input, w_in) + b_ininput_rnn = tf.reshape(input_rnn, [-1, time_step, rnn_unit])cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(rnn_unit) for i in range(lstm_layers)])init_state = cell.zero_state(batch, dtype=tf.float32)output_rnn, final_states = tf.nn.dynamic_rnn(cell, input_rnn, initial_state=init_state, dtype=tf.float32)output = tf.reshape(output_rnn, [-1, rnn_unit])w_out = weights['out']b_out = biases['out']pred = tf.matmul(output, w_out) + b_outreturn pred, final_statesdef train_lstm():global batch_sizewith tf.variable_scope("sec_lstm"):pred, _ = lstm(batch_size)loss = tf.reduce_mean(tf.square(tf.reshape(pred, [-1]) - tf.reshape(Y, [-1])))train_op = tf.train.AdamOptimizer(lr).minimize(loss)saver = tf.train.Saver(tf.global_variables())loss_list = []with tf.Session() as sess:sess.run(tf.global_variables_initializer())for i in range(100):  # We can increase the number of iterations to gain better result.start = 0end = start + batch_sizewhile (end < len(train_x)):_, loss_ = sess.run([train_op, loss], feed_dict={X: train_x[start:end], Y: train_y[start:end]})start += batch_sizeend = end + batch_sizeloss_list.append(loss_)if i % 10 == 0:print("Number of iterations:", i, " loss:", loss_list[-1])if i > 0 and loss_list[-2] > loss_list[-1]:saver.save(sess, 'model_save1\\modle.ckpt')# I run the code in windows 10,so use  'model_save1\\modle.ckpt'# if you run it in Linux,please use  'model_save1/modle.ckpt'print("The train has finished")train_lstm()def prediction():with tf.variable_scope("sec_lstm", reuse=tf.AUTO_REUSE):pred, _ = lstm(1)saver = tf.train.Saver(tf.global_variables())with tf.Session() as sess:saver.restore(sess, 'model_save1\\modle.ckpt')# I run the code in windows 10,so use  'model_save1\\modle.ckpt'# if you run it in Linux,please use  'model_save1/modle.ckpt'predict = []for i in range(0, np.shape(train_x)[0]):next_seq = sess.run(pred, feed_dict={X: [train_x[i]]})predict.append(next_seq[-1])plt.figure()plt.plot(list(range(len(data))), data, color='b')plt.plot(list(range(time_step + 1, np.shape(train_x)[0] + 1 + time_step)), [value * s + m for value in predict],color='r')plt.show()prediction()

参考文章:

基于TensorFlow构建LSTM
TensorFlow实战:LSTM的结构与cell中的参数

Tensorflow实现LSTM详解相关推荐

  1. Tensorflow安装教程详解(图文详解,深度好文)

    Tensorflow安装教程详解(图文详解,深度好文) 前言 安装前的准备工作 关于python 关于Anaconda 开始使用Tensorflow 系统内配置Anaconda使用路径 Anacond ...

  2. 【nn.LSTM详解】

    参数详解 nn.LSTM是pytorch中的模块函数,调用如下: torch.nn.lstm(input_size,hidden_size,num_layers,bias,batch_first,dr ...

  3. 长短时记忆网络(Long Short Term Memory,LSTM)详解

    长短时记忆网络是循环神经网络(RNNs)的一种,用于时序数据的预测或文本翻译等方面.LSTM的出现主要是用来解决传统RNN长期依赖问题.对于传统的RNN,随着序列间隔的拉长,由于梯度爆炸或梯度消失等问 ...

  4. Tensorflow系列 | Tensorboard详解(下篇)

    编辑 | 安可 [导读]:本文接续Tensorboard详解(上篇)介绍Tensorboard和总结Tensorboard的所有功能并有代码演练.欢迎大家点击上方蓝字关注我们的公众号:深度学习与计算机 ...

  5. windows环境下tensorflow安装过程详解(亲测安装成功后测试那块)

    写在最前: 在安装过程中遇到很多坑,一开始自己从官网下载了Python3.6.3或者Python3.6.5或者Python3.7.1等多个版本,然后直接pip install tensorflow或者 ...

  6. Tensorflow载入模型详解,方法一(基础版):针对测试模型性能 和 使用模型。

    我们知道了如何保存我们的模型接下来,我们就要想办法加载模型,调用模型,这也是我们用来做验证也好.做应用也好必须要做的.当然这里我们只考虑应用和验证,且只涉及模型部分,数据预处理,大家要自己加油啦.下一 ...

  7. Tensorflow保存模型详解(进阶版二):如何保存最近的.ckpt文件 及 如何分开保存.ckpt数据文件和.meta图文件

    在学会了如何有选择的保存变量后,我们来学习如何如何分开保存.ckpt数据文件和.meta图文件 和 如何 保存最近几轮的.ckpt数据文件. 直接上代码: import tensorflow as t ...

  8. RNN到LSTM详解

    RNN(Recurrent Neural Network)是一类用于处理序列数据的神经网络.首先我们要明确什么是序列数据,摘取百度百科词条:时间序列数据是指在不同时间点上收集到的数据,这类数据反映了某 ...

  9. 深度学习中的循环神经网络LSTM详解

    (一).什么是循环神经网络LSTM? LSTM指的是长短期记忆网络(Long Short Term Memory),它是循环神经网络中最知名和成功的扩展.由于循环神经网络有梯度消失和梯度爆炸的问题,学 ...

最新文章

  1. 《JavaScript权威指南》笔记(一)
  2. Java第一个程序(CMD环境)
  3. C++ auto 关键字的使用
  4. 【PC工具】复制翻译神器!有了这个开源免费的翻译软件,阅读英文文档变得再也不困难了...
  5. WildFly上具有AngularJS的Java EE 7和Java WebSocket API(JSR 356)
  6. C#打开php链接传参然后接收返回值
  7. pe我的手机服务器存档文件,我的世界手机版怎么导出存档 pe版怎么把存档给别人用...
  8. innodb_force_recovery
  9. linux dstat rpm,dstat监控工具介绍
  10. Java实现部标JTT1078实时音视频传输指令——视频流负载包(RTP)传输
  11. CTF MISC(杂项)知识点总结——图片类(一)
  12. 怎样能把在线视频(不提供下载)储存下来到电脑
  13. Hybrid Dilated Convolution学习笔记
  14. C语言变量常量,基本数据类型及数据类型转换详讲
  15. synergy使用方法和安装包
  16. 如何防护 DDoS 攻击?
  17. Git中rebase的使用
  18. 未来已经降临,只是先后有别
  19. 在线记录源码调试之@Qualifier源码分析
  20. android向联系人中添加头像以及获得电话记录

热门文章

  1. 花三千块钱求推荐一个靠谱的C++工程师
  2. Linux 内核系统架构
  3. 机器学习——超参数调优
  4. 二进制包如何知道go 版本_gops 是怎么和 Go 的运行时进行交互的?
  5. 1+X web中级 Laravel学习笔记——blade模版
  6. Packet Tracer 5.0 建构 CCNA 实验攻略——路由器实现 Vlan 间通信
  7. 剑指Offer - 面试题33. 二叉搜索树的后序遍历序列(递归)
  8. LeetCode 870. 优势洗牌(贪心 二分查找)
  9. LeetCode 59. 螺旋矩阵 II LeetCode 54. 螺旋矩阵
  10. maven deploy plugin_Maven快速上手