@作者 | Billy Z

模型加入先验知识的必要性

端到端的深度神经网络是个黑盒子,虽然能够自动学习到一些可区分度好的特征,但是往往会拟合到一些非重要特,导致模型会局部坍塌到一些不好的特征上面。常常一些人们想让模型去学习的特征模型反而没有学习到。

为了解决这个问题,给模型加入人为设计的先验信息会让模型学习到一些关键的特征。下面就从几个方面来谈谈如何给模型加入先验信息。

为了方便展示,我这边用一个简单的分类案例来展示如何把先验知识加入到一个具体的 task 中。我们的 task 是在所有的鸟类中识别出一种萌萌的鹦鹉,这中鹦鹉叫鸮(xiāo)鹦鹉,它长成下面的样子:

▲ 鸮(xiāo)鹦鹉

这种鸟有个特点:

就是它可能出现在任何地方,但就是不可能在天上,因为它是世界上唯一一种不会飞的鹦鹉(不是唯一一种不会飞的鸟)。

好,介绍完 task 的背景,咱们就可以分分钟搭建一个端到端的分类神经网络,可以选择的网络结构可以有很多,如 resnet, mobilenet 等等,loss 往往是一个常用的分类 Loss,如交叉熵,高级一点的用个 focal loss 等等。确定好了最优的数据(扰动方式),网络结构,优化器,学习率等等这些之后,往往模型的精度也就达到了一个上限。

然后你测试模型发现,有些困难样本始终分不开,或者是一些简单的样本也容易分错。这个时候如果你还想提升网络的精度,可以通过给模型加入先验的方式来进一步提升模型的精度。

基于pretrain模型给模型加入先验

给模型加入先验,大家最容易想到的是把网络的 weight 替换成一个在另外一个任务上 pretrain 好的模型 weight。经过的预训练的模型(如 ImageNet 预训练)往往已经具备的识别到一些基本的图片 pattern 的能力,如边缘,纹理,颜色等等,而识别这些信息的能力是识别一副图片的基础。如下图所示:

但这些先验信息都是一些比较 general 的信息,我们是否可以加入一些更加 high level 的先验信息呢。

基于输入给模型加入先验

假如你有这样的一个先验:

你觉得鸮鹦鹉的头是一个区别其他它和鸟类的重要部分,也就是说相比于身体,它的头部更能区分它和其他鸟类。

这时怎么让网络更加关注鸮鹦鹉的头部呢。这时你可以这样做,把整个鸮鹦鹉和它的头部作为一个网络的两路输入,在网咯的后端再把两路输入的信息融合。以达到既关注局域,又关注整体的目的。一个简单的示意图如下所示。

基于模型重现给模型加入先验

接着上面的设定来,假如说你觉得给模型两路输入太麻烦,而且增加的计算量让你感觉很不爽。

这时,你可以尝试让模型自己发现你设定的先验知识

假如说你的模型可以自己输出鸟类头部的位置,虽然这个鸟类头部的位置信息是你不需要的,但是输出这样的信息代表着你的网络能够 locate 鸟类头部的位置,也就给鸟类的头部更加多的 attention,也就相当于给把鸟类头部这个先验信息给加上去了。

当然直接模仿 detection 那样去回归出位置来这个任务太 heavy 了,你可以通过一个生成网络的支路来生成一个鸟类头部位置的 Mask,一个简单的示意图如下:

▲ 测试的时候不增加计算量

基于CAM图激活限制给模型加入先验

针对鸮鹦鹉的分类,我在上面的提到一个非常有意思的先验信息:

那就是鸮鹦鹉是世界上唯一一种不会飞的鹦鹉。

这个信息从侧面来说就是,鸮鹦鹉所有地方都可能出现,就是不可能出现在天空中(当然也不可能出现在水中)。

也就是说不但鸮鹦鹉本身是一个分类的重点,鸮鹦鹉出现的背景也是分类的一个重要参考。假如说背景是天空,那么就一定不是鸮鹦鹉,同样的,假如说背景是海水,那么也一定不是鸮鹦鹉,假如说背景是北极,那么也一定不是鸮鹦鹉,等等。

也就是说,你不能通过背景来判断一只未知的鸟是鸮鹦鹉,但是你能通过背景来判断一只未知的鸟肯定不是鸮鹦鹉(是其他的鸟类)。

所以假如说获取了一张输入图片的激活图(包含背景的),那么这张激活图的鸟类身体部分肯定包含了鸮鹦鹉和其他鸟类的激活,但是鸟类身体外的背景部分只可能包含其他鸟类的激活。

