文章目录

  • 前言
  • 一、任务说明
  • 二、几个概念的理解
  • 三、自定义数据集的处理和加载
    • 1.数据来源
    • 2.数据处理
      • 2.1定义labels
      • 2.2合并
    • 3.load和split
  • 四、训练
    • 1.定义模型和tokenizer
    • 2.load&map数据
    • 3.训练参数
  • 五、评价指标及效果
  • 总结

前言

Huggingface Transformers库是Huggingface公司在github上开源的基于Transformers结构开发的预训练NLP框架。其一直致力于大规模预训练模型应用的平民化,让开发者能够方便使用SOTA的模型,而非受困于训练的算力资源。其打造了一个开放的平台,提供了大量的Models,datasets以及如何快速使用的api,同时也欢迎开发者上传自己的模型和数据集。目前已有超过39000个models,在github上transformer库也超过61k星,可见其火热程度。随着transformer在其他领域出色的表现,目前Huggingface Transformers的任务和模型除了Natural Language Processing,也延申到Multimodal,Audio,Computer Vision,Reinforcement Learning等领域。


一、任务说明

Huggingface Transformers库 官方文档有较完整的说明以及相关视频教程。但一些细节处理问题并没有详细说明和例程,中文的应用也相对较少。所以本文以一个作词人风格分类的任务,把几个关键技术点串联起来,包括数据加载,处理,模型,分词,训练,评价,推理。

任务输入:一段歌词
任务输出:哪个作词人的风格。
如图:

text列是输入的歌词
prediction是模型预测的写词人风格
true_label是真实标签
score是置信度

二、几个概念的理解

dataset:数据集类,提供了很多数据处理接口,如load_dataset,map,split等,可以很方便高效地处理数据集
model:模型,可通过from_pretrained 方法直接加载huggingface上的模型
tokenizer:分词类,同样可通过from_pretrained方法直接加载。每种模型都有对应的vocab词典以及对应分词方法,所以这里加载的模型name或路径需要跟加载model时的参数保持一致。
Pipeline:快速推理通道,不是所有模型都支持
autoclass:主要分为两类:AutoTokenizer+和AutoModel+,通过这两类的from_pretrained来加载分词器和模型
Metrics:评价类,可通过load_metric加载预定义的Metrics。定义好评价函数后在调用训练接口Trainer时赋值给compute_metrics,则训练过程会计算相关的predictions
TrainingArguments:训练参数,搬运官网的说明:TrainingArguments is the subset of the arguments we use in our example scripts which relate to the training loop itself.
Trainer:训练类,参数包括模型,训练参数,训练集,验证集,compute_metrics
以上接口均可在官网doc里找到:https://huggingface.co/docs.

三、自定义数据集的处理和加载

加载自定义数据集:

1.数据来源

音乐数据,可在github搜一下QQMusicSpider

2.数据处理

2.1定义labels

为方便验证效果,筛选了6位优秀写词人的作品,分别是 [‘林夕’, ‘方文山’, ‘黄霑’, ‘罗大佑’, ‘李宗盛’, ‘黄伟文’],并以每个写词人命名json文件分开存放,如下图:

json文件格式如下图:

2.2合并

把所有写词人的数据合并到一个大的json里面。由于我们用于分类的歌词一般是歌词句子而不是整首歌词,所以把每首歌歌词进行了分句,并两两拼接。如下图:

3.load和split

huggingface的数据集是Dataset类型,所以需要将json数据转换为Dataset。这里使用transformer库自带的load_dataset方法进行转换。并且利用train_test_split对数据集按8:1:1进行训练集,验证集,测试集的切分。测试集不参与训练过程,只用于检查模型性能。

