• Informer论文:https://arxiv.org/pdf/2012.07436.pdf

  • Informer源码:GitHub - zhouhaoyi/Informer2020: The GitHub repository for the paper "Informer" accepted by AAAI 2021.

  • Transformer笔记:《Attention Is All You Need》_郑烯烃快去学习的博客-CSDN博客

目录

0x01 Transformer存在的问题

0x02 Informer研究背景

0x03 Informer整体架构

(一)ProbSparse Self-attention

(二)Self-attention Distilling

(三)Generative Style Decoder

0x04 计算损失与迭代优化

0x05 源码阅读——站在巨人的肩膀上

(一)环境的搭建

(二)Informer文件框架

(三)数据的输入

(四)模型训练

(1)预处理

(2)encoder

(3)dropout

(4)Convlayer

(5)decoder


0x01 Transformer存在的问题

Informer实质是在Transformer的基础上进行改进,通过修改transformer的结构,提高transformer的速度。那么Transformer有什么样的缺点:

(1)self-attention的平方复杂度。self-attention的时间和空间复杂度是O(L^2),L为序列长度。

(2)对长输入进行堆叠(stack)时的内存瓶颈。多个encoder-decoder堆叠起来就会形成复杂的空间复杂度,这会限制模型接受较长的序列输入。

(3)预测长输出时速度骤降。对于Tansformer的输出,使用的是step-by-step推理得像RNN模型一样慢,并且动态解码还存在错误传递的问题。

0x02 Informer研究背景

论文的研究背景为:长序列预测问题。这些问题会出现在哪些地方呢:

  • 股票预测(数据、规则都在变,模型都是无法预测的)

  • 机器人动作的预测

  • 人体行为识别(视频前后帧的关系)

  • 气温的预测、疫情下的确诊人数

  • 流水线每一时刻的材料消耗,预测下一时刻原材料需要多少....

那么以上需要时间线来进行实现的,无疑会想到使用Transformer来解决这些问题,Transformer的最大特点就是利用了attention进行时序信息传递。每次进行一次信息传递,我们需要执行两次矩阵乘积,也就是QKV的计算。并且我们需要思考一下,我们每次所执行的attention计算所保留下来的值是否是真的有效的吗?我们有没有必要去计算这么多attention?

那么对于现在的时间预测可以大致分为下面三种:

  • 短序列预测

  • 趋势预测

  • 精准长序列预测

很多算法都是基于短序列进行预测的,先得知前一部分的数据,之后去预测短时间的情况。想要预测一个长序列,就不可以使用短预测,预测未来半年or一年,很难预测很准。长序列其实像是滑动窗口,不断地往后滑动,一步一步走,但是越滑越后的时候,他一直在使用预测好的值进行预测,长时间的序列预测是有难度的。

那么有那些时间序列的经典算法:

  • Prophnet:很实用的工具包,很适合预测趋势,但算的不精准。

  • Arima:短序列预测还算精准,但是趋势预测不准。多标签。

以上两种一旦涉及到了长序列,都不可以使用。

  • Informer中将主要致力于长序列问题的解决

可能在这里大家也会想到LSTM:但是这个模型在长序列预测中,如果序列越长,那速度肯定越慢,效果也越差。这个模型使用的为串行结构,效率很低,也会基于前面的特征来预测下一个特征,其损失函数的值也会越来越大。

那么我们Transformer中也有提及到改进LSTM的方法,其优势和问题在于:

(1)万能模型,可直接套用,代码实现简单。

(2)并行的,比LSTM快,全局信息丰富,注意力机制效果好。

(3)长序列中attention需要每一个点跟其他点计算,如果序列太长,其效率很低。

(4)Decoder输出很麻烦,都要基于上一个预测结果来推断当前的预测结果,这对于一个长序列的预测中最好是不要出现这样的情况。

那么Informer就需要解决如下的问题:

