Keras中长短期记忆网络LSTM的5步生命周期

使用Keras在Python中创建和评估深度学习神经网络非常容易,但您必须遵循严格的模型生命周期。

在这篇文章中,您将发现在Keras中创建,训练和评估长期短期记忆(LSTM)回归神经网络的分步生命周期,以及如何使用训练有素的模型进行预测。

阅读这篇文章后,你会知道:

如何在Keras中定义,编译,拟合和评估LSTM。
如何为回归和分类序列预测问题选择标准默认值。
如何将它们结合在一起,在Keras开发和运行您的第一个LSTM循环神经网络。

Overview
下面概述了我们将要研究的Keras LSTM模型生命周期中的5个步骤。

定义网络
编译网络
拟合网络
评估网络
作出预测
环境
     本教程假设您已安装Python SciPy环境。 您可以在此示例中使用Python 2或3。本教程假设您安装了TensorFlow或Theano后端的Keras v2.0或更高版本。本教程还假设您安装了scikit-learn,Pandas,NumPy和Matplotlib。接下来,让我们看看标准时间序列预测问题,我们可以将其用作此实验的上下文。

步骤1.定义网络
      第一步是定义您的网络。神经网络在Keras中被定义为层序列。 这些图层的容器是Sequential类。第一步是创建Sequential类的实例。 然后,您可以创建图层并按照它们应连接的顺序添加它们。 由存储器单元组成的LSTM循环层称为LSTM()。 通常跟随LSTM层并用于输出预测的完全连接层称为Dense()。

例如,我们可以分两步完成:

model = Sequential()
model.add(LSTM(2))
model.add(Dense(1))

但是我们也可以通过创建一个层数组并将其传递给sequence的构造函数来一步完成。

layers = [LSTM(2), Dense(1)]
model = Sequential(layers)

网络的第一层必须定义预期的输入数量。输入必须是三维的,包括样本、步长和特征。

样本。这些是数据中的行。

步长。这些是对某个特性(如延迟变量)的过去观察。

特性。这些是数据中的列。

假设您的数据作为一个NumPy数组加载,您可以使用NumPy中的重塑()函数将2D数据集转换为3D数据集。如果您希望列成为某个特性的时间步骤,可以使用:

data = data.reshape((data.shape[0], data.shape[1], 1))

如果您希望2D数据中的列使用单步长,可以使用:

data = data.reshape((data.shape[0], 1, data.shape[1]))

您可以指定input_shape参数,该参数期望一个包含时间步长和特征数量的元组。例如,对于单变量时间序列,如果我们有两个时间步骤和一个特征,每一行有两个滞后观测值,它将被指定如下:

model = Sequential()
model.add(LSTM(5, input_shape=(2,1)))
model.add(Dense(1))

可以通过将LSTM层添加到顺序模型中来堆叠它们。重要的是,在叠加LSTM层时,我们必须为每个输入输出一个序列,而不是单个值,以便后续的LSTM层可以拥有所需的3D输入。我们可以通过将return_sequences参数设置为True来实现这一点。例如:

model = Sequential()
model.add(LSTM(5, input_shape=(2,1), return_sequences=True))
model.add(LSTM(5))
model.add(Dense(1))

可以将顺序模型看作是一个管道,在管道的末端输入原始数据,在管道的另一端输出预测。在Keras中,这是一个有用的容器,传统上与层关联的关注点也可以分离出来,并作为单独的层添加,清楚地显示它们在从输入到预测的数据转换中的角色。

例如,可以提取从层中的每个神经元转换求和信号的激活函数,并将其作为类层对象(称为激活)添加到序列中。

model = Sequential()
model.add(LSTM(5, input_shape=(2,1)))
model.add(Dense(1))
model.add(Activation('sigmoid'))

激活函数的选择对于输出层来说是最重要的,因为它将定义预测所采用的格式。

例如,下面是一些常见的预测建模问题类型,以及可以在输出层使用的结构和标准激活函数:

回归:线性激活函数,或“线性”,以及匹配输出数量的神经元数量。

二元分类(2类):逻辑激活函数,或称“乙状结肠”,输出层有一个神经元。

多类分类(>2类):Softmax激活函数,或“Softmax”,每个类值有一个输出神经元,假设有一个热编码输出模式。

步骤2:编译网络