def json2tranformers_data(self):m_dataset = load_dataset('json', data_files=self.all_data_path)['train']# train_test_split默认能够shuffle数据,无需人工再打乱split_dataset = m_dataset.train_test_split(test_size=0.1)train_dataset = split_dataset['train']train_split_dataset = train_dataset.train_test_split(test_size=0.1)# 训练集保存到本地train_dataset = train_split_dataset['train']train_dataset.to_csv('data/train_data.csv')# 验证集保存到本地eval_dataset = train_split_dataset['test']eval_dataset.to_csv('data/eval_data.csv')# 测试集保存到本地test_dataset = split_dataset['test']test_dataset.to_csv('data/test_data.csv')

最后产生可在数据集下载

四、训练

1.定义模型和tokenizer

如何选择模型:在models里,筛选task为fill_mask,Languages为zh。如图:
选择一个作为我们finetune的模型,本文以hfl/rbt3为例,其是3层网络的 RoBERTa-wwm-ext 。原理不在这里展开,感兴趣可以看论文:https://arxiv.org/abs/1906.08101。注意每个model都有其对应的tokenizer方法,必须一一匹配。
定义模型和tokenizer的代码如下:

tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")
model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3", num_labels=6)

2.load&map数据

代码如下:

def preprocess_function(self, examples):return self.tokenizer(examples["text"], truncation=True)train_dataset = load_dataset('csv', data_files='data/train_data.csv')['train']
eval_dataset = load_dataset('csv', data_files='data/eval_data.csv')['train']tokenized_train_dataset = train_dataset.map(self.preprocess_function, batched=True)
tokenized_eval_dataset = eval_dataset.map(self.preprocess_function, batched=True)

3.训练参数

代码如下:

training_args = TrainingArguments(output_dir="./results",learning_rate=2e-5,per_device_train_batch_size=96,per_device_eval_batch_size=96,num_train_epochs=5,weight_decay=0.01,save_steps=2000)trainer = Trainer(model=model,args=training_args,train_dataset=tokenized_train_dataset,eval_dataset=tokenized_eval_dataset,tokenizer=tokenizer,data_collator=data_collator,
)trainer.train()

五、评价指标及效果

dataset类提供了许多评价方法,使用list_metrics()可查看

from datasets import list_metrics
metrics_list = list_metrics()
print(metrics_list)

我使用了load_metric(“accuracy”)的方法,代表的准确率:

def evaluate(self):from transformers.pipelines.base import KeyDatasettest_dataset = load_dataset('csv', data_files='data/test_data.csv')['train']references = [label['label'] for label in test_dataset]predictions = []kd = KeyDataset(test_dataset, "text")for out in tqdm(self.pipe(kd)):predictions.append(int(out['label'][-1]))accuracy_metric = load_metric("accuracy")results = accuracy_metric.compute(references=references, predictions=predictions)print(results)return results

最终,模型在测试集上的表现为91%:

总结

以上就是整个文本分类finetune的过程,本文仅仅介绍了transformer的文本分类训练过程和简单的应用。而没有对原理,模型,参数,优化方法进行深入的研究。未来有时间再根据实际任务进行探究。
训练的代码可参考lyric-classification

