论文学习心得

  • 前言
    • 应用场景
    • 基础概念
      • 什么是元学习
      • 元学习的分类
      • MAML
        • 基本概念理解
        • MAML中的Task
        • MAML算法详解
    • 摘要
  • 本文贡献
    • 利用来自多个城市的信息来提高迁移的稳定性
    • 元学习时空预测框架-MetaST
  • 问题定义
  • 方法
    • ST-Net
    • Knowledge Transfer
      • 适应初始化
      • 学习时空记忆
      • 转移知识到目标域
      • 算法
  • 实验
    • 交通预测任务
      • 对比实验方法
        • 非转移方法
        • 转移方法
      • 数据集
      • 评价标准
      • 实验结果
    • 水质预测任务
      • 对比实验方法
      • 数据集
      • 评价标准
      • 实验结果
  • 总结
  • 衍生内容
    • Crowd Flow Prediction by Deep Spatio-Temporal Transfer Learning (2019IJCAI)
    • Transfer Knowledge between Cities (2016KDD)
  • 参考文献

前言

应用场景

近年来,智慧城市建设显著改变了城市管理和服务。准确的时空预测是智能城市建设的基础技术之一。例如,交通预测系统可以帮助城市预先分配交通资源和智能控制交通信号。一个准确的环境预测系统可以帮助政府制定环境政策,进而提高公众的健康水平。

基础概念

什么是元学习

meta-learning即元学习,也可以称为“learning to learn”。常见的深度学习模型,目的是学习一个用于预测的数学模型。而元学习面向的不是学习的结果,而是学习的过程。其学习的不是一个直接用于预测的数学模型,而是学习“如何更快更好地学习一个数学模型”。

元学习的分类

  • learning good weight initializations : 学习一个好的初始化权重,从而在新任务上实现fast adaptation,即在小规模的训练样本上迅速收敛并完成fine-tune。其中MAML[4]属于本类中的经典算法。本方法也属于此类。
  • meta-models that generate the parameters of other models :
  • learning transferable optimizers :

MAML

此部分主要参考原论文和【经典论文解析】Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks,此博客对MAML的讲解更为详细。因此部分只是为了辅助理解,所以只阐述了基本概念和预训练算法,MAML与分类及强化学习结合的算法,本部分并未涉及。

基本概念理解

MAML 的中文名就是模型无关的元学习。意思就是不论什么深度学习模型,都可以使用MAML来进行少样本学习。论文中提到该方法可以用在分类、回归,甚至强化学习上。
本文的代码是基于分类的,那么就从分类的角度展开对MAML的解析。
本文介绍的MAML,其实是一种固定模型的meta learning ,可能会有人问
不是说MAML是模型无关的吗?为什么需要固定模型?
模型无关的意思是该方法可以用在CNN,也可以用在RNN,甚至可以用在RL中。但是MAML做的是固定模型的结构,只学习初始化模型参数这件事。
什么意思呢?就是我们希望通过meta-learning学习出一个非常好的模型初始化参数,有了这个初始化参数后,我们只需要少量的样本就可以快速在这个模型中进行收敛。
那么既然是learning to learn,那么输入就不再是单纯的数据了,而是一个个的任务(task)。就像人类在区分物体之前,已经看过了很多中不同物体的区分任务(task),可能是猫狗分类,苹果香蕉分类,男女分类等等,这些都是一个个的任务task。那么MAML的输入是一个个的task,并不是一条条的数据,这与常见的机器学习和深度学习模型是不同的。


MAML算法实际上优化的是一个可以快速适应新任务的表示θ\thetaθ

MAML中的N-way K-shot learning:
这里的N是用于分类的类别数量。K为每个类别的数据量(用于训练)

MAML中的Task

MAML的论文中多次出现名词task,模型的训练过程都是围绕task展开的,而作者并没有给它下一个明确的定义。要正确地理解task,我们需要了解的相关概念包括Dmeta-train, Dmeta-test, support set, query set, meta-train classes, meta-test classes等等。

我们假设这样一个场景:我们需要利用MAML训练一个数学模型模型 Mfine-tune ,目的是对未知标签的图片做分类,类别包括 P1-P5 (每类5个已标注样本用于训练。另外每类有15个已标注样本用于测试)。我们的训练数据除了 P1-P5 中已标注的样本外,还包括另外10个类别的图片C1-C10(每类30个已标注样本),用于帮助训练元学习模型 Mmeta 。我们的实验设置为5-way 5-shot。