一旦定义了网络,就必须编译它。编译是一个有效的步骤。它将我们定义的简单层序列转换为一组高效的矩阵转换,其格式旨在在GPU或CPU上执行,具体取决于如何配置Keras。可以将编译看作是网络的预计算步骤。定义模型之后总是需要它。编译需要指定许多参数,这些参数是专门为训练您的网络而定制的。具体来说,用优化算法来训练网络,用损失函数来评价被优化算法最小化的网络。例如,下面是一个编译已定义模型并指定随机梯度下降(sgd)优化算法和均值平方误差(mean_squared_error)损失函数的例子,这是一个回归类型问题。

model.compile(optimizer='sgd', loss='mean_squared_error')

另外,优化器可以在作为编译步骤的参数提供之前创建和配置。

algorithm = SGD(lr=0.1, momentum=0.3)
model.compile(optimizer=algorithm, loss='mean_squared_error')

预测建模问题的类型对可以使用的损失函数的类型施加了约束。

例如,以下是针对不同预测模型类型的一些标准损失函数:

 回归:平均平方误差或“平均平方误差”。

二元分类(2类):对数损失,也称为交叉熵或“binary_crossentropy”。

多类分类(>2类):多类对数损失或“categorical_crossentropy”。

最常见的优化算法是随机梯度下降,但Keras还支持一组其他最先进的优化算法,这些算法在配置很少或没有配置的情况下工作良好。也许最常用的优化算法,因为他们通常更好的性能是:

 随机梯度下降法(sgd):需要调整学习速率和动量。

ADAM:需要调整学习速度。

RMSprop:需要调整学习速率。

最后,除了损失函数之外,您还可以指定在拟合模型时要收集的指标。通常,需要收集的最有用的附加度量是分类问题的准确性。要收集的指标由数组中的名称指定。

例如:

model.compile(optimizer='sgd', loss='mean_squared_error', metrics=['accuracy'])

步骤3:拟合网络

一旦网络被编译,它就可以被拟合,这意味着调整训练数据集上的权重。拟合网络需要指定训练数据,输入模式矩阵X和匹配输出模式数组y。利用反向传播算法对网络进行训练,并根据优化算法和模型编译时指定的损失函数进行优化。反向传播算法要求对网络进行特定次数的训练。每个epoch可以划分为一组称为批的输入-输出模式对。这定义了网络在一个epoch内更新权重之前所暴露的模式的数量。它也是一种效率优化,确保一次不会有太多的输入模式加载到内存中。

拟合网络的最小例子如下:

history = model.fit(X, y, batch_size=10, epochs=100)

一旦拟合,将返回一个history对象,该对象提供训练期间模型性能的摘要。这既包括损失,也包括在编译模型时指定的任何额外指标,记录每个epoch。训练可能需要很长时间,根据网络的大小和训练数据的大小,从几秒钟到几小时到几天不等。默认情况下,命令行上会显示每个纪元的进度条。这可能会给您带来太多的噪音,或者可能会给您的环境带来问题,例如,如果您在交互式笔记本或IDE中。通过将详细参数设置为2,可以将显示的信息量减少到每个epoch的损失。您可以通过将verbose设置为1关闭所有输出。

例如:

history = model.fit(X, y, batch_size=10, epochs=100, verbose=1)

步骤4:评估网络

一旦网络被训练,它就可以被评估。可以根据训练数据对网络进行评估,但这不能作为预测模型提供网络性能的有用指示,因为以前已经看到了所有这些数据。我们可以在一个单独的数据集上评估网络的性能,在测试期间是看不到的。这将提供网络在预测未来不可见数据方面的性能评估。模型评估跨所有测试模式的损失,以及在编译模型时指定的任何其他度量,如分类精度。返回一个评估指标列表。

例如,对于使用精度度量标准编译的模型,我们可以在新的数据集上对其进行如下评估:

损失,精度=模型。评估(X, y)

loss, accuracy = model.evaluate(X, y)

通过对网络的拟合,给出了详细的输出,给出了评价模型的进展情况。我们可以通过将详细参数设置为0来关闭它。

loss, accuracy = model.evaluate(X, y, verbose=0)

  第5步:作出预测

一旦我们对fit模型的性能感到满意,就可以使用它对新数据进行预测。这与在模型上使用新的输入模式数组调用predict()函数一样简单。
例如:

