论文名称:Temporal Fusion Transformers for interpretable multi-horizon time series forecasting
论文下载:https://www.sciencedirect.com/science/article/pii/S0169207021000637
论文年份:2021
论文被引:93(2022/05/01)
论文代码:https://github.com/greatwhiz/tft_tf2

Abstract

Multi-horizon forecasting often contains a complex mix of inputs – including static (i.e. time-invariant) covariates, known future inputs, and other exogenous time series that are only observed in the past – without any prior information on how they interact with the target. Several deep learning methods have been proposed, but they are typically ‘black-box’ models that do not shed light on how they use the full range of inputs present in practical scenarios. In this paper, we introduce the Temporal Fusion Transformer (TFT) – a novel attention-based architecture that combines high-performance multi-horizon forecasting with interpretable insights into temporal dynamics. To learn temporal relationships at different scales, TFT uses recurrent layers for local processing and interpretable self-attention layers for long-term dependencies. TFT utilizes specialized components to select relevant features and a series of gating layers to suppress unnecessary components, enabling high performance in a wide range of scenarios. On a variety of real-world datasets, we demonstrate significant performance improvements over existing benchmarks, and highlight three practical interpretability use cases of TFT.

多范围预测(Multi-horizon forecasting)通常包含复杂的输入组合——包括静态(即时间不变)协变量已知的未来输入以及仅在过去观察到的其他外生时间序列——没有任何关于它们如何与目标交互的先验信息。已经提出了几种深度学习方法,但它们通常是“黑盒”模型,没有阐明它们如何使用实际场景中存在的全部输入。在本文中,我们介绍了时间融合 Transformer(Temporal Fusion Transformer,TFT)——一种新颖的基于注意力的架构,它结合了高性能多范围预测和对时间动态的可解释洞察力。为了学习不同尺度的时间关系,TFT 使用循环层进行局部处理,使用可解释的自注意力层进行长期依赖。 TFT利用专门的组件来选择相关的特征,并利用一系列的门控层来抑制不必要的组件,从而在广泛的场景中实现高性能。在各种现实世界的数据集上,我们展示了比现有基准测试显着的性能改进,并强调了 TFT 的三个实际可解释性用例。

1. Introduction

多范围预测,即在多个未来时间步长的感兴趣变量的预测,是时间序列机器学习中的一个关键问题。与一步一步的预测相比,多范围预测为用户提供了对整个路径的估计的访问权限,使他们能够在未来的多个步骤中优化他们的行动(例如,零售商优化整个即将到来的季节的库存,或临床医生优化患者的治疗计划)。多范围预测在零售(Böse et al., 2017; Courty & Li, 1999),医疗保健(Lim, Alaa, & van der Schaar, 2018; Zhang & Nawata, 2018)和经济学(Capistran, Constandse, & RamosFrancia, 2010)——在此类应用中对现有方法的性能改进非常有价值。

实际的多范围预测应用程序通常可以访问各种数据源,如图 1 所示,包括

  • 有关未来的已知信息(例如即将到来的假期日期)
  • 其他外生时间序列(exogenous time series,例如历史客户客流量)
  • 静态元数据(static metadata,例如商店的位置)

没有任何关于它们如何交互的先验知识数据源的这种异质性以及关于它们交互的信息很少,使得多范围时间序列预测特别具有挑战性

【已有研究存在的问题】

深度神经网络 (DNN) 越来越多地用于多范围预测,与传统时间序列模型相比表现出强大的性能改进(Alaa & van der Schaar, 2019; Makridakis, Spiliotis, & Assimakopoulos, 2020; Rangapuram et al., 2018)。虽然许多架构都专注于循环神经网络 (RNN) 架构的变体(Rangapuram et al., 2018; Salinas, Flunkert, Gasthaus, & Januschowski, 2019; Wen et al., 2017),但最近的改进也使用了基于注意力的增强过去相关时间步长选择的方法(Fan et al., 2019)——包括基于Transformer的模型(Li et al., 2019)。然而,这些通常没有考虑多范围预测中常见的不同类型的输入,并且要么假设所有外生输入都是未来已知的(Li et al., 2019; Rangapuram et al., 2018; Salinas et al., 2018)——自回归模型的一个常见问题——或忽略重要的静态协变量 (Wen et al., 2017)——它们在每一步都简单地与其他时间相关特征连接。最近时间序列模型的许多改进都源于架构与独特数据特征的一致性(Koutník, Greff, Gomez, & Schmidhuber, 2014; Neil et al., 2016)。我们论证并证明,通过设计具有适合多范围预测的归纳偏差的网络,也可以获得类似的性能提升

除了不考虑常见多范围预测输入的异质性(heterogeneity)之外,大多数当前架构都是“黑盒”模型,其中预测由许多参数之间的复杂非线性相互作用控制。这使得很难解释模型是如何得出预测结果的,反过来又使用户难以信任模型的输出和模型构建者对其进行调试。不幸的是,DNN 常用的可解释性方法不太适合应用于时间序列。在传统形式中,事后方法(例如 LIME (Ribeiro et al., 2016) 和 SHAP (Lundberg & Lee, 2017))不考虑输入特征的时间顺序例如,对于 LIME,代理模型是为每个数据点独立构建的,而对于 SHAP,特征是针对相邻时间步独立考虑的。这种事后方法会导致解释质量差,因为时间步长之间的依赖关系在时间序列中通常很重要。另一方面,提出了一些基于注意力的架构,它们对顺序数据(主要是语言或语音)具有固有的可解释性,例如 Transformer 架构。应用它们的基本警告是,多范围预测包括许多不同类型的输入特征,而不是语言或语音。在它们的传统形式中,这些架构可以提供对多范围预测的相关时间步长的洞察,但它们无法区分给定时间步长不同特征的重要性。总体而言,除了需要新的方法来解决多维度预测中数据的异质性以实现高性能之外,还需要新的方法来使这些预测具有可解释性,考虑到用例的需要。

