最近因为需要用BERT-CRF模型做一个英文数据的实体抽取模型训练,因为github上BERT-CRF大多是对中文数据做NER, 这里特此记录一下处理过程中的解决方法与思路,废话不多说直接上代码,这里的代码模版参考的是 CLUENER2020项目下的BERT-CRF模型代码, 主要修改部分在 collate_fn 部分的 batch数据的 padding与aligning处理

首先,说一下我遇到的主要问题:

因为模型的数据padding, aligning的预处理是针对中文数据的,而在用英文数据调试训练过程发现padding和aligning总是出现越界溢出问题,经过网上多方调研,发现是由于中文分词与英文分词的方法不同,中文是单纯的词组切分,而英文是依据词源词根进行切分,导致分词的序列长与原句次数不一致,故原模型的padding与aligning处理方法已不适于英文数据,需要依据英文分词特点进行padding与aligning处理。

import os
import json
import torch
import numpy as np
from transformers import BertTokenizer, BertTokenizerFast
from transformers import RobertaTokenizer, RobertaModel
from torch.utils.data import Datasetclass NERDataset(Dataset):def __init__(self, words, labels, config, word_pad_idx=0, label_pad_idx=-1):self.tokenizer = BertTokenizerFast.from_pretrained(config.bert_model, do_lower_case=True, add_special_tokens=True)# self.tokenizer = BertTokenizerFast.from_pretrained(config.bert_model, do_lower_case=False, add_special_tokens=True)self.label2id = config.label2idself.id2label = {_id: _label for _label, _id in list(config.label2id.items())}self.dataset = self.preprocess(words, labels)self.word_pad_idx = word_pad_idxself.label_pad_idx = label_pad_idxself.device = config.devicedef preprocess(self, origin_sentences, origin_labels):"""Maps tokens and tags to their indices and stores them in the dict data.examples: word:['[CLS]', '浙', '商', '银', '行', '企', '业', '信', '贷', '部']sentence:([101, 3851, 1555, 7213, 6121, 821, 689, 928, 6587, 6956], array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))label:[3, 13, 13, 13, 13, 0, 0, 0, 0, 0]"""data = []sentences = []labels = []for line in origin_sentences:# replace each token by its index# we can not use encode_plus because our sentences are aligned to labels in list typewords = []word_lens = []for token in line:words.append(self.tokenizer.tokenize(token))word_lens.append(len(token))# 变成单个字的列表,开头加上[CLS]words = ['[CLS]'] + [item for token in words for item in token] + ['SEP']token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])sentences.append((self.tokenizer.convert_tokens_to_ids(words), token_start_idxs))for tag in origin_labels:label_id = [self.label2id.get(t) for t in tag]labels.append(label_id)for sentence, label in zip(sentences, labels):data.append((sentence, label))return datadef __getitem__(self, idx):"""sample data to get batch"""word = self.dataset[idx][0]label = self.dataset[idx][1]return [word, label]def __len__(self):"""get dataset size"""return len(self.dataset)def collate_fn(self, batch):""""""sentences = [x[0] for x in batch]labels = [x[1] for x in batch]# batch lengthbatch_len = len(sentences)  # batch sizebatch_max_subwords_len = max([len(s[0]) for s in sentences])max_subword_len = min(batch_max_subwords_len, 512)max_token_len = 0# padding data 初始化batch_data = self.word_pad_idx * np.ones((batch_len, max_subword_len))  # 初始化标注数据默认为0 64 * max_len =batch_token_starts = []# padding and aligningfor j in range(batch_len):cur_subwords_len = len(sentences[j][0])  # word_id listif cur_subwords_len <= max_subword_len:batch_data[j][:cur_subwords_len] = sentences[j][0]else:batch_data[j] = sentences[j][0][:max_subword_len]token_start_ids = sentences[j][-1]token_starts = np.zeros(max_subword_len)token_starts[[idx for idx in token_start_ids if idx < max_subword_len]] = 1batch_token_starts.append(token_starts)max_token_len = max(int(sum(token_starts)), max_token_len)batch_labels = self.label_pad_idx * np.ones((batch_len, max_token_len))for j in range(batch_len):cur_labels_len = len(labels[j])if cur_labels_len <= max_token_len:batch_labels[j][:cur_labels_len] = labels[j]else:batch_labels[j] = labels[j][:max_token_len]# convert data to torch LongTensorsbatch_data = torch.tensor(batch_data, dtype=torch.long)batch_token_starts = torch.tensor(batch_token_starts, dtype=torch.long)batch_labels = torch.tensor(batch_labels, dtype=torch.long)# print(batch_data.size())# print(batch_token_starts.size())# print(batch_labels.size())# shift tensors to GPU if availablebatch_data, batch_label_starts = batch_data.to(self.device ), batch_token_starts.to(self.device )batch_labels = batch_labels.to(self.device)return [batch_data, batch_label_starts, batch_labels]