predictions = model.predict(X)

预测将以网络输出层提供的格式返回。在回归问题中,这些预测可能直接以问题的形式出现,由线性激活函数提供。对于二元分类问题,预测可能是第一类的概率数组,可以通过四舍五入将其转换为1或0。对于多类分类问题,结果可能以概率数组的形式出现(假设一个热编码输出变量),可能需要使用argmax() NumPy函数将其转换为单个类输出预测。另外,对于分类问题,我们可以使用predict_classes()函数,该函数将自动将不清晰的预测转换为清晰的整数类值。

predictions = model.predict_classes(X)

通过对网络进行拟合和评价,给出了详细的输出,以了解模型预测的进展情况。我们可以通过将详细参数设置为0来关闭它。

predictions = model.predict(X, verbose=0)

端到端工作的例子

让我们用一个简单的例子把所有这些联系起来。本例将使用一个简单的问题来学习10个数字序列。我们将向网络显示一个数字,如0.0,并期望它预测0.1。然后显示它0.1,并期望它预测0.2,以此类推到0.9。

定义网络:我们将构建一个LSTM神经网络,在可见层有1个输入时间步和1个输入特征,在LSTM隐层有10个记忆单元,在完全连接的输出层有1个神经元,具有线性(默认)激活函数。

编译网络:由于是一个回归问题,我们将使用具有默认配置和均方误差损失函数的高效ADAM优化算法。

拟合网络:拟合网络1000个时点,使用与训练集中模式数量相等的批处理大小,关闭所有冗余输出。

评估网络。我们将在训练数据集上对网络进行评估。通常我们会在测试或验证集上评估模型。

作出预测。我们将对训练输入数据进行预测。通常我们会对不知道正确答案的数据进行预测。

下面提供了完整的代码清单:

# Example of LSTM to learn a sequence
from pandas import DataFrame
from pandas import concat
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
# create sequence
length = 10
sequence = [i/float(length) for i in range(length)]
print(sequence)
# create X/y pairs
df = DataFrame(sequence)
df = concat([df.shift(1), df], axis=1)
df.dropna(inplace=True)
# convert to LSTM friendly format
values = df.values
X, y = values[:, 0], values[:, 1]
X = X.reshape(len(X), 1, 1)
# 1. define network
model = Sequential()
model.add(LSTM(10, input_shape=(1,1)))
model.add(Dense(1))
# 2. compile network
model.compile(optimizer='adam', loss='mean_squared_error')
# 3. fit network
history = model.fit(X, y, epochs=1000, batch_size=len(X), verbose=0)
# 4. evaluate network
loss = model.evaluate(X, y, verbose=0)
print(loss)
# 5. make predictions
predictions = model.predict(X, verbose=0)
print(predictions[:, 0])

运行此示例将生成以下输出,显示10个数字的原始输入序列、预测整个序列时网络的平均平方误差损失以及每个输入模式的预测。为了便于阅读,输出被隔开了。我们可以看到,这个数列学得很好,尤其是如果我们把预测四舍五入到小数点前一位。

[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]4.54527471447e-05[ 0.11612834 0.20493418 0.29793766 0.39445466 0.49376178 0.59512401
0.69782174 0.80117452 0.90455914]

