作者:清华大数据软件团队机器学习组

本文长度为1700字,建议阅读6分钟

本文为你介绍 Trans-Learn 算法库。

Trans-Learn是基于PyTorch实现的一个高效、简洁的迁移学习算法库,目前发布了第一个子库——深度域自适应算法库(DALIB),支持的算法包括:

  • Domain Adversarial Neural Networks (DANN)

  • Deep Adaptation Network (DAN)

  • Joint Adaptation Networks (JAN)

  • Conditional Adversarial Domain Adaptation (CDAN)

  • Maximum Classifier Discrepancy (MCD)

  • Margin Disparity Discrepancy (MDD)

项目地址:

https://github.com/thuml/Transfer-Learning-Library

域自适应背景介绍

目前深度学习模型在部分计算机视觉、自然语言处理任务中已经超过了人类的表现,但是它们的成功依赖于大规模的数据标注。但是实际场景中,标注数据往往是稀缺的。解决标注数据稀缺问题的一个方法是通过计算机模拟生成训练数据,例如用计算机图形学的技术合成训练数据。

图表 1 VisDA2017竞赛任务

但是由于训练数据和测试数据不再服从独立同分布,训练得到的深度网络的准确率大打折扣。为了解决上述数据漂移造成的问题,域自适应(Domain Adaptation) 的概念被提出。域自适应的目标是将模型在源域(Source) 学到的知识迁移到目标域(Target)。例如计算机模拟生成训练数据的例子中,合成数据是源域,真实场景的数据是目标域。

域自适应有效地缓解了深度学习对于人工标注数据的依赖,受到了学术界和工业界广泛的关注。目前已经被引入到图片分类、图像分割(Segmentation)、目标检测(Object Detection)、机器翻译(Machine Translation) 等众多任务上。吴恩达曾说过:“在监督学习之后,迁移学习将引领下一波机器学习技术商业化浪潮。”随着产品级的机器学习应用进入数据稀缺的领域,监督学习得到的尖端模型性能大打折扣,域自适应变得至关重要。

研究现状

深度域自适应方法主要包括以下两大类:

1. 矩匹配。通过最小化分布差异来对齐不同域的特征分布。例如深度适配网络DAN,联合适配网络JAN。

2. 对抗训练域对抗网络DANN是最早的工作,它引入一个领域判别器,鼓励特征提取器学到领域无关的特征。 在DANN的基础上,衍生出了一系列方法,例如条件域对抗网络CDAN,间隔差异散度MDD等。

图表 2 DANN网络架构图

图表 3 MDD网络架构图

上述方法在实验数据上体现了良好的性能。然而目前学术界域自适应方法的开源实现中存在下述问题:

  • 复用性差。域自适应方法和模型架构、数据集耦合在一起,不利于域自适应方法在新的模型、数据集上复用。

  • 稳定性差。部分对抗训练方法随着训练进行,准确率会大幅度下降。

DALIB设计的初衷就是让用户通过少数几行代码,就可以将域自适应算法用在实际项目中,而无需考虑域自适应模块的实现细节。

易用性

DALIB将现有域自适应训练代码中的域自适应损失函数分离出来,按照PyTorch交叉熵损失函数的形式进行封装,方便用户的使用。域自适应损失函数也和模型架构进行了解耦,因此不依赖于具体的分类任务,所以算法库很容易扩展到图片分类以外的分类任务。

如下,使用两行代码即可定义一个与任务无关的域对抗损失函数。

不同域自适应损失函数中有一些公用的模块,例如所有算法中都用到的分类器模块,对抗训练中用到的梯度翻转模块、域判别器模块,核方法中的核函数模块等。这些公用模块和提供的域自适应损失函数是分离的。因此,在DALIB中,用户可以像搭积木一样,重新定制自己需要的域自适应损失函数。

例如,核方法中,用户可以自己定义不同参数的高斯核或者其他核函数,然后传入到多核最大均值差异(MK-MMD)的计算中。

目前,所有的模块和损失函数均已提供详细的API说明文档。

https://dalib.readthedocs.io/en/latest/

稳定性

域自适应算法研究领域往往关注方法的创新程度或者理论层面的价值,而忽视了工程实现中的稳定性和可复现性。在复现现有的算法的过程中,出现了部分算法准确率不稳定的问题。通过对数值方面的改进,这些问题都已经得到解决。(具体实现就不在此处展开了。)

此外,DALIB几乎在所有任务上,准确率都比原论文汇报准确率高,部分数据集上甚至能高14%。下图分别是Office-31和VisDA-2017上的测试结果。

图表 4 Office-31上不同算法的准确率

图表 5 VisDA2017上不同算法的准确率

算法库提供了各个算法在Office-31、Office-Home和VisDA-2017上的测试结果,以及所有的测试脚本。我们认为开源该算法库对于这个领域未来的研究工作是具有巨大价值的。

未来的工作

域自适应算法子库DALIB下一个版本会支持域自适应算法的不同设定,包括部分域自适应任务(Partial Domain Adaptation)、开放集域自适应任务(Open-set Domain Adaptation)、通用域自适应任务(Universal Domain Adaptation)等。