Transformer的缺点 Informer的改进
self-attention平方级的计算复杂度 提出ProbSparse Self-attention筛选出最重要的Q,降低计算复杂度
堆叠多层网络,内存占用瓶颈 提出Self-attention Distilling进行下采样操作,减少维度和网络参数的数量
step-by-step解码预测,速度较慢 提出Generative Style Decoder,一步可以得到所有预测的

基于以上,Informer提出了LSTF(Long sequence time-series forecasting)长时间序列预测。

0x03 Informer整体架构

(一)ProbSparse Self-attention

通过以下的数据可以看到,并不是每个QK的点积都是有效值,我们也不需要花很多时间在处理这些数据上:

这个结果也是合理的,因为某个元素可能只和几个元素高度相关,和其他的元素并没有很显著的关联。如果我们要提高计算效率的话,我们需要关注那些有特点的那些值,那我们要怎么去关注那些有特点的值呢:

我们需要进行一次Query稀疏性的衡量:

作者从概率的角度看待自注意力,定义是概率的形式,即在给定第i个query的条件下key的分布。作者认为,如果算出来的这个结果接近于均匀分布 ,那么就说明这个query是在偷懒,没办法选中那些重要的Key,如果反之,就说明这个Q为积极的,活跃的:

其计算公式如下:

之后我们进行比较:

我们算出了其概率以及与均匀分布的差异,如果差异越大,那么这个Q就有机会去被关注、说明其起到了作用。那么其计算方法到底是怎么样进行的,我们要取哪些Q哪些K进行计算:

(1)输入序列长度为96,首先在K中进行采样,随机选取25个K。

(2)计算每个Q与25个K的点积,可以得到M(qi,K),现在一个Q一共有25个得分

(3)在25个得分中,选取最高分的那个Q与均值算差异。

(4)这样我们输入的96个Q都有对应的差异得分,我们将差异从大到小排列,选出差异前25大的Q。

(5)那么传进去参数例如:[32,8,25,96],代表的意思为输入96个序列长度,32个batch,8个特征,25个Q进行处理。

(6)其他位置淘汰掉的Q使用均匀方差代替,不可以因为其不好用则不处理,需要进行更新,保证输入对着有输出。

以上的时间复杂度为O(L ln L):

ProbSparse Attention在为每个Q随机采样K时,每个head的采样结果是相同的,也就是采样的K是相同的。但是由于每一层self-attention都会对QKV做线性转换,这使得序列中同一个位置上不同的head对应的QK都不同,那么每一个head对于Q的差异都不同,这就使得每个head中的得到的前25个Q也是不同的。这样也等价于每个head都采取了不同的优化策略。

(二)Self-attention Distilling

这一层类似于下采样。将我们输入的序列缩小为原来的二分之一。作者在这里提出了自注意力蒸馏的操作,具体是在相邻的的Attention Block之间加入卷积池化操作,来对特征进行降采样。为什么可以这么做,在上面的ProbSparse Attention中只选出了前25个Q做点积运算,形成Q-K对,其他Q-K对则置为0,所以当与value相乘时,会有很多冗余项。这样也可以突出其主要特征,也降低了长序列输入的空间复杂度,也不会损失很多信息,大大提高了效率。

另外,作者为了提高encoder的鲁棒性,还提出了一个strick。途中输入embedding经过了三个Attention Block,最终得到Feature Map。还可以再复制一份具有一半输入的embedding,让它让经过两个Attention Block,最终会得到和上面维度相同的Feature Map,然后把两个Feature Map拼接。作者认为这种方式对短周期的数据可能更有效一些。

(三)Generative Style Decoder

对于Transformer其输出是先输出第一个,再基于第一个输出第二个,以此类推。这样子效率慢并且精度不高。看看总的架构图可以发现,decoder由两部分组成:第一部分为encoder的输出,第二部分为embedding后的decoder输入,即用0掩盖了后半部分的输入。

