来源:Deep IMBA

本文约4800字,建议阅读9分钟

本文将讨论6种方法,使模型可以在保持旧的性能的同时适应新数据,并避免需要在整个数据集(旧+新)上进行重新训练。

持续学习是指在不忘记从前面的任务中获得的知识的情况下,按顺序学习大量任务的模型。这是一个重要的概念,因为在监督学习的前提下,机器学习模型被训练为针对给定数据集或数据分布的最佳函数。而在现实环境中,数据很少是静态的,可能会发生变化。当面对不可见的数据时,典型的ML模型可能会性能下降。这种现象被称为灾难性遗忘。

解决这类问题的常用方法是在包含新旧数据的新的更大数据集上对整个模型进行再训练。但是这种做法往往代价高昂。所以有一个ML研究领域正在研究这个问题,基于该领域的研究,本文将讨论6种方法,使模型可以在保持旧的性能的同时适应新数据,并避免需要在整个数据集(旧+新)上进行重新训练。

Prompt

Prompt 想法源于对GPT 3的提示(短序列的单词)可以帮助驱动模型更好地推理和回答。所以在本文中将Prompt 翻译为提示。提示调优是指使用小型可学习的提示,并将其与实际输入一起作为模型的输入。这允许我们只在新数据上训练提供提示的小模型,而无需再训练模型权重。

具体来说,我选择了使用提示进行基于文本的密集检索的例子,这个例子改编自Wang的文章《Learning to Prompt for continuous Learning》。

该论文的作者使用下图描述了他们的想法:

实际编码的文本输入用作从提示池中识别最小匹配对的key。在将这些标识的提示输入到模型之前,首先将它们添加到未编码的文本嵌入中。这样做的目的是训练这些提示来表示新的任务,同时保持旧的模型不变,这里提示的很小,大概每个提示只有20个令牌。

class PromptPool(nn.Module):def __init__(self, M = 100, hidden_size = 768, length = 20, N=5):super().__init__()self.pool = nn.Parameter(torch.rand(M, length, hidden_size), requires_grad=True).float()self.keys = nn.Parameter(torch.rand(M, hidden_size), requires_grad=True).float()self.length = lengthself.hidden = hidden_sizeself.n = Nnn.init.xavier_normal_(self.pool)nn.init.xavier_normal_(self.keys)def init_weights(self, embedding):pass# function to select from pool based on indexdef concat(self, indices, input_embeds):subset = self.pool[indices, :] # 2, 2, 20, 768subset = subset.to("cuda:0").reshape(indices.size(0),self.n*self.length,self.hidden) # 2, 40, 768return torch.cat((subset, input_embeds), 1)# x is cls outputdef query_fn(self, x):# encode input x to same dim as key using cosinex = x / x.norm(dim=1)[:, None]k = self.keys / self.keys.norm(dim=1)[:, None]scores = torch.mm(x, k.transpose(0,1).to("cuda:0"))# get argminsubsets = torch.topk(scores, self.n, 1, False).indices # k smallestreturn subsetspool = PromptPool()

然后我们使用的经过训练的旧数据模型,训练新的数据,这里只训练提示部分的权重。

