↑↑↑关注后"星标"Datawhale

每日干货 & 每月组队学习,不错过

Datawhale干货

作者:欧明锋,浙江大学

导读:在实际的深度学习项目中,难免遇到多个相似数据集,这时一次仅用单个数据集训练模型,难免造成局限。是否存在利用多个数据集训练的可能性?本文带来解读。

01 介绍

迄今为止,在深度学习领域,最流行的范式或者大家最常用的范式是端到端学习范式。

我们可以把该范式简单概括为四个步骤:准备数据,喂入网络数据,神经网络优化,最后评估模型。这个范式确实也在各个领域取得了巨大成功。

然而,当我们在做一些实际的工程应用时,一项任务可能有多个相似数据集,比如在宠物分类的Dogs vs Cats, Oxford-IIIT Pet数据集,交通车辆检测的BDD100k,KITTI-object等数据集。通常的做法是一次仅选择其中的一个进行各种模型训练,这不仅浪费了其他的数据集,也同时给模型带来局限。

因此,我们可能会问这样一个问题:为什么只使用一个数据集来训练神经网络模型?

这是我在Graviti作为算法实习生,与leader以及导师一起完成的一项研究工作,已经被ICML2021接受了,非常感谢Datawhale给我向大家分享论文。今天的分享简单分为 介绍(包括movivation,related work等等),方法,实验验证,最后的结论 四个部分。

回到正题,针对上面的问题,那肯定要利用起多个数据集的。

有些数据集可以轻松融合在一起,因为他们有重叠的标签,就像下面这两个traffic相关的数据集有共同的标签类 person和bike, 但有些不能,我们认为其中一个主要的瓶颈之一是标签差异,标签集存在不同的语义层次或粒度。

就像这里底部宠物数据集的例子,数据集a标签是猫狗等,数据集b标签是一些猫狗的品种如布偶猫,萨摩耶等,因为两个数据集的标签粒度存在差异,导致其无法直接融合。

事实上,确实有些前人的工作涉及该方面, 我将这些工作主要分为了两类:1.是左边的直接融合,直接在标签空间进行,这要求标签的一致性,这通常可以通过伪标签的方式进行;2.是右边的间接融合,它可以抽象为通过共享的隐藏向量空间进行数据集融合,相应的算法框架涉及迁移学习、领域自适应等。

而我们的思路是从数据集的语义信息角度出发, 由于具有相似目的的数据集其标签在领域知识是具有的语义关联,所以我们就通过构造一个统一的知识驱动的标签图来在标签空间中直接进行数据集融合。

这里举了个具体的例子,左边的部分是动物领域的三个相似的数据集及其标签集,由于这些标签集之间的语义层次和粒度不同,它们无法轻松融合。然而,在通过标签集之间的语义关系建立标签图之后,这些数据集成功地连接起来,三个数据集就被组合成一个单一的数据集。

更具体地来说,左边是传统的未融合数据集的示例,几个相似的数据集,但标签集之间存在差异,每个数据集对应一个单标签预测模型的训练过程。右边我们提出的方法,我们将这些数据集连接在一起,驱动模型预测 标签图上以目标节点为终点的整个轨迹,而不是单一的标签预测。

我们模型的基本架构就是特征提取网络接上序列生成网络,即Encoder-Decoder的结构。

介绍部分就到这里,接下来是方法部分。

02 方法

首先是图谱构建的流程,这里其实是展示了一个抽象化的流程。这里假设对两个数据集的标签来构建图谱, 这两个数据集分别假设为:

  • 猫狗二分类数据集

  • 猫狗的细粒度品种分类数据集

构建步骤抽象为以下四个步骤, 1.首先是添加根节点,就是黄色的动物节点;2. 所有数据集的标签节点,就是绿色的节点;3. 以及代表属性特征的扩展节点,即蓝色的节点;4. 最后连接边。

但实际上这个图的构建过程是更为具体和直接的,因为这个图其实不是我们构造的,而是通过 “窃取”来的。因为这个标签图本质上是从相关的领域几十年来积累的领域知识中获得的。