看看Embedding的操作:

  • Scalar是采用conv1d将1维转换为512维向量。

  • Local Time Stamp采用Transformer中的Positional Embedding

  • Gloabal Time Stamp则是上述处理后的时间戳经过Embedding。可以添加上我们的年月日时。

这种位置编码信息有比较丰富的返回,不仅有绝对位置编码,还包括了跟时间相关的各种编码。

最后,使用三者相加得到相加得到最后的输入(shape:[batch_size,seq_len,d_model])。

Decoder的最后一个部分是过一个linear layer将decoder的输出扩展到与vocabulary size一样的维度上,经过softmax后,选择概率最高的一个word作为预测结果。

那么假设我们有一个已经训练好的Transformer的神经网络,在预测时,传统的步骤是step by step的:

(1)给decoder输入encoder对整个句子embedding的结果和一个特殊的开始符号。decoder将产生预测,产生”I”。

(2)给decoder输入encoder的embedding结果和“I”,产生预测“am”

(3)给decoder输入encoder的embedding结果和“I am”,产生预测“a”

(4)给decoder输入encoder的embedding的结果和“I am a”,产生预测”student“。

(5)给decoder输入encoder的embedding的结果和“I am a student”,decoder应该生成句子结尾的标记,decoder应该输出“ ”。

(6)最后decoder生成了,翻译完成。

那么我们再看看Informer一步到位的预测:

提供一个start标志位:

  • 要让Decoder输出预测结果,你得先告诉它从哪开始输出。

  • 先给一个引导,比如要输出20-30号的预测结果,Decoder中需先给出。

  • 前面一个序列的结果,例如10-20号的标签值。

其实我们可以理解为一段有效的标签值带着一群预测值进行学习,效率更高。可以说是生成式推理,作者在这里没有选择一个特定的标记来做开始序列,而是选择了一段长的序列,比如目标序列之前一段已知序列。举例来说如果我们要预测7天的,我们可以把之前5天的信息作为开始序列,那么我们上述的式子这种方法可以一步到位生成目标序列,不需要再使用动态解码。

对于Decoder输入:

源码中的decoder输入长度为72,其中前48是真实值,后24是预测值。第一步是做自身的ProbAttention,注意要加上Mask(避免未卜先知)。先计算完自身的Attention。再算与encoder的Attention即可。

0x04 计算损失与迭代优化

损失函数为预测值和真值的MSE(均方误差),并且损失从解码器的输出反向传播到整个模型。优化器使用的Adam。

0x05 源码阅读——站在巨人的肩膀上

(一)环境的搭建

首先我们可以看到Informer所需要的环境:

首先我们需要下载并配置好Anaconda,并下载好PyTorch。安装Anaconda网上很多教程,下面是我安装Pytorch的过程:

  • 先配置好清华源

conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
conda config --set show_channel_urls yes
conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/

如果出现了这种错误:

An HTTP error occurred when trying to retrieve this URL. HTTP errors are often intermittent, and a simple retry will get you on your way.

把镜像源中出现的https改为http就可以了。

清空源的命令:

conda config --remove-key channels
  • 创建环境PyTorch

conda create-n pytorch python=3.6
  • 查看环境是否安装成功

conda info --envs
  • 根据自己的情况安装PyTorch的版本

Pytorch官网:PyTorch

查看自己的CUDA:

nvidia-smi

我的CUDA为11.1,那么安装Pytorch的命令为:

conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge

打开Anaconda Prompt命令窗口,进入刚刚所创建的环境中进行

conda activate PyTorch

之后我们就可以进入环境中。

最后再执行官方提供的指令,即可下载,最后可得:

之后进入我们Infomer的源码中进行activate我们的PyTorch环境:

打开PyCharm更改其Python解释器到我们Pytorch文件夹中对应的解释器:

之后在Pycharm的终端中输入:

pip install -r requirements.txt

如何使用呢:这个工程给了我们三个测试样例,我们只需要输入其中一个即可:

# ETTh1(指定使用模型为Informer 数据集为哪些 attention机制又为什么 如何处理时间)
python -u main_informer.py --model informer --data ETTh1 --attn prob --freq h# ETTh2
python -u main_informer.py --model informer --data ETTh2 --attn prob --freq h# ETTm1
python -u main_informer.py --model informer --data ETTm1 --attn prob --freq t

之后我们点击运行就可以看到输出了:

那么我们就通过断点调试来进行学习吧。

(二)Informer文件框架

使用Pycharm打开文件夹我们可以看到如下:

  • data

项目的数据文件夹,其中的data_loader文件是加载数据、预处理数据的作用。

  • exp

项目训练功能文件夹,这里的py文件是用来训练模型的作用。

  • model

看过上面的大概描述应该大概知道这一个个文件要干的事情是什么了,这些东西都是这个模型的详细实现了,可以好好看看。

  • scripts

包含了模型的启动脚本,使用脚本。

  • utils

这里包含了模型的评估指标、时间轴的时间特征处理、指数缩减学习率、提前停止训练策略、数据标准化策略等功能,当然也有Mask机制。

(三)数据的输入

我们打开程序后,我们可以发现有一些数据输入的.csv文件:

那么我们可以打开源文件看看到底是个什么东西:

我们可以看到其时间其实是非常明确的,一个样本都有一个固定的时间,我们的数据集是以小时为单位的,可以理解为每个采样点间隔一个小时,后面的内容我们可以理解为一个时间点其具有多个特征,最后的那一列则为输出结果。那么看到这应该可以知道大概怎么替换自己的数据了吧。.csv文件格式适合使用pandas进行处理。

回到main函数,我们在这打上断点1:

这一句很明显是传入参数,可能大部分人看到上面那些一大堆的参数不知道是什么,要怎么改,我在这里注释了大部分:

parser = argparse.ArgumentParser(description='[Informer] Long Sequences Forecasting')# 使用的网络结构(方便对比实验),使用defalut更改网络结构
parser.add_argument('--model', type=str, default='informer',help='model of experiment, options: [informer, informerstack, informerlight(TBD)]')# 读的数据是什么(类型 路径)
parser.add_argument('--data', type=str, default='WTH', help='data')
parser.add_argument('--root_path', type=str, default='./data/', help='root path of the data file')
parser.add_argument('--data_path', type=str, default='WTH.csv', help='data file')# 预测的种类及方法
parser.add_argument('--features', type=str, default='M',help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')# 哪一列要当作是标签
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
# 数据中存在时间 时间是以什么为单位(属于数据挖掘中的重采样)
parser.add_argument('--freq', type=str, default='h',help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
# 模型最后保存位置
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
# 当前输入序列长度(可自定义)
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length of Informer encoder')
# 标签(带着预测值的那个东西)长度(可自定义)
parser.add_argument('--label_len', type=int, default=48, help='start token length of Informer decoder')
# 预测未来序列长度 (可自定义)
parser.add_argument('--pred_len', type=int, default=24, help='prediction sequence length')
# Informer decoder input: concat[start token series(label_len), zero padding series(pred_len)]
# 编码器、解码器输入输出维度
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
# 输出预测未来多少个值
parser.add_argument('--c_out', type=int, default=7, help='output size')
# 隐层特征
parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
# 多头注意力机制
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
# 要做几次多头注意力机制
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
# 堆叠几层encoder
parser.add_argument('--s_layers', type=str, default='3,2,1', help='num of stack encoder layers')
parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
# 对Q进行采样,对Q采样的因子数
parser.add_argument('--factor', type=int, default=5, help='probsparse attn factor')
parser.add_argument('--padding', type=int, default=0, help='padding type')
# 是否下采样操作pooling
parser.add_argument('--distil', action='store_false',help='whether to use distilling in encoder, using this argument means not using distilling',default=True)
parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
# 注意力机制
parser.add_argument('--attn', type=str, default='prob', help='attention used in encoder, options:[prob, full]')parser.add_argument('--embed', type=str, default='timeF',help='time features encoding, options:[timeF, fixed, learned]')
parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')
parser.add_argument('--mix', action='store_false', help='use mix attention in generative decoder', default=True)
# 读数据
parser.add_argument('--cols', type=str, nargs='+', help='certain cols from the data files as the input features')
# windows用户只能给0
parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
# 训练轮数以及epoch
parser.add_argument('--itr', type=int, default=2, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=6, help='train epochs')
parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
# 停止策略
parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
# 学习率
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
# 损失函数
parser.add_argument('--loss', type=str, default='mse', help='loss function')
parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
# 是否为分布式
parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
parser.add_argument('--gpu', type=int, default=0, help='gpu')
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
# 如果为分布式指定有几个显卡
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')args = parser.parse_args()

之后单步调试后,我们可以发现上面的数据都被读取进来了:

之后在往下走,我们可以看到我们输入的数据已经被定义为几个字典:

其中我们可以发现其指定了标签、指定了列、指定了输出大小等等,那么同样的,如果我们想要使用这个模型训练自己的数据,我们可以这么编写:

代表的是我在某个csv中读取数据,指定了col列为我的标签,我要预测我未来52个数据。

那么在这个函数中,我们就已经把数据读入到我们的设置的变量中啦:

那么下面这个东西也是同理的,我们指定了要循环多少个encoder,时间的频率又是多少,直接从上面读下来:

那么接下来是这样的:

定义了一个类。我们将我们的数据储存进去开始训练啦:

我们已经开始迭代我们的训练以及预测的过程了:

(四)模型训练

(1)预处理

我们跳进其训练的函数,我们可以看到这个:

上面的三个断点我们可以明显看到这是数据集的划分,我们先点进去训练集:

这些其实就是刚刚那个setting中的参数啦。接下来是下面定义的一个字典:

这个是我们的数据集中,处理数据的一种方式,我们可以打开其对应的函数进行看看:

之后配置我们数据的要怎么样进行训练:

配置完毕后,他又继续读入数据,按照配置进行数据处理:

那么我们进去数据处理的时候可以发现他将我们要训练的参数进行了划分,分别就是我们上面的96,48,24,划定了输入的序列长度以及输入的带预测的值以及最后的预测值:

之后是数据标准化处理:

初始化了一个均值以及标准差,处理后返回一个预处理结果,这个时候还没有输入数据。他下面就开始读入数据了:

他在这个时候将我们csv的数据都读取了进去,进行标准化操作:

最后框出来的这一列是我们要预测的标签:

接下来就是处理数据了,使用的是pandas的方法:

在这一步的时候去除了我们的标签以及日期,因为这些东西并不是我们所要的特征值,所以使用pandas的.columns进行去除。接下来就开始分训练集、测试集以及验证集:

之后我们指定好数据集的起始位置以及终止位置,取出我们的96个序列,以及取出我们的验证集,分好我们训练集、测试集、验证集的位置:

总的来说上面的几个数字是这个意思:

我们就可以根据前面的set_type变量来设定我们的边界。

在这一步,我们去掉了我们前面的时间,留下我们的所有特征:

之后我们就进行标准化操作了,我们先将我们训练集的数据拿进去算均值以及方差,最后再进行transformer操作:

最后每个数都可以得到一个数据,这是每个数减去其均值之后再除以方差可得的一个标准差,很基本的标准化操作:

之后我们再分离出我们的时间,时间序列的训练肯定少不了很多时间的处理方式,最后转换为pandas的格式,方便pandas的处理:

关于时间的处理:

看看这个东西:

return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]).transpose(1,0)

先看看函数time_features_from_frequency_str:

这个东西其实就是我们前面讲的以几个小时为单元的特征提取,返回适合于给定频率字符串的时间特征列表

那么里面offsets又是什么东西,其实就是pandas已经处理好的包了,它的目的是想将我们csv中的时间,根据日月年时间进行读取特征,并且用其他数字来代替这些时间,我们可以看看具体的操作。