def train():count = 0print("*********** Started Training *************")start = time.time()for epoch in range(40):model.eval()pool.train()optimizer.zero_grad(set_to_none=True)lap = time.time()for batch in iter(train_dataloader):count += 1q, p, train_labels = batchqueries_emb = model(input_ids=q['input_ids'].to("cuda:0"),attention_mask=q['attention_mask'].to("cuda:0"))passage_emb = model(input_ids=p['input_ids'].to("cuda:0"),attention_mask=p['attention_mask'].to("cuda:0"))      # poolq_idx = pool.query_fn(queries_emb)raw_qembedding = model.model.embeddings(input_ids=q['input_ids'].to("cuda:0"))q = pool.concat(indices=q_idx, input_embeds=raw_qembedding)p_idx = pool.query_fn(passage_emb)raw_pembedding = model.model.embeddings(input_ids=p['input_ids'].to("cuda:0"))p = pool.concat(indices=p_idx, input_embeds=raw_pembedding)qattention_mask = torch.ones(batch_size, q.size(1))pattention_mask = torch.ones(batch_size, p.size(1))queries_emb = model.model(inputs_embeds=q,attention_mask=qattention_mask.to("cuda:0")).last_hidden_statepassage_emb = model.model(inputs_embeds=p,attention_mask=pattention_mask.to("cuda:0")).last_hidden_stateq_cls = queries_emb[:, pool.n*pool.length+1, :]p_cls = passage_emb[:, pool.n*pool.length+1, :]loss, ql, pl = calc_loss(q_cls, p_cls)                    loss.backward()optimizer.step()optimizer.zero_grad(set_to_none=True)if count % 10 == 0:print("Model Loss:", round(loss.item(),4), \"| QL:", round(ql.item(),4), "| PL:", round(pl.item(),4), \"| Took:", round(time.time() - lap), "seconds\n")lap = time.time()if count % 40 == 0 and count > 0:print("model saved")torch.save(model.state_dict(), model_PATH)torch.save(pool.state_dict(), pool_PATH)if count == 4600: returnprint("Training Took:", round(time.time() - start), "seconds")print("\n*********** Training Complete *************")

训练完成后,后续的推理过程需要将输入与检索到的提示结合起来。例如这个例子得到了性能—93%的新数据提示池,而完全(旧+新)训练为—94%。这与原论文中提到的表现类似。但是需要说明的一点是结果可能会因任务而不同,你应该尝试实验来知道什么是最好的。

要使此方法成为值得考虑的方法,它必须能够在旧数据上保留老模型> 80%的性能,同时提示也应该帮助模型在新数据上获得良好的性能。

这种方法的缺点是需要使用提示池,这会增加额外的时间。这也不是一个永久的解决方案,但是目前来说是可行的,也或许以后还会有新的方法出现。

Data Distillation

你可能听说过知识蒸馏一词,这是一种使用来自教师模型的权重来指导和训练较小规模模型的技术。数据蒸馏(Data Distillation)的工作原理也类似,它是使用来自真实数据的权重来训练更小的数据子集。因为数据集的关键信号被提炼并浓缩为更小的数据集,我们对新数据的训练只需要提供一些提炼的数据以保持旧的性能。

在此示例中,我将数据蒸馏应用于密集检索(文本)任务。目前看没有其他人在这个领域使用这种方法,所以结果可能不是最好的,但如果你在文本分类上使用这种方法应该会得到不错的结果。

本质上,文本数据蒸馏的想法源于 Li 的一篇题为 Data Distillation for Text Classification 的论文,该论文的灵感来自 Wang 的 Dataset Distillation,他对图像数据进行了蒸馏。Li 用下图描述了文本数据蒸馏的任务:

根据论文,首先将一批蒸馏数据输入到模型以更新其权重。然后使用真实数据评估更新后的模型,并将信号反向传播到蒸馏数据集。该论文在 8 个公共基准数据集上报告了良好的分类结果(> 80% 准确率)。

按照提出的想法,我做了一些小的改动,使用了一批蒸馏数据和多个真实数据。以下是为密集检索训练创建蒸馏数据的代码:

class DistilledData(nn.Module):def __init__(self, num_labels, M, q_len=64, hidden_size=768):super().__init__()self.num_samples = Mself.q_len = q_lenself.num_labels = num_labelsself.data = nn.Parameter(torch.rand(num_labels, M, q_len, hidden_size), requires_grad=True) # i.e. shape: 1000, 4, 64, 768# init using model embedding, xavier, or load from state dictdef init_weights(self, model, path=None):if model:self.data.requires_grad = Falseprint("Init weights using model embedding")raw_embedding = model.model.get_input_embeddings()soft_embeds = raw_embedding.weight[:, :].clone().detach()nums = soft_embeds.size(0)for i1 in range(self.num_labels):for i2 in range(self.num_samples):for i3 in range(self.q_len):random_idx = random.randint(0, nums-1)self.data[i1, i2, i3, :] = soft_embeds[random_idx, :]print(self.data.shape)self.data.requires_grad = Trueif not path:nn.init.xavier_normal_(self.data)else:distilled_data.load_state_dict(torch.load(path), strict=False)# function to sample a passage and positive sample as in the article, i am doing dense retrievaldef get_sample(self, label):q_idx = random.randint(0, self.num_samples-1)sampled_dist_q = self.data[label, q_idx, :, :]p_idx = random.randint(0, self.num_samples-1)while q_idx == p_idx:p_idx = random.randint(0, self.num_samples-1)sampled_dist_p = self.data[label, p_idx, :, :]return sampled_dist_q, sampled_dist_p, q_idx, p_idx