【本文的解决方案】

在本文中,我们提出了时间融合Transformer(Temporal Fusion Transformer,TFT)——一种基于注意力的 DNN 架构,用于多范围预测,在实现高性能的同时实现新的可解释性形式。为了在状态基准测试上获得显着的性能改进,我们引入了多种新颖的想法,以使架构与多范围预测常见的全部潜在输入和时间关系保持一致——特别是结合

  • 1)编码上下文的静态协变量编码器用于网络其他部分的向量
  • 2)贯穿始终的门控机制和样本相关变量选择,以最小化不相关输入的贡献
  • 3)用于局部处理已知和观察到的输入的序列到序列层
  • 4)一个时间自注意力解码器,用于学习数据集中存在的任何长期依赖关系

使用这些专门的组件也有助于解释性;特别是,我们展示了 TFT 支持三个有价值的可解释性用例:帮助用户识别 (i) 预测问题的全局重要变量,(ii) 持久的时间模式,以及 (iii) 重要事件。在各种现实世界的数据集上,我们展示了如何实际应用 TFT,以及它提供的见解和好处。

2. Related work

用于多范围预测的 DNN:与传统的多范围预测方法类似(Marcellino, Stock, & Watson, 2006; Taieb, Sorjamaa, & Bontempi, 2010),最近的深度学习方法可以分类为使用自回归模型的迭代方法(Li et al., 2019; Rangapuram et al., 2018; Salinas et al., 2019)或基于序列到序列模型的直接方法(Fan et al., 2019; Wen et al., 2017)。

迭代方法利用单步预测模型,通过将预测递归地输入到未来的输入中来获得多步预测。已经考虑了具有LSTM 网络的方法,例如

  • Deep AR (Salinas et al., 2019),它使用堆叠的 LSTM 层来生成单步高斯预测的参数分布

  • 深度状态空间模型 (DSSM) (Rangapuram et al., 2018) 采用了类似的方法,利用 LSTM 生成预定义线性状态空间模型的参数,该模型具有通过卡尔曼滤波产生的预测分布——扩展了多变量时间序列数据 Wang et al. (2019)。

最近,Li et al. (2019) 探索了基于 Transformer 的架构,它提出了使用卷积层进行局部处理和稀疏注意机制来增加预测过程中感受野的大小。尽管它们很简单,但迭代方法依赖于这样一个假设,即除目标之外的所有变量的值在预测时都是已知的——因此只有目标需要递归地输入到未来的输入中。然而,在许多实际场景中,存在许多有用的时变输入,其中许多是预先未知的。因此,它们的直接使用仅限于迭代方法。另一方面,TFT 明确考虑了输入的多样性——自然地处理静态协变量和(过去观察到的和未来已知的)时变输入

相比之下,直接方法被训练以在每个时间步显式生成多个预定义范围的预测。他们的架构通常依赖于序列到序列模型,例如LSTM 编码器用于总结过去的输入,以及各种生成未来预测的方法

  • Multi-horizon Quantile Recurrent Forecaster (MQRNN) (Wen et al., 2017) 使用 LSTM 或卷积编码器生成上下文向量,这些上下文向量被馈送到每个层的多层感知器 (MLP)。

  • Fan et al. (2019) 多模态注意机制与 LSTM 编码器一起使用,为双向 LSTM 解码器构建上下文向量。尽管性能优于基于 LSTM 的迭代方法,但对于此类标准直接方法而言,可解释性仍然具有挑战性。

相比之下,我们表明,通过解释注意力模式,TFT 可以提供关于时间动态的有见地的解释,并在这样做的同时保持各种数据集上的最先进性能

带注意的时间序列可解释性:注意机制用于翻译(Vaswani 等,2017)、图像分类(Wang、Jiang、Qian、Yang、Li、Zhang、Wang 和 Tang,2017)或表格学习(Arik & Pfister , 2019),使用注意力权重的大小来识别每个实例的输入的显着部分。最近,它们已经被应用到具有可解释性的时间序列(Alaa & van der Schaar,2019;Choi 等人,2016;Li 等人,2019),使用基于 LSTM(Song 等人,2018)和基于Transformer (Li et al., 2019) 的架构。然而,这是在没有考虑静态协变量的重要性的情况下完成的(因为上述方法在每个输入处混合变量)TFT 通过在自注意力之上的每个时间步对静态特征使用单独的编码器-解码器注意力来确定时变输入的贡献,从而缓解了这一问题

DNN 的实例变量重要性实例(即样本)变量重要性可以通过事后解释方法(Lundberg & Lee, 2017; Ribeiro et al., 2016; Yoon, Arik, & Pfister, 2019)和固有的可解释模型(Choi 等人,2016 年;Guo、Lin 和 Antulov-Fantulin,2019 年)获得。

  • 事后解释方法,例如LIME (Ribeiro et al., 2016)、SHAP (Lundberg & Lee, 2017) 和 RL-LIM (Yoon et al., 2019) 应用于预训练的黑盒模型,通常基于提炼成可解释的代理模型,或分解为特征属性。它们的设计并未考虑输入的时间顺序,从而限制了它们对复杂时间序列数据的使用

  • 固有的可解释建模方法将用于特征选择的组件直接构建到架构中。特别是对于时间序列预测,它们基于明确量化时间相关变量的贡献。例如,Interpretable Multi-Variable LSTM (Guo et al., 2019) 对隐藏状态进行分区,使每个变量对其自己的内存段做出唯一贡献,并对内存段进行加权以确定变量的贡献。 Choi 等人 (2016) 也考虑了结合时间重要性和变量选择的方法,它根据每个人的注意力权重计算单个贡献系数。然而,除了仅建模一步超前预测的缺点之外,现有方法还专注于注意力权重的实例特定(即样本特定)解释,而没有提供对全局时间动态的洞察

