Bert/RoBerta 微调笔记

  • 前言
  • 为什么要进行微调?
  • 怎么微调?参数的设置?
  • 问题:
    • (1)Bert/RoBerta所有参数是不是都要训练?
    • (2)微调Bert/RoBerta时,无法载入全部参数报错CUDA out of memory
    • (3)如何冻结模型参数?
    • (4)如何保存fine-tune好的BERT/ROBERTA模型参数,以及如何在特征提取阶段使用这些参数?
    • (5)逐层微调
  • 以MELD数据集为例的微调RoBerta的代码

前言

本文记录我在学习BERT/ROBERTA fine-tuning过程的遇到的问题,包括内存受限,微调概念,微调方法等。文章方法不适用于逐层微调,且只以NLP文本分类举例,微调代码参考github link
更新: (5)逐层微调

为什么要进行微调?

任务数据集与BERT/ROBERTA预训练的数据集差异较大,用微调使预训练模型(PLM)能更好地适应数据集。在实际任务中,用官方下载的PLM参数进行文本特征提取,提取出来的特征效果特别差,例如ROBERTA提取出来的特征经过训练,同一模型的f1score低于原论文10%以上。而进行微调之后能与原论文持平。

怎么微调?参数的设置?

1.第一种将PLM模型后接分类器,直接训练完整的网络。以[1e-6,2e-5,3e-5]相同量级的小学习率训练PLM,epoch小于10完全可以训练出一个好的PLM,再保存PLM的参数。在特征提取时,Bert/RoBerta作为特征提取器载入已保存的参数提取文本特征。本文实现的也是这种方式。
2.第二种PLM后接完整的自己的模型,逐层冻结PLM层参数,一层一层地调Bert/RoBerta的参数,因为在Bert/RoBerta中越底层,参数变化越不明显,所以需要调节的参数都在高层。

问题:

(1)Bert/RoBerta所有参数是不是都要训练?

在fine-tune阶段整个bert模型参数都需要训练。

(2)微调Bert/RoBerta时,无法载入全部参数报错CUDA out of memory

首先确定batch_size要足够小,我在3090/24G显存上fine-tune,设置的bacth_size=1勉强可以跑得动
其次是代码的复杂度,不要有多余的参数放到cuda()里,在每个epoch后都torch.cuda.empty_cache()清理缓存

fro e in range(config['epoch']):torch.cuda.empyt_cache()# train, valid, test dateset...e += 1

然后是数据集的格式,之前使用csv格式逐条读取,只能用24GB的3090训练,后面换成txt文件读取速度和内存占用有很大区别。
或者在test or eval()阶段可以添加with torch.no_grad()使模型在预测阶段不计算参数梯度来减少内存的使用量。

(3)如何冻结模型参数?

有很多种方法,遍历模型参数列表,使模型参数梯度计算False;

for param in model.parameters():param.requires_grad = False

在预测阶段,模型所有参数都不需要梯度计算,都可以冻结,就可以用with torch.no_grad()

# 训练阶段需要计算梯度
model.train()
log_prob = model(**kwargs)
# 验证,测试阶段不需要梯度
model.eval()
with torch.no_grad():log_prob = model(**kwargs)

(4)如何保存fine-tune好的BERT/ROBERTA模型参数,以及如何在特征提取阶段使用这些参数?

保存、载入模型参数方法来自于:solution from github
保存阶段:

save_path = './saved_models/myModel.pth'
# finetune过程跳过
torch.save({'model_state_dict': model.encoder.state_dict()}, save_path)
# 其中self.encoder = AutoModel.from_pretrain('./roberta=large')

调用阶段:

checkpoint = torch.load('../saved_models/mymodel.pth')
model = AutoModel.from_pretrained('../roberta-large')
model.load_state_dict(checkpoint['model_state_dict'])
model.cuda()

(5)逐层微调

方法:在模型参数初始化下面,添加需要冻结的层名字,或者添加不需要冻结的层名字,以roberta-large为例,冻结列表所列出以外的roberta层。

class testModel(nn.Module):def __init__(self):self.encoder = AutoModel.from_pretrained('./roberta-large')self.classifier = nn.Sequential(nn.Linear(args.emb_dim, 300),nn.ReLU(),nn.Linear(300, n_class))unfreeze_layers = ['layer.17', 'layer.18', 'layer.19', 'layer.20','layer.21', 'layer.22', 'layer.23','bert.pooler', 'out.']for name, param in self.encoder.named_parameters():param.requires_grad = Falsefor ele in unfreeze_layers:if ele in name:param.requires_grad = True

