信息抽取(三)三元关系抽取——改良后的层叠式指针网络

  • 前言
  • 优化在验证集上的模型推理结果的SPO抽取方法
  • 不随机选择S(subject),⽽是遍历所有不同主语的标注样本构建训练集。
  • 模型优化
  • 加入对抗训练FGM
  • 总结

前言

基于我上一篇的博客:信息抽取(二)花了一个星期走了无数条弯路终于用TF复现了苏神的《Bert三元关系抽取模型》,我到底悟到了什么?

复现后的模型在百度2019年语言竞赛三元关系抽取的数据集上F1值仅达到77%,我在博文总结了几点可以优化的方向,并实现一系列层叠式指针网络的改良。在此贴出代码和提升结果。


优化在验证集上的模型推理结果的SPO抽取方法

原方案是将token decode后组成结果,改良后的方案通过在token上的index返回到原文本中切割出答案,这避免了token无法识别一些特殊文字和符号亦或是空格。

def rematch_text_word(tokenizer,text,enc_context,enc_start,enc_end):span = [a.span()[0] for a in re.finditer(' ', text)]decode_list = [tokenizer.decode([i]) for i in enc_context][1:]start = 0end = 0len_start = 0for i in range(len(decode_list)):if i ==  enc_start - 1:start = len_startj = decode_list[i]if '#' in j and len(j)>1:j = j.replace('#','')if j == '[UNK]':j = '。'len_start += len(j)if i == enc_end - 1:end = len_startbreakfor span_index in span:if start >= span_index:start += 1end += 1if end > span_index and span_index>start:end += 1return text[start:end]

不随机选择S(subject),⽽是遍历所有不同主语的标注样本构建训练集。

原方案是对于每组文本数据,仅随机抽取一个S以及其相关的PO构建成一组数据。
改良后,对于每组文本数据,分别抽取其所有不同的S以及其相关的PO组成多组数据。

尽管对不不同样本来说S是相同的,但在实验中发现,模型对于S的推理往往比PO关系优秀太多,因此S的可能过拟合来提升模型在PO上的表现是值得的。

以上两个优化方案的提升效果:F1:0.7719 —> 0.7979

def proceed_data(text_list,spo_list,p2id,id2p,tokenizer,MAX_LEN,sop_count):id_label = {}ct = len(text_list)MAX_LEN = MAX_LENprint(sop_count)input_ids = np.zeros((sop_count,MAX_LEN),dtype='int32')attention_mask = np.zeros((sop_count,MAX_LEN),dtype='int32')start_tokens = np.zeros((sop_count,MAX_LEN),dtype='int32')end_tokens = np.zeros((sop_count,MAX_LEN),dtype='int32')send_s_po = np.zeros((sop_count,2),dtype='int32')object_start_tokens = np.zeros((sop_count,MAX_LEN,len(p2id)),dtype='int32')object_end_tokens = np.zeros((sop_count,MAX_LEN,len(p2id)),dtype='int32')index_vaild = -1for k in range(ct):context_k = text_list[k].lower().replace(' ','')enc_context = tokenizer.encode(context_k,max_length=MAX_LEN,truncation=True) if len(spo_list[k])==0:continue          start = []S_index = []for j in range(len(spo_list[k])):answers_text_k = spo_list[k][j]['subject'].lower().replace(' ','')chars = np.zeros((len(context_k)))index = context_k.find(answers_text_k)chars[index:index+len(answers_text_k)]=1offsets = []idx=0for t in enc_context[1:]:w = tokenizer.decode([t])if '#' in w and len(w)>1:w = w.replace('#','')if w == '[UNK]':w = '。'offsets.append((idx,idx+len(w)))idx += len(w)toks = []for i,(a,b) in enumerate(offsets):sm = np.sum(chars[a:b])if sm>0: toks.append(i) if len(toks)>0:S_start = toks[0]+1S_end = toks[-1]+1if (S_start,S_end) not in start:index_vaild += 1start.append((S_start,S_end))input_ids[index_vaild,:len(enc_context)] = enc_contextattention_mask[index_vaild,:len(enc_context)] = 1start_tokens[index_vaild,S_start] = 1end_tokens[index_vaild,S_end] = 1send_s_po[index_vaild,0] = S_startsend_s_po[index_vaild,1] = S_endS_index.append([j,index_vaild])else:S_index.append([j,index_vaild])if len(S_index) > 0:for index_ in range(len(S_index)):#随机选取object的首位,如果选取错误,则作为负样本object_text_k = spo_list[k][S_index[index_][0]]['object'].lower().replace(' ','')predicate = spo_list[k][S_index[index_][0]]['predicate']p_id = p2id[predicate]chars = np.zeros((len(context_k)))index = context_k.find(object_text_k)chars[index:index+len(object_text_k)]=1offsets = [] idx = 0for t in enc_context[1:]:w = tokenizer.decode([t])if '#' in w and len(w)>1:w = w.replace('#','')if w == '[UNK]':w = '。'offsets.append((idx,idx+len(w)))idx += len(w)toks = []for i,(a,b) in enumerate(offsets):sm = np.sum(chars[a:b])if sm>0: toks.append(i) if len(toks)>0:id_label[p_id] = predicateP_start = toks[0]+1P_end = toks[-1]+1object_start_tokens[S_index[index_][1]][P_start,p_id] = 1object_end_tokens[S_index[index_][1]][P_end,p_id] = 1return input_ids[:index_vaild],attention_mask[:index_vaild],start_tokens[:index_vaild],\
end_tokens[:index_vaild],send_s_po[:index_vaild],object_start_tokens[:index_vaild],\
object_end_tokens[:index_vaild],id_label