相比之下,第 7 节中的用例表明 TFT 能够分析全局时间关系并允许用户解释模型在整个数据集上的全局行为——特别是在识别任何持久模式(例如季节性或滞后效应)和目前的制度(regime)

注:查阅资料可知,regime 亦即 state (状态)

3. Multi-horizon forecasting

让给定的时间序列数据集中有唯一的实体(entities)——例如零售中的不同商店或医疗保健中的患者。每个实体 iii 与一组静态协变量 si∈Rmss_i ∈ \R^{m_s}si​∈Rms​ 以及输入 χi,t∈Rmχχ_{i,t} ∈ R^{m_χ}χi,t​∈Rmχ​ 和标量目标 yi,t∈Ry_{i,t} ∈ \Ryi,t​∈R 在每个时间步 t∈[0,Ti]t ∈ [0, T_i]t∈[0,Ti​] 相关联。与时间相关的输入特征被细分为两类 χi,t=[zi,tT,xi,tT]Tχ_{i,t} = [z^T_{i,t}, x^T_{i,t}]^Tχi,t​=[zi,tT​,xi,tT​]T ——观察到的输入 zi,t∈R(mz)z_{i,t} ∈ \R^{(m_z)}zi,t​∈R(mz​) 只能在每一步测量并且事先未知,已知输入 χi,t∈Rmχχ_{i,t} ∈ R^{m_χ}χi,t​∈Rmχ​ 可以预先确定(例如,时间 ttt 的星期几)。

在许多情况下,通过提供目标可以采用的可能的最佳和最坏情况值的指示,预测间隔的规定可用于优化决策和风险管理。因此, 我们对多范围预测设置采用分位数回归(例如,在每个时间步输出第 10、第 50 和第 90 个百分位数)。每个分位数预测采用以下形式

与其他直接方法一致,我们同时输出 τmaxτ_{max}τmax​ 时间步长的预测——即 τ∈{1,...,τmax}τ ∈ \{1, . . . , τ_{max}\}τ∈{1,...,τmax​}。我们将所有过去的信息合并到一个有限的回顾窗口 kkk 中,仅使用

  • 直到并包括预测开始时间 t 的目标和已知输入,即 yi,t−k:t={yi,t−k,...,yi,t}y_{i,t−k:t} = \{y_{i,t−k}, ..., y_{i,t} \}yi,t−k:t​={yi,t−k​,...,yi,t​}

  • 整个范围内的已知输入,即 xi,t−k:t+τ={xi,t−k,...,xi,t,...,xi,t+τ}x_{i,t−k:t+τ} = \{x_{i,t−k}, . . ., x_{i,t}, . . . , x_{i,t+τ} \}xi,t−k:t+τ​={xi,t−k​,...,xi,t​,...,xi,t+τ​}

4. Model architecture

我们将 TFT 设计为使用规范组件来有效地为每种输入类型(即静态、已知、观察到的输入)构建特征表示,从而在广泛的问题上实现高预测性能。 TFT 的主要成分是:

  1. 门控机制:跳过架构中任何未使用的组件,提供自适应深度和网络复杂性以适应广泛的数据集和场景。
  2. 变量选择网络:在每个时间步选择相关的输入变量。
  3. 静态协变量编码器:将静态特征集成到网络中,通过上下文向量的编码来调节时间动态
  4. 时间处理:从观察到的和已知的随时间变化的输入中学习长期和短期时间关系。序列到序列层用于局部处理,而长期依赖关系使用新颖的可解释多头注意力块捕获
  5. 通过分位数预测的预测区间,以确定每个预测范围内可能的目标值的范围。

图 2 显示了 Temporal Fusion Transformer (TFT) 的高级架构,后续部分将详细介绍各个组件。

4.1. Gating mechanisms

外生输入和目标之间的精确关系通常是事先未知的,因此很难预测哪些变量是相关的。此外,很难确定所需非线性处理的程度,并且在某些情况下,更简单的模型可能是有益的——例如,当数据集很小或嘈杂时。为了使模型能够灵活地仅在需要时应用非线性处理,我们提出了如图 2 所示的门控残差网络 (GRN) 作为 TFT 的构建块GRN 接受一个主要输入 a 和一个可选的上下文向量 c 并产生

GLU 允许 TFT 控制 GRN 对原始输入 a 的贡献程度——如有必要,可能会完全跳过该层,因为 GLU 输出可能全部接近 0,以抑制非线性贡献。对于没有上下文向量的实例,GRN 只是将上下文输入视为零——即等式 (4) 中的 c = 0。在训练期间,在门控层和层归一化之前应用 dropout ——即到等式 (3) 中的 η1

4.2. Variable selection networks

虽然可能有多个变量可用,但它们的相关性和对输出的具体贡献通常是未知的TFT 旨在通过使用应用于静态协变量和时间相关协变量的变量选择网络来提供实例变量选择。除了提供对预测问题最重要的变量的见解之外,变量选择还允许 TFT 消除可能对性能产生负面影响的任何不必要的噪声输入。大多数现实世界的时间序列数据集包含预测内容较少的特征,因此变量选择可以通过仅利用最显着的学习能力来极大地帮助模型性能