Keras中长短期记忆网络LSTM的5步生命周期相关推荐

  1. 1014长短期记忆网络(LSTM)

    长短期记忆网络(LSTM) 长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题,解决这个问题最早的方法之一就是 LSTM 发明于90年代 使用的效果和 GRU 相差不大,但是使用的东西更加复杂 ...

  2. 基于长短期记忆网络(LSTM)对股票价格的涨跌幅度进行预测

    完整代码:https://download.csdn.net/download/qq_38735017/87536579 为对股票价格的涨跌幅度进行预测,本文使用了基于长短期记忆网络(LSTM)的方法 ...

  3. 简单介绍长短期记忆网络 - LSTM

    文章目录 一.引言 1.1 什么是LSTM 二.循环神经网络RNN 2.1 为什么需要RNN 三.长短时记忆神经网络LSTM 3.1 为什么需要LSTM 3.2 LSTM结构分析 3.3 LSTM背后 ...

  4. MATLAB-基于长短期记忆网络(LSTM)的SP500的股票价格预测 股价预测 matlab实战 数据分析 数据可视化 时序数据预测 变种RNN 股票预测

    MATLAB-基于长短期记忆网络(LSTM)的SP500的股票价格预测 股价预测 matlab实战 数据分析 数据可视化 时序数据预测 变种RNN 股票预测 摘要 近些年,随着计算机技术的不断发展,神 ...

  5. 『NLP学习笔记』长短期记忆网络LSTM介绍

    长短期记忆网络LSTM介绍 文章目录 一. 循环神经网络 二. 长期依赖问题 三. LSTM 网络 四. LSTM 背后的核心理念 4.1 忘记门 4.2 输入门 4.3 输出门 五. LSTM总结( ...

  6. [react] react中发起网络请求应该在哪个生命周期中进行?为什么?

    [react] react中发起网络请求应该在哪个生命周期中进行?为什么? 异步情况可以在componentDidMount()函数中进行. 同步的情况可以在componentWillMount()中 ...

  7. keras实现简单lstm_四十二.长短期记忆网络(LSTM)过程和keras实现股票预测

    一.概述 传统循环网络RNN可以通过记忆体实现短期记忆进行连续数据的预测,但是,当连续数据的序列边长时,会使展开时间步过长,在反向传播更新参数的过程中,梯度要按时间步连续相乘,会导致梯度消失或者梯度爆 ...

  8. 长短期记忆网络LSTM

    1. LSTM是循环神经网络的一个变体可以有效的解决简单循环神经网络的梯度消失和梯度爆炸的问题. 2. 改进方面: 新的内部状态 Ct专门进行线性的循环信息传递,同时(非线性的)输出信息给隐藏层的外部 ...

  9. 长短期记忆网络 LSTM

    这里写目录标题 1. LSTM介绍 1.1 什么是LSTM 1.2 LSTM相较于RNN的优势 1.3 LSTM的结构图 1.3.1 LSTM的核心思想 1.3.2 LSTM的遗忘门 1.3.3 LS ...

  10. 【个人整理】长短是记忆网络LSTM的原理以及缺点

    前言:普通的循环神经网络RNN是很难训练的,这导致了它在实际应用中,很难处理长距离的依赖.在本文中,我们将介绍一种改进之后的循环神经网络:长短时记忆网络(Long Short Term Memory ...

最新文章

  1. android 闹钟服务,如果闹钟时间已经过去,android可以防止即时触发闹钟服务
  2. 把 14 亿中国人民都拉到一个微信群里在技术上能实现吗?
  3. Erlang 位串和二进制数据
  4. 正则只能小于0负数_2019–2020学年七年级数学期末考试考点之正数与负数考点详解...
  5. python画条形图-python使用Matplotlib画条形图
  6. c语言中用分数表示结果,C语言实例 计算分数的精确值
  7. Mujoco-一阶单摆建模与控制
  8. C++ 学习记录(18) NVI
  9. SpringBoot resultful风格返回格式
  10. 从程序员到项目经理(8):程序员加油站 -- 不要死于直率
  11. 原生js实现购物车添加删除商品、计算价格功能
  12. Jenkins服务器磁盘空间管理策略
  13. 『XXG JS』JavaScript 数组 - 查找
  14. LWIP+ENC28J60长时间运行后无法访问外网服务器
  15. CSS 文本字体颜色设置方法。
  16. MRM:基于ISMRM研究与欧洲痴呆研究动脉自旋灌注成像临床应用的补充建议
  17. 多多评价怎么显示第一个_拼多多商品质量分哪里看?怎么看评分?
  18. 华为云开发者日震撼来袭!11月20日,上海见
  19. 川轻化c语言实验答案,计算机二级c语言第4套笔试模拟试卷.doc
  20. Linux如何手动编译fcitx文件,linux下安装和配置fcitx中文输入法

热门文章

  1. NoSQL、memcached介绍、安装memcached、查看memcached状态
  2. 阿里启动“Buy+”计划,正式成立 VR 实验室
  3. linux-磁盘结构
  4. ELK 日志管理系统,初次尝试记录
  5. edx : Permission denied
  6. 使用命令行搜索你的java 库
  7. 由乱序播放说开了去-数组的打乱算法Fisher–Yates Shuffle
  8. 计算机科学中常见计量单位解析
  9. Silverlight 解谜游戏 之十二 游戏暗示(1)
  10. react 动态获取数据