以猫的品种分类为例:

首先,我们将cat设置为根节点,接着我们从Purina这样的领域网站上发现了三种类型的coat特性。因此,我们添加它们作为增强节点来表示猫的一方面外观特征;其次,我们check了coat field中的对应框“Short”,发现了许多短毛品种,并将它们放置在增强节点shorthair下。通过类似的方式,就可以构建出一张很大的或者说完整的标签图。

同时在刚刚的这个过程中,我们很容易发现,构造过程类似于人类在执行分类时的决策方式。当我们人看到一种动物时,我们首先根据它的全局特征来判断它的大致类别,然后仔细观察它的局部特征来确定它细分的品种。

也就是说在我们的方法中,模型在执行推理时,标签图其实提供了一个“决策过程”。

此外,我们认为这种方法是象征主义和连接主义的结合。也就是说,我们将几十年积累起来的领域知识归纳为一个深度神经网络模型。

为了更好地捕捉下方标签图上同一层级节点间的关系,我们定义了竞争节点的概念。

定义u和w是竞争节点,当且仅当u和w有着共同的祖先节点,并且它们在分类法上是互斥的。

针对竞争节点,我们提出了block-softmax;因为对于一般softmax,所有类别都在相互竞争。但是,在我们的体系结构中,竞争关系仅存在于竞争节点之间。因此做了一个block的限制,从而将相对概率的计算限制到了每个竞争节点组内。右图就是一个对比示意图:

说完节点来到路径,我们也定义了确定性和不确定性路径来分别处理 类别具有确定性以及不确定特征 的情况。首先是确定性路径,它的定义如这里所示,比较抽象,我们就直接来看一个具体的例子:

给定标签节点v和经过该节点的路径P(v),如果不存其他路径P′(v)满足条件:∃ u∈P(v),w∈P^′(v), u,w形成竞争节点并且u ≠w 则P(v)是确定性路径。

右图中的一个例子就是动物-猫-〉短毛->英国短毛猫, 之所以说这条路径是确定的是因为,所有的英国短毛猫都是短毛的。

首先是确定性路径的训练,我们采用了Teacher forcing的训练策略, 该流程如右图所示,对于确定性ground truth路径P,我们将其视为一个序列,让循环单元自回归地预测序列上的每个节点, 然后我们就能得到如下的损失函数,(本质上就是最大化整条正确路径的概率),从而反向传播并优化。

然后是关于非确定性路径。给定路径锚定(anchoring)标签节点,,如果存一条其他路径满足条件:,,,形成竞争节点并且 ,则是非确定性路径。

右图中有三条不确定性路径,被标记为红色。因为英国短发猫的毛色模式可以是纯色、重点色、虎斑色中的任意一种。因此,经过这三个节点到英国短毛节点的路径都是不确定的。

由于其路径中的不确定节点导致teacher forcing策略无法正常使用,所以我们采用了Reinforce算法。首先我们定义了一个激励函数,即“模型采样的生成路径”和“ground truth标签节点集”之间交集的归一化大小。进而定义出了损失函数,其实本质上就是最大化采样生成路径的期望奖励,能够通过最后一个式子估计出不确定性路径的梯度,具体的推导请参考reinforce的论文。

然后我们最终的训练策略的话其实就是在一个batch中依次进行确定性和非确定性路径的训练,具体详细的训练流程就不在这里说了,有兴趣的可以看一下我们论文中的伪代码。

03 实验

实验部分我们分别在单标签图像和文本分类任务上进行的。

首先,关于数据集设置,分为三组:

第一组是关于宠物分类,第二组是关于花分类, 第三组是对arxiv文章进行学科分类,arxiv学科的标签其实是有层级的,比如第一级cs,第二级 ml,arxiv augment就只保留了其最高层级的标签。

前两组的标签图都是我们通过现有的领域知识构建的,arxiv那一组标签其实是有层级的,比如第一级cs,第二级 ml,就直接将层级关系展开为标签图。

组1和组3对应于细粒度和粗粒度数据集的融合,并且数据集之间没有标签重叠, 组2对应于在相同粒度级别上标注的两个数据集的融合,其中重叠标签数量为8