或者冻结指定层,其他层参数默认需要训练

     freeze_layers = ['layer.1', 'layer.2']for name, param in self.encoder.named_parameters():for ele in freeze_layers:if ele in name:param.requires_grad = False

以MELD数据集为例的微调RoBerta的代码

import torch
import random
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_sequence
from sklearn.metrics import precision_recall_fscore_support
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmupconfig = {'bert_path': './roberta-large','dataset_path': './data/MELD','saved_model': './saved_models/mymodel.pth',  # 这里要注意一下'emb_dim': 1024,'n_class': 7,'batch': 1,'epoch': 10,'lr': 1e-6,'max_grad_norm': 10
}roberta_tokenizer = AutoTokenizer.from_pretrained(config['bert_path'])class MELD_loader(Dataset):def __init__(self, txt_file, dataclass):self.dialogs = []f = open(txt_file, 'r')dataset = f.readlines()f.close()temp_speakerList = []context = []context_speaker = []self.speakerNum = []# 'anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise'emodict = {'anger': "anger", 'disgust': "disgust", 'fear': "fear", 'joy': "joy", 'neutral': "neutral",'sadness': "sad", 'surprise': 'surprise'}self.sentidict = {'positive': ["joy"], 'negative': ["anger", "disgust", "fear", "sadness"],'neutral': ["neutral", "surprise"]}self.emoSet = set()self.sentiSet = set()for i, data in enumerate(dataset):if i < 2:continueif data == '\n' and len(self.dialogs) > 0:self.speakerNum.append(len(temp_speakerList))temp_speakerList = []context = []context_speaker = []continuespeaker, utt, emo, senti = data.strip().split('\t')context.append(utt)if speaker not in temp_speakerList:temp_speakerList.append(speaker)speakerCLS = temp_speakerList.index(speaker)context_speaker.append(speakerCLS)self.dialogs.append([context_speaker[:], context[:], emodict[emo], senti])self.emoSet.add(emodict[emo])self.sentiSet.add(senti)self.emoList = sorted(self.emoSet)self.sentiList = sorted(self.sentiSet)if dataclass == 'emotion':self.labelList = self.emoListelse:self.labelList = self.sentiListself.speakerNum.append(len(temp_speakerList))def __len__(self):return len(self.dialogs)def __getitem__(self, idx):return self.dialogs[idx], self.labelList, self.sentidictdef encode_right_truncated(text, tokenizer, max_length=511):'''完成分词工作:return:'''tokenized = tokenizer.tokenize(text)truncated = tokenized[-max_length:]ids = tokenizer.convert_tokens_to_ids(truncated)return [tokenizer.cls_token_id] + idsdef padding(ids_list, tokenizer):max_len = 0for ids in ids_list:if len(ids) > max_len:max_len = len(ids)pad_ids = []for ids in ids_list:pad_len = max_len - len(ids)add_ids = [tokenizer.pad_token_id for _ in range(pad_len)]pad_ids.append(ids + add_ids)return torch.tensor(pad_ids)def make_batch_roberta(sessions):'''collate_fn:return:'''batch_input, batch_labels, batch_speaker_tokens = [], [], []for session in sessions:data = session[0]label_list = session[1]context_speaker, context, emotion, sentiment = datanow_speaker = context_speaker[-1]speaker_utt_list = []inputString = ""for turn, (speaker, utt) in enumerate(zip(context_speaker, context)):inputString += '<s' + str(speaker + 1) + '> '  # s1, s2, s3...inputString += utt + " "if turn < len(context_speaker) - 1 and speaker == now_speaker:speaker_utt_list.append(encode_right_truncated(utt, roberta_tokenizer))concat_string = inputString.strip()batch_input.append(encode_right_truncated(concat_string, roberta_tokenizer))if len(label_list) > 3:label_ind = label_list.index(emotion)else:label_ind = label_list.index(sentiment)batch_labels.append(label_ind)batch_speaker_tokens.append(padding(speaker_utt_list, roberta_tokenizer))batch_input_tokens = padding(batch_input, roberta_tokenizer)batch_labels = torch.tensor(batch_labels)return batch_input_tokens, batch_labels, batch_speaker_tokensdef CELoss(pred_outs, labels):"""pred_outs: [batch, clsNum]labels: [batch]"""loss = nn.CrossEntropyLoss()loss_val = loss(pred_outs, labels)return loss_valdef _CalACC(model, dataloader):model.eval()correct = 0label_list = []pred_list = []# label arragnewith torch.no_grad():for i_batch, data in tqdm(enumerate(dataloader), desc='testing is on...'):"""Prediction"""batch_input_tokens, batch_labels, batch_speaker_tokens = databatch_input_tokens, batch_labels = batch_input_tokens.cuda(), batch_labels.cuda()pred_logits = model(batch_input_tokens)  # (1, clsNum)"""Calculation"""pred_label = pred_logits.argmax(1).item()true_label = batch_labels.item()pred_list.append(pred_label)label_list.append(true_label)if pred_label == true_label:correct += 1acc = correct / len(dataloader)return acc, pred_list, label_listclass ft_model(nn.Module):def __init__(self):super(ft_model, self).__init__()self.context_model = AutoModel.from_pretrained(config['bert_path'])self.classifier = nn.Sequential(nn.Linear(config['emb_dim'], 300),nn.ReLU(),nn.Linear(300, config['n_class']))def forward(self, batch_input_tokens):batch_context_output = self.context_model(batch_input_tokens).last_hidden_state[:, 0, :]logits = self.classifier(batch_context_output)return logitsif __name__ == '__main__':torch.cuda.empty_cache()make_batch = make_batch_robertatrain_path = config['dataset_path'] + '/MELD_train.txt'dev_path = config['dataset_path'] + '/MELD_dev.txt'test_path = config['dataset_path'] + '/MELD_test.txt'train_dataset = MELD_loader(train_path, 'emotion')train_dataloader = DataLoader(train_dataset, batch_size=config['batch'], shuffle=True,num_workers=4, collate_fn=make_batch)dev_dataset = MELD_loader(dev_path, 'emotion')dev_dataloader = DataLoader(dev_dataset, batch_size=config['batch'], shuffle=False,num_workers=4, collate_fn=make_batch)test_dataset = MELD_loader(test_path, 'emotion')test_dataloader = DataLoader(test_dataset, batch_size=config['batch'], shuffle=False,num_workers=4, collate_fn=make_batch)model = ft_model().cuda()model.train()# training processnum_warmup_steps = len(train_dataset)num_training_steps = len(train_dataset) * config['epoch']train_sample_num = int(len(train_dataloader))optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps,num_training_steps=num_training_steps)best_dev_fscore, best_test_fscore = 0, 0best_dev_fscore_macro, best_dev_fscore_micro,\best_test_fscore_macro, best_test_fscore_micro = 0, 0, 0, 0best_epoch = 0for epoch in range(config['epoch']):model.train()for i_batch, data in tqdm(enumerate(train_dataloader), desc='training is on ...'):if i_batch > train_sample_num:print(i_batch, train_sample_num)break"""Prediction"""batch_input_tokens, batch_labels, batch_speaker_tokens = databatch_input_tokens, batch_labels = batch_input_tokens.cuda(), batch_labels.cuda()pred_logits = model(batch_input_tokens)"""Loss calculation & training"""loss_val = CELoss(pred_logits, batch_labels)loss_val.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),config['max_grad_norm'])  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)optimizer.step()scheduler.step()optimizer.zero_grad()model.eval()dev_acc, dev_pred_list, dev_label_list = _CalACC(model, dev_dataloader)dev_pre, dev_rec, dev_fbeta, _ = precision_recall_fscore_support(dev_label_list, dev_pred_list,average='weighted')test_acc, test_pred_list, test_label_list = _CalACC(model, test_dataloader)test_pre, test_rec, test_fbeta, _ = precision_recall_fscore_support(test_label_list, test_pred_list,average='weighted')"""Best Score & Model Save"""if test_fbeta > best_test_fscore:best_test_fscore = test_fbetabest_epoch = epochtorch.save({'model_state_dict': model.context_model.state_dict(),}, config['saved_model'])print('epoch: {}, accuracy: {}, precision: {}, recall: {}, fscore: {}'.format(epoch + 1, test_acc, test_pre, test_rec, test_fbeta))print('Final Fscore ## test-fscore: {}, test_epoch: {}'.format(best_test_fscore, best_epoch))