这是将信号提取到蒸馏数据上的代码

def distll_train(chunk_size=32):count, times = 0, 0print("*********** Started Training *************")start = time.time()lap = time.time()for epoch in range(40):        distilled_data.train()for batch in iter(train_dataloader):count += 1# get real query, pos, label, distilled data query, distilled data pos, ... from batchq, p, train_labels, dq, dp, q_indexes, p_indexes = batchfor idx in range(0, dq['input_ids'].size(0), chunk_size):model.train()with torch.enable_grad():  # train on distiled data firstx1 = dq['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)x2 = dp['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)q_emb = model(inputs_embeds=x1.to("cuda:0"),attention_mask=dq['attention_mask'][idx:idx+chunk_size].to("cuda:0")).cpu()p_emb = model(inputs_embeds=x2.to("cuda:0"),attention_mask=dp['attention_mask'][idx:idx+chunk_size].to("cuda:0"))loss = default_loss(q_emb.to("cuda:0"), p_emb)del q_emb, p_embloss.backward(retain_graph=True, create_graph=False)state_dict = model.state_dict()# update model weightswith torch.no_grad():for idx, param in enumerate(model.parameters()):if param.requires_grad and not param.grad is None:param.data -= (param.grad*3e-5)# real datamodel.eval()q_embs = []p_embs = []for k in range(0, len(q['input_ids']), chunk_size):with torch.no_grad():q_emb = model(input_ids=q['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()p_emb = model(input_ids=p['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()q_embs.append(q_emb)p_embs.append(p_emb)q_embs = torch.cat(q_embs, 0)p_embs = torch.cat(p_embs, 0)r_loss = default_loss(q_embs.to("cuda:0"), p_embs.to("cuda:0"))del q_embs, p_embs# distill backwardif count % 2 == 0:d_grad = torch.autograd.grad(inputs=[x1.to("cuda:0")],#, x2.to("cuda:0")],outputs=loss,grad_outputs=r_loss)indexes = q_indexeselse:d_grad = torch.autograd.grad(inputs=[x2.to("cuda:0")],outputs=loss,grad_outputs=r_loss)indexes = p_indexesloss.detach()r_loss.detach()grads = torch.zeros(distilled_data.data.shape) # lbl, 10, 100, 768for i, k in enumerate(indexes):grads[train_labels[i], k, :, :] = grads[train_labels[i], k, :, :].to("cuda:0") \+ d_grad[0][i, :, :]distilled_data.data.grad = gradsdata_optimizer.step()data_optimizer.zero_grad(set_to_none=True)model.load_state_dict(state_dict)model_optimizer.step()model_optimizer.zero_grad(set_to_none=True)if count % 10 == 0:print("Count:", count ,"| Data:", round(loss.item(), 4), "| Model:", \round(r_loss.item(),4), "| Time:", round(time.time() - lap, 4))# print()lap = time.time()if count % 100 == 0:  torch.save(model.state_dict(), model_PATH)torch.save(distilled_data.state_dict(), distill_PATH)if loss < 0.1 and r_loss < 1:times += 1if times > 100:print("Training Took:", round(time.time() - start), "seconds")print("\n*********** Training Complete *************")returndel loss, r_loss, grads, q, p, train_labels, dq, dp, x1, x2, state_dictprint("Training Took:", round(time.time() - start), "seconds")print("\n*********** Training Complete *************")

这里省略了数据加载等代码,训练完蒸馏的数据后,我们可以通过在其上训练新模型来使用它,例如将其与新数据合并一起训练。