模型优化

图中并没有表示如何加入实体的position embedding,这部分也是本人自己摸索出来的,通过抽取Bert的position embedding加入到hidden state中,然后作self_attention,这样能给模型的f1带来0.06的提升。
也尝试过重新构建一个position embedding让模型自己学习,但并没有提升。

def build_model(pretrained_path,config,MAX_LEN,p2id):ids = tf.keras.layers.Input((MAX_LEN,), dtype=tf.int32)att = tf.keras.layers.Input((MAX_LEN,), dtype=tf.int32)s_po_index =  tf.keras.layers.Input((2,), dtype=tf.int32)config.output_hidden_states = Truebert_model = TFBertModel.from_pretrained(pretrained_path,config=config,from_pt=True)x, _, hidden_states = bert_model(ids,attention_mask=att)layer_1 = hidden_states[-1]start_logits = tf.keras.layers.Dense(1,activation = 'sigmoid')(layer_1)start_logits = tf.keras.layers.Lambda(lambda x: x**2)(start_logits)end_logits = tf.keras.layers.Dense(1,activation = 'sigmoid')(layer_1)end_logits = tf.keras.layers.Lambda(lambda x: x**2)(end_logits)subject_1 = extract_subject([layer_1,s_po_index])Normalization_1 = LayerNormalization(conditional=True)([layer_1, subject_1])'''图中并没没有表示如何加入实体的position embedding,这部分也是本人自己摸索出来的。仅供参考'''position_emb_s = bert_model.bert.get_input_embeddings().position_embeddings(s_po_index[:,0])position_emb_e = bert_model.bert.get_input_embeddings().position_embeddings(s_po_index[:,1])position_embedding = position_emb_s + position_emb_eposition_embedding = position_embedding[:,tf.newaxis,:]add_position = Normalization_1 + position_embeddingself_attenion = TFBertSelfAttention(768,1)(add_position,att,head_mask=None,output_attentions=False)dense = tf.keras.layers.Dense(768,activation='relu')(self_attenion)dense = tf.keras.layers.Dropout(0.2)(dense)dense = tf.keras.layers.Dense(512,activation='relu')(dense)op_out_put_start = tf.keras.layers.Dense(len(p2id),activation = 'sigmoid')(dense)op_out_put_start = tf.keras.layers.Lambda(lambda x: x**4)(op_out_put_start)op_out_put_end = tf.keras.layers.Dense(len(p2id),activation = 'sigmoid')(dense)op_out_put_end = tf.keras.layers.Lambda(lambda x: x**4)(op_out_put_end)model = tf.keras.models.Model(inputs=[ids,att,s_po_index], outputs=[start_logits,end_logits,op_out_put_start,op_out_put_end])model_2 = tf.keras.models.Model(inputs=[ids,att], outputs=[start_logits,end_logits])model_3 = tf.keras.models.Model(inputs=[ids,att,s_po_index], outputs=[op_out_put_start,op_out_put_end])return model,model_2,model_3

