Entity Linking for Chinese Short Texts

1. Task Description

This evaluation task centers on entity-linking technology and the AI applications that rely on it. It extends and improves the CCKS 2019 entity-linking task for Chinese short texts, with the following main changes:
(1) Entity recognition is removed, so the task focuses on disambiguating highly ambiguous entities in Chinese short-text scenarios;
(2) Type prediction is added for new (NIL) entities, i.e. predicting their upper-level concept type;
(3) The annotated text data is adjusted: text sources from multimodal task scenarios are added, and the proportion of ambiguous entities is rebalanced.
Entity linking for Chinese short texts, abbreviated EL (Entity Linking): given a Chinese short text (e.g. a search query, a Weibo post, dialogue content, or the title of an article/video/image), EL links the entities mentioned in it to the corresponding entities in a given knowledge base.
Traditional entity-linking work mainly targets long documents, whose rich context helps disambiguate entities and complete the linking. Entity linking for Chinese short texts is considerably harder, mainly because:
(1) the language is highly colloquial, which makes disambiguation difficult;
(2) short texts provide little context, so the limited context must be understood precisely;
(3) compared with English, the characteristics of the Chinese language make short-text linking even more challenging.
The input and output of this task are defined as follows:
Input:
A Chinese short text together with the set of entity mentions it contains.
Output:
The entity-linking results for the given Chinese short text. Each result contains: the entity mention, its character offset in the short text, and its id in the given knowledge base; for NIL mentions, the upper-level concept type of the entity must be given instead (the closed type inventory is detailed in the appendix).
Example input:

Example output:

Note: for queries containing ambiguous entities, the system must be able to decide which of the candidate entities in the knowledge base is the correct link. For example, the knowledge base may contain 8 different entities that could all be the correct link for 『琅琊榜』, because all 8 can be retrieved via the surface form 『琅琊榜』; but the given context (『海燕』, 『原创小说』, 『权谋小说』) provides enough information to decide which of these candidates should actually be linked.

2. Data Description

2.1 Knowledge Base

The knowledge base for this task comes from the Baidu Baike knowledge base. Each entity in it has a subject_id (knowledge-base id), a subject name, the entity's aliases, its concept type(s), and a set of related <predicate, object> (<attribute, attribute value>) pairs. Each line of the knowledge-base file is one record (one entity), in JSON format.
An example record is shown below:
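The original example record is omitted here. As a rough sketch assembled from the field names that the preprocessing code below reads and the sample values quoted in its comments (the exact values are illustrative only), one line of kb.json has roughly this shape:

{"subject_id": "10001", "subject": "张健", "alias": [], "type": "Person", "data": [{"predicate": "中文名", "object": "张健"}, {"predicate": "性别", "object": "男"}, {"predicate": "学历", "object": "大专"}, {"predicate": "政治面貌", "object": "中共党员"}, {"predicate": "义项描述", "object": "潜山县塔畈乡副主任科员、纪委副书记"}]}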

2.2 Annotated Dataset

The annotated dataset consists of a training set, a validation set and a test set, roughly 100,000 entries in total, all produced via Baidu crowdsourcing; detailed annotation-quality statistics will be released together with the data.
Each record in the annotated dataset has the following format:

The annotated data mainly comes from real web-page titles, video titles and search queries.
Example annotated records:
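The original sample record is omitted here. As a rough sketch using the field names consumed by the preprocessing and evaluation code below (the numeric kb_id is a placeholder, and mentions with no KB entry carry a NIL_ prefix followed by their type, matching what the result-generation code later in this post emits), one line of train.json looks roughly like:

{"text_id": "1", "text": "小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人", "mention_data": [{"kb_id": "NIL_Work", "mention": "战狼故事", "offset": "3"}, {"kb_id": "10001", "mention": "吴京", "offset": "10"}]}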

3. Evaluation Metric
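Submissions are scored with Micro-F1 over all annotated mentions, as implemented in the official evaluation script reproduced at the end of this post. Writing TP for the number of mentions whose predicted kb_id (or NIL type) matches the gold annotation and FP for those that do not:

Precision = TP / (TP + FP), Recall = TP / (number of gold mentions), Micro-F1 = 2 * Precision * Recall / (Precision + Recall)

The rest of this post walks through a baseline solution built on the chinese-roberta-wwm-ext pretrained model: a binary classifier scores each (mention, candidate KB entry) pair for linking, and a multi-class classifier predicts the type of mentions that cannot be linked (NIL mentions).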

import torch
if torch.cuda.is_available():
    # Tell PyTorch to use the GPU.
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
There are 1 GPU(s) available.
We will use the GPU: GeForce GTX 1070

Set up imports, paths and global configuration

import os
import json
import logging
import random
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from transformers import (
    DataProcessor,
    InputExample,
    BertConfig,
    BertTokenizer,
    BertForSequenceClassification,
    glue_convert_examples_to_features,
)

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Path to the pretrained model
PRETRAINED_PATH = './chinese_roberta_wwm_ext_pytorch/'
# Checkpoint directory for entity-linking training
EL_SAVE_PATH = './pytorch-lightning-checkpoints/EntityLinking/'
# Checkpoint directory for entity-typing training
ET_SAVE_PATH = './pytorch-lightning-checkpoints/EntityTyping/'
# Project data directory
DATA_PATH = './data/'
# Raw data of the CCKS 2020 entity-linking competition
RAW_PATH = DATA_PATH + 'ccks2020_el_data_v1/'
# Pickle files exported by preprocessing
PICKLE_PATH = DATA_PATH + 'pickle/'
if not os.path.exists(PICKLE_PATH):
    os.mkdir(PICKLE_PATH)
# Prediction result files
RESULT_PATH = DATA_PATH + 'result/'
if not os.path.exists(RESULT_PATH):
    os.mkdir(RESULT_PATH)
# tsv files used for training, validation and inference
TSV_PATH = DATA_PATH + 'tsv/'
if not os.path.exists(TSV_PATH):
    os.mkdir(TSV_PATH)
# Checkpoint files produced by training
CKPT_PATH = './ckpt/'

PICKLE_DATA = {
    # Entity name -> list of KB ids
    'ENTITY_TO_KBIDS': None,
    # KB id -> list of entity names
    'KBID_TO_ENTITIES': None,
    # KB id -> attribute text
    'KBID_TO_TEXT': None,
    # KB id -> list of entity types (note: one entity may have several types separated by '|')
    'KBID_TO_TYPES': None,
    # KB id -> list of predicates
    'KBID_TO_PREDICATES': None,
    # Index -> type mapping (list)
    'IDX_TO_TYPE': None,
    # Type -> index mapping (dict)
    'TYPE_TO_IDX': None,
}

for k in PICKLE_DATA:
    filename = k + '.pkl'
    if os.path.exists(PICKLE_PATH + filename):
        PICKLE_DATA[k] = pd.read_pickle(PICKLE_PATH + filename)
    else:
        print(f'File {filename} not Exist!')


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

Preprocess the knowledge-base entity data and clean out empty values

logger = logging.getLogger(__name__)


class PicklePreprocessor:
    """Preprocessor that generates the global pickle files"""

    def __init__(self):
        # Entity name -> KB ids, e.g. "张健" -> {"10001"}
        self.entity_to_kbids = defaultdict(set)
        # KB id -> entity names, e.g. "10001" -> {"张健"}
        self.kbid_to_entities = dict()
        # KB id -> attribute text, e.g. "10001" -> {"政治面貌:中共党员", "义项描述:潜山县塔畈乡副主任科员、纪委副书记",
        # "性别:男", "学历:大专", "中文名:张健"}
        self.kbid_to_text = dict()
        # KB id -> entity types, e.g. "10001" -> {"Person"}
        self.kbid_to_types = dict()
        # KB id -> predicates, e.g. "10001" -> {"政治面貌", "义项描述", "性别", "学历", "中文名"}
        self.kbid_to_predicates = dict()
        # Index -> type list, e.g. ["Person"]
        self.idx_to_type = list()
        # Type -> index dict, e.g. {"Person": 0}
        self.type_to_idx = dict()

    def run(self, shuffle_text=True):
        with open(RAW_PATH + 'kb.json', 'r', encoding='utf-8') as f:
            for line in tqdm(f):
                line = json.loads(line)
                kbid = line['subject_id']
                # Merge the entity name with its aliases
                entities = set(line['alias'])
                entities.add(line['subject'])
                for entity in entities:
                    self.entity_to_kbids[entity].add(kbid)
                self.kbid_to_entities[kbid] = entities

                text_list, predicate_list = [], []
                for x in line['data']:
                    # Simply concatenate predicate and object; other strategies could be tried here
                    text_list.append(':'.join([x['predicate'].strip(), x['object'].strip()]))
                    predicate_list.append(x['predicate'].strip())
                if shuffle_text:  # Randomly shuffle the attribute texts
                    random.shuffle(text_list)
                self.kbid_to_predicates[kbid] = predicate_list
                self.kbid_to_text[kbid] = ' '.join(text_list)
                # Remove special characters from the text
                for c in ['\r', '\t', '\n']:
                    self.kbid_to_text[kbid] = self.kbid_to_text[kbid].replace(c, '')

                type_list = line['type'].split('|')
                self.kbid_to_types[kbid] = type_list
                for t in type_list:
                    if t not in self.type_to_idx:
                        self.type_to_idx[t] = len(self.idx_to_type)
                        self.idx_to_type.append(t)

        # Save the pickle files
        pd.to_pickle(self.entity_to_kbids, PICKLE_PATH + 'ENTITY_TO_KBIDS.pkl')
        pd.to_pickle(self.kbid_to_entities, PICKLE_PATH + 'KBID_TO_ENTITIES.pkl')
        pd.to_pickle(self.kbid_to_text, PICKLE_PATH + 'KBID_TO_TEXT.pkl')
        pd.to_pickle(self.kbid_to_types, PICKLE_PATH + 'KBID_TO_TYPES.pkl')
        pd.to_pickle(self.kbid_to_predicates, PICKLE_PATH + 'KBID_TO_PREDICATES.pkl')
        pd.to_pickle(self.idx_to_type, PICKLE_PATH + 'IDX_TO_TYPE.pkl')
        pd.to_pickle(self.type_to_idx, PICKLE_PATH + 'TYPE_TO_IDX.pkl')
        logger.info('Process Pickle File Finish.')

Generate the training and validation sets for the models and save them as tsv files

class DataFramePreprocessor:
    """Generates the tsv files used for model training, validation and inference"""

    def __init__(self):
        pass

    def process_link_data(self, input_path, output_path, max_negs=-1):
        entity_to_kbids = PICKLE_DATA['ENTITY_TO_KBIDS']
        kbid_to_text = PICKLE_DATA['KBID_TO_TEXT']
        kbid_to_predicates = PICKLE_DATA['KBID_TO_PREDICATES']

        link_dict = defaultdict(list)
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f):
                line = json.loads(line)
                for data in line['mention_data']:
                    # Special handling for the test set
                    if 'kb_id' not in data:
                        data['kb_id'] = '0'
                    # Entities that do not exist in the KB are not linked
                    if not data['kb_id'].isdigit():
                        continue
                    entity = data['mention']
                    kbids = list(entity_to_kbids[entity])
                    random.shuffle(kbids)

                    num_negs = 0
                    for kbid in kbids:
                        if num_negs >= max_negs > 0 and kbid != data['kb_id']:
                            continue
                        link_dict['text_id'].append(line['text_id'])
                        link_dict['entity'].append(entity)
                        link_dict['offset'].append(data['offset'])
                        link_dict['short_text'].append(line['text'])
                        link_dict['kb_id'].append(kbid)
                        link_dict['kb_text'].append(kbid_to_text[kbid])
                        link_dict['kb_predicate_num'].append(len(kbid_to_predicates[kbid]))
                        if kbid != data['kb_id']:
                            link_dict['predict'].append(0)
                            num_negs += 1
                        else:
                            link_dict['predict'].append(1)

        link_data = pd.DataFrame(link_dict)
        link_data.to_csv(output_path, index=False, sep='\t')

    def process_type_data(self, input_path, output_path):
        kbid_to_types = PICKLE_DATA['KBID_TO_TYPES']

        type_dict = defaultdict(list)
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f):
                line = json.loads(line)
                for data in line['mention_data']:
                    entity = data['mention']
                    # Special handling for the test set
                    if 'kb_id' not in data:
                        entity_type = ['Other']
                    elif data['kb_id'].isdigit():
                        entity_type = kbid_to_types[data['kb_id']]
                    else:
                        entity_type = data['kb_id'].split('|')
                        for x in range(len(entity_type)):
                            entity_type[x] = entity_type[x][4:]  # strip the 'NIL_' prefix
                    for e in entity_type:
                        type_dict['text_id'].append(line['text_id'])
                        type_dict['entity'].append(entity)
                        type_dict['offset'].append(data['offset'])
                        type_dict['short_text'].append(line['text'])
                        type_dict['type'].append(e)

        type_data = pd.DataFrame(type_dict)
        type_data.to_csv(output_path, index=False, sep='\t')

    def run(self):
        self.process_link_data(
            input_path=RAW_PATH + 'train.json',
            output_path=TSV_PATH + 'EL_TRAIN.tsv',
            max_negs=2,
        )
        logger.info('Process EL_TRAIN Finish.')
        self.process_link_data(
            input_path=RAW_PATH + 'dev.json',
            output_path=TSV_PATH + 'EL_VALID.tsv',
            max_negs=-1,
        )
        logger.info('Process EL_VALID Finish.')
        self.process_link_data(
            input_path=RAW_PATH + 'test.json',
            output_path=TSV_PATH + 'EL_TEST.tsv',
            max_negs=-1,
        )
        logger.info('Process EL_TEST Finish.')
        self.process_type_data(
            input_path=RAW_PATH + 'train.json',
            output_path=TSV_PATH + 'ET_TRAIN.tsv',
        )
        logger.info('Process ET_TRAIN Finish.')
        self.process_type_data(
            input_path=RAW_PATH + 'dev.json',
            output_path=TSV_PATH + 'ET_VALID.tsv',
        )
        logger.info('Process ET_VALID Finish.')
        self.process_type_data(
            input_path=RAW_PATH + 'test.json',
            output_path=TSV_PATH + 'ET_TEST.tsv',
        )
        logger.info('Process ET_TEST Finish.')

A quick look at the training data

train_data = pd.read_csv(TSV_PATH + 'ET_TRAIN.tsv', sep='\t')
train_data.head()
text_id entity offset short_text type
0 1 小品 0 小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人 Other
1 1 战狼故事 3 小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人 Work
2 1 吴京 10 小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人 Person
3 1 障碍 16 小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人 Other
4 1 爱人 20 小品《战狼故事》中,吴京突破重重障碍解救爱人,深情告白太感人 Other

A quick look at the validation data

valid_data = pd.read_csv(TSV_PATH + 'ET_VALID.tsv', sep='\t')
valid_data.head()
text_id entity offset short_text type
0 1 天下没有不散的宴席 0 天下没有不散的宴席 - ╰つ雲中帆╰つ Work
1 1 ╰つ雲中帆╰つ 12 天下没有不散的宴席 - ╰つ雲中帆╰つ Other
2 2 永嘉 0 永嘉厂房出租 Location
3 2 厂房 2 永嘉厂房出租 Location
4 2 出租 4 永嘉厂房出租 Other

A quick look at the test data

test_data = pd.read_csv(TSV_PATH + 'ET_TEST.tsv', sep='\t')
test_data.head()
text_id entity offset short_text type
0 1 林平之 0 林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了 Other
1 1 岳灵珊 5 林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了 Other
2 1 师娘 18 林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了 Other
3 1 令狐冲 21 林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了 Other
4 2 思追 0 思追原来是个超级妹控,不愿妹妹嫁人,然而妹妹却喜欢一博老师 Other

Entity typing data processing

class EntityTypingProcessor(DataProcessor):
    """Data processor for entity typing"""

    def get_train_examples(self, file_path):
        return self._create_examples(self._read_tsv(file_path), set_type='train')

    def get_dev_examples(self, file_path):
        return self._create_examples(self._read_tsv(file_path), set_type='valid')

    def get_test_examples(self, file_path):
        return self._create_examples(self._read_tsv(file_path), set_type='test')

    def get_labels(self):
        return PICKLE_DATA['IDX_TO_TYPE']

    def _create_examples(self, lines, set_type):
        examples = []
        for i, line in enumerate(lines):
            if i == 0:  # skip the header row
                continue
            guid = f'{set_type}-{i}'
            text_a = line[1]   # entity mention
            text_b = line[3]   # short text
            label = line[-1]   # entity type
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def create_dataloader(self, examples, tokenizer, max_length=64,
                          shuffle=False, batch_size=64, use_pickle=False):
        pickle_name = 'ET_FEATURE_' + examples[0].guid.split('-')[0].upper() + '.pkl'
        if use_pickle:
            features = pd.read_pickle(PICKLE_PATH + pickle_name)
        else:
            features = glue_convert_examples_to_features(
                examples,
                tokenizer,
                label_list=self.get_labels(),
                max_length=max_length,
                output_mode='classification',
            )
            pd.to_pickle(features, PICKLE_PATH + pickle_name)
        dataset = torch.utils.data.TensorDataset(
            torch.LongTensor([f.input_ids for f in features]),
            torch.LongTensor([f.attention_mask for f in features]),
            torch.LongTensor([f.token_type_ids for f in features]),
            torch.LongTensor([f.label for f in features]),
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=batch_size,
            num_workers=2,
        )
        return dataloader

    def generate_feature_pickle(self, max_length):
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        train_examples = self.get_train_examples(TSV_PATH + 'ET_TRAIN.tsv')
        valid_examples = self.get_dev_examples(TSV_PATH + 'ET_VALID.tsv')
        test_examples = self.get_test_examples(TSV_PATH + 'ET_TEST.tsv')
        self.create_dataloader(
            examples=train_examples, tokenizer=tokenizer, max_length=max_length,
            shuffle=True, batch_size=32, use_pickle=False,
        )
        self.create_dataloader(
            examples=valid_examples, tokenizer=tokenizer, max_length=max_length,
            shuffle=False, batch_size=32, use_pickle=False,
        )
        self.create_dataloader(
            examples=test_examples, tokenizer=tokenizer, max_length=max_length,
            shuffle=False, batch_size=32, use_pickle=False,
        )

Entity linking data processing

class EntityLinkingProcessor(DataProcessor):
    """Data processor for entity linking"""

    def get_train_examples(self, file_path):
        return self._create_examples(self._read_tsv(file_path), set_type='train')

    def get_dev_examples(self, file_path):
        return self._create_examples(self._read_tsv(file_path), set_type='valid')

    def get_test_examples(self, file_path):
        return self._create_examples(self._read_tsv(file_path), set_type='test')

    def get_labels(self):
        return ['0', '1']

    def _create_examples(self, lines, set_type):
        examples = []
        for i, line in enumerate(lines):
            if i == 0:  # skip the header row
                continue
            guid = f'{set_type}-{i}'
            text_a = line[1] + ' ' + line[3]   # entity mention + short text
            text_b = line[5]                   # candidate KB attribute text
            label = line[-1]                   # 1 if this candidate is the correct link
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

    def create_dataloader(self, examples, tokenizer, max_length=384,
                          shuffle=False, batch_size=32, use_pickle=False):
        pickle_name = 'EL_FEATURE_' + examples[0].guid.split('-')[0].upper() + '.pkl'
        if use_pickle:
            features = pd.read_pickle(PICKLE_PATH + pickle_name)
        else:
            features = glue_convert_examples_to_features(
                examples,
                tokenizer,
                label_list=self.get_labels(),
                max_length=max_length,
                output_mode='classification',
            )
            pd.to_pickle(features, PICKLE_PATH + pickle_name)
        dataset = torch.utils.data.TensorDataset(
            torch.LongTensor([f.input_ids for f in features]),
            torch.LongTensor([f.attention_mask for f in features]),
            torch.LongTensor([f.token_type_ids for f in features]),
            torch.LongTensor([f.label for f in features]),
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=batch_size,
            num_workers=2,
        )
        return dataloader

    def generate_feature_pickle(self, max_length):
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        train_examples = self.get_train_examples(TSV_PATH + 'EL_TRAIN.tsv')
        valid_examples = self.get_dev_examples(TSV_PATH + 'EL_VALID.tsv')
        test_examples = self.get_test_examples(TSV_PATH + 'EL_TEST.tsv')
        self.create_dataloader(
            examples=train_examples, tokenizer=tokenizer, max_length=max_length,
            shuffle=True, batch_size=32, use_pickle=False,
        )
        self.create_dataloader(
            examples=valid_examples, tokenizer=tokenizer, max_length=max_length,
            shuffle=False, batch_size=32, use_pickle=False,
        )
        self.create_dataloader(
            examples=test_examples, tokenizer=tokenizer, max_length=max_length,
            shuffle=False, batch_size=32, use_pickle=False,
        )

Entity linking model

class EntityLinkingModel(pl.LightningModule):
    """Entity linking model"""

    def __init__(self, max_length=384, batch_size=32, use_pickle=True):
        super(EntityLinkingModel, self).__init__()
        # Maximum input length
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_pickle = use_pickle
        self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        self.bert = BertForSequenceClassification.from_pretrained(
            "hfl/chinese-roberta-wwm-ext",
            num_labels=1,
        )
        # Binary classification loss
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, token_type_ids):
        logits = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]
        return logits.squeeze()

    def prepare_data(self):
        self.processor = EntityLinkingProcessor()
        self.train_examples = self.processor.get_train_examples(TSV_PATH + 'EL_TRAIN.tsv')
        self.valid_examples = self.processor.get_dev_examples(TSV_PATH + 'EL_VALID.tsv')
        self.test_examples = self.processor.get_test_examples(TSV_PATH + 'EL_TEST.tsv')
        self.train_loader = self.processor.create_dataloader(
            examples=self.train_examples, tokenizer=self.tokenizer, max_length=self.max_length,
            shuffle=True, batch_size=self.batch_size, use_pickle=self.use_pickle,
        )
        self.valid_loader = self.processor.create_dataloader(
            examples=self.valid_examples, tokenizer=self.tokenizer, max_length=self.max_length,
            shuffle=False, batch_size=self.batch_size, use_pickle=self.use_pickle,
        )
        self.test_loader = self.processor.create_dataloader(
            examples=self.test_examples, tokenizer=self.tokenizer, max_length=self.max_length,
            shuffle=False, batch_size=self.batch_size, use_pickle=self.use_pickle,
        )

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, token_type_ids, labels = batch
        logits = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(logits, labels.float())
        preds = (logits > 0).int()
        acc = (preds == labels).float().mean()
        tensorboard_logs = {'train_loss': loss, 'train_acc': acc}
        return {'loss': loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, token_type_ids, labels = batch
        logits = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(logits, labels.float())
        preds = (logits > 0).int()
        acc = (preds == labels).float().mean()
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': val_loss, 'val_acc': val_acc}
        return {'val_loss': val_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-5, eps=1e-8)

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.valid_loader

Entity linking inference

class EntityLinkingPredictor:
    def __init__(self, ckpt_name, batch_size=8, use_pickle=True):
        self.ckpt_name = ckpt_name
        self.batch_size = batch_size
        self.use_pickle = use_pickle

    def generate_tsv_result(self, tsv_name, tsv_type='Valid'):
        processor = EntityLinkingProcessor()
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        if tsv_type == 'Valid':
            examples = processor.get_dev_examples(TSV_PATH + tsv_name)
        elif tsv_type == 'Test':
            examples = processor.get_test_examples(TSV_PATH + tsv_name)
        else:
            raise ValueError('tsv_type error')
        dataloader = processor.create_dataloader(
            examples=examples, tokenizer=tokenizer, max_length=384,
            shuffle=False, batch_size=self.batch_size, use_pickle=self.use_pickle,
        )

        model = EntityLinkingModel.load_from_checkpoint(checkpoint_path=CKPT_PATH + self.ckpt_name)
        model.to(DEVICE)
        model = nn.DataParallel(model)
        model.eval()

        result_list, logit_list = [], []
        for batch in tqdm(dataloader):
            for i in range(len(batch)):
                batch[i] = batch[i].to(DEVICE)
            input_ids, attention_mask, token_type_ids, labels = batch
            logits = model(input_ids, attention_mask, token_type_ids)
            preds = (logits > 0).int()
            result_list.extend(preds.tolist())
            logit_list.extend(logits.tolist())

        tsv_data = pd.read_csv(TSV_PATH + tsv_name, sep='\t')
        tsv_data['logits'] = logit_list
        tsv_data['result'] = result_list
        result_name = tsv_name.split('.')[0] + '_RESULT.tsv'
        tsv_data.to_csv(RESULT_PATH + result_name, index=False, sep='\t')

Entity typing model

class EntityTypingModel(pl.LightningModule):
    """Entity typing model"""

    def __init__(self, max_length=64, batch_size=64, use_pickle=True):
        super(EntityTypingModel, self).__init__()
        # Maximum input length
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_pickle = use_pickle
        # Multi-class classification loss (one label per entity type)
        self.criterion = nn.CrossEntropyLoss()
        self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        # Pretrained model
        self.bert = BertForSequenceClassification.from_pretrained(
            "hfl/chinese-roberta-wwm-ext",
            num_labels=len(PICKLE_DATA['IDX_TO_TYPE']),
        )

    def forward(self, input_ids, attention_mask, token_type_ids):
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]

    def prepare_data(self):
        self.processor = EntityTypingProcessor()
        self.train_examples = self.processor.get_train_examples(TSV_PATH + 'ET_TRAIN.tsv')
        self.valid_examples = self.processor.get_dev_examples(TSV_PATH + 'ET_VALID.tsv')
        self.test_examples = self.processor.get_test_examples(TSV_PATH + 'ET_TEST.tsv')
        self.train_loader = self.processor.create_dataloader(
            examples=self.train_examples, tokenizer=self.tokenizer, max_length=self.max_length,
            shuffle=True, batch_size=self.batch_size, use_pickle=self.use_pickle,
        )
        self.valid_loader = self.processor.create_dataloader(
            examples=self.valid_examples, tokenizer=self.tokenizer, max_length=self.max_length,
            shuffle=False, batch_size=self.batch_size, use_pickle=self.use_pickle,
        )
        self.test_loader = self.processor.create_dataloader(
            examples=self.test_examples, tokenizer=self.tokenizer, max_length=self.max_length,
            shuffle=False, batch_size=self.batch_size, use_pickle=self.use_pickle,
        )

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, token_type_ids, labels = batch
        outputs = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(outputs, labels)
        _, preds = torch.max(outputs, dim=1)
        acc = (preds == labels).float().mean()
        tensorboard_logs = {'train_loss': loss, 'train_acc': acc}
        return {'loss': loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, token_type_ids, labels = batch
        outputs = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(outputs, labels)
        _, preds = torch.max(outputs, dim=1)
        acc = (preds == labels).float().mean()
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': val_loss, 'val_acc': val_acc}
        return {'val_loss': val_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-5, eps=1e-8)

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.valid_loader

Entity typing inference

class EntityTypingPredictor:
    def __init__(self, ckpt_name, batch_size=8, use_pickle=True):
        self.ckpt_name = ckpt_name
        self.batch_size = batch_size
        self.use_pickle = use_pickle

    def generate_tsv_result(self, tsv_name, tsv_type='Valid'):
        processor = EntityTypingProcessor()
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        if tsv_type == 'Valid':
            examples = processor.get_dev_examples(TSV_PATH + tsv_name)
        elif tsv_type == 'Test':
            examples = processor.get_test_examples(TSV_PATH + tsv_name)
        else:
            raise ValueError('tsv_type error')
        dataloader = processor.create_dataloader(
            examples=examples, tokenizer=tokenizer, max_length=64,
            shuffle=False, batch_size=self.batch_size, use_pickle=self.use_pickle,
        )

        model = EntityTypingModel.load_from_checkpoint(checkpoint_path=CKPT_PATH + self.ckpt_name)
        model.to(DEVICE)
        model = nn.DataParallel(model)
        model.eval()

        result_list = []
        for batch in tqdm(dataloader):
            for i in range(len(batch)):
                batch[i] = batch[i].to(DEVICE)
            input_ids, attention_mask, token_type_ids, labels = batch
            outputs = model(input_ids, attention_mask, token_type_ids)
            _, preds = torch.max(outputs, dim=1)
            result_list.extend(preds.tolist())

        # Map predicted label indices back to type names
        idx_to_type = PICKLE_DATA['IDX_TO_TYPE']
        result_list = [idx_to_type[x] for x in result_list]

        tsv_data = pd.read_csv(TSV_PATH + tsv_name, sep='\t')
        tsv_data['result'] = result_list
        result_name = tsv_name.split('.')[0] + '_RESULT.tsv'
        tsv_data.to_csv(RESULT_PATH + result_name, index=False, sep='\t')

Define and run the full pipeline

def preprocess_pickle_file():
    processor = PicklePreprocessor()
    processor.run()


def preprocess_tsv_file():
    processor = DataFramePreprocessor()
    processor.run()


def generate_feature_pickle():
    processor = EntityLinkingProcessor()
    processor.generate_feature_pickle(max_length=384)
    processor = EntityTypingProcessor()
    processor.generate_feature_pickle(max_length=64)


def train_entity_linking_model(ckpt_name):
    model = EntityLinkingModel(max_length=384, batch_size=32)
    trainer = pl.Trainer(
        max_epochs=1,
        gpus=1,
        distributed_backend='dp',
        default_save_path=EL_SAVE_PATH,
        profiler=True,
    )
    trainer.fit(model)
    trainer.save_checkpoint(CKPT_PATH + ckpt_name)


def train_entity_typing_model(ckpt_name):
    model = EntityTypingModel(max_length=64, batch_size=64)
    trainer = pl.Trainer(
        max_epochs=1,
        gpus=1,
        distributed_backend='dp',
        default_save_path=ET_SAVE_PATH,
        profiler=True,
    )
    trainer.fit(model)
    trainer.save_checkpoint(CKPT_PATH + ckpt_name)


def generate_link_tsv_result(ckpt_name):
    predictor = EntityLinkingPredictor(ckpt_name, batch_size=24, use_pickle=True)
    predictor.generate_tsv_result('EL_VALID.tsv', tsv_type='Valid')
    predictor.generate_tsv_result('EL_TEST.tsv', tsv_type='Test')


def generate_type_tsv_result(ckpt_name):
    predictor = EntityTypingPredictor(ckpt_name, batch_size=64, use_pickle=True)
    predictor.generate_tsv_result('ET_VALID.tsv', tsv_type='Valid')
    predictor.generate_tsv_result('ET_TEST.tsv', tsv_type='Test')


def make_predication_result(input_name, output_name, el_ret_name, et_ret_name):
    entity_to_kbids = PICKLE_DATA['ENTITY_TO_KBIDS']
    el_ret = pd.read_csv(RESULT_PATH + el_ret_name, sep='\t',
                         dtype={'text_id': np.str_, 'offset': np.str_, 'kb_id': np.str_})
    et_ret = pd.read_csv(RESULT_PATH + et_ret_name, sep='\t',
                         dtype={'text_id': np.str_, 'offset': np.str_})

    result = []
    with open(RAW_PATH + input_name, 'r', encoding="utf-8") as f:
        for line in tqdm(f):
            line = json.loads(line)
            for data in line['mention_data']:
                text_id = line['text_id']
                offset = data['offset']
                candidate_data = el_ret[(el_ret['text_id'] == text_id) & (el_ret['offset'] == offset)]
                # Entity Linking: pick the candidate with the highest positive logit
                if len(candidate_data) > 0 and candidate_data['logits'].max() > 0:
                    max_idx = candidate_data['logits'].idxmax()
                    data['kb_id'] = candidate_data.loc[max_idx]['kb_id']
                # Entity Typing: fall back to a NIL type when no candidate is accepted
                else:
                    type_data = et_ret[(et_ret['text_id'] == text_id) & (et_ret['offset'] == offset)]
                    data['kb_id'] = 'NIL_' + type_data.iloc[0]['result']
            result.append(line)

    with open(RESULT_PATH + output_name, 'w', encoding="utf-8") as f:
        for r in result:
            json.dump(r, f, ensure_ascii=False)
            f.write('\n')
set_random_seed(20200619)
preprocess_pickle_file()
preprocess_tsv_file()
generate_feature_pickle()
train_entity_linking_model('EL_BASE_EPOCH0.ckpt')
generate_link_tsv_result('EL_BASE_EPOCH0.ckpt')
train_entity_typing_model('ET_BASE_EPOCH1.ckpt')
generate_type_tsv_result('ET_BASE_EPOCH1.ckpt')

A quick look at the entity-typing prediction results

el_ret = pd.read_csv("./data/result/ET_TEST_RESULT.tsv", sep='\t')
el_ret.head()
text_id entity offset short_text type result
0 1 林平之 0 林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了 Other Person
1 1 岳灵珊 5 林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了 Other Person
2 1 师娘 18 林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了 Other Other
3 1 令狐冲 21 林平之答应岳灵珊报仇之事,从长计议,师娘想令狐冲了 Other Person
4 2 思追 0 思追原来是个超级妹控,不愿妹妹嫁人,然而妹妹却喜欢一博老师 Other Person
el_ret = pd.read_csv("./data/result/ET_VALID_RESULT.tsv", sep='\t')
el_ret.head()
text_id entity offset short_text type result
0 1 天下没有不散的宴席 0 天下没有不散的宴席 - ╰つ雲中帆╰つ Work Work
1 1 ╰つ雲中帆╰つ 12 天下没有不散的宴席 - ╰つ雲中帆╰つ Other Person
2 2 永嘉 0 永嘉厂房出租 Location Location
3 2 厂房 2 永嘉厂房出租 Location Other
4 2 出租 4 永嘉厂房出租 Other Other
path = "./data/result/ET_VALID_RESULT.tsv"
data = pd.read_csv(path, sep='\t', dtype={'text_id': np.str_,'offset': np.str_,'kb_id': np.str_})
data.head()
text_id entity offset short_text type result
0 1 天下没有不散的宴席 0 天下没有不散的宴席 - ╰つ雲中帆╰つ Work Work
1 1 ╰つ雲中帆╰つ 12 天下没有不散的宴席 - ╰つ雲中帆╰つ Other Person
2 2 永嘉 0 永嘉厂房出租 Location Location
3 2 厂房 2 永嘉厂房出租 Location Other
4 2 出租 4 永嘉厂房出租 Other Other
make_predication_result('dev.json', 'valid_result.json', 'EL_VALID_RESULT.tsv', 'ET_VALID_RESULT.tsv')
10000it [13:17, 12.55it/s]
make_predication_result('test.json', 'test_result.json', 'EL_TEST_RESULT.tsv', 'ET_TEST_RESULT.tsv')
10000it [31:54,  5.22it/s]

Computing the evaluation score (Micro-F1)

# !/bin/env python
# -*- coding: utf-8 -*-
#####################################################################################
#
#  Copyright (c) CCKS 2020 Entity Linking Organizing Committee.
#  All Rights Reserved.
#
#####################################################################################
"""
@version 2020-03-30
@brief:Entity Linking效果评估脚本,评价指标Micro-F1
"""
# import sys# reload(sys)
# sys.setdefaultencoding('utf-8')
import json
from collections import defaultdictclass Eval(object):"""Entity Linking Evaluation"""def __init__(self, golden_file_path, user_file_path):self.golden_file_path = golden_file_pathself.user_file_path = user_file_pathself.tp = 0self.fp = 0self.total_recall = 0self.errno = Nonedef format_check(self, file_path):"""文件格式验证:param file_path: 文件路径:return: Bool类型:是否通过格式检查,通过为True,反之False"""flag = Truefor line in open(file_path,encoding='utf-8'):json_info = json.loads(line.strip())if 'text_id' not in json_info:flag = Falseself.errno = 1breakif 'text' not in json_info:flag = Falseself.errno = 2breakif 'mention_data' not in json_info:flag = Falseself.errno = 3breakif not json_info['text_id'].isdigit():flag = Falseself.errno = 5break           if not isinstance(json_info['mention_data'], list):flag = Falseself.errno = 7breakfor mention_info in json_info['mention_data']:if 'kb_id' not in mention_info:flag = Falseself.errno = 7breakif 'mention' not in mention_info:flag = Falseself.errno = 8breakif 'offset' not in mention_info:flag = Falseself.errno = 9break                if not mention_info['offset'].isdigit():flag = Falseself.errno = 13breakreturn flagdef micro_f1(self):""":return: float类型:精确率,召回率,Micro-F1值"""# 文本格式验证flag_golden = self.format_check(self.golden_file_path)flag_user = self.format_check(self.user_file_path)# 格式验证失败直接返回Noneif not flag_golden or not flag_user:return None, None, Noneprecision = 0recall = 0self.tp = 0self.fp = 0self.total_recall = 0golden_dict = defaultdict(list)for line in open(self.golden_file_path,encoding='utf-8'):golden_info = json.loads(line.strip())text_id = golden_info['text_id']text = golden_info['text']mention_data = golden_info['mention_data']for mention_info in mention_data:kb_id = mention_info['kb_id']mention = mention_info['mention']offset = mention_info['offset']key = '\1'.join([text_id, text, mention, offset]).encode('utf8')# value的第二个元素表示标志位,用于判断是否已经进行了统计golden_dict[key] = [kb_id, 0]self.total_recall += 1# 进行评估for line in open(self.user_file_path,encoding='utf-8'):golden_info = json.loads(line.strip())text_id = golden_info['text_id']text = golden_info['text']mention_data = golden_info['mention_data']for mention_info in mention_data:kb_id = mention_info['kb_id']mention = mention_info['mention']offset = mention_info['offset']key = '\1'.join([text_id, text, mention, offset]).encode('utf8')if key in golden_dict:kb_result_golden = golden_dict[key]if kb_id.isdigit():if kb_id in [kb_result_golden[0]] and kb_result_golden[1] in [0]:self.tp += 1else:self.fp += 1else:# nil golden结果nil_res = kb_result_golden[0].split('|')if kb_id in nil_res and kb_result_golden[1] in [0]:self.tp += 1else:self.fp += 1golden_dict[key][1] = 1else:self.fp += 1if self.tp + self.fp > 0:precision = float(self.tp) / (self.tp + self.fp)if self.total_recall > 0:recall = float(self.tp) / self.total_recalla = 2 * precision * recallb = precision + recallif b == 0:return 0, 0, 0f1 = a / breturn precision, recall, f1
eval = Eval('./data/ccks2020_el_data_v1/dev.json', './data/result/valid_result.json')
prec, recall, f1 = eval.micro_f1()
print(prec, recall, f1)
if eval.errno:
    print(eval.errno)
0.8488185641242852 0.8488185641242852 0.8488185641242852
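Precision, recall and F1 coincide here because the prediction pipeline emits exactly one kb_id (or NIL type) for every gold mention in dev.json, so the number of scored predictions equals the number of gold mentions.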

Related link: CCKS 2020, the 2020 China Conference on Knowledge Graph and Semantic Computing: http://sigkg.cn/ccks2020/?page_id=69