根据我的实验,一个在蒸馏数据上训练的模型(每个标签只包含4个样本)获得了66%的最佳性能,而一个完全在原始数据上训练的模型也是得到了66%的最佳性能。而未经训练的普通模型得到45%的性能。就像上面提到的这些数字对于密集检索任务可能不太好,分类数据上会好很多。

要使此方法成为在调整模型以适应新数据时值是一个有用的方法,需要能够提取出比原始数据小得多的数据集(即~ 1%)。经过提炼的数据也能够给你一个略低于或等于主动学习方法的表现。

这个方法的优点是可以创建用于永久使用的蒸馏数据。缺点是提取的数据没有可解释性,并且需要额外的训练时间。

Curriculum/Active training

Curriculum training是一种方法,训练时向模型提供训练样本的难度逐渐变大。在对新数据进行训练时,此方法需要人工的对任务进行标注,将任务分为简单、中等或困难,然后对数据进行采样。为了理解模型的简单、中等或困难意味着什么,我以这张图片为例:

这是在分类任务中的混淆矩阵,困难样本是假阳性(False Positive),是指模型预测为True的可能性很高,但实际上不是True的样本。中等样本是那些具有中到高的正确性可能性但低于预测阈值的True Negative。而简单样本则是那些可能性较低的True Positive/Negative。

Maximally Interfered Retrieval

这是 Rahaf 在题为“Online Continual Learning with Maximally Interfered Retrieval”的论文(1908.04742)中介绍的一种方法。主要思想是,对于正在训练的每个新数据批次,如果针对较新数据更新模型权重,将需要识别在损失值方面受影响最大的旧样本。保留由旧数据组成的有限大小的内存,并检索最大干扰的样本以及每个新数据批次以一起训练。

这篇论文在持续学习领域是一篇成熟的论文,并且有很多引用,因此可能适用于您的案例。

Retrieval Augmentation

检索增强(Retrieval Augmentation)是指通过从集合中检索项目来扩充输入、样本等的技术。这是一个普遍的概念而不是一个特定的技术。我们到目前为止所讨论的方法,大多数都在一定程度都是检索相关的操作。Izacard 的题为 Few-shot Learning with Retrieval Augmented Language Models 的论文使用更小的模型获得了出色的少样本 学习的性能。检索增强也用于许多其他情况,例如单词生成或回答事实问题。

扩展模型

在训练时使用附加层是最常见也最简单的方法,但是不一定有效,所以在这里不进行详细的讨论,这里的一个例子是 Lewis 的 Efficient Few-Shot Learning without Prompts。使用附加层通常是在新旧数据上获得良好性能的最简单但经过尝试和测试的方法。主要思想是保持模型权重固定,并通过分类损失在新数据上训练一层或几层。有兴趣可以参考他们的 Github