所以具体的做法是基于激活图,通过限制激活图的激活区域,加入目标先验

CAM [1] 激活图是基于分类网络的倒数第二层卷积层的输出的 feature_map 的线性加权,权重就是最后一层分类层的权重,由于分类层的权重编码了类别的信息,所以加权后的响应图就有了基于不同类别的区域相应。

具体的介绍可以看:

https://zhuanlan.zhihu.com/p/51631163

具体的激活图生成方式可以如下表示:

说了这么多,下面就展示展示激活图的样子:

大家可以看到,上面一张是一只鸮鹦鹉的激活图,下面是一只在天空飞翔的大雁的激活图。

因为鸮鹦鹉的 Label 是 0,其他鸟类的 Label 是 1,所以在激活图上,只要是负值的激活区域都是鸮鹦鹉的激活,也就是 Label 为 0 的激活,只要是正值的激活都是其他鸟类的激活,也就是 Label 为 1 的激活。

为了方便展示,我把负值的激活用冷色调来显示,把正值的激活用暖色调来显示,所以就是变成了上面两幅激活图的样子。而右边的数字是具体的激活矩阵(把激活矩阵进行 GAP 就可以变成最终输出的 Logits)。

到这里不知道大家有没有发现一个问题,就是无论对于鸮鹦鹉还是大雁的图片,它们的激活图除了分布在鸟类本身,也会有一部分分布在背景上。对于大雁我们好理解,因为大雁是飞在天空中的,而鸮鹦鹉是不可能在天空中的,所以天空的正激活是非常合理的。但是对于鸮鹦鹉来说,其在鸟类身体以外的负激活就不是太合理,因为,大雁或者是其他的鸟类,也可能在鸮鹦鹉的地面栖息环境中(但是鸮鹦鹉却不可能在天空中)。

所以环境不能提供任何证据来证明这一次鸟类是一只鸮鹦鹉,鸮鹦鹉的负激活只是在鸟类的身体上是合理的。而其他鸟类的正激活却可以同时在鸟类身体上又可能在鸟类的背景上(如天空或者海洋)。

所以我们需要这样建模这个问题,就是在除鸟类身体的背景上,不能出现鸮鹦鹉的激活,也就是说不能出现负激活(Label 为 0 的激活)。所以下面的激活才是合理的:

从上面来看,在除鸟类身体外的背景部分是不存在负激活的,虽然上面的背景部分有一些正的激活(其他鸟类的激活),但是从右边的激活矩阵来看,负激活的 scale 是占据绝对优势的,所以完全不会干扰对于鸮鹦鹉的判断。

所以问题来了,怎么从网络设计方面来达到这个目的呢?

其实可以从 Loss 设计方面来达到这个效果。我们假设每一个鸟都有个对应的 mask,mask 内是鸟类的身体部分,mask 外是鸟类的背景部分。那么我们需要做的就是抑制 mask 外的背景部分激活矩阵的负值,把那一部分负值给抑制到 0 即可。

鸟类的激活矩阵和 mask 的关系如下图(红色的曲线代表鸟的边界 mask):

我们的 Loss 设计可以用下面的公式表示:

具体的网络的 framework 可以如下所示:

其中虚线部分只是训练时候需要用到,inference 的时候是不需要的,所以这种方法也是不会占用任何在 inference 前向时候的计算量。


基于辅助学习给模型加入先验知识

到现在为止,咱们还只是把我们的鸟类分类的 task 当成一个二分类来处理,即鸮鹦鹉是一类,其他的鸟类是一类。

但是我们知道,世界的鸟类可不仅仅是两类,除了鸮鹦鹉之外还有很多种类的鸟类。而不同鸟类的特征或许有很大的差别,比如鸵鸟的特征就是脖子很长,大雁的特征就是翅膀很大。

假如只是把鸮鹦鹉当做一类,把其他的鸟类当做一类来学习的话,那么模型很可能不能学到可以利用的区分非鸮鹦鹉的特征,或者是会坍塌到一些区分度不强的特征上面,从而没有学到能够很好的区分不同其他鸟类的特征,而那些特征对去区别鸮鹦鹉和其他鸟类或许是重要的。

所以我们有必要加入其他鸟类存在不同类别的先验知识。而这里,我主要介绍基于辅助学习的方式去学习类似的先验知识。首先我要解释一下什么是辅助学习,以及辅助学习和多任务学习的区别:

上图的左侧是多任务学习的例子,右侧是辅助学习的例子。左侧是个典型的 face attribute 的 task,意思是输入一张人脸,通过多个 branch 来输出这一张人脸的年龄,性别,发型等等信息,各个 branch 的任务是独立的,同时又共享同一个 backbone。

右边是一个典型的辅助学习的 task,意思是出入一张人脸,判断这一张人脸的性别,同时另外开一个(或几个)branch,通过这个 branch 来让网络学一些辅助信息,比如发型,皮肤等等,来帮助网络主任务(分男女)的判别。

好,回到我们的鸮鹦鹉分类的 task,我们可能首先会想到下面的 Pipeline:

这样虽然可以把不同类别的鸟类的特征都学到,但是却削弱了网络对于鸮鹦鹉和其他鸟类特征的分别。

经过实验发现,这种网络架构不能很好的增加主任务的分类精度。为了充分的学到鸮鹦鹉和其他鸟类特征的分别,同时又能带入不同种类鸟类类别的先验,我们引入辅助任务:

在上面的 Pipeline 中,辅助任务相比如主任务,把其他鸟类做更加细致的分类。这样网络就学到了区分不同其他鸟类的能力。

但是从实验效果来看这个 Pipeline 的精度并不高。经过分析原因,发现在主任务和辅助任务里面都有鸮鹦鹉这一类,这样当回传梯度的时候,相当于把区分鸮鹦鹉和其他鸟类的特征回传了两次梯度,而回传两次梯度明显是没用的,而且会干扰辅助任务学习不同其他鸟类的特征。

所以我们可以把辅助任务的鸮鹦鹉类去除,于是便形成了下面的 pipeline:

经过实验发现,这种 pipeline 是有利于主任务精度提升的,网络对于特征明显的其他鸟类的分类能力得到了一定程度的提升,同时对于困难类别的分类能力也有一定程度的提升。

当然,辅助任务的 branch 可以不只是一类,你可以通过多个类别来定义你的辅助任务的 branch:

这时候你会想,上面的 pipeline 好是好,但是我没有那么多的 label 啊。是的,上面的 pipeline 除了主任务的 label 标注,它还同时需要很多的辅助任务的 label 标注,而标注 label 是深度学习任务里面最让人头疼的问题(之一)。

别怕,我下面介绍一个 work,它基于 meta-learning 的方法,让你不再为给辅助任务标注 label 而烦恼,它的 framework 如下:

这个 framework 采用基于 maxl [2] 的方案(https://github.com/lorenmt/maxl),辅助任务的数据和 label 不是由人为手工划分,而是由一个 label generator 来产生,label generator 的优化目标是让主网络在主任务的 task 上的 loss 降低,主网络的目标是在主任务和辅助任务上的 loss 同时降低。

但是这个 framework 有个缺点,就是训练时间会上升一个数量级,同时 label generator 会比较难优化。感兴趣的同学可以自己尝试。但是不得不说,这篇文章有两个结论倒是很有意思:

1. 假设 primary 和 auxiliary task 是在同一个 domain,那么 primary task 的 performance 会提高当且仅当 auxiliary task 的 complexity 高于 primary task。

2. 假设 primary 和 auxiliary task 是在同一个 domain,那么 primary task 的最终 performance 只依赖于 complexity 最高的 auxiliary task。

结语

先总结一下所有可以有效的加入先验信息的框架:

你可以通过上述框架的选择来加入自己的先验信息。

给神经网络的黑盒子里面加入一些人为设定的先验知识,这样往往能给你的task带来一定程度的提升,不过具体的task需要加入什么样的先验知识,需要如何加入先验知识还需要自己探索。

参考文献

[1] CAM https://arxiv.org/abs/1512.04150

[2] maxl https://arxiv.org/abs/1901.08933

更多阅读

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

如何给模型加入先验知识?相关推荐

  1. 综述:如何给模型加入先验知识

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨Billy Z@知乎(已授权) 来源丨https://zhua ...

  2. 【论文阅读】如何给模型加入先验知识

    如何给模型加入先验知识 1. 基于pretain模型给模型加入先验 把预训练模型的参数导入模型中,这些预训练模型在另一个任务中已经p retrain好了模型的weight,往往具备了一些基本图片的能力 ...

  3. 关于NLP相关技术全部在这里:预训练模型、图神经网络、模型压缩、知识图谱、信息抽取、序列模型、深度学习、语法分析、文本处理...

    NLP近几年非常火,且发展特别快.像BERT.GPT-3.图神经网络.知识图谱等技术应运而生. 我们正处在信息爆炸的时代.面对每天铺天盖地的网络资源和论文.很多时候我们面临的问题并不是缺资源,而是找准 ...

  4. 详解预训练模型、图神经网络、模型压缩、知识图谱、信息抽取、序列模型、深度学习、语法分析、文本处理...

    NLP近几年非常火,且发展特别快.像BERT.GPT-3.图神经网络.知识图谱等技术应运而生.我们正处在信息爆炸的时代.面对每天铺天盖地的网络资源和论文.很多时候我们面临的问题并不是缺资源,而是找准资 ...

  5. 详解NLP技术中的:预训练模型、图神经网络、模型压缩、知识图谱

    NLP近几年非常火,且发展特别快.像BERT.GPT-3.图神经网络.知识图谱等技术应运而生. 我们正处在信息爆炸的时代.面对每天铺天盖地的网络资源和论文.很多时候我们面临的问题并不是缺资源,而是找准 ...

  6. 基于图卷积网络的测量与先验知识相结合的故障诊断方法

    目录 Graph Convolutional Network-Based Method for Fault Diagnosis Using a Hybrid of Measurement and Pr ...

  7. 谷歌发布TensorFlow 1.4与TensorFlow Lattice:利用先验知识提升模型准确度 搜狐科技 10-12 15:29 选自:Google Research Blog 参与:李泽南、

    谷歌发布TensorFlow 1.4与TensorFlow Lattice:利用先验知识提升模型准确度 昨天,谷歌发布了 TensorFlow 1.4.0 先行版,将 tf.data 等功能加入了 A ...

  8. AU R-CNN:利用专家先验知识进行表情运动单元检测的R-CNN模型

    ©PaperWeekly 原创 · 作者|Chen Ma 学校|清华大学 研究方向|人脸识别和物体检测 这篇论文率先利用先验知识和物体检测技术做 Action Unit 人脸表情识别,在 BP4D 和 ...

  9. 如何将先验知识注入推荐模型

    看到知乎上的一个问题"如何向深度学习模型中加入先验知识?",觉得这是一个很好的问题,恰好自己在这方面有一些心得,今天拿出来和大家聊一聊. 说这个问题有趣,是因为提问者一定是对DNN ...

最新文章

  1. 产权分割商铺,太坑人!
  2. 程序调试的时候利用Call Stack窗口查看函数调用信息
  3. MLP is Best?
  4. arcserver连接oracle,ArcSDE的二种连接方式(应用服务器连接,直接连接)
  5. Java集合之LinkedHashMap源码分析
  6. 为什么现在那么多人都想做电商?
  7. typeof()用法及JS基本类型
  8. 【优化调度】基于matlab粒子群算法求解梯级水电站调度优化问题【含Matlab源码 065期】
  9. Rayson API 框架分析系列之1: 简介
  10. 红米note3android版本,小米-红米note3-LOS-安卓9.0.0-稳定版Stable3.0-来去电归属-农历等-本地化增强适配...
  11. Androidstudio 连接夜神模拟器
  12. 修真院java_【修真院JAVA小課堂】JMeter的簡單介紹
  13. 苹果对抗FBI 自由与限制的百年难题
  14. 如何使用百度天气预报API接口
  15. Oracle session active 和 inactive 状态 说明
  16. 在Word2010文档中设置和显示隐藏文字
  17. python异常模块raise的概念以及基本用法
  18. 计算机技术与软件证书用处,【考计算机技术与软件专业资格水平考试有什么用,各级别证书有什么用?】- 环球网校...
  19. win8计算机丢失xinput1+3.dll,win8提示xinput1 3.dll丢失的解决方法
  20. python生成随机密码生成器加特殊字符

热门文章

  1. 零售业小程序行业解决方案
  2. yeezy350灰橙_yeezy 350灰橙4.0什么时候发售 椰子350灰橙1.0、2.0和3.0对比赏析
  3. LeetCode - OrderMap - 715.Range模块
  4. MXNet作者李沐:我在CMU读博的这五年
  5. 硕士阶段总结《科苑行》之工作习惯
  6. 【MQ】MQ消息中间件RabbitMQ
  7. 帝国时代之罗马复兴玩法技巧
  8. CocosCreator中游戏摇杆的实现
  9. 从零开始学习SVM(二)---松弛变量
  10. php 爬取新闻,scrapy抓取学院新闻报告