class SecondOfMinute(TimeFeature):"""Minute of hour encoded as value between [-0.5, 0.5]"""def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:return index.second / 59.0 - 0.5class MinuteOfHour(TimeFeature):"""Minute of hour encoded as value between [-0.5, 0.5]"""def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:return index.minute / 59.0 - 0.5class HourOfDay(TimeFeature):"""Hour of day encoded as value between [-0.5, 0.5]"""def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:return index.hour / 23.0 - 0.5class DayOfWeek(TimeFeature):"""Hour of day encoded as value between [-0.5, 0.5]"""def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:return index.dayofweek / 6.0 - 0.5class DayOfMonth(TimeFeature):"""Day of month encoded as value between [-0.5, 0.5]"""def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:return (index.day - 1) / 30.0 - 0.5class DayOfYear(TimeFeature):"""Day of year encoded as value between [-0.5, 0.5]"""def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:return (index.dayofyear - 1) / 365.0 - 0.5class MonthOfYear(TimeFeature):"""Month of year encoded as value between [-0.5, 0.5]"""def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:return (index.month - 1) / 11.0 - 0.5class WeekOfYear(TimeFeature):"""Week of year encoded as value between [-0.5, 0.5]"""def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:return (index.isocalendar().week - 1) / 52.0 - 0.5

最后处理后的时间是这样的:

那么以上就是我们所有的数据处理。而且训练集测试集验证集都是一样的。那么到现在,我们的训练集、测试集、验证集都已经做完了:

之后就开始训练了,规定了模型的位置,以及计算时间、规定训练次数以及损失函数的定义:

每一次迭代都会进入这个函数进行数据的处理:

在这里他对96个输入作为x,将72个输出作为y:

那么以上循环32次。

(2)encoder

接下来看看模型的搭建:

在decoder的处理中,我们先将decoder输入预测的24个值全部初始化为0,之后再进行拼接:

那么最后,我们将我们上面处理好的值,终于要传入网络结构了:

  • batch_x与batch_x_mark:输入的encoder的数据,96个长度以及96个数据的时间

  • dec_inp与batch_y_mark:输出的decoder的数据,72个数据以及72个数据的时间

模型的搭建位于model.py中的Informer类中:

我们首先看看forward函数:

对于输入的函数,我们首先进行了embeddding,是什么操作?

对输入的序列进行一维卷积,并且使用padding等操作后,最后变为512维的向量:

之后加入了位置编码以及时间特征,这些特征最都映射为512的特征。

最后的x是这样的:

class Encoder(nn.Module):def __init__(self, attn_layers, conv_layers=None, norm_layer=None):super(Encoder, self).__init__()self.attn_layers = nn.ModuleList(attn_layers)self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else Noneself.norm = norm_layerdef forward(self, x, attn_mask=None):# x [B, L, D]attns = []if self.conv_layers is not None:for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):x, attn = attn_layer(x, attn_mask=attn_mask)x = conv_layer(x) # pooling后再减半,还是为了速度考虑attns.append(attn)x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)attns.append(attn)else:for attn_layer in self.attn_layers:x, attn = attn_layer(x, attn_mask=attn_mask)attns.append(attn)if self.norm is not None:x = self.norm(x)return x, attns

之后我们跳进去EncoderLayer中:

我们发现其输入了3个x,为什么要三个x?原因是我们想要计算QKV矩阵呀。之后我们就步入了AttentionLayer:

我们将我们输入的参数x,分别乘以三个权重参数矩阵,可以得到QKV三个向量。在这一步也同样做了多头注意力机制的操作,在上面定义了8个头,在view的过程中将其转换为512=8*64,形成了多头注意力机制。

之后进入了Informer独有的ProbAttention

也就是上面所说的,筛选出明显特征的Q,找出Q的代表!!这篇论文最核心的东西了。