关于具体的训练过程,会在MAML算法详解中介绍。这里只需要有一个大概的了解:

  • MAML首先利用 C1-C10的数据集训练元模型Mmeta,再在P1~P5的数据集上精调(fine-tune)得到最终的模型 Mfine-tune。
  • 此时,C1-C10即meta-train classesC1-C10包含的共计300个样本,即 Dmeta-train是用于训练Mmeta的数据集。与之相对的,P1-P5 即meta-test classesP1~P5 包含的共计100个样本,即 Dmeta-test ,是用于训练和测试 Mfine-tune 的数据集
  • 根据5-way 5-shot的实验设置,我们在训练 Mmeta 阶段,从 C1~C10 中随机取5个类别,每个类别再随机取20个已标注样本,组成一个task T 。其中的5个已标注样本称为 T 的support set,另外15个样本称为 T 的query set。这个task T, 就相当于普通深度学习模型训练过程中的一条训练数据。那我们肯定要组成一个batch,才能做随机梯度下降SGD对不对?所以我们反复在训练数据分布中抽取若干个这样的task T ,组成一个batch。在训练 Mfine-tune 阶段,task、support set、query set的含义与训练 Mmeta 阶段均相同。

MAML算法详解

以下为预训练阶段的算法,目的是得到模型Mmeta:

第一个Require指的是在 Dmeta-train中task的分布。结合我们在上一小节举的例子,这里即反复随机抽取task T ,形成一个由若干个(e.g., 1000个)T 组成的task池,作为MAML的训练集。有的小伙伴可能要纳闷了,训练样本就这么多,要组合形成那么多的task,岂不是不同task之间会存在样本的重复?或者某些task的query set会成为其他task的support set?没错!就是这样!我们要记住,MAML的目的,在于fast adaptation,即通过对大量task的学习,获得足够强的泛化能力,从而面对新的、从未见过的task时,通过fine-tune就可以快速拟合。task之间,只要存在一定的差异即可。再强调一下,MAML的训练是基于task的,而这里的每个task就相当于普通深度学习模型训练过程中的一条训练数据。

第二个Require就很好理解啦。step size其实就是学习率,MAML是基于二重梯度的(gradient by gradient),每次迭代包括两次参数更新的过程,所以有两个学习率可以调整。

步骤1,随机初始化模型的参数。

步骤2,是一个循环,可以理解为一轮迭代过程或一个epoch,当然预训练的过程是可以有多个epoch的。

步骤3,相当于pytorch中的DataLoader,即随机对若干个(e.g., 4个)task进行采样,形成一个batch。

步骤4~步骤7,是第一次梯度更新的过程。注意这里可以理解为copy了一个原模型,计算出新的参数,用在第二轮梯度的计算过程中。我们说过,MAML是gradient by gradient的,有两次梯度更新的过程。步骤4~7中,利用batch中的每一个task,我们分别对模型的参数进行更新(4个task即更新4次)。注意这一个过程在算法中是可以反复执行多次的,伪代码没有体现这一层循环,但是作者再分析的部分明确提到" using multiple gradient updates is a straightforward extension"。

步骤5,即对利用batch中的某一个task中的support set,计算每个参数的梯度。在N-way K-shot的设置下,这里的support set应该有NK个。作者在算法中写with respect to K examples,默认对每一个class下的K个样本做计算。实际上参与计算的总计有NK个样本。这里的loss计算方法,在回归问题中,就是MSE;在分类问题中,就是cross-entropy。

步骤6,即第一次梯度的更新。

步骤4~步骤7结束后,MAML完成了第一次梯度更新。接下来我们要做的,是根据第一次梯度更新得到的参数,通过gradient by gradient,计算第二次梯度更新。第二次梯度更新时计算出的梯度,直接通过SGD作用于原模型上,也就是我们的模型真正用于更新其参数的梯度。换句话说,第一次梯度更新是为了第二次梯度更新,而第二次梯度更新才是为了更新模型参数。

关于以上过程,这里再补充一下解释:假设原模型是θa\theta_aθa​,我们复制了它,得到θb\theta_bθb​。在θb\theta_bθb​上,我们做了反向传播及更新参数,得到第一次梯度更新的结果 θb′\theta'_bθb′​。接着,在θb′\theta'_bθb′​上,我们将计算第二次梯度更新。此时需要先在θb′\theta'_bθb′​上计算梯度(计算方法如接下来的步骤8所述),但是梯度更新的并非是θb′\theta'_bθb′​而是原模型θa\theta_aθa​。这就是二重梯度在代码中的实现。

