在机器学习中,我们通常致力于针对单个任务,也就是优化单个指标。但是多任务学习(MTL)在机器学习的许多应用中都取得了成功,从自然语言处理和语音识别到计算机视觉和药物发现。

MTL最著名的例子可能是特斯拉的自动驾驶系统。在自动驾驶中需要同时处理大量任务,如物体检测、深度估计、3D重建、视频分析、跟踪等,你可能认为需要10个以上的深度学习模型,但事实并非如此。

HydraNet介绍

一般来说多任务学的模型架构非常简单:一个骨干网络作为特征的提取,然后针对不同的任务创建多个头。利用单一模型解决多个任务。

上图可以看到,特征提取模型提取图像特征。输出最后被分割成多个头,每个头负责一个特定的情况,由于它们彼此独立可以单独进行微调!

特斯拉的讲演中详细的说明这个模型(youtube:v=3SypMvnQT_s)

多任务学习项目

在本文中,我们将介绍如何在Pytorch中实现一个更简单的HydraNet。这里将使用UTK Face数据集,这是一个带有3个标签(性别、种族、年龄)的分类数据集。

我们的HydraNet将有三个独立的头,它们都是不同的,因为年龄的预测是一个回归任务,种族的预测是一个多类分类问题,性别的预测是一个二元分类任务。

每一个Pytorch 的深度学习的项目都应该从定义Dataset和DataLoader开始。

在这个数据集中,通过图像的名称定义了这些标签,例如UTKFace/30_0_3_20170117145159065.jpg.chip.jpg

  • 30岁是年龄
  • 0为性别(0:男性,1:女性)
  • 3是种族(0:白人,1:黑人,2:亚洲人,3:印度人,4:其他)

所以我们的自定义Dataset可以这样写:

 class UTKFace(Dataset):def __init__(self, image_paths):self.transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])self.image_paths = image_pathsself.images = []self.ages = []self.genders = []self.races = []for path in image_paths:filename = path[8:].split("_")if len(filename)==4:self.images.append(path)self.ages.append(int(filename[0]))self.genders.append(int(filename[1]))self.races.append(int(filename[2]))def __len__(self):return len(self.images)def __getitem__(self, index):img = Image.open(self.images[index]).convert('RGB')img = self.transform(img)age = self.ages[index]gender = self.genders[index]eth = self.races[index]sample = {'image':img, 'age': age, 'gender': gender, 'ethnicity':eth}return sample

简单的做个介绍:

__init__

方法初始化我们的自定义数据集,负责初始化各种转换和从图像路径中提取标签。

__get_item__

将:它将加载一张图像,应用必要的转换,获取标签,并返回数据集的一个元素,也就是说这个方法会返回数据集中的单条数据(单个样本)

然后我们定义dataloader

 train_dataloader = DataLoader(UTKFace(train_dataset), shuffle=True, batch_size=BATCH_SIZE)val_dataloader = DataLoader(UTKFace(valid_dataset), shuffle=False, batch_size=BATCH_SIZE)

下面我们定义模型,这里使用一个预训练的模型作为骨干,然后创建3个头。分别代表年龄,性别和种族。

 class HydraNet(nn.Module):def __init__(self):super().__init__()self.net = models.resnet18(pretrained=True)self.n_features = self.net.fc.in_featuresself.net.fc = nn.Identity()self.net.fc1 = nn.Sequential(OrderedDict([('linear', nn.Linear(self.n_features,self.n_features)),('relu1', nn.ReLU()),('final', nn.Linear(self.n_features, 1))]))self.net.fc2 = nn.Sequential(OrderedDict([('linear', nn.Linear(self.n_features,self.n_features)),('relu1', nn.ReLU()),('final', nn.Linear(self.n_features, 1))]))self.net.fc3 = nn.Sequential(OrderedDict([('linear', nn.Linear(self.n_features,self.n_features)),('relu1', nn.ReLU()),('final', nn.Linear(self.n_features, 5))]))def forward(self, x):age_head = self.net.fc1(self.net(x))gender_head = self.net.fc2(self.net(x))ethnicity_head = self.net.fc3(self.net(x))return age_head, gender_head, ethnicity_head