迁移学习算法库Trans-Learn目前还处于初期开发阶段,难免有不完善的地方,欢迎其他研究者提意见。同时迁移学习这个方向也还在不断发展,今后会不断跟进新工作中比较好的算法。

当前版本由龙明盛老师课题组的江俊广和付博同学开发,如果有任何意见和建议,欢迎联系JiangJunguang1123@outlook.com

fb1121@vip.qq.com

编辑:于腾凯

校对:林亦霖

原创 | 清华开源迁移学习算法库相关推荐

  1. 分布对齐 目标函数 迁移学习_原创 | 清华开源迁移学习算法库

    本文长度为1700字,建议阅读6分钟 本文为你介绍 Trans-Learn 算法库. Trans-Learn是基于PyTorch实现的一个高效.简洁的迁移学习算法库,目前发布了第一个子库--深度域自适 ...

  2. 2019 outlook 数据迁移_清华开源迁移学习算法库

    清华大学龙明盛老师课题组长期致力于迁移学习研究.今天,我们很高兴地宣布,我们开源了基于PyTorch实现的一个高效.简洁的迁移学习算法库--Trans-Learn. 目前我们发布了第一个子库--深度域 ...

  3. 清华大学开源迁移学习算法库:基于PyTorch实现,支持轻松调用已有算法

    机器之心报道 编辑:魔王 作者:清华大学大数据研究中心 近日,清华大学大数据研究中心机器学习研究部开源了一个高效.简洁的迁移学习算法库 Transfer-Learn,并发布了第一个子库--深度领域自适 ...

  4. 清华大学开源迁移学习算法库:基于PyTorch实现已有算法

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习> ...

  5. 简单易用高性能!一文了解开源迁移学习框架EasyTransfer

    简介:近日,阿里云正式开源了深度迁移学习框架EasyTransfer,这是业界首个面向NLP场景的深度迁移学习框架.该框架由阿里云机器学习PAI团队研发,让自然语言处理场景的模型预训练和迁移学习开发与 ...

  6. 迁移学习算法之TrAdaBoost ——本质上是在用不同分布的训练数据,训练出一个分类器...

    迁移学习算法之TrAdaBoost from: https://blog.csdn.net/Augster/article/details/53039489 TradaBoost算法由来已久,具体算法 ...

  7. 再无需从头训练迁移学习模型!亚马逊开源迁移学习数据库 Xfer

    雷锋网 AI 科技评论按:所谓的「迁移学习」,是指重新利用已训练的机器学习模型来应对新任务的技术.它给深度学习领域带来了许多好处,最明显的是,一旦无需从头开始训练模型,我们可以省下大量的计算.数据以及 ...

  8. 再无需从头训练迁移学习模型!亚马逊开源迁移学习数据库 Xfer...

    雷锋网 AI 科技评论按:所谓的「迁移学习」,是指重新利用已训练的机器学习模型来应对新任务的技术.它给深度学习领域带来了许多好处,最明显的是,一旦无需从头开始训练模型,我们可以省下大量的计算.数据以及 ...

  9. 清华开源深度学习框架计图,开源超级玩家再进阶

    2020-03-22 09:24 导语:清华开源计图,背后是三代人的共同努力. 雷锋网AI源创评论报道,据官方消息,清华大学计算机系图形实验室宣布开源一个全新的深度学习框架:Jittor,中文名计图. ...

最新文章

  1. Nature:人工甜味剂改变小鼠肠道菌群组成及功能
  2. python中字符串的布尔值_Python基础之字符串,布尔值,整数,列表,元组,字典,集合...
  3. T400的5100无线网卡在Centos下跑起来了
  4. DataTable的AcceptChange方法为什么不能在Update之前?
  5. Android 下拉刷新
  6. Python网络爬虫--Scrapy使用IP代理池
  7. javaScript——原型
  8. NOIP2011题目简析
  9. 【Codeforces AIM Tech Round 4 (Div. 2) C】
  10. 使用Git进行Vivado版本控制
  11. Cimage类的介绍及使用
  12. Tomcat9的安装和配置
  13. Python 代码库之Tuple如何append添加元素
  14. matlab中的对数函数,[matlab对数函数]对数函数运算法则是什么呢?
  15. Excel VBA简介
  16. vue-lazyload图片懒加载的简单使用
  17. 移动端h5文字长按复制_H5实现移动端复制文字功能
  18. jdk9 jdk10 jdk11启动rocketMQ的问题
  19. linux dnf命令安装
  20. 微信模版消息发送失败

热门文章

  1. 初识 Knative: 跨平台的 Serverless 编排框架
  2. 洛谷P2252 取石子游戏(威佐夫博弈)
  3. Discuz学习总结——部分bug解决方案
  4. 文件查找和压缩——Linux基本命令(12)
  5. Html5 Json应用
  6. 深入解析Android关机
  7. Target host is not specified错误
  8. google appengine的yaml文件,配置说明
  9. 134安装教程_PS教程连载第135课:PS第三方插件安装方法
  10. 华中科技大学计算机科学卓越班,2016年华中科技大学光电信息科学与工程(卓越计划实验班)专业在江苏录取分数线...