NER任务中BERT-CRF 模型的英文数据padding与aligning相关推荐

  1. Django框架(11.Django中的通过模型类查询数据以及相关函数和条件)

     Django中的查询函数 通过模型类.objects属性可以调用如下函数,实现对模型类对应的数据表的查询.    不管哪个函数注意返回值的类型 函数名 功能 返回值 说明 get 返回表中满足条件的 ...

  2. 水环境模型与大数据技术融合研究

    点击上方蓝字关注我们 水环境模型与大数据技术融合研究 马金锋1, 饶凯锋1, 李若男1,2, 张京1, 郑华1,2 1 中国科学院生态环境研究中心城市与区域生态国家重点实验室,北京 100085 2  ...

  3. 信息抽取实战:命名实体识别NER【ALBERT+Bi-LSTM模型 vs. ALBERT+Bi-LSTM+CRF模型】(附代码)

    实战:命名实体识别NER 目录 实战:命名实体识别NER 一.命名实体识别(NER) 二.BERT的应用 NLP基本任务 查找相似词语 提取文本中的实体 问答中的实体对齐 三.ALBERT ALBER ...

  4. bert模型简介、transformers中bert模型源码阅读、分类任务实战和难点总结

    bert模型简介.transformers中bert模型源码阅读.分类任务实战和难点总结:https://blog.csdn.net/HUSTHY/article/details/105882989 ...

  5. NLP中BERT模型详解

    标题NLP中BERT模型详解 谷歌发表的论文为: Attention Is ALL You Need 论文地址:[添加链接描述](https://arxiv.org/pdf/1706.03762.pd ...

  6. BiLSTM-CRF模型中的CRF层讲解

    参考:最通俗易懂的BiLSTM-CRF模型中的CRF层讲解 代码:Bert-BiLSTM-CRF

  7. huggingface中Bert模型的简单使用

    因为项目和毕设的缘故,做了挺多关于Bert分类的实际操作的,本文主要记录下transformers库中使用较多的类. 在本文中,你将看到 huggingface(hf)中Bert模型的简单介绍 Ber ...

  8. bert+crf可以做NER,那么为什么还有bert+bi-lstm+crf ?

    我在自己人工标注的一份特定领域的数据集上跑过,加上bert确实会比只用固定的词向量要好一些,即使只用BERT加一个softmax层都比不用bert的bilstm+crf强.而bert+bilstm+c ...

  9. NLP(二十五)实现ALBERT+Bi-LSTM+CRF模型

      在文章NLP(二十四)利用ALBERT实现命名实体识别中,笔者介绍了ALBERT+Bi-LSTM模型在命名实体识别方面的应用.   在本文中,笔者将介绍如何实现ALBERT+Bi-LSTM+CRF ...

最新文章

  1. 无事“自动驾驶”,有事“辅助驾驶”?
  2. OpenCV寻找复杂背景下物体的轮廓
  3. 重磅消息:Redis 6.0.0 稳定版发布
  4. Android中的表格布局
  5. Lua === Lua 十分钟基础入门上手
  6. python excel操作xlrd_python操作Excel读写--使用xlrd
  7. R绘图 vs Python绘图(散点图、折线图、直方图、条形图、箱线图、饼图、热力图、蜘蛛图)
  8. 粒度过粗_这些书帮助我度过了第一次成为技术主管的经历
  9. 微型计算机使用字符编码,微型计算机系统中普遍使用的字符编码是( )
  10. python 遍历_Python遍历字典
  11. haproxy5-ssl
  12. bash取得相应行的数据
  13. AFNetworking2.0源代码解析
  14. Ubuntu14.04安装中文输入法以及解决Gedit中文乱码问题
  15. spark 读取ftp_scala – 使用ftp在Apache Spark中的远程计算机上读取文件
  16. 大学生体育运动网页设计模板代码 校园篮球网页作业成品 学校篮球网页制作模板 学生简单体育运动网站设计成品...
  17. linux 移动硬盘 mac,Mac下使用NTFS格式的移动硬盘
  18. 没有装php可以用phpmyadmin,phpMyAdmin 安装及问题总结
  19. 临时手机短信云接收(防骚扰)
  20. Python使用腾讯云-短信服务发送手机短信

热门文章

  1. 中职计算机高考试题卷,中职高中高考计算机试卷试题及含答案.doc
  2. 完整!贪吃蛇游戏c语言代码分享(包括界面,计数,提示)
  3. Android 花里胡哨的加载Loading动画
  4. python循环控制--for-else循环
  5. 基于easyui 1.3.6设计的后台管理系统模板界面
  6. php页游sf源码,开源程序 PHP源码 页游联运系统 CPA+CPS
  7. 不同调制方式对信道容量影响的分析
  8. Android反编译修改apk并重新打包
  9. uboot启动过程教程详解
  10. JS基础_Unicode编码表