@Author:Runsen

BERT模型在NLP各项任务中大杀四方,那么我们如何使用这一利器来为我们日常的NLP任务来服务呢?首先介绍使用BERT做文本多标签分类任务。

文本多标签分类是常见的NLP任务,文本介绍了如何使用Bert模型完成文本多标签分类,并给出了各自的步骤。

参考官方教程:https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html

复旦大学邱锡鹏老师课题组的研究论文《How to Fine-Tune BERT for Text Classification?》。

论文: https://arxiv.org/pdf/1905.05583.pdf

这篇论文的主要目的在于在文本分类任务上探索不同的BERT微调方法并提供一种通用的BERT微调解决方法。这篇论文从三种路线进行了探索:

  • (1) BERT自身的微调策略,包括长文本处理、学习率、不同层的选择等方法;
  • (2) 目标任务内、领域内及跨领域的进一步预训练BERT;
  • (3) 多任务学习。微调后的BERT在七个英文数据集及搜狗中文数据集上取得了当前最优的结果。

作者的实现代码: https://github.com/xuyige/BERT4doc-Classification

数据集来源:https://www.kaggle.com/shivanandmn/multilabel-classification-dataset?select=train.csv

该数据集包含 6 个不同的标签(计算机科学、物理、数学、统计学、生物学、金融),以根据摘要和标题对研究论文进行分类。
标签列中的值 1 表示标签属于该标签。每个论文有多个标签为 1。

Bert模型加载

Transformer 为我们提供了一个基于 Transformer 的可以微调的预训练网络。

由于数据集是英文, 因此这里选择加载bert-base-uncased。

具体下载链接:https://huggingface.co/bert-base-uncased/tree/main

from transformers import BertTokenizerFast as BertTokenizer
# 直接下载很很慢,建议下载到文件夹中
# BERT_MODEL_NAME = "bert-base-uncased"
BERT_MODEL_NAME = "model/bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

微调BERT模型

bert微调就是在预训练模型bert的基础上只需更新后面几层的参数,这相对于从头开始训练可以节省大量时间,甚至可以提高性能,通常情况下在模型的训练过程中,我们也会更新bert的参数,这样模型的性能会更好。

微调BERT模型主要在D_out进行相关的改变,去除segment层,直接采用了字符输入,不再需要segment层。

下面是微调BERT的主要代码

class BertClassifier(nn.Module):def __init__(self, num_labels: int, BERT_MODEL_NAME, freeze_bert=False):super().__init__()self.num_labels = num_labelsself.bert = BertModel.from_pretrained(BERT_MODEL_NAME)#  hidden size of BERT, hidden size of our classifier, and number of labels to classifyD_in, H, D_out = self.bert.config.hidden_size, 50, num_labels# Instantiate an one-layer feed-forward classifierself.classifier = nn.Sequential(nn.Dropout(p=0.3),nn.Linear(D_in, H),nn.ReLU(),nn.Dropout(p=0.3),nn.Linear(H, D_out),)# lossself.loss_func = nn.BCEWithLogitsLoss()if freeze_bert:print("freezing bert parameters")for param in self.bert.parameters():param.requires_grad = Falsedef forward(self, input_ids, attention_mask, labels=None):outputs = self.bert(input_ids, attention_mask=attention_mask)last_hidden_state_cls = outputs[0][:, 0, :]logits = self.classifier(last_hidden_state_cls)if labels is not None:predictions = torch.sigmoid(logits)loss = self.loss_func(predictions.view(-1, self.num_labels), labels.view(-1, self.num_labels))return losselse:return logits

其他

关于数据预处理,DataLoader等代码有点多,这里不一一列举,需要代码的在公众号回复:”bert“ 。

最后的训练结果如下所示:

【小白学习PyTorch教程】十六、在多标签分类任务上 微调BERT模型相关推荐

  1. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型...

    「@Author:Runsen」 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先 ...

  2. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型

    @Author:Runsen 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先了解 ...

  3. 【小白学习PyTorch教程】十五、BERT:通过PyTorch来创建一个文本分类的Bert模型

    @Author:Runsen 2018 年,谷歌发表了一篇题为<Pre-training of deep bidirectional Transformers for Language Unde ...

  4. 【小白学习PyTorch教程】十九、 基于torch实现UNet 图像分割模型

    @Author:Runsen 在图像领域,除了分类,CNN 今天还用于更高级的问题,如图像分割.对象检测等.图像分割是计算机视觉中的一个过程,其中图像被分割成代表图像中每个不同类别的不同段. 上面图片 ...

  5. 【小白学习PyTorch教程】十四、迁移学习:微调ResNet实现男人和女人图像分类

    「@Author:Runsen」 上次微调了Alexnet,这次微调ResNet实现男人和女人图像分类. ResNet是 Residual Networks 的缩写,是一种经典的神经网络,用作许多计算 ...

  6. 【小白学习PyTorch教程】十、基于大型电影评论数据集训练第一个LSTM模型

    「@Author:Runsen」 本博客对原始IMDB数据集进行预处理,建立一个简单的深层神经网络模型,对给定数据进行情感分析. 数据集下载 here. 原始数据集,没有进行处理here. impor ...

  7. 【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

    「@Author:Runsen」 GAN 是使用两个神经网络模型训练的生成模型.一种模型称为生成网络模型,它学习生成新的似是而非的样本.另一个模型被称为判别网络,它学习区分生成的例子和真实的例子. 生 ...

  8. 【小白学习PyTorch教程】九、基于Pytorch训练第一个RNN模型

    「@Author:Runsen」 当阅读一篇课文时,我们可以根据前面的单词来理解每个单词的,而不是从零开始理解每个单词.这可以称为记忆.卷积神经网络模型(CNN)不能实现这种记忆,因此引入了递归神经网 ...

  9. 【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度...

    「@Author:Runsen」 上次基于CIFAR-10 数据集,使用PyTorch构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段. imp ...

最新文章

  1. 直播APP常用动画效果
  2. 多线程-单生产单消费模型
  3. python tkinter布局混用_[宜配屋]听图阁
  4. 微信公众嵌套页面里再嵌入其他页面的一些问题
  5. Spring MVC学习总结(15)——SpringMVC之国际化简单实现
  6. 事业单位资产管理系统广西某单位案例:实现资产动态全过程管理
  7. Markdown中如何添加特殊符号
  8. 使用BOOTICE 恢复系统启动项
  9. Store generated project files externally
  10. 东南亚电商龙头 shopee 社招,校招 内推(长期有效)
  11. QlikView介绍
  12. python爬虫与java爬虫的区别_java爬虫(一)主流爬虫框架的基本介绍
  13. pow函数以及math.h的一些坑
  14. oracle中treat函数,PL/SQL Challenge 每日一题:2016-3-24 面向对象编程:向下转型TREAT...
  15. Qt 串口获取串口设备名称
  16. UG快捷键使用技巧总结(补充中....)
  17. SVM支持向量机——MATLAB在数学建模中的应用
  18. ◆聚会时可以玩的游戏◆
  19. 什么是我的java.net.SocketException:连接重置?
  20. Python爬虫-面向对象-《传闻中的陈芊芊》豆瓣热评

热门文章

  1. 触摸屏通常接在微型计算机,计算机应用基础习题答案.doc
  2. TVS二极管电压6V,有哪些型号?
  3. html5手机网站照片查看器,HTMLayout仿Picasa照片查看器效果
  4. 简述人工智能的发展历程图_简述墙体彩绘机发展历程
  5. python post请求 415_接收错误415:使用REST API发送GET请求时不支持媒体类型
  6. 剪切粘贴时总是上次的内容_如何关闭 iOS 14 的粘贴通知
  7. 程序员如何做好技术规划?
  8. XGB 调参基本方法
  9. 小白安装eclipse插件—testNG
  10. Kettle 合并记录报错!