模型提升效果:F1:0.7979 —> 0.8095


加入对抗训练FGM

@tf.function
def train_step(model,x,y,loss_func,optimizer,train_loss):with tf.GradientTape() as tape:y_pred = model(x,training=True)loss1 = loss_func(y['lambda'],tf.squeeze(y_pred[0]))loss2 = loss_func(y['lambda_1'],tf.squeeze(y_pred[1]))loss3 = loss_func(y['lambda_2'],tf.squeeze(y_pred[2]))loss4 = loss_func(y['lambda_3'],tf.squeeze(y_pred[3]))loss = loss1+loss2+loss3+loss4embedding = model.trainable_variables[0]embedding_gradients = tape.gradient(loss,[model.trainable_variables[0]])[0]embedding_gradients = tf.zeros_like(embedding) + embedding_gradientsdelta = 0.1 * embedding_gradients / (tf.math.sqrt(tf.reduce_sum(embedding_gradients**2)) + 1e-8)  # 计算扰动model.trainable_variables[0].assign_add(delta)with tf.GradientTape() as tape2:y_pred = model(x,training=True)loss1 = loss_func(y['lambda'],tf.squeeze(y_pred[0]))loss2 = loss_func(y['lambda_1'],tf.squeeze(y_pred[1]))loss3 = loss_func(y['lambda_2'],tf.squeeze(y_pred[2]))loss4 = loss_func(y['lambda_3'],tf.squeeze(y_pred[3]))new_loss = loss1+loss2+loss3+loss4gradients = tape2.gradient(new_loss,model.trainable_variables)model.trainable_variables[0].assign_sub(delta)optimizer.apply_gradients(zip(gradients,model.trainable_variables))train_loss.update_state(new_loss)

但由于本人设备配置有限,并没有得到相应的对抗训练结果,大家可以自己尝试~


总结


完整代码地址: https://github.com/zhengyanzhao1997/TF-NLP-model/blob/main/model/train/Three_relation_extract.py

参考文章:
苏剑林. (2020, Jan 03). 《用bert4keras做三元组抽取 》[Blog post]. Retrieved from https://kexue.fm/archives/7161
一人之力,刷爆三路榜单!信息抽取竞赛夺冠经验分享