这里计算出了我们需要25个Q:

之后进入了_prob_QK

它先将K进行复制,扩充了一个维度,之后进行随机采样25个K。再使用25个K与96个Q相乘计算其内积。最后得到的是96*25的的矩阵,最后再赋予到那个复制的维度中。之后我们再继续算出了Q与K之间的关系:

首先找出每一个Q(一共96个Q)中最大的那个值,之后要与均匀分布作比较:

比较完后,我们进行排序,最后得到索引,去找对应的25个Q。

到了这里,我们已经采样完25个特征最明显的Q了。K还是一样保持了96个。这样我们的ProAttention就这么算完了。

之后再排除掉维度造成的影响,也就是公式中除的那个根号d:

之后我们要对V矩阵进行处理,我们进入到函数_get_initial_context

在这里我们将没有视为特征的Q,使用均值来代替。让他一直在平庸。在这一步我们还没有得知25个Q是哪些,我们把所有的V都初始化为我们的平庸值。

之后我们将有特征的值进行保留,进入函数_update_context

在其中,我们也执行了softmax操作,目的是为了这条公式:

之后我们对25个Q进行更新V,并且计算好了其Attention。之后返回了我们的上下文,并且经过了全连接。

(3)dropout

这个操作就是Transformer中的残差连接。叫做至少不比原来差

(4)Convlayer

蒸馏操作,我们需要做多次attention,但是我们不会以原大小进行操作,我们会将其缩小为原来的二分之一,提速的作用:1维卷积+ELU激活函数+最大池化

之后再重复我们上面的操作。

(5)decoder

以上我们就已经执行完了encoder的操作了,短短的两行程序操作了很多东西:

那么我们就进入decoder的世界了:

首先我们进行了embedding,这个操作跟上面的是一模一样的。

之后进入了类Decoder:

我们可以发现上面框住的Attention与下面的那个框其实是不一样的,但是上面的那个执行过程,与encoder其实是一模一样的。

不一样在哪呢,我们的decoder加入了mask机制呀,最大的区别就是不可以未卜先知。也就是图示中的Masked Multi-head。

那么我们进入FullAttention去看看操作到底是怎么样的:

那么self-attention与cross-attention有什么区别呢:

  • self-attention是自己与自己之间

  • cross-attention是encoder与decoder之间

其他操作都是一模一样的。

那么从代码中我们可以发现,我们最后的输出是联系了encoder的self-attention以及decoder的cross-attention:

上面的self-attention中,输入的三个x其实是为了计算我们的QKV,这个x其实就是我们那96个从csv文件中读取出来的数据,而下面的cross-attention,首先输入的第一个x是我们decoder前embedding好的要带着我们预测数据的那48个值,它只需要一个Q向量,我们使用encoder中算出来的Ek以及Ev进行预测我们要的预测值。

之后就再做一次残差连接,然后就快胜利了!!

之后我们对我们的72个输出,将48个进行抛弃,读取我们需要的24个。并且这个模型做了一个多变量的预测。需要注意的是,Mask机制只是将48个后面的24个进行mask,而不是一开始就直接mask。


学习这个网络花了我好多好多时间,希望自己下次效率可以再快一些!!