出于评估目的,我们的测试都是在难度更大的细粒度数据集上进行的:

然后,是关于模型的设置的。

首先是baseline, 在图像分类中,有三种。1.传统的单标签预测模型 2.基于伪标签的融合数据集,即为粗数据集中的样本生成细粒度伪标签,并将这些样本合并到细粒度数据集中。3.它是一个多标签分类设置,采用了之前工作中的一个关键实验。而在文本分类任务中,基线是传统的单标签预测模型。

然后是我们的模型。其中对于Encoder,图像分类任务中使用EfficientNet-b4而文本分类任务使用Bert或LSTM作为特征提取器,对于Decoder使用GRU, 并且在图像分类任务中融合了注意力模块来帮助GRU单元在不同的step关注到图像中不同位置的信息。

然后是实验的主要结果。从表中可以看出两点:

1.如红色虚线框中对比数据所示,即使没有额外数据集的帮助,简单地将标签扩展为标签关系图,再加上我们的训练策略,表现仍然会有所提升。因为将标签扩展为标签关系图,其实本质上就是一种数据增强的方式,只是与传统的数据增强方法集中于数据本身上不同,本文增强了标签之间的关系,或者另一种角度来看本文为每个标签的样本又引入了额外的标签,即额外的监督信息。

2.如绿色虚线框中的对比数据所示,使用本文所提出的方法要优于直接融合,以及基于伪标签融合的方法,同时也要优于传统的单标签预测模型,说明了我们方法在标签空间进行数据集融合的可行性。

更重要的是,我们的方法具有增强的可解释性。为了说明这一点,我们以波斯猫为例,波斯猫用红色虚线椭圆标记,波斯猫的毛色模式是重点色或纯色,这是不确定的。该模型通过确定性的重点色和纯色的猫类样本来学习这两种颜色模式的特征,应用在不确定性路径样本的推理上,从而区分波斯猫中不同毛色模式的样本。这就像之前说的,我们的标签图其实就是为我们的模型在推理时提供了决策过程的过程,从而使其更具有可解释性。实验部分到此结束。

04 结论

在这项工作中,我们研究了数据集连接的问题,更具体地说是在标签系统不一致时的标签集连接问题。我们提出了一个新的框架来解决这个问题,包括标签空间扩充、递归神经网络、序列训练和策略梯度。经过训练的模型在性能和可解释性方面都显示出良好的结果。

当然这项工作只是一个多数据集连接初步的探索, 其中还有很多问题可以研究解决,包括以下:

  • 图谱质量的如何衡量,

  • 如何构建更加鲁棒的方法来适应的有噪声标签关系图,

  • 融合后数据集产生的分布偏移问题该如何解决,

同时直接还有很多可扩展的方向,包括:

  • 伪标签方法相结合

  • 在其他任务如目标检测、分割上进行探索

以上的话就是对我们这项工作的整体介绍,关于该项工作的更多细节可以去arxiv上看看我们的paper。

相关链接:数据驱动的算法工程落地

干货学习,三连