我们将实体嵌入 (Gal & Ghahramani, 2016) 用于分类变量作为特征表示,并使用线性变换用于连续变量——将每个输入变量转换为 (dmodel) 维向量,该向量与后续层中的维度相匹配以进行跳过连接。所有静态、过去和未来的输入都使用具有不同权重的单独变量选择网络(如图 2 的主架构图中不同颜色所示)。不失一般性,我们在下面展示了过去输入的变量选择网络——注意其他输入的变量选择网络采用相同的形式。

我们注意到每个变量都有自己的 GRNξ(j)GRN_{ξ(j)}GRNξ(j)​,权重在所有时间步 ttt 之间共享。然后通过其可变选择权重对处理后的特征进行加权并组合

4.3. Static covariate encoders

与其他时间序列预测架构相比,TFT 经过精心设计以整合来自静态元数据的信息,使用单独的 GRN 编码器生成四个不同的上下文向量 cs、ce、cc 和 ch。这些联系向量连接到时间融合解码器(第 4.5 节)中的各个位置,其中静态变量在处理中起重要作用。具体来说,这包括用于

  • 1)时间变量选择 (cs)
  • 2)时间特征 (cc, ch) 的局部处理
  • 3)使用静态信息 (ce) 丰富时间特征的上下文

例如,将 ζζζ 作为静态变量选择网络的输出,时间变量选择的上下文将根据 cs=GRNcs(ζ)c_s = GRN_{c_s}(ζ)cs​=GRNcs​​(ζ) 进行编码。

4.4. Interpretable multi-head attention

TFT 采用自注意机制来学习不同时间步长的长期关系,我们从基于 Transformer 的架构中的多头注意进行修改(Li et al., 2019; Vaswani et al., 2017)以增强可解释性。一般来说,注意力机制根据键 K∈RN×dattnK ∈ \R^{N×d_{attn}}K∈RN×dattn​ 和查询 Q∈RN×dattnQ ∈ \R^{N×d_{attn}}Q∈RN×dattn​ 之间的关系来缩放 V∈RN×dVV ∈ \R^{N×d_V}V∈RN×dV​,如下所示:

其中 A() 是归一化函数,N 是输入注意力层的时间步数(即 k + τmax)。一个常见的选择是缩放点积注意力(Vaswani et al., 2017):

为了提高标准注意力机制的学习能力,Vaswani et al., (2017) 提出了多头注意力机制,为不同的表示子空间采用不同的头

鉴于每个头部使用不同的值,单独的注意力权重并不能表示特定特征的重要性。因此,我们修改多头注意力以在每个头中共享值,并采用所有头的加法聚合

比较方程式 (9) 和 (14),可以看到可解释的多头注意力的最终输出与单个注意力层非常相似——关键区别在于生成注意力权重 A(Q,K) 的方法。从方程式 (15),每个头部可以学习不同的时间模式 A (Q W (h) Q , K W (h) K ),同时关注一组共同的输入特征 V——这可以解释为注意力权重的简单集成到组合等式 (14) 中的矩阵 ∼A(Q , K )与等式 (10) 中的 A(Q , K) 相比,~A(Q , K) 以有效的方式产生增加的表示能力,同时仍然允许通过分析一组注意力权重来执行简单的可解释性研究

4.5. Temporal fusion decoder

时间融合解码器使用下面描述的一系列层来学习数据集中存在的时间关系:

4.5.1. Locality enhancement with sequence-to-sequence layer

在时间序列数据中,重要点通常与其周围的值相关,例如异常、变化点或周期性模式。因此,通过在逐点值之上构建利用模式信息的特征,利用局部上下文可以提高基于注意力的架构的性能。例如,Li et al. (2019) 采用单个卷积层进行局部增强——在所有时间都使用相同的滤波器提取局部模式。但是,由于过去和未来输入的数量不同,这可能不适用于存在观察到的输入的情况

因此,我们建议应用序列到序列层来自然处理这些差异——将 ∼ξt-k:t 馈入编码器,将 ∼ξt+1:t+τmax 馈入解码器然后生成一组统一的时间特征,作为时间融合解码器本身的输入,记为 φ(t, n) ∈ {φ(t, -k), . . . , φ(t, τmax)},其中 n 是位置索引。受其在规范序列编码问题上的成功启发,我们考虑使用 LSTM 编码器-解码器,这是其他多范围预测架构中常用的构建块(Fan et al., 2019; Wen et al., 2017),尽管也可以采用其他设计。这也可以替代标准位置编码,为输入的时间排序提供适当的归纳偏置。此外,为了允许静态元数据影响局部处理,我们使用来自静态协变量编码器的 cc, ch 上下文向量分别为层中的第一个 LSTM 初始化单元状态和隐藏状态。我们还在这一层上使用了一个门控跳过连接:

4.5.2. Static enrichment layer

由于静态协变量通常对时间动态有显着影响(例如疾病风险的遗传信息),我们引入了一个静态富集层,通过静态元数据增强时间特征。对于给定的位置索引 n,静态扩充采用以下形式:

其中 GRNφ 的权重在整个层中共享,并且 ce 是来自静态协变量编码器的上下文向量。

4.5.3. Temporal self-attention layer

在静态丰富之后,我们接下来应用自注意力层。所有静态丰富的时间特征首先被分组到一个矩阵中——即 Θ(t) = [θ(t, -k), . . ., θ(t, τ)]T ——并且可解释的多头注意力(参见第 4.4 节)应用于每个预测时间(N = τ_max + k + 1):

解码器掩码 (Li et al., 2019; Vaswani et al., 2017) 应用于多头注意力层,以确保每个时间维度只能关注其之前的特征。除了通过掩码保留因果信息流外,自注意力层还允许 TFT 获取可能对基于 RNN 的架构学习具有挑战性的长期依赖关系。在 self-attention 层之后,还应用了一个额外的门控层来促进训练:

4.5.4. Position-wise feed-forward layer