信息抽取(三)三元关系抽取——改良后的层叠式指针网络,让我的模型F1提升近4%(接上篇)相关推荐

  1. 信息抽取(二)花了一个星期走了无数条弯路终于用TF复现了苏神的《Bert三元关系抽取模型》,我到底悟到了什么?

    信息抽取(二)花了一个星期走了无数条弯路终于用TF复现了苏神的<Bert三元关系抽取>模型,我到底悟到了什么? 前言 数据格式与任务目标 模型整体思路 复现代码 数据处理 数据读取 训练数 ...

  2. A.2【数据标注】:基于Label studio的训练数据标注指南:信息抽取(实体关系抽取)、文本分类等

    NLP专栏简介:数据增强.智能标注.意图识别算法|多分类算法.文本信息抽取.多模态信息抽取.可解释性分析.性能调优.模型压缩算法等 专栏详细介绍:NLP专栏简介:数据增强.智能标注.意图识别算法|多分 ...

  3. 徐阿衡 | 知识抽取-实体及关系抽取(一)

    本文转载自公众号:徐阿衡. 这一篇是关于知识抽取,整理并补充了上学时的两篇笔记 NLP笔记 - Information Extraction 和 NLP笔记 - Relation Extraction ...

  4. python实体关系抽取_【关系抽取】从文本中进行关系抽取的几种不同的方法

    关系提取是指从文本中提取语义关系,这种语义关系通常发生在两个或多个实体之间.这些关系可以是不同类型的." Paris is in France "表示巴黎与法国之间的" ...

  5. 信息抽取(四)【NLP论文复现】Multi-head Selection和Deep Biaffine Attention在关系抽取中的实现和效果

    Multi-head Selection和Deep Biaffine Attention在关系抽取中的应用 前言 Multi-head Selection 一.Joint entity recogni ...

  6. 【文本信息抽取与结构化】深入了解关系抽取你需要知道的东西

    常常在想,自然语言处理到底在做的是一件什么样的事情?到目前为止,我所接触到的NLP其实都是在做一件事情,即将自然语言转化为一种计算机能够理解的形式.这一点在知识图谱.信息抽取.文本摘要这些任务中格外明 ...

  7. 信息抽取——关系抽取

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 简介信息抽取(information extraction),即从自然语言文本中,抽取出特定的 ...

  8. 必读!信息抽取(Information Extraction)【关系抽取】

    来源: AINLPer 微信公众号(每日给你好看-) 编辑: ShuYini 校稿: ShuYini 时间: 2020-08-11 引言     信息抽取(information extraction ...

  9. AGGCN | 基于图神经网络的关系抽取模型

    今天给大家介绍2019年6月发表在ACL上的论文"Attention Guided Graph Convolutional Networks for Relation Extraction& ...

最新文章

  1. Redis中有序集合zset数据类型(增加(添加元素)、获取(获取指定范围元素、返回权值在min和max之间的成员、返回成员member的score值)、删除(删除指定元素和指定权值范围的元素))
  2. 万字干货 | 一文揭秘Presto在腾讯资讯业务中的应用
  3. win7系统, vim的_vimrc文件无法修改
  4. 查询存在表1但不存在表2的所有数据
  5. springboot 创建地址_这可能是史上最易懂的 Spring Boot 入门教程
  6. hyperstudy联合matlab,HyperStudy对后处理排气管道参数的灵敏度分析及优化设计
  7. python安装redis库
  8. makefile的两个变量(自动变量和普通变量)
  9. excel查找在哪里_Excel办公自动化,让低值费时的工作自动进行
  10. JAVA声明圆锥体类,实现Area和Volume接口,计算表面积和体积,按体积比较大小
  11. 如何利用导数推导向心加速度公式? + 开普勒 第三定律的推导过程
  12. 计算机系统中引入虚拟内存的好处是什么,虚拟内存的作用是什么
  13. OS学习笔记-1(清华大学慕课)操作系统概述
  14. android记账app开发全过程,android开发实战-记账本APP(一)
  15. amd为什么还用针脚_闲聊CPU针脚 一年一换都怪AMD不给力?
  16. struts1 使用poi组件 读取excel文件,创建excel ,输出excel文件
  17. 【2018-11-15】中证1000指数的估值详情
  18. shell版本爬取NVD网站信息
  19. TP5.1 支付宝app支付 (沙箱本地测试)
  20. UE5 官方案例Lyra 全特性详解 7.资源管理

热门文章

  1. 你不知道的Node.js性能优化,读了之后水平直线上升
  2. 聊聊前段插件之Datatables
  3. 如何在程序中添加iAd广告
  4. erlang分布式编程模型
  5. 数据结构——HDU1312:Red and Black(DFS)
  6. js 调用父窗口的方法
  7. [28期] lamp兄弟连28期学员手册,请大家务必看一下
  8. java f.add()_f.add(p1,First); 那个“First”是什么意思呀?
  9. KEIL-MDK编译错误问题解决办法
  10. 国外在线学习网站+慕课平台