forward方法返回每个头的结果。

损失作为优化的基础时十分重要的,因为它将会影响到模型的性能,我们能想到的最简单的事就是地把损失相加:

 L = L1 + L2 + L3

但是我们的模型中

L1:与年龄相关的损失,如平均绝对误差,因为它是回归损失。

L2:与种族相关的交叉熵,它是一个多类别的分类损失。

L3:性别有关的损失,例如二元交叉熵。

这里损失的计算最大问题是损失的量级是不一样的,并且损失的权重也是不相同的,这是一个一直在被深入研究的问题,我们这里暂不做讨论,我们只使用简单的相加,所以我们的一些超参数如下:

 model = HydraNet().to(device=device)ethnicity_loss = nn.CrossEntropyLoss()gender_loss = nn.BCELoss()age_loss = nn.L1Loss()sig = nn.Sigmoid()optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.09)

然后我们训练的循环如下:

 for epoch in range(n_epochs):model.train()total_training_loss = 0for i, data in enumerate(tqdm(train_dataloader)):inputs = data["image"].to(device=device)age_label = data["age"].to(device=device)gender_label = data["gender"].to(device=device)eth_label = data["ethnicity"].to(device=device)optimizer.zero_grad()age_output, gender_output, eth_output = model(inputs)loss_1 = ethnicity_loss(eth_output, eth_label)loss_2 = gender_loss(sig(gender_output), gender_label.unsqueeze(1).float())loss_3 = age_loss(age_output, age_label.unsqueeze(1).float())loss = loss_1 + loss_2 + loss_3loss.backward()optimizer.step()total_training_loss += loss

这样我们最简单的多任务学习的流程就完成了

关于损失的优化

多任务学习的损失函数,对每个任务的损失进行权重分配,在这个过程中,必须保证所有任务同等重要,而不能让简单任务主导整个训练过程。手动的设置权重是低效而且不是最优的,因此,自动的学习这些权重是十分必要的,

Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics cvpr_2018

这篇论文提出,将不同的loss拉到统一尺度下,这样就容易统一,具体的办法就是利用同方差的不确定性,将不确定性作为噪声,进行训练

End-to-End Multi-Task Learning with Attention cvpr_2019

这篇论文提出了一种可以自动调节权重的机制( Dynamic Weight Average),使得权重分配更加合理,大概的意思是每个任务首先计算前个epoch对应损失的比值,然后除以一个固定的值T,进行exp映射后,计算各个损失所占比

最后如果你对多任务学习感兴趣,可以先看看这篇论文:

A Survey on Multi-Task Learning arXiv 1707.08114

从算法建模、应用和理论分析的角度对MTL进行了调查,是入门的最好的资料。

https://avoid.overfit.cn/post/57d4e8712c634fe887247ce66e694f8f

作者:Alessandro Lamberti