步骤8即对应第二次梯度更新的过程。这里的loss计算方法,大致与步骤5相同,但是不同点有两处。一处是我们不再是分别利用每个task的loss更新梯度,而是像常见的模型训练过程一样,计算一个batch的loss总和,对梯度进行随机梯度下降SGD。另一处是这里参与计算的样本,是task中的query set,在我们的例子中,即5-way*15=75个样本,目的是增强模型在task上的泛化能力,避免过拟合support set。步骤8结束后,模型结束在该batch中的训练,开始回到步骤3,继续采样下一个batch。
θ←θ−β∇θ∑Ti−p(T)LTi(fθi′)\theta\gets\theta-\beta\nabla_{\theta}\sum_{T_i-p(T)}\mathcal{L}_{T_i}(f_{\theta'_i})θ←θ−β∇θ​∑Ti​−p(T)​LTi​​(fθi′​​)
minθ∑Ti−p(T)LTi(fθi′)=∑Ti−p(T)LTi(fθ−α∇θLi(fθ))min_{\theta}\sum_{T_i-p(T)}\mathcal{L}_{T_i}(f_{\theta'_i})=\sum_{T_i-p(T)}\mathcal{L}_{T_i}(f_{\theta}-\alpha\nabla_\theta\mathcal{L}_{i}(f_\theta))minθ​∑Ti​−p(T)​LTi​​(fθi′​​)=∑Ti​−p(T)​LTi​​(fθ​−α∇θ​Li​(fθ​))
以上即时MAML预训练得到的全部过程。接下来,面对心得task,在目前的基础上,精调得到方法。

fine-tune的过程与预训练的过程大致相同,不同的地方主要在于以下几点:

步骤1中,fine-tune不用再随机初始化参数,而是利用训练好的初始化参数。

步骤3中,fine-tune只需要抽取一个task进行学习,自然也不用形成batch。fine-tune利用这个task的support set训练模型,利用query set测试模型。实际操作中,我们会随机抽取许多个task(e.g., 500个),分别微调模型,并对最后的测试结果进行平均,从而避免极端情况。

fine-tune没有步骤8,因为task的query set是用来测试模型的,标签对模型是未知的。因此fine-tune过程没有第二次梯度更新,而是直接利用第一次梯度计算的结果更新参数

摘要

时空预测是构建智能城市的一个基本问题,对交通控制、出租车调度、环境决策等任务具有重要意义。由于数据采集机制,数据采集的空间分布不平衡是很常见的。例如,一些城市可能会发布多年的出租车数据,而另一些城市则只发布几天的数据;一些地区可能有由传感器监测的固定水质数据,而一些地区只有少量的水样本收集。在本文中,作者解决了只有短时间数据采集的城市的时空预测问题。作者的目标是通过迁移学习来利用来自其他城市的长期数据。与以往将知识从单一来源城市转移到目标城市的研究不同,本文是第一个利用来自多个城市的信息来提高迁移的稳定性的研究。具体来说,本文提出的模型被设计为一个具有元学习范式的时空网络。元学习范式学习了时空网络的广义初始化,可以有效地适应目标城市。此外,还设计了一种基于模式的时空记忆机制来提取长期的时间信息(即周期性)。本文在交通(出租车和自行车)预测和水质预测两个任务上进行了广泛的实验。实验证明了本文提出的模型在几个竞争基线模型上的有效性。

本文贡献

利用来自多个城市的信息来提高迁移的稳定性

迁移学习已经被研究为解决数据不足问题的有效解决方案,通过利用那些数据丰富的城市的知识(例如,覆盖几年的GPS跟踪)。在FLORAL[6]中,作者提出转移从数据丰富的城市即源城市学习到的语义相关字典,以预测数据不足的城市即目标城市的空气质量类别。在RegionTrans[3]中提出的方法将源城市和目标城市的相似区域对齐,以实现更细粒度的传输。然而,这些方法只从单一来源城市转移知识,会导致不稳定的结果和负转移的风险。如果城市之间的基础数据分布存在显著差异,知识转移不会有任何贡献,甚至影响性能
为了降低风险,本文将研究从多个源城市转移知识,用于目标城市的时空预测。与单个城市相比,从多个城市提取的知识涵盖了更全面的城市时空相关性,如时间依赖性、空间亲近性和区域功能,从而提高了迁移的稳定性。然而,这个问题面临着两个关键的挑战。

  • 如何使这些知识适应目标城市的各种时空关联情景?
  • 如何从源城市中获取和借用长周期的时空格局?

元学习时空预测框架-MetaST

是第一个将元学习范式纳入时空网络(ST-net,实际上就是一个局部CNN和LSTM的组合时空网络)的框架

  • 解决第一个挑战:从多个源城市的大量预测任务中学习ST-net的广义初始化,其中涵盖了全面的时空场景。随后,初始化可以很容易地通过微调来适应目标城市,即使只有少数训练样本是可访问的。
  • 解决第二个挑战:从所有的源城市中学习一个基于全局模式的时空记忆(memory),并将其转移到一个目标城市,以支持长期模式。描述和存储长期时空模式的记忆,与ST-net以端到端的方式联合训练。

问题定义

yrct,kct+1∗=argmaxyrct,kct+1p(yrct,kct+1∣Yct,fθ0)y^*_{r_{c_t},k_{c_t+1}}=argmax_{y_{r_{c_t},k_{c_t+1}}}p(y_{r_{c_t},k_{c_t+1}}|\mathcal{Y}_{c_t},f_{\theta_{0}})yrct​​,kct​+1​∗​=argmaxyrct​​,kct​+1​​​p(yrct​​,kct​+1​​∣Yct​​,fθ0​​)
其中,yrc,kcy_{r_c,k_c}yrc​,kc​​是待预测的时空信息(如交通需求、空气质量、气候价值)。rcr_crc​表示区域c,kck_ckc​表示目前的时间步,Kc\mathcal{K}_{c}Kc​表示不重叠的连续时间段,Yc\mathcal{Y}_{c}Yc​表示时空序列,fff表示ST-net,作为预测时空序列的基础模型。在元学习范式中,θ0\theta_{0}θ0​表示ST-net的初始化的所有参数。

方法

ST-Net


简单的CNN+LSTM的时空网络

Knowledge Transfer

本节提出了一个元学习框架,使ST-net模型能够从多个城市借用知识。该框架包括两个部分:适应初始化和学习时空记忆。图中展示的是整个框架。

适应初始化

(此处一定先阅读MAML的相关内容)
在ST-net中,参数θ\thetaθ正是加密时空相关性的知识。为了有效地使参数适应目标城市,如MAML[4]所建议,初始化多个源城市,即θ0\theta_0θ0​,使θ0\theta_0θ0​初始化的ST-net达到所有源城市泛化损失平均值的最小值,即:
θ0=minθ0∑cs∈CsLcs′(fθ0−α∇θLcs(fθ))\theta_0=min_{\theta_0}\sum_{c_s\in C_s}\mathcal{L}'_{c_s}(f_{\theta_0}-\alpha\nabla_\theta\mathcal{L}_{c_s}(f_\theta))θ0​=minθ0​​∑cs​∈Cs​​Lcs​′​(fθ0​​−α∇θ​Lcs​​(fθ​))
其中,Lcs(fθ)\mathcal{L}_{c_s}(f_\theta)Lcs​​(fθ​)表示从CsC_sCs​中采样的一个城市csc_scs​的训练集上的训练损失,即Dcs\mathcal{D}_{c_s}Dcs​​(参见图中的S-train)。Lcs′(.)\mathcal{L}'_{c_s}(.)Lcs​′​(.)评估城市测试集上的损失,即Dcs′\mathcal{D}'_{c_s}Dcs​′​(参见图中的S-test)。通过使用随机梯度下降(图2中的紫色实心箭头所示)优化此等式,可得到一个初始化,可以很好地推广到不同的源城市。
θcs=θ0−α∇θLcs(fθ)\theta_{c_s}=\theta_0-\alpha\nabla_\theta\mathcal{L}_{c_s}(f_\theta)θcs​​=θ0​−α∇θ​Lcs​​(fθ​)
这是其中一个梯度下降的例子。这说明了参数θcs\theta_{c_s}θcs​​在训练过程中的迭代更新过程。在实践中,可以使用几个梯度下降步骤从初始化θ0\theta_0θ0​更新到θcs\theta_{c_s}θcs​​。对于每个城市csc_scs​,训练过程是重复在从

【论文笔记】Learning from Multiple Cities: A Meta-Learning Approach for Spatial-Temporal Prediction相关推荐

  1. 【论文笔记09】Differentially Private Hypothesis Transfer Learning 差分隐私迁移学习模型, ECMLPKDD 2018

    目录导引 系列传送 Differentially Private Hypothesis Transfer Learning 1 Abstract 2 Bg & Rw 3 Setting &am ...

  2. 小样本论文笔记5:Model Based - [6] One-shot learning with memory-augmented neural networks.

    小样本论文笔记5:Model Based - [6] One-shot learning with memory-augmented neural networks 文章目录 小样本论文笔记5:Mod ...

  3. 论文笔记:Bootstrap Your Own Latent A New Approach to Self-Supervised Learning

    论文笔记:Bootstrap Your Own Latent A New Approach to Self-Supervised Learning abstract: 介绍了BYOL网络(原理):依赖 ...

  4. 论文笔记《Incorporating Copying Mechanism in Sequence-to-Sequence Learning》

    论文笔记<Incorporating Copying Mechanism in Sequence-to-Sequence Learning> 论文来源:2016 ACL 论文主要贡献:提出 ...

  5. 论文笔记:Decoding Brain Representations by Multimodal Learning of Neural Activity and Visual Features

    论文笔记:Decoding Brain Representations by Multimodal Learning of Neural Activity and Visual Features(通过 ...

  6. 论文笔记:Missing Value Imputation for Multi-view UrbanStatistical Data via Spatial Correlation Learning

    TKDE 2021(Apr) 0 摘要 作为城市化的发展趋势,海量的多视角(如人口和经济视角)的城市统计数据被越来越多地收集并受益于不同领域,包括交通服务.区域分析等. 划分为细粒度区域的数据在获取和 ...

  7. 【论文笔记】《Social Influence-Based Group Representation Learning for Group Recommendation》

    ICDE 19 A会 这篇论文一作是阴老师,获得了ICDE19最佳论文奖. Abstract 作为群居动物,参加群组活动是人日常生活中必不可少的一部分,为群组用户推荐满意的活动是推荐系统一项重要的任务 ...

  8. 【论文笔记】Gradient Episodic Memory for Continual Learning

    Gradient Episodic Memory for Continual Learning(用于持续学习的梯度情景记忆) 本篇论文的贡献 创新性 Gradient of Episodic Memo ...

  9. 论文笔记 WWW 2022|Ontology-enhanced Prompt-tuning for Few-shot Learning

    文章目录 1 简介 1.1 动机 1.2 创新 2 方法 2.1 General Framework with Prompt-Tuning 2.2 Ontology Transformation 2. ...

最新文章

  1. 有存款,才能过得更踏实
  2. 首场见习挑战赛倒计时3天!20000元奖学金瓜分就等你了!
  3. LINQ to XML 常用操作(转)
  4. SpringBoot - Spring Boot 应用剖析
  5. 什么叫「人的格局」?是否有必要培养大的格局或怎么培养?
  6. opensuse x64下编译Ice源码(以编译c++为例)
  7. java.util.regex_java.util.regex.PatternSyntaxException:索引附近的...
  8. 从前中后序遍历构造二叉树,三题无脑秒杀
  9. pom.xml中依赖的<optional>true</optional>标签
  10. 经典面试题(19):以下代码将输出的结果是什么?
  11. Spring学习笔记专题二
  12. 1014.修改clion的工具链
  13. MFC小笔记:系统托盘实现
  14. 二分查找的代码实现--go语言
  15. 接口接收数据_基于原语的千兆以太网RGMII接口设计
  16. 命令行_Pytest之命令行执行
  17. 华为谷歌安装器 Android6.0,gms安装器华为
  18. 记一次python cpu100%分析记录
  19. 迅雷如何添加html文件夹,迅雷7上我的收藏怎么找
  20. 数字图像处理笔记(八)彩色图像和彩色图像直方图均衡化

热门文章

  1. 聆听第18期贡献者荣誉榜发布,体验落地才是王道
  2. 学会聆听,职场最重要的事情,没有之一!!!
  3. 《图解密码技术》笔记2:历史上的密码-写一篇别人看不懂的文章
  4. [2015国家集训队互测]口胡
  5. 攻防世界各类题目相关
  6. 综合案例——手写数字图像处理算法比较
  7. Android刷机SD卡分区指南 [
  8. MATLAB[2]:绘图坐标轴的设置
  9. java反算坐标方位角,根据经纬度求方位角,以北为0,顺时针为正方向
  10. Python爬取童程童美TTS网站知识点图片