(https://github.com/huggingface/setfit)

总结

在本文中,我介绍了在新数据上训练模型时可以使用的 6 种方法。与往常一样应该进行实验并决定哪种方法最适合,但是需要注意的是,除了我上面的方法外还有很多方法,例如数据蒸馏是计算机视觉中的一个活跃领域,你可以找到很多关于它的论文。最后说明的一点是:要使这些方法有价值,它们应该在旧数据和新数据上同时获得良好的性能。

编辑:于腾凯

校对:林亦霖

持续学习常用6种方法总结:使ML模型适应新数据的同时保持旧数据的性能相关推荐

  1. js中当等于最小值是让代码不执行_网页中JS函数自动执行常用三种方法

    本文为大家分享了在网页中JS函数自动执行常用方法,供大家参考,具体内容如下 一.JS方法 1.最简单的调用方式,直接写到html的body标签里面: 2.在JS语句调用: function myfun ...

  2. elixir开发的项目_我对Elixir的介绍:学习另一种编程语言如何使您成为更好的开发人员...

    elixir开发的项目 by Nikolas O'Donnell 由Nikolas O'Donnell 我对Elixir的介绍:学习另一种编程语言如何使您成为更好的开发人员 (My intro to ...

  3. python缺失值与异常值处理_pandas学习(常用数学统计方法总结、读取或保存数据、缺省值和异常值处理)...

    pandas学习(常用数学统计方法总结.读取或保存数据.缺省值和异常值处理) 目录 常用数学统计方法总结 读取或保存数据 缺省值和异常值处理 常用数学统计方法总结 count 计算非NA值的数量 de ...

  4. Adapter-适配预训练持续学习的一种技术

    前言 长期做预训练的小伙伴,可以关注一下这个技术点即adapter,最近关于这方面的工作还挺多的,其是这样一个背景:在不遗忘以前学到知识前提下,怎么向大模型中持续性注入知识. 今天就给大家带来两篇最新 ...

  5. 数据分析常用三种方法

    数据分析常用三种方法:趋势分析.对比分析.细分分析 1. 趋势分析 趋势分析般而言,适用于产品核心指标的长期跟踪,比如,点击率,GMV,活跃用户数等.做出简单的数据趋势图,并不算是趋势分析,趋势分析更 ...

  6. 创建JSONArray的常用四种方法

     创建JSONArray的常用四种方法 1.从头或者从零开始,创建一个JSONArray(Creating a JSONArray from scratch) 实例1: Java代码  JSONA ...

  7. 【Arduino串口数据保存到excel中常用三种方法】

    [Arduino串口数据保存到excel中常用三种方法] 1. 前言 2. 利用excel自带Data Streamer读取 2.1 启用 Data Streamer 加载项 2.2 刷写代码并将微控 ...

  8. html网页自动运行函数,在网页中JS函数自动执行常用三种方法

    在网页中JS函数自动执行常用三种方法 在HTML中的Head区域中,有如下函数: functionn MyAutoRun() { //以下是您的函数的代码,请自行修改先! alert("函数 ...

  9. datawhale 10月学习——树模型与集成学习:两种并行集成的树模型

    前情回顾 决策树 CART树的实现 集成模式 结论速递 本次学习了两种并行集成的树模型,随机森林和孤立森林,并进行了相应的代码实践.其中对孤立森林的学习比较简略,有待后续补充. 这里写自定义目录标题 ...

最新文章

  1. 八种基本类型的包装类你真的懂了?
  2. 用Core Temp查看服务器CPU温度
  3. Python网络数据采集
  4. 【数据结构与算法】之深入解析“回文数”的求解思路和算法示例
  5. SSD浅层网络_目标检测SSD
  6. C++中两个常用的控制语句格式的函数(width和precision函数)
  7. iPhone 11办理联通5G套餐后,上网速度变快?网友:发广告翻车了?
  8. 中断数周之后 微软网站恢复销售华为笔记本电脑
  9. 单片机数码管00 99c语言,STC89C52单片机数码管显示00~99,间隔1S程序
  10. C语言实战项目:学生管理系统
  11. 系统集成项目管理工程师14 总结
  12. 服务器系统小米随身wifi,Mac OS10.13正常使用的小米随身WIFI无线驱动 | 陳松's 博客...
  13. springboot 整合 ueditor 并实现文件上传(自定义上传路径)
  14. java chmod 777_java中 执行shell中的chmod 777命令,出现Caused by: java.io.IOException: Permission denied???...
  15. 查看CPU物理核数和逻辑核数
  16. 即便到愚人节,也千万别做的恶作剧!
  17. grpc-gateway 返回值中默认值为什么不显示?
  18. 开/闭环控制的直流调速系统
  19. 精简jre(JDK6瘦身)
  20. oracle数据库密码如果忘了怎么办?(修改密码和用户解锁)

热门文章

  1. 快学Scala 读书笔记之 Chapter 2、3、4(控制结构函数,数组,映射,元组)
  2. 【它山之玉】Trump:让人们发出噢、啊的惊叹声!—科学网马臻
  3. abp框架学习笔记(三)--Angular和前端
  4. 计算机网络 标性能指标
  5. Spring,SpringBoot,Springcloud都是干嘛的?
  6. FMODE学习之-------第一站
  7. 【历史上的今天】6 月 8 日:万维网之父诞生;PHP 公开发布;iPhone 4 问世
  8. 网络营销实战课-笔记5
  9. day22-网络爬虫2
  10. sa-token使用简单使用