我们对自注意力层的输出应用额外的非线性处理。与静态丰富层类似,这利用了 GRN:

其中 GRNψ 的权重在整个层中共享。根据图 2,我们还应用了一个门控残差连接,它跳过了整个Transformer块,提供了到序列到序列层的直接路径——如果不需要额外的复杂性,则产生一个更简单的模型,如下所示:

4.6. Quantile outputs

根据之前的工作(Wen et al., 2017),TFT 还在点预测之上生成预测区间。这是通过在每个时间步同时预测各种百分位数(例如第 10、第 50 和第 90)来实现的。分位数预测是使用时间融合解码器的输出的线性变换生成的

5. Loss functions

TFT 通过联合最小化分位数损失进行训练(Wen et al., 2017),对所有分位数输出求和:

其中 Ω 是包含 M 个样本的训练数据域,W 表示 TFT 的权重,Q 是输出分位数的集合(我们在实验中使用 Q = {0.1, 0.5, 0.9},并且 (.)+ = max( 0, .)。对于样本外测试,我们评估整个预测范围内的归一化分位数损失——关注 P50 和 P90 风险以与之前的工作保持一致(Li et al., 2019; Rangapuram et al., 2018; Salinas et al., 2019):

其中〜Ω是测试样本的域。关于超参数优化和训练的完整细节可以在附录 A 中找到。

6. Performance evaluation

6.1. Datasets

我们选择数据集来反映广泛具有挑战性的多范围预测问题中普遍观察到的特征。为了建立与先前学术工作相关的基线和位置,我们

  • 首先评估 Li et al. (2019), Rangapuram et al. (2018), Salinas et al. (2019) 使用的电力和交通数据集的性能——专注于更简单的单变量时间序列,其中仅包含目标旁边的已知输入
  • 然后,零售数据集帮助我们使用在多范围预测应用程序中观察到的所有复杂输入(见第 3 节)对模型进行基准测试——包括丰富的静态元数据和观察到的随时间变化的输入。
  • 最后,为了评估在较小噪声数据集上过拟合的鲁棒性,我们考虑了波动率预测的金融应用——使用比其他数据集小得多的数据集

下面可以找到每个数据集的广泛描述,以及附录 B 中对数据集目标的探索性分析:

  • 电力:UCI 电力负荷图数据集,包含 370 名客户的电力消耗——按 Yu、Rao 和 Dhillon(2016 年)按小时汇总。根据 (Salinas et al., 2019),我们使用过去一周(即 168 小时)来预测未来 24 小时
  • 交通:UCI PEM-SF 交通数据集描述了旧金山湾区 440 条高速公路的占用率(yt ∈ [0, 1])——如 Yu 等人所述。 (2016 年)。它还根据电力数据集按小时汇总,具有相同的回顾窗口和预测范围。
  • 零售:来自 Kaggle 竞赛(Favorita,2018 年)的 Favorita 杂货销售数据集,该数据集结合了不同产品和商店的元数据,以及在日常级别采样的其他外生时变输入。我们使用过去 90 天的信息预测未来 30 天的原木产品销售。
  • 波动率(或波动率):OMI 已实现库(Heber、Lunde、Shephard 和 Sheppard,2009 年)包含根据日内数据计算得出的 31 个股票指数的每日已实现波动率值及其每日收益。对于我们的实验,我们使用过去一年(即 252 个工作日)的信息来考虑对下周(即 5 个工作日)的预测

6.2. Training procedure

对于每个数据集,我们将所有时间序列划分为 3 个部分——用于学习的训练集、用于超参数调整的验证集和用于性能评估的保留测试集。超参数优化是通过随机搜索进行的,对 Volatility 使用 240 次迭代,对其他使用 60 次迭代。所有超参数的完整搜索范围如下,数据集和最佳模型参数列于表 1。

  • State size – 10, 20, 40, 80, 160, 240, 320
  • Dropout rate – 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9
  • Minibatch size – 64, 128, 256
  • Learning rate – 0.0001, 0.001, 0.01
  • Max. gradient norm – 0.01, 1.0, 100.0
  • Num. heads – 1, 4

为了保持可解释性,我们只采用了一个可解释的多头注意力层。对于 ConvTrans (Li et al., 2019),我们使用与其相同的固定堆栈大小(3 层)和头数(8 个头)。我们保持相同的注意力模型,并将卷积处理层的内核大小视为超参数(∈ {1, 3, 6, 9})——因为观察到最佳内核大小取决于数据集(Li et al., 2019)。可以在 GitHub 上找到这些数据集上 TFT 的开源实现,以获得完全的可重复性。

6.3. Computational cost

在所有数据集中,每个 TFT 模型也在单个 GPU 上进行了训练,并且可以在不需要大量计算资源的情况下进行部署。例如,使用 NVIDIA Tesla V100 GPU,我们的最佳 TFT 模型(用于 Electricity 数据集)只需 6 小时多一点的时间来训练(每个 epoch 大约为 52 分钟)。对整个验证数据集(由 50,000 个样本组成)的批量推理需要 8 分钟。通过特定于硬件的优化,可以进一步减少 TFT 训练和推理时间。

6.4. Benchmarks

基于第 2 节中描述的类别,我们将 TFT 与广泛的多范围预测模型进行了广泛的比较。超参数优化是在预定义的搜索空间上使用随机搜索进行的,在所有基准测试中使用相同数量的迭代给定的数据集。其他详细信息包含在附录 A 中。


直接方法:由于 TFT 属于此类多范围模型,我们主要将比较重点放在直接生成未来范围预测的深度学习模型上,包括:

  • 1)具有全局上下文的简单序列到序列模型(Seq2Seq)
  • 2)Multihorizon Quantile Recurrent Forecaster (MQRNN) (Wen et al., 2017)