Pytorch实现Bert/RoBerta微调(以MELD数据集为例)相关推荐

  1. pytorch搭建卷积网络(以minist数据集为例)以及如何查看输出每层的权重和特征图

    1.总的程序 # -*- coding: utf-8 -*- """ Created on Sun Jul 18 15:19:41 2021@author: pony & ...

  2. 在MELD数据集上利用BERT得到句向量表示

    目标数据格式: [[{"text": "Hi, I need an ID.", "speaker": "Ses05F_impro0 ...

  3. Pytorch+Google BERT模型(RoBERTa+LSTM+GRU)实战

    Pytorch+Google BERT模型(RoBERTa+LSTM+GRU)实战 BERT(Bidirectional Encoder Representations from Transforme ...

  4. 【Pytorch】BERT+LSTM+多头自注意力(文本分类)

    [Pytorch]BERT+LSTM+多头自注意力(文本分类) 2018年Google提出了BERT[1](Bidirectional Encoder Representations from Tra ...

  5. 学习笔记:深度学习(8)——基于PyTorch的BERT应用实践

    学习时间:2022.04.26~2022.04.30 文章目录 7. 基于PyTorch的BERT应用实践 7.1 工具选取 7.2 文本预处理 7.3 使用BERT模型 7.3.1 数据输入及应用预 ...

  6. [学习日志]使用pytorch 和 bert 实现一个简单的文本分类任务

    项目简介 最近在学习pytorch和Bert,所以做了一个这样完全新手向的入门项目来练习. 由于之前在网上学习发现现存的教程比较少,所以记录一下自己的学习过程,加深印象,也希望能帮到别的学习者吧,能涨 ...

  7. 【小白学习PyTorch教程】七、基于乳腺癌数据集​​构建Logistic 二分类模型

    「@Author:Runsen」 在逻辑回归中预测的目标变量不是连续的,而是离散的.可以应用逻辑回归的一个示例是电子邮件分类:标识为垃圾邮件或非垃圾邮件.图片分类.文字分类都属于这一类. 在这篇博客中 ...

  8. EasyBert,基于Pytorch的Bert应用

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx EasyBert 基于Pytorch的Bert应用,包括命名实体识别.情感分析.文本分类以及文 ...

  9. 使用pytorch获取bert词向量 将字符转换成词向量

    转载保存: 使用pytorch获取bert词向量_海蓝时见鲸_的博客-CSDN博客_获取bert词向量 pytorch-pretrained-bert简单使用_风吹草地现牛羊的马的博客-CSDN博客_ ...

最新文章

  1. BEC攻击危害惊人 3年造成23亿美元损失
  2. Oracle 原理: PL/SQL基础
  3. java的import关键字的使用
  4. Azure系列2.1.13 —— CloudBlockBlob
  5. 基于Matlab的跨孔电磁波\跨孔雷达的胖射线追踪(一)
  6. 高级程序员证书_过了而立之年的程序员应该何去何从?
  7. 数据分析师熬夜整理:最全「零售业」数据指标和使用技巧
  8. linux是发展历史,linux发展历史.doc.doc
  9. eclipse提交代码到github其他分支
  10. linux shell 命令批量杀死进程
  11. 在阿里云注册域名后怎样进行网站的备案流程
  12. 怎么解决在微信中不能直接下载APP(APK)的方案
  13. 推荐 :数据科学研究的现状与趋势
  14. 不久的明天,也许是很光明的
  15. DBCO-PEG-DPPE DBCO-二棕榈酰基磷脂酰乙醇胺-聚乙二醇
  16. Java实现微信公众号每日推送
  17. 小程序、APP Store 需要的 SSL 证书是个什么东西?
  18. php mail 垃圾邮件,如何避免我的邮件从PHP邮件()被标记为垃圾邮件? - 程序园
  19. 实现阿拉伯数字转中文大写
  20. Github配置ssh key【不用密码访问Github上代码】

热门文章

  1. 使用Aspose组件将WORD、PDF、PPT转为图片
  2. 【MySQL】SHOW WARNINGS和SHOW ERRORS的作用是什么?
  3. 骁龙8gen1Plus和骁龙8gen1区别
  4. WPS 多文档独立显示
  5. matlab频谱分析中振幅的物理意义,频谱图分析的意义,频谱图的物理意义是什么呢,频率的振幅能够反映什么物理意义呢?(例如下图)请哪位大师指点,不胜感激!...
  6. 如何在WPS中打开多个窗口
  7. 接入层交换机、汇聚层交换机和核心层交换机的区别
  8. 办公族如何防治鼠标手?
  9. DTI及MRI数据预处理
  10. 键盘右Crtl键变成了鼠标键效果的解决办法