【 基于transformer的歌词分类】相关推荐

  1. 【Pytorch基础教程32】基于transformer的情感分类

    note: 常用的BERT模型其实就是transformer模型的编码器部分,用户为下游任务生成一段话的文本表示.BERT是一个无监督学习的过程,可通过MLM和NSP两种预训练任务实现无监督训练的过程 ...

  2. 基于Transformer实现更精准的脑出血多标签分类

    本文已在飞桨公众号发布,查看请戳链接: 基于Transformer实现更精准的脑出血多标签分类 灵医智惠是百度旗下深耕医疗领域的AI医疗品牌,多年来一直致力于将AI能力深度赋能医疗行业,加速智慧医疗产 ...

  3. 基于 Transformer 模型的电影评论情感分类

    # -*- coding: utf-8 -*- """论文代码 基于Transformer模型的电影评论感情分析 - 环境 tensorflow==2.7.0 GPUnu ...

  4. 刷爆 AI 圈!基于 Transformer 的 DALL-E 代码刚刚开源了

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 转自 | AI科技评论 OpenAI在1月5日公布DALL-E模型以 ...

  5. 无需卷积,完全基于Transformer的首个视频理解架构TimeSformer出炉

    选自Facebook AI 机器之心编译 编辑:小舟.陈萍 Facebook AI 提出新型视频理解架构:完全基于Transformer,无需卷积,训练速度快.计算成本低. TimeSformer 是 ...

  6. 重磅开源!首个基于Transformer的视频理解网络来啦!

    部分转载自:机器之心  |  编辑:小舟.陈萍 Facebook AI 提出新型视频理解架构:完全基于Transformer,无需卷积,训练速度快.计算成本低.最近由Facebook提出的首个完全基于 ...

  7. NLP——基于transformer 的翻译系统

    文章目录 基于transformer 的翻译系统 1. 数据处理 1.1 英文分词 1.2 中文分词 1.3 生成字典 1.4 数据生成器 2. 构建模型 2.1 构造建模组件 layer norm层 ...

  8. 最新综述:基于Transformer的NLP预训练模型已经发展到何种程度?

    ©作者 | 机器之心编辑部 来源 | 机器之心 Transformer 为自然语言处理领域带来的变革已无需多言.近日,印度国立理工学院.生物医学人工智能创业公司 Nference.ai 的研究者全面调 ...

  9. 【综述】基于Transformer的视频语言预训练

    关注公众号,发现CV技术之美 ▊ 1. 论文和代码地址 Survey: Transformer based Video-Language Pre-training 论文地址:https://arxiv ...

最新文章

  1. python和c学习-学习 Python与C相互调用
  2. 跨工厂物料状态/特定工厂的物料状态
  3. 笔记-信息化与系统集成技术-智慧城市建设参考模型
  4. 服务器上如何安装两个php网站,服务器安装两个php版本吗
  5. (转)Fiddler教程(Web调试工具)
  6. linux web 服务器性能,Linux系统Web服务器性能测试(2)
  7. python内函数名加括号和不加括号的区别,python中 函数名加括号与不加括号
  8. Java生产环境下性能监控与调优详解 第5章 Tomcat性能监控与调优
  9. 十大最受欢迎蓝牙耳机品牌推荐,学生党打工人平价蓝牙耳机
  10. 计算机wifi无法启动不了,电脑wifi启用不了怎么办
  11. Matlab实现数字图像处理——滤波
  12. Android简历附件2
  13. 英尺英寸和厘米的换算_C语言中关于英尺、英寸、厘米的换算
  14. win10开机无响应 无服务器,win10开机假死-状态栏和开始菜单无响应
  15. 实战:高级,高级 让 Kubectl的输出像彩虹一样绚丽多彩-2021.11.13
  16. 解决Docker镜像缺少字体的问题
  17. 苹果太狠了:升级iOS 8小心变砖
  18. 跟我一起玩转微信支付
  19. 计算机等级考试17周岁能考吗,他,8岁计算机过二级,16岁被保送清华,信息类竞赛大奖拿到手软...
  20. 一头扎进Shiro-HelloWorld

热门文章

  1. python判断两个对象是否为相等使用的运算符是_Python入门_浅谈逻辑判断与运算符...
  2. python计算机二级操作题详解(一)
  3. 团体程序设计天梯赛-练习集 L1-040 最佳情侣身高差
  4. Popeyes:姗姗迟来的洋快餐,凭什么敢称“炸鸡大师” | 知消观察
  5. 服务器系统陆金苹果下拉十,《龙吟大陆》10日凌晨4点新版本合服更新公告
  6. Python queue (队列)
  7. Cocos Store 五一大促来袭!首发新品大放价,还有超多好礼免费拿
  8. 计算机鼓励语,鼓励人坚持的经典语句 励志的话简短霸气
  9. Java初级工程师必读的书籍
  10. android开机自动启动程序