此外,我们还包括两个简单的直接基准来评估深度学习模型的优势:

  • 1)多层感知器 (MLP)
  • 2)具有 L2 正则化 (Ridge) 的线性分位数回归

对于 MLP,我们使用单个两层神经网络,它获取每个时间步长的所有可用信息(即 {yt−k:t, zt−k:t, xt−k:t+τ , s}),并预测预测范围内的所有分位数(即 ^y(q, t, τ ) ∀q ∈ {0.1, 0.5, 0.9} 和 τ ∈ {1, . . . , τmax})。对于 Ridge,我们为每个范围/分位数输出使用一组单独的线性系数,输入与 MLP 相同的输入。鉴于我们数据集的大小,我们还使用随机梯度下降训练 Ridge。


迭代方法:为了定位迭代模型上的大量工作,我们使用与 (Salinas et al., 2019) 中电力和交通数据集相同的设置来评估 TFT。这扩展了 (Li et al., 2019) 的结果:

  • 1)DeepAR (Salinas et al., 2019)

  • 2)DSSM (Rangapuram et al., 2018)

  • 3)具有局部卷积处理的 Transformer 架构——称为 ConvTrans (Li et al., 2019) 。

对于更复杂的数据集,我们专注于 ConvTrans 模型,因为它在之前的工作中优于其他迭代模型,而 DeepAR 由于它在从业者中很受欢迎。由于此类模型需要了解未来的所有输入以生成预测,因此我们通过将未知输入与它们的最后可用值进行插补来适应复杂数据集


对于更简单的单变量数据集,我们注意到 ARIMA、ETS、TRMF、DeepAR、DSSM 和 ConvTrans 的结果已从 (Li et al., 2019) 中的表 2 复制,以保持一致性

6.5. Results and discussion

表 2 显示 TFT 显着优于第 6.1 节中描述的各种数据集的所有基准——展示了明确地将架构与一般多范围预测问题保持一致的好处。这适用于点预测和不确定性估计,与次佳模型相比,TFT 的 P50 损失平均降低 7%,P90 损失平均降低 9%。我们还在附录 C 中测试了 TFT 改进的统计显着性,这表明 TFT 损失显着低于具有 95% 置信度的下一个最佳基准。此外,附录 E 中还提供了对 TFT 可信区间的更定性评估,以供参考。

比较直接模型和迭代模型,我们观察到考虑观察到的输入的重要性——注意到在需要观察输入插补的复杂数据集(即波动性和零售)上,ConvTrans 的结果较差。此外,当高斯分布不能很好地捕获目标时,也可以观察到分位数回归的好处,而直接模型在这些场景中表现出色。例如,从目标分布明显偏斜的交通数据集中可以看出这一点——超过 90% 的入住率介于 0 和 0.1 之间,其余的均匀分布直到 1.0。

6.6. Ablation analysis

为了量化我们提出的每个架构贡献的好处,我们进行了广泛的消融分析——从网络中删除每个组件,如下所示,并量化损失相对于原始架构的百分比增加

  • 门控层:我们通过将每个 GLU 层(等式 (5))替换为简单的线性层,然后是 ELU 来消融。
  • 静态协变量编码器:我们通过将所有上下文向量设置为零(即 cs=ce=cc=ch=0)并将所有转换后的静态输入连接到所有与时间相关的过去和未来输入来消融。
  • 实例变量选择网络:我们通过替换等式 (6) 的softmax 输出来消融。具有可训练的系数,并删除生成可变选择权重的网络。我们保留,然而,可变 GRN(见公式(7))保持了相似数量的非线性处理。
  • 自注意力层:我们通过用可训练参数矩阵 W_A 替换可解释的多头注意力层(等式(14))的注意力矩阵来消融——即 ∼A(Q, K) = W_A,其中 W_A ∈ RN×N。这可以防止 TFT 在不同时间关注不同的输入特征,有助于评估实例注意力权重的重要性。
  • 用于局部处理的序列到序列层:我们通过用 Vaswani 等人(2017)使用的标准位置编码替换第 4.5.1 节的序列到序列层来消融。

使用表 1 的超参数对每个数据集进行消融网络训练。图 3 显示,所有数据集对 P50 和 P90 损失的影响相似,所有组件都有助于整体性能提升

一般来说,负责捕获时间关系、局部处理和自注意力层的组件对性能的影响最大,当消融时,P90 损失平均增加 > 6%,在选定数据集上增加 > 20%。时间序列数据集的多样性也可以从各个时间分量的消融影响的差异中看出。具体来说,虽然局部处理在交通、零售和波动性方面至关重要,但消融后 P50 损失较低表明它可能对电力有害——自我注意层发挥着更重要的作用。一种可能的解释是,持续的每日季节性似乎主导了电力数据集中的其他时间关系。对于这个数据集,附录 D 的表 D.6 还显示,一天中的小时在所有时间输入中具有最大的变量重要性得分,甚至超过了目标(即用电量)本身。与过去目标观察更重要的其他数据集(例如交通)相比,直接关注前几天似乎有助于了解电力的日常季节性模式——相邻时间步之间的局部处理不太必要。我们可以通过将时间融合解码器中的序列到序列架构视为要调整的超参数来解释这一点,包括无需任何局部处理的简单位置编码选项。

静态协变量编码器和实例变量选择的影响次之——将 P90 损失平均增加 2.6% 和 4.1% 以上。在电力数据集上观察到这些最大的好处,其中一些输入特征的重要性非常低。

最后,门控层的消融分析也显示 P90 损失增加,平均增加 1.9%。这是波动性最显著的(P90 损失增加 4.1%),这是门控组件对更小、更嘈杂的数据集的好处的基础