Pytorch创建多任务学习模型相关推荐

  1. 【项目实战课】从零掌握安卓端Pytorch原生深度学习模型部署

    欢迎大家来到我们的项目实战课,本期内容是<从零掌握安卓端Pytorch原生深度学习模型部署>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战 ...

  2. 多任务学习模型之ESMM介绍与实现

    简介:本文介绍的是阿里巴巴团队发表在 SIGIR'2018 的论文<Entire Space Multi-Task Model: An Effective Approach for Estimat ...

  3. IJCAI 2019 | 为推荐系统生成高质量的文本解释:基于互注意力机制的多任务学习模型...

    编者按:在个性化推荐系统中,如果能在提高推荐准确性的同时生成高质量的文本解释,将更容易获得用户的"芳心".然而,现有方法通常将两者分开优化,或只优化其中一个目标.为了同时兼顾二者, ...

  4. 多任务学习模型ESMM原理与实现(附代码)

    来源:DataFunTalk 本文约2500字,建议阅读5分钟 文章基于 Multi-Task Learning (MTL) 的思路,提出一种名为ESMM的CVR预估模型. [ 导读 ] 本文介绍的是 ...

  5. 排序层-深度模型-2020:PLE【多任务学习模型】【腾讯】

    PLE模型是腾讯发表在RecSys '20上的文章,这篇paper获得了recsys'20的best paper award,也算为腾讯脱离技术贫民的大业添砖加瓦了.这篇文章号称极大的缓解了多任务学习 ...

  6. pytorch argmax_PyTorch深度学习模型的服务化部署

    本文将介绍如何使用Flask搭建一个基于PyTorch的图片分类服务以及并行处理的相关技术.作为一个深度学习工程师,学习这些内容是为了方便对服务化的模型进行debug,因为web开发的同时常常表示他们 ...

  7. MMOE——多任务学习模型

    摘要 对于多任务学习,我们的目标是建立一个单一的模型,同时学习这些多个目标和任务.然而,常用的多任务模型的预测质量往往对任务之间的关系比较敏感.因此,研究任务特定目标和任务间关系之间的建模权衡是很重要 ...

  8. Pytorch 深度强化学习模型训练速度慢

    最近一直在用Pytorch来训练深度强化学习模型,但是速度一直很慢,Gpu利用率也很低. 一.起初开始在训练参数 batch_size = 200, graph_size = 40, epoch_si ...

  9. Sunny.Xia的深度学习(四)MMOE多任务学习模型实战演练

    本专栏文章会在本博客和知乎专栏--Sunny.Xia的深度学习同步更新,对于评论博主若未能够及时回复的,可以知乎私信.未经本人允许,请勿转载,谢谢. 一.什么是MMOE? 三张图分别是多任务模型的不同 ...

  10. 推荐系统遇上深度学习(九十二)-[腾讯]RecSys2020最佳长论文-多任务学习模型PLE

    今天介绍的是腾讯提出的一种新的多任务学习个性化推荐模型,该论文荣获了RecSys2020最佳长论文奖,一起来学习下! 1.背景 多任务学习通过在一个模型中同时学习多个不同的目标,如CTR和CVR,最近 ...

最新文章

  1. “让技术做好事”,最特殊的创投在大湾区成立
  2. 基于CefSharp构建基于Chromium的应用程序
  3. 计算机四级的英文,计算机四级考试中英文术语对照
  4. 精通ASP.NET MVC ——辅助器方法
  5. python中serial模块的使用_python中pyserial模块使用方法
  6. systemverilog编译介绍
  7. 基于机器学习的源代码分类
  8. 算法分析神器—时间复杂度
  9. windows whistler系统安装
  10. 网络爬虫+数据可视化
  11. Learun敏捷框架甘特图——摆脱项目管理的泥沼
  12. BAT机器学习面试1000题系列大集合整理(320)
  13. Chrome google浏览器从缓存下载视频
  14. 如何在计算机添加打印机驱动程序,教你如何安装打印机驱动程序
  15. 新版标准日本语中级_第二十三课
  16. 随机指标计算机程序,MACD/随机指标组合应用分析
  17. JAVA之旅(三十五)——完结篇,终于把JAVA写完了,真感概呐!
  18. C语言进阶——指针笔试题图解
  19. Windows系统封装(三)安装软件和系统优化清理。
  20. R手册(Common)--R语言入门

热门文章

  1. 冬训成果何在?林丹无缘新赛季首冠状态成迷
  2. 移动终端开发详解总结(一)(kotlin版)| CSDN创作打卡
  3. 解决cnzz加载时间长的问题
  4. ips细胞最新进展:利用iPS细胞成功培养出抑制宫颈癌繁殖的免疫杀伤T细胞,有望实现宫颈癌的免疫细胞疗法
  5. emc re 整改 超标_EMC测试及整改办法
  6. H5唤醒支付宝登录授权
  7. leetcode 没有php,Leetcode PHP题解--D99 860. Lemonade Change
  8. Vulhub安装过程记录(包括kali快速安装,一个apache中间件漏洞测试)
  9. 计算机二级是专业技术职务吗,计算机二级算中级技能证吗
  10. C程序~一元二次方程求解