源码阅读及理论详解《 Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting 》相关推荐

  1. Linux内核源码阅读以及工具详解

    接上篇Linux内核源码下载方法 这篇总结了如何利用source insight对Linux内核代码进行阅读和学习(资料来源于网络) 随着linux的逐步普及,现在有不少人对于Linux的安装及设置已 ...

  2. Xposed源码剖析——app_process作用详解

    Xposed源码剖析--app_process作用详解 首先吐槽一下CSDN的改版吧,发表这篇文章之前其实我已经将此篇文章写过了两三次了.就是发表不成功.而且CSDN将我的文章草稿也一带>删除掉 ...

  3. mysql data文件夹恢复_【专注】Zabbix源码安装教程—步骤详解(2)安装并配置mysql...

    四.安装并配置mysql(1) 解压mysql-5.7.26.tar.gz与boost_1_59_0.tar.gz #tar -xvf mysql-5.7.26.tar.gz #tar -xvf bo ...

  4. php+mysql案例含源码_【专注】Zabbix源码安装教程—步骤详解(1)安装前准备

    一.实验环境准备 Rhel 7.6 x86_64(server) 192.168.163.72 Rhel 6.5 x86_64(agent) 192.168.163.61 均已配置操作安装光盘为YUM ...

  5. React 源码系列 | React Context 详解

    目前来看 Context 是一个非常强大但是很多时候不会直接使用的 api.大多数项目不会直接使用 createContext 然后向下面传递数据,而是采用第三方库(react-redux). 想想项 ...

  6. dockerfile源码安装mysql_docker容器详解五: dockerfile实现tomcat环境以及源码安装mysql...

    tomcat 上一节讲到了dockerfile的基础,这一次咱们来作一个小的练习 首先要了解tomcat安装的整个过程 首先搭建 jdk环境: 下载jdk包,解压以后添加环境变量 而后搭建tomcat ...

  7. 未能找到元数据文件_Flink 源码:Checkpoint 元数据详解

    本文是 Flink 源码解析系列,通过阅读本文你能 get 到以下点: Flink 任务从 Checkpoint 处恢复流程概述 Checkpoint 元数据详解 从源码层分析:JM 该如何合理地给每 ...

  8. python随机数程序源码_Python 实现随机数详解及实例代码

    Python3实现随机数 random是用于生成随机数的,我们可以利用它随机生成数字或者选择字符串. random.seed(x)改变随机数生成器的种子seed. 一般不必特别去设定seed,Pyth ...

  9. 源码免杀教程 源码免杀思路详解

    绝对不一样的源码免杀教程!绝对不一样的免杀实战体验!清晰的思路!细致全面的思路详解!让你感到免杀原来可以这么简单!教你如何在源代码中找出被杀代码,修改代码从而达到免杀效果! 免杀之-网络攻防入门书籍推 ...

最新文章

  1. mysqli存储过程
  2. 四张图揭秘中国AI人才现状
  3. 将query存进数组 php,thinkphp下通过QueryList获取网站指定数据并封装成数组,存入数据库...
  4. 来自极客标签10款最新设计素材-系列七
  5. xfce的开始菜单增加搜索框
  6. gbase 8s oracle,GBase8s 查看数据库表空间信息
  7. 禅道11.0windows本机安装
  8. 优秀网站设计:打造有吸引力的网站(原书第3版)
  9. LeetCode 513. Find Bottom Left Tree Value
  10. 【java线程系列】java线程系列之java线程池详解
  11. HashTable源码简单介绍
  12. html内容change事件,HTML onvolumechange事件用法及代码示例
  13. 《Redis开发与运维》学习第十章
  14. 阿里云Maven镜像
  15. nuxt 引入iconfont多色图标
  16. 我读《格鲁夫给经理人的第一课》
  17. 手机屏分几种?什么叫水滴屏、刘海屏、瀑布屏、全面屏?
  18. 两起并购!深兰科技完成自动驾驶新能源车产业生态链布局
  19. 数学分析教程史济怀练习16.3
  20. php 天干地支,php实现天干地支计算器示例

热门文章

  1. 彻底理解 模拟频率、数字频率、模拟角频率
  2. 短信验证码发送失败的常见原因有哪些?
  3. PHP 微信小程序获取用户信息
  4. 小企业如何挑选在线客服系统
  5. 51单片机IIC 12864 OLED屏幕滚动显示仿真
  6. Android资源使用详解(一)
  7. volatile能保持线程安全吗_Java线程安全(volatile synchronized)
  8. 小甲鱼python视频xxoo爬虫代码改进--煎蛋网
  9. 入手评测 华为Watch3和Watch3 Pro的区别
  10. 三星watch3测量血压和心电图