7. Interpretability use cases

在确定了我们模型的性能优势之后,接下来我们将展示我们的模型设计如何允许分析其各个组件以解释它所学习的一般关系。我们展示三个可解释性用例:

  • 1)检查每个输入变量在预测中的重要性

  • 2)可视化持久的时间模式

  • 3)识别导致时间动态显着变化的任何状态或事件

与其他基于注意力的可解释性示例(Alaa & van der Schaar, 2019; Li et al., 2019; Song et al., 2018)相比,这些示例放大了有趣但特定于实例的示例,我们的方法侧重于聚合整个数据集的模式——提取关于时间动态的可概括的见解。

7.1. Analyzing variable importance


我们首先通过分析 4.2 节中描述的变量选择权重来量化变量重要性。具体来说,我们汇总了整个测试集中每个变量的选择权重(即等式(8)中的 v(j)χt),记录了每个抽样分布的第 10、50 和 90 个百分位数。由于零售数据集包含全套可用输入类型(即静态元数据、已知输入、观察到的输入和目标),我们在表 3 中展示了其变量重要性分析的结果。我们还注意到其他数据集中的类似发现,为完整起见,它们记录在附录 D.1 中。总体而言,结果表明 TFT 仅提取了在预测中直观地发挥重要作用的关键输入的子集。持续时间模式的分析通常是理解给定数据集中存在的时间相关关系的关键。例如,经常采用滞后模型来研究干预措施生效所需的时间长度(Du、Song、Han 和 Hong,2018 年)——例如政府增加公共支出对经济增长的影响国民生产总值(Baltagi,2008 年)。季节性模型也常用于计量经济学中,以识别感兴趣目标中的周期性模式(Hylleberg,1992)并测量每个周期的长度。从实际的角度来看,模型构建者可以使用这些见解来进一步改进预测模型——例如,如果在回溯窗口开始时观察到注意力峰值,则通过增加感受野来整合更多历史,或者通过工程特征直接整合季节性效果。因此,使用时间融合解码器的自注意力层中存在的注意力权重,我们提出了一种识别相似持久模式的方法——通过测量过去固定滞后的特征对不同视野预测的贡献。结合方程 (14) 和 (19),我们看到自注意力层在每个预测时间 t 包含一个注意力权重矩阵——即 ∼A(φ(t), φ(t))。每个预测范围 τ 的多头注意力输出(即 β(t, τ ))然后可以描述为每个位置 n 的较低级别特征的注意力加权和

由于解码器掩蔽,我们还注意到 α(t, i, j) = 0, ∀i > j。对于每个预测范围 τ ,前一个时间点 n < τ 的重要性因此可以通过分析所有时间步长和实体中的 α(t, n, τ ) 的分布来确定

7.2. Visualizing persistent temporal patterns

注意力权重模式可用于阐明 TFT 模型基于其决策的最重要的过去时间步长与其他依赖基于模型的季节性和滞后分析规范的传统和机器学习时间序列方法相比TFT 可以从原始训练数据中学习此类模式

图 4 显示了我们所有测试数据集的注意力权重模式 ——上图绘制了一步提前预测的注意力权重的平均值以及注意力权重的第 10、50 和 90 个百分位数(即 α(t, 1, τ )) 在测试集上,底部的图表绘制了不同视野的平均注意力权重(即 τ ∈ {5, 10, 15, 20})。我们观察到这三个数据集呈现出季节性模式,在电力和交通方面每天观察到明显的注意力峰值,而零售方面的每周模式稍弱。对于零售业,我们还观察到衰退趋势模式,过去几天占主导地位

波动性没有观察到强烈的持续模式——注意力权重平均分布在所有位置。这类似于特征级别的移动平均滤波器,并且——考虑到与波动过程相关的高度随机性——通过消除高频噪声可用于提取整个时期的趋势。

TFT 从原始训练数据中学习这些持久的时间模式,无需任何人工硬编码。预计这种能力对于通过健全性检查与人类专家建立信任非常有用。模型开发人员也可以将这些用于模型改进,例如通过特定的特征工程或数据收集。

7.3. Identifying regimes & significant events

识别时间模式的突然变化也非常有用,因为重要的制度或事件的存在可能会发生临时变化。例如,制度转换行为已在金融市场中得到广泛记录(Ang & Timmermann,2012 年),观察到回报特征(例如波动性)在不同制度之间突然变化。因此,识别此类制度变化可以深入了解潜在问题,这对于识别重大事件很有用。

首先,对于给定的实体,我们将每个预测范围的平均注意力模式定义为:

其中 ρ(p,q)=∑j√pjqjρ(p, q) = ∑_j √p_jq_jρ(p,q)=∑j​√pj​qj​ 是测量离散分布之间重叠的 Bhattacharya 系数 (Kailath, 1967)——其中 pj, qj 分别是概率向量 p, q 的元素。对于每个实体,然后使用具有平均模式的每个点的注意力向量之间的距离来测量时间动态的显着变化,针对所有范围汇总如下

使用波动率数据集,我们尝试通过将距离度量应用于标准普尔 500 指数在我们训练期间(2001 年至 2015 年)的注意力模式来分析制度。在图 5 的底部图表中针对目标(即对数实际波动率)绘制 dist(t),在高波动期(例如 2008 年金融危机)附近可以观察到注意力模式的显着偏差——对应于在图 5 中观察到的峰值距离(t)。从图中,我们可以看到 TFT 似乎改变了其在不同制度之间的行为——当波动性较低时,对过去的输入给予同等的关注,而在高波动性时期更多地关注急剧的趋势变化——这表明在每一种制度中学习到的时间动态存在差异案例

8. Conclusions