作者解读ICML接收论文:如何使用不止一个数据集训练神经网络模型?相关推荐

  1. 【学术相关】作者解读ICML接收论文:如何使用不止一个数据集训练神经网络模型?...

    作者:欧明锋,浙江大学 导读:在实际的深度学习项目中,难免遇到多个相似数据集,这时一次仅用单个数据集训练模型,难免造成局限.是否存在利用多个数据集训练的可能性?本文带来解读. 01 介绍 迄今为止,在 ...

  2. 定点 浮点 神经网络 量化_神经网络模型量化论文小结

    神经网络模型量化论文小结 发布时间:2018-07-22 13:25, 浏览次数:278 现在"边缘计算"越来越重要,真正能落地的算法才是有竞争力的算法.随着卷积神经网络模型堆叠的 ...

  3. 神经网络模型量化论文小结

    现在"边缘计算"越来越重要,真正能落地的算法才是有竞争力的算法.随着卷积神经网络模型堆叠的层数越来越多,网络模型的权重参数数量也随之增长,专用硬件平台可以很好的解决计算与存储的双重 ...

  4. 卷积神经网络模型解读及数学原理 ——翻拍图片识别

    目录 一.需求背景 二.知识储备 1.深度学习 2.卷积神经网络 3.PyTorch框架 4.张量 5.梯度下降法 三.模型解读 1.输入层 2.隐藏层 1)卷积层 2)激活函数 3)池化层 4)流向 ...

  5. ICML 2021论文数据分析:谷歌第一,国内北大论文最多

    转自:机器之心 ICML 2021 官方公布了接收论文结果,共有 5513 篇论文投稿,共有 1184 篇被接收(包括 1018 篇短论文和 166 篇长论文),接受率 21.48%. 这应该是 IC ...

  6. 清北超越剑桥,谷歌全球霸榜,百度领衔中国公司,ICML 2020论文数排名公布

    萧箫 发自 凹非寺 量子位 报道 | 公众号 QbitAI AI领域,谁能跻身前列? 就在刚刚,ICML 2020论文数排名统计出炉. 今年ICML论文接收率为21.8%,相较于去年基本持平. 而在今 ...

  7. 接收率25.9%,ICCV 2021接收论文列表放出,你中了吗?

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 数小时前,ICCV 2021 官方放出了接收论文 ID 列表,在 6236 篇有效提交论文中,有 16 ...

  8. CVPR2021最新接收论文合集!22个方向100+篇论文汇总

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 导读 CVPR2021结果已出,本文为CVPR最新接收论文的资源汇总贴,附有相关文章与代码链接. 官网 ...

  9. ICML 2020论文贡献榜排名出炉:Google单挑斯坦福、MIT、伯克利;清华进TOP 20

    来源:新智元 本文约2800字,建议阅读6分钟. ICML 2020论文贡献榜排名出炉,斯坦福则获高校第一.国内高校.企业上榜. [ 导读 ] ICML 2020论文贡献榜排名出炉,Google在众多 ...

最新文章

  1. 干掉 ZooKeeper?阿里为什么不用 ZK 做服务发现?
  2. mysql数据库导入到excel表格数据_[转载]将EXCEL表格中的数据导入mysql数据库表中(两种方法)...
  3. c语言i o编程,C 语言输入输出 (I/O)
  4. 计算机系统组织结构,第4章 操作系统计算机组织结构.ppt
  5. 显示日期的指令: date
  6. LeetCode 1490. 克隆 N 叉树(DFS/BFS)
  7. Ubuntu 10.04下安装jekyll
  8. winlogon病毒清除
  9. JavaScript 基础知识 表达式和运算符
  10. 自学python要多久-自学Python要学会需要多久?老男孩Python培训班
  11. iOS开发日记29-UIAlertController
  12. 如何在iPhone和Android上使用Instagram效果
  13. Iptables-外网地址及端口映射到内网地址及端口
  14. Unity开发手机游戏从第一行代码到Taptap上线,我经历了什么
  15. 家庭数据中心-私有云服务器定义和选择
  16. ip地址、子网掩码及ip地址的相关计算
  17. GO语言数据结构之队列
  18. 报表相关的同比和环比
  19. 【采用】大数据风控---风险量化和风险定价
  20. 输入一个8bit数,输出其中1的个数。如果只能使用1bit全加器,最少需要几个,请使用verilog进行描述?(附verilog代码)

热门文章

  1. Android Java使用JavaMail API发送和接收邮件的代码示例
  2. LR常见的报错处理方法
  3. Word英文字符间距太大 中英文输入切换都不行
  4. 1282. Game Tree
  5. 创建ASP.NET WEB自定义控件——例程2
  6. 【直播】黎佳佳:音频数据分析以及特征提取
  7. 【怎样写代码】确保对象的唯一性 -- 单例模式(二):解决方案
  8. 【C++】stack的部分使用(之后会不定时进行更新)
  9. Python 自动化办公之 Excel 对比工具
  10. 留不住客户?该从你的系统上找找原因了