我们介绍了 TFT,这是一种新颖的基于注意力的深度学习模型,用于可解释的高性能多范围预测。为了在广泛的多范围预测数据集中有效地处理静态协变量、先验已知输入和观察到的输入,TFT 使用了专门的组件。具体来说,这些包括:(1) 序列到序列和基于注意力的时间处理组件,用于捕获不同时间尺度的时变关系,(2) 静态协变量编码器,允许网络根据静态元数据进行时间预测,(3 ) 允许跳过网络中不必要部分的门控组件,(4) 变量选择以在每个时间步选取相关的输入特征,以及 (5) 分位数预测,以获得所有预测范围内的输出间隔。在广泛的现实世界任务中——在仅包含已知输入的简单数据集和包含所有可能输入的复杂数据集上——我们展示了 TFT 实现了最先进的预测性能。最后,我们通过一系列可解释性用例研究 TFT 学习的一般关系——提出使用 TFT 的新方法(i)分析给定预测问题的重要变量,(ii)可视化学习的持久时间关系(例如季节性), (iii) 识别重大的制度变化。

【时序】TFT:具有可解释性的时间序列多步直接预测 Transformers相关推荐

  1. R语言时间序列(time series)分析实战:时序数据加载、绘制时间序列图

    R语言时间序列(time series)分析实战:时序数据加载.绘制时间序列图 目录

  2. bagging和时间序列预测_时间序列的LSTM模型预测——基于Keras

    一.问题背景     现实生活中,在一系列时间点上观测数据是司空见惯的活动,在农业.商业.气象军事和医疗等研究领域都包含大量的时间序列数据.时间序列的预测指的是基于序列的历史数据,以及可能对结果产生影 ...

  3. 深度学习时间序列预测:LSTM算法构建时间序列单变量模型预测大气压( air pressure)+代码实战

    深度学习时间序列预测:LSTM算法构建时间序列单变量模型预测大气压( air pressure)+代码实战 长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主 ...

  4. 深度学习时间序列预测:LSTM算法构建时间序列单变量模型预测空气质量(PM2.5)+代码实战

    深度学习时间序列预测:LSTM算法构建时间序列单变量模型预测空气质量(PM2.5)+代码实战 # 导入需要的包和函数: from __future__ import print_function im ...

  5. 时间序列的分析和预测

    由于预测股票市场的未来股票价格对投资者至关重要,时间序列及其相关概念具有组织数据以进行准确预测的卓越品质.在本文中,让我们阅读时间序列的重要性.分析和预测. 在这里,涵盖的一些基本子主题是: 什么是时 ...

  6. bagging和时间序列预测_时间序列多步预测的五种策略

    通常,时间序列预测描述了预测下一个时间步长的观测值.这被称为"一步预测",因为仅要预测一个时间步.在一些时间序列问题中,必须预测多个时间步长.与单步预测相比,这些称为多步时间序列预 ...

  7. 【时间序列】-航空数据预测

    学习预测时序数据,如有侵权,请联系删除. 主要参考: A comprehensive beginner's guide to create a Time Series Forecast (with C ...

  8. 什么原数据更容易平稳_【时间序列】-航空数据预测

    ts_log_moving_avg_diff.dropna(inplace=True) test_stationarity(ts_log_moving_avg_diff) 这看起来像一个更好的系列.滚 ...

  9. 时间序列多步预测_使用LSTM深度学习模型进行温度的时间序列单步和多步预测...

    本文的目的是提供代码示例,并解释使用python和TensorFlow建模时间序列数据的思路. 本文展示了如何进行多步预测并在模型中使用多个特征. 本文的简单版本是,使用过去48小时的数据和对未来1小 ...

最新文章

  1. b-spline python_SciPyTutorial-一元B样条插值
  2. 序列表转换成横向菜单
  3. 机器学习(二十二)——推荐算法中的常用排序算法, Tri-training
  4. SAP Cloud for Customer upselling的前台实现
  5. 破五唯后,高校从“唯论文”变成了“唯纵向”?​
  6. python2.7虚拟环境
  7. 学三菱plc编程应该先学什么?
  8. android qq纯净输入法,QQ输入法纯净版更新 同步手机词库
  9. windows抓包工具——Fiddler配置及使用、手机抓包(iPhone、安卓)
  10. IDEA2019安装教程
  11. Markdown如何给图片添加图注
  12. 人工智能、机器学习、深度学习 三者关系
  13. 浮点数为什么不精确?
  14. 七.面向对象编程(中)
  15. Visionpro从小白到大佬,第一章了解工具名称和用途
  16. 福州三中 计算机竞赛,福建福州三中喜获信息学竞赛NOIP2020全省人数第1!总计35人获奖...
  17. R语言筛选两列中元素相同的重复数据
  18. 群晖挂载玩客云网络磁盘
  19. Scale-Equalizing Pyramid Convolution for Object Detection论文阅读
  20. 一个WPF和SL的严重BUG,能导致任何的寄主程序崩溃

热门文章

  1. 猿编程python_猿编程下载-猿编程客户端 v2.12.0.1103 官方版 - 安下载
  2. c语言中如果产量的隐藏类型是,如果随着产量的增加,生产函数首先表现出边际产量增加,然后表现出边际产量递减,那么相应的边际成本曲线将 答案:是U形的...
  3. Express Cookie的使用
  4. 刷入magisk无限重启_手机刷成砖了?别慌,这些方法可以救回来
  5. acm总结——多源BFS
  6. 计算机停车管理系统界面,智慧停车管理系统-智慧停车整体解决方案
  7. 高尔顿数据集和Anscombe四重奏数据集
  8. CEF 、chromium源码下载前相关代理配置
  9. YY游戏云的AngularJS实践
  10. SCARA四轴机器人eye-to-hand手眼标定(九点标定)