《原始论文:Attention-based bidirectional long short-term memory networks for relation classification》

一、概述

1、本文idea提出原因

传统的方法中,大多数研究依赖于一些现有的词汇资源(例如WordNet)、NLP系 统或一些手工提取的特征。这样的方法可能导致计算复杂度的增加,并且特征提取工作本身会耗费大量的时间和精力,特征提取质量的对于实验的结果也有很大的影响。

提出了 ATT-BLSTM的网络结构解决关系端对端识别问题

这篇论文从这一角度出发,提出一个基于Attention机制的双向 LSTM神经网络模型进行关系抽取研究,Attention机制能够自动 发现那些对于分类起到关键作用的词,使得这个模型可以从每个句子中捕获最重要的语义信息,它不依赖于任何外部的知识或者NLP系统

2、本论文历史意义

巧妙地在双向LSTM模型中加入Attention机制,用于关系抽取任务,避免了传统的 任务中复杂的特征工程,大大简化了实验过程并得到相当不错的结果,也为相关的研究提供了可操作性的思路

这篇论文的整体的逻辑十分清晰,紧紧围绕研究动机.整篇论文的思路十分简单,模型也一目了然,但是结果表现优秀

3、摘要核心

  1. 目前关系识别依赖于Mp工具提取特征;
  2. 提出一种不需要复杂预处理的关系识别方法att-blstm;
  3. 实验结果表明该方法是有效的,达到the state-of-the-art的效果

二、Attention-BiLSTM模型结构

1、模型结构


ATT-BLSTM网络结构以word embeding为基础,加入实体标识位,通过ATT-BLSTM的结构让模型动态区分关系分类的重要词汇。
As shown in Figure 1, the model proposed in this paper contains five components:

  1. 输入句子:Input layer: input sentence to this model;
  2. Embedding layer: map each word into a low dimension vector;
  3. BiLSTM:LSTM layer: utilize BLSTM to get high level features from step (2);
  4. Attention layer: produce a weight vector, and merge word-level features from each time step into a sentence-level feature vector, by multiplying the weight vector;
  5. Output layer: the sentence-level feature vec- tor is finally used for relation classification.

2、Attention 原理

Attention 原理:Attention Mechanism可以帮助模型对输入的X每个部分赋予不同的权重,抽取出更加关键及重要的信息,使模型做出更加准确的判断,同时不会对模型的计算和存储带来更大的开销。


根据Attention的计算区域,可以分成以下几种:

  1. Soft-Attention/Global Attention:这是比较常见的Attention方式,对所有key求权重概率,每个key都有一个对应的权重,是一种全局的计算方式(也可以叫Global Attention).
  2. Hard-Attention:这种方式是直接精准定位到某个key,其余key就都不管了,相当于这个key的 概率是1 ,其余key的概率全部是0。因此这种对齐方式要求很高,要求一步到位,如果没有正确对齐, 会带来很大的影响。另一方面,因为不可导,一般需要用强化学习的方法进行训练
  3. Local-Attention:这种方式其实是以上两种方式的一个折中,对一个窗口区域进行计算。先用 Hard方式定位到某个地方,以这个点为中心可以得到一个窗口区域,在这个小区域内用Soft方式来
    算 Attention。

3、小技巧

对实体前后添加特定标识符标明实体位置

采用带约束的正则损失

三、实验结果

compare various model configurations on the SemEval-2010 Task 8 dataset

四、论文结论

1、关键点

不依赖任何其他NLP工具

2、创新点

引入Attention-BiLSTM结构

3、启发点

网格结构完全不依何nlp工具或词法资源,只需要带位置标识的原始文本作为输入。

This model does not rely on NLP tools or lexical resources to get, it uses raw text with position indicators as input.

五、论文代码

1、数据集

1.1 原始数据集

train_file.txt【样本1-8000】

1    "The system as described above has its greatest application in an arrayed <e1>configuration</e1> of antenna <e2>elements</e2>."
Component-Whole(e2,e1)
Comment: Not a collection: there is structure here, organisation.2 "The <e1>child</e1> was carefully wrapped and bound into the <e2>cradle</e2> by means of a cord."
Other
Comment:3  "The <e1>author</e1> of a keygen uses a <e2>disassembler</e2> to look at the raw assembly code."
Instrument-Agency(e2,e1)
Comment:4  "A misty <e1>ridge</e1> uprises from the <e2>surge</e2>."
Other
Comment:5  "The <e1>student</e1> <e2>association</e2> is the voice of the undergraduate student population of the State University of New York at Buffalo."
Member-Collection(e1,e2)
Comment:6  "This is the sprawling <e1>complex</e1> that is Peru's largest <e2>producer</e2> of silver."
Other
Comment:7  "The current view is that the chronic <e1>inflammation</e1> in the distal part of the stomach caused by Helicobacter pylori <e2>infection</e2> results in an increased acid production from the non-infected upper corpus region of the stomach."
Cause-Effect(e2,e1)
Comment:8  "<e1>People</e1> have been moving back into <e2>downtown</e2>."
Entity-Destination(e1,e2)
Comment:9  "The <e1>lawsonite</e1> was contained in a <e2>platinum crucible</e2> and the counter-weight was a plastic crucible with metal pieces."
Content-Container(e1,e2)
Comment: prototypical example10    "The solute was placed inside a beaker and 5 mL of the <e1>solvent</e1> was pipetted into a 25 mL glass <e2>flask</e2> for each trial."
Entity-Destination(e1,e2)
Comment:
......

test_file.txt【样本8001-10717】

8001 "The most common <e1>audits</e1> were about <e2>waste</e2> and recycling."
Message-Topic(e1,e2)
Comment: Assuming an audit = an audit document.8002 "The <e1>company</e1> fabricates plastic <e2>chairs</e2>."
Product-Producer(e2,e1)
Comment: (a) is satisfied8003 "The school <e1>master</e1> teaches the lesson with a <e2>stick</e2>."
Instrument-Agency(e2,e1)
Comment:8004   "The suspect dumped the dead <e1>body</e1> into a local <e2>reservoir</e2>."
Entity-Destination(e1,e2)
Comment:8005   "Avian <e1>influenza</e1> is an infectious disease of birds caused by type A strains of the influenza <e2>virus</e2>."
Cause-Effect(e2,e1)
Comment:8006   "The <e1>ear</e1> of the African <e2>elephant</e2> is significantly larger--measuring 183 cm by 114 cm in the bush elephant."
Component-Whole(e1,e2)
Comment:8007   "A child is told a <e1>lie</e1> for several years by their <e2>parents</e2> before he/she realizes that a Santa Claus does not exist."
Product-Producer(e1,e2)
Comment: (a) is satisfied; negation is outside8008  "Skype, a free software, allows a <e1>hookup</e1> of multiple computer <e2>users</e2> to join in an online conference call without incurring any telephone costs."
Member-Collection(e2,e1)
Comment:8009   "The disgusting scene was retaliation against her brother Philip who rents the <e1>room</e1> inside this apartment <e2>house</e2> on Lombard street."
Component-Whole(e1,e2)
Comment:8010   "This <e1>thesis</e1> defines the <e2>clinical characteristics</e2> of amyloid disease."
Message-Topic(e1,e2)
Comment: may be we could leave clinical out of e2.

1.2 处理后的数据

preprocess.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6import json
import re
from nltk.tokenize import word_tokenizedef search_entity(sentence):e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)sentence = word_tokenize(sentence)sentence = ' '.join(sentence)sentence = sentence.replace('< e1 >', '<e1>')sentence = sentence.replace('< e2 >', '<e2>')sentence = sentence.replace('< /e1 >', '</e1>')sentence = sentence.replace('< /e2 >', '</e2>')sentence = sentence.split()assert '<e1>' in sentenceassert '<e2>' in sentenceassert '</e1>' in sentenceassert '</e2>' in sentencereturn sentencedef convert(path_src, path_des):with open(path_src, 'r', encoding='utf-8') as fr:data = fr.readlines()with open(path_des, 'w', encoding='utf-8') as fw:for i in range(0, len(data), 4):id_s, sentence = data[i].strip().split('\t')sentence = sentence[1:-1]sentence = search_entity(sentence)meta = dict(id=id_s,relation=data[i+1].strip(),sentence=sentence,comment=data[i+2].strip()[8:])json.dump(meta, fw, ensure_ascii=False)fw.write('\n')if __name__ == '__main__':path_train = './SemEval2010_task8_all_data/SemEval2010_task8_training/TRAIN_FILE.TXT'path_test = './SemEval2010_task8_all_data/SemEval2010_task8_testing_keys/TEST_FILE_FULL.TXT'convert(path_train, 'train.json')convert(path_test, 'test.json')

train.json

{"id": "1", "relation": "Component-Whole(e2,e1)", "sentence": ["The", "system", "as", "described", "above", "has", "its", "greatest", "application", "in", "an", "arrayed", "<e1>", "configuration", "</e1>", "of", "antenna", "<e2>", "elements", "</e2>", "."], "comment": " Not a collection: there is structure here, organisation."}
{"id": "2", "relation": "Other", "sentence": ["The", "<e1>", "child", "</e1>", "was", "carefully", "wrapped", "and", "bound", "into", "the", "<e2>", "cradle", "</e2>", "by", "means", "of", "a", "cord", "."], "comment": ""}
{"id": "3", "relation": "Instrument-Agency(e2,e1)", "sentence": ["The", "<e1>", "author", "</e1>", "of", "a", "keygen", "uses", "a", "<e2>", "disassembler", "</e2>", "to", "look", "at", "the", "raw", "assembly", "code", "."], "comment": ""}
{"id": "4", "relation": "Other", "sentence": ["A", "misty", "<e1>", "ridge", "</e1>", "uprises", "from", "the", "<e2>", "surge", "</e2>", "."], "comment": ""}
{"id": "5", "relation": "Member-Collection(e1,e2)", "sentence": ["The", "<e1>", "student", "</e1>", "<e2>", "association", "</e2>", "is", "the", "voice", "of", "the", "undergraduate", "student", "population", "of", "the", "State", "University", "of", "New", "York", "at", "Buffalo", "."], "comment": ""}
......

test.json

{"id": "8001", "relation": "Message-Topic(e1,e2)", "sentence": ["The", "most", "common", "<e1>", "audits", "</e1>", "were", "about", "<e2>", "waste", "</e2>", "and", "recycling", "."], "comment": " Assuming an audit = an audit document."}
{"id": "8002", "relation": "Product-Producer(e2,e1)", "sentence": ["The", "<e1>", "company", "</e1>", "fabricates", "plastic", "<e2>", "chairs", "</e2>", "."], "comment": " (a) is satisfied"}
{"id": "8003", "relation": "Instrument-Agency(e2,e1)", "sentence": ["The", "school", "<e1>", "master", "</e1>", "teaches", "the", "lesson", "with", "a", "<e2>", "stick", "</e2>", "."], "comment": ""}
{"id": "8004", "relation": "Entity-Destination(e1,e2)", "sentence": ["The", "suspect", "dumped", "the", "dead", "<e1>", "body", "</e1>", "into", "a", "local", "<e2>", "reservoir", "</e2>", "."], "comment": ""}
{"id": "8005", "relation": "Cause-Effect(e2,e1)", "sentence": ["Avian", "<e1>", "influenza", "</e1>", "is", "an", "infectious", "disease", "of", "birds", "caused", "by", "type", "A", "strains", "of", "the", "influenza", "<e2>", "virus", "</e2>", "."], "comment": ""}
......

1.3 relation2id

Other  0
Cause-Effect(e1,e2)    1
Cause-Effect(e2,e1)    2
Component-Whole(e1,e2) 3
Component-Whole(e2,e1) 4
Content-Container(e1,e2)   5
Content-Container(e2,e1)   6
Entity-Destination(e1,e2)  7
Entity-Destination(e2,e1)  8
Entity-Origin(e1,e2)   9
Entity-Origin(e2,e1)   10
Instrument-Agency(e1,e2)   11
Instrument-Agency(e2,e1)   12
Member-Collection(e1,e2)   13
Member-Collection(e2,e1)   14
Message-Topic(e1,e2)   15
Message-Topic(e2,e1)   16
Product-Producer(e1,e2)    17
Product-Producer(e2,e1)    18

2、预训练词向量:静态词向量HLBL

hlbl-embeddings-scaled.EMBEDDING_SIZE=50

*UNKNOWN* -0.166038776479 0.104395984608 0.163119732357 0.0899594154863 -0.0192271099805 -0.0417631572501 -0.0163376687927 0.0357616216019 0.0536077591673 0.0127688536503 -0.00284508433021 -0.0626207031228 -0.0379452734015 -0.103548297666 0.0381169119981 0.00199421074321 -0.0474636488659 -0.0127526851513 0.016404178535 -0.12759853361 -0.0292937037717 -0.0512566352549 0.0233097445983 0.0360505083995 0.00229317984472 -0.0771565284227 0.0071461584378 -0.051608090196 -0.0267547654304 0.0492994451068 -0.0531630844999 0.00787191810391 0.082280106873 0.066908641868 -0.0283930612982 0.216840166248 0.164923151267 0.00188498983723 0.0328679039324 -0.00175432516758 0.0614261774935 0.0987773071377 0.0548423375506 -0.0307057922059 0.053074241476 0.04982054279 -0.0572485864016 0.132236444766 -0.0379717035014 -0.120915939814
the -0.0841015569168 0.145263825738 0.116945121935 -0.0754618634155 0.17901499611 -0.000652852605208 -0.0713783879233 0.207273704502 0.060711721477 0.0366727701165 -0.0269791566731 -0.156993473526 -0.0393947453024 0.00749161628231 -0.332851634057 -0.1708430781 -0.275163605231 -0.266592614101 0.43349041466 -0.00779248211778 0.031101796379 -0.0257114150838 0.174856713352 -0.0543054233622 -0.0846669459476 -0.006234398456 0.00414488584462 0.119738648443 -0.0914876936952 -0.317381121871 -0.27471439742 0.234269597998 0.170305945138 -0.0282815073325 -0.10127814458 0.156451476203 0.154703520781 -0.0014827085612 0.164287521114 0.0328582913203 0.0356570354049 -0.190254406793 -0.112029936115 -0.198875312619 0.00102875631152 -0.00161517169984 -0.125210890327 0.196903181061 -0.112017915766 -0.00838804375065
. -0.0875932389444 -0.0586365253633 0.0729727126603 0.32072000431 0.0745620569276 -0.0494709138174 0.208708067552 -0.025035364294 -0.197531050237 0.177318202028 0.297077745222 -0.0256369072571 0.182364658364 0.189089099105 0.0589179494006 -0.0627276310572 0.0682898379459 0.241161712515 0.253510796291 -0.0325139691451 -0.0129081882483 -0.083367340352 0.0276167362372 -0.00757124183183 -0.0905801885623 0.305015208385 0.0755474920504 -0.00516459185438 -0.0412876867803 0.105047372601 -0.718674456034 0.184682477295 0.232732814491 0.0929975692214 0.0999329447708 -0.0968008990987 0.421525505372 -0.136460066398 -0.323294448817 0.118318915141 0.415411774103 -0.135770867168 0.0404792691614 0.264279769529 -0.133076243622 0.195087919022 -0.087589323012 0.0335223022065 -0.0365650611956 -0.0163760300203
, -0.023019838485 0.277215570968 0.241932261453 -0.105403438907 0.247316949736 0.0859618436243 -0.0130132156599 0.123988163629 -0.150741462418 0.129993766762 0.0766431623839 0.0547135456598 0.187342182554 0.176303102861 -0.121401723217 0.0458278230666 0.0339804870854 -0.0619606057248 0.0514787739809 0.00732501266557 0.0879996990484 -0.369288823679 0.235222707122 -0.0528783055204 0.0121891472663 -0.165169815904 -0.136829953355 -0.0750751223049 -0.0503433833321 0.0782539868365 -0.400940778018 -0.099745222007 -0.152448498545 -0.0815002789835 -0.010575616616 0.331604536668 -0.0124179474775 0.00173559407939 -0.230971231526 0.0162523457081 0.213848645598 0.184698023693 0.158368229826 0.0975422545404 -0.0307127563081 0.093420146492 -0.0377856184872 -0.0181716170654 0.43322993915 -0.113289957059
to 0.134693667961 0.392203653086 0.0346151199225 0.135354475458 0.0719918082372 0.118667933013 -0.0698386234679 -0.0139927084407 0.144452931939 0.0383223273458 -0.0491954394553 -0.126435975874 0.23979196724 -0.186550477314 0.0602616605691 -0.0875395769807 0.0788848675161 0.132691898026 0.155618778336 0.00680378469567 -0.126513561203 -0.436124771467 0.132675129426 -0.0946286638801 0.0986847070674 -0.354397304845 -0.196909463175 -0.0911408611189 0.134975690877 0.0625931974859 0.0108112360985 -0.107933544401 -0.166545488854 0.0137397678012 -0.0268394211932 -0.260328038765 0.0745185746772 0.020864049205 0.133485534344 -0.0479098207297 0.145382061477 -0.116284346216 0.0822848147919 -0.00621959258902 0.0135679910959 -0.0723116375013 -0.422793539068 0.144456402991 -0.119019192402 0.0659297394103
......

3、config.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6import argparse
import torch
import os
import random
import json
import numpy as npclass Config(object):def __init__(self):# get init configargs = self.__get_config()for key in args.__dict__:setattr(self, key, args.__dict__[key])# select deviceself.device = Noneif self.cuda >= 0 and torch.cuda.is_available():self.device = torch.device('cuda:{}'.format(self.cuda))else:self.device = torch.device('cpu')# determine the model name and model dirif self.model_name is None:self.model_name = 'Att_BLSTM'self.model_dir = os.path.join(self.output_dir, self.model_name)if not os.path.exists(self.model_dir):os.makedirs(self.model_dir)# backup dataself.__config_backup(args)# set the random seedself.__set_seed(self.seed)def __get_config(self):parser = argparse.ArgumentParser()parser.description = 'config for models'# several key selective parametersparser.add_argument('--data_dir', type=str,default='./data',help='dir to load data')parser.add_argument('--output_dir', type=str,default='./output',help='dir to save output')# word embeddingparser.add_argument('--embedding_path', type=str,default='./embedding/glove.6B.100d.txt',help='pre_trained word embedding')parser.add_argument('--word_dim', type=int,default=100,help='dimension of word embedding')# train settingsparser.add_argument('--model_name', type=str,default=None,help='model name')parser.add_argument('--mode', type=int,default=1,choices=[0, 1],help='running mode: 1 for training; otherwise testing')parser.add_argument('--seed', type=int,default=5782,help='random seed')parser.add_argument('--cuda', type=int,default=0,help='num of gpu device, if -1, select cpu')parser.add_argument('--epoch', type=int,default=30,help='max epoches during training')# hyper parametersparser.add_argument('--batch_size', type=int,default=10,help='batch size')parser.add_argument('--lr', type=float,default=1.0,help='learning rate')parser.add_argument('--max_len', type=int,default=100,help='max length of sentence')parser.add_argument('--emb_dropout', type=float,default=0.3,help='the possiblity of dropout in embedding layer')parser.add_argument('--lstm_dropout', type=float,default=0.3,help='the possiblity of dropout in (Bi)LSTM layer')parser.add_argument('--linear_dropout', type=float,default=0.5,help='the possiblity of dropout in liner layer')parser.add_argument('--hidden_size', type=int,default=100,help='the dimension of hidden units in (Bi)LSTM layer')parser.add_argument('--layers_num', type=int,default=1,help='num of RNN layers')parser.add_argument('--L2_decay', type=float, default=1e-5,help='L2 weight decay')args = parser.parse_args()return argsdef __set_seed(self, seed=1234):os.environ['PYTHONHASHSEED'] = '{}'.format(seed)random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)  # set seed for cputorch.cuda.manual_seed(seed)  # set seed for current gputorch.cuda.manual_seed_all(seed)  # set seed for all gpudef __config_backup(self, args):config_backup_path = os.path.join(self.model_dir, 'config.json')with open(config_backup_path, 'w', encoding='utf-8') as fw:json.dump(vars(args), fw, ensure_ascii=False)def print_config(self):for key in self.__dict__:print(key, end=' = ')print(self.__dict__[key])if __name__ == '__main__':config = Config()config.print_config()

4、model.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequenceclass Att_BLSTM(nn.Module):def __init__(self, word_vec, class_num, config):super().__init__()self.word_vec = word_vecself.class_num = class_num# hyper parameters and othersself.max_len = config.max_lenself.word_dim = config.word_dimself.hidden_size = config.hidden_sizeself.layers_num = config.layers_numself.emb_dropout_value = config.emb_dropoutself.lstm_dropout_value = config.lstm_dropoutself.linear_dropout_value = config.linear_dropout# net structures and operationsself.word_embedding = nn.Embedding.from_pretrained(embeddings=self.word_vec,freeze=False,)self.lstm = nn.LSTM(input_size=self.word_dim,hidden_size=self.hidden_size,num_layers=self.layers_num,bias=True,batch_first=True,dropout=0,bidirectional=True,)self.tanh = nn.Tanh()self.emb_dropout = nn.Dropout(self.emb_dropout_value)self.lstm_dropout = nn.Dropout(self.lstm_dropout_value)self.linear_dropout = nn.Dropout(self.linear_dropout_value)self.att_weight = nn.Parameter(torch.randn(1, self.hidden_size, 1))self.dense = nn.Linear(in_features=self.hidden_size,out_features=self.class_num,bias=True)# initialize weightinit.xavier_normal_(self.dense.weight)init.constant_(self.dense.bias, 0.)def lstm_layer(self, x, mask):lengths = torch.sum(mask.gt(0), dim=-1)x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)h, (_, _) = self.lstm(x)h, _ = pad_packed_sequence(h, batch_first=True, padding_value=0.0, total_length=self.max_len)h = h.view(-1, self.max_len, 2, self.hidden_size)h = torch.sum(h, dim=2)  # B*L*Hreturn hdef attention_layer(self, h, mask):att_weight = self.att_weight.expand(mask.shape[0], -1, -1)  # B*H*1att_score = torch.bmm(self.tanh(h), att_weight)  # B*L*H  *  B*H*1 -> B*L*1# mask, remove the effect of 'PAD'mask = mask.unsqueeze(dim=-1)  # B*L*1att_score = att_score.masked_fill(mask.eq(0), float('-inf'))  # B*L*1att_weight = F.softmax(att_score, dim=1)  # B*L*1reps = torch.bmm(h.transpose(1, 2), att_weight).squeeze(dim=-1)  # B*H*L *  B*L*1 -> B*H*1 -> B*Hreps = self.tanh(reps)  # B*Hreturn repsdef forward(self, data):token = data[:, 0, :].view(-1, self.max_len)mask = data[:, 1, :].view(-1, self.max_len)emb = self.word_embedding(token)  # B*L*word_dimemb = self.emb_dropout(emb)h = self.lstm_layer(emb, mask)  # B*L*Hh = self.lstm_dropout(h)reps = self.attention_layer(h, mask)  # B*repsreps = self.linear_dropout(reps)logits = self.dense(reps)return logits

5、train_or_test.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6import os
import torch
import torch.nn as nn
import torch.optim as optimfrom config import Config
from utils import WordEmbeddingLoader, RelationLoader, SemEvalDataLoader
from model import Att_BLSTM
from evaluate import Evaldef print_result(predict_label, id2rel, start_idx=8001):with open('predicted_result.txt', 'w', encoding='utf-8') as fw:for i in range(0, predict_label.shape[0]):fw.write('{}\t{}\n'.format(start_idx+i, id2rel[int(predict_label[i])]))def train(model, criterion, loader, config):train_loader, dev_loader, _ = loaderoptimizer = optim.Adadelta(model.parameters(), lr=config.lr, weight_decay=config.L2_decay)print(model)print('traning model parameters:')for name, param in model.named_parameters():if param.requires_grad:print('%s :  %s' % (name, str(param.data.shape)))print('--------------------------------------')print('start to train the model ...')eval_tool = Eval(config)min_f1 = -float('inf')for epoch in range(1, config.epoch+1):for step, (data, label) in enumerate(train_loader):model.train()data = data.to(config.device)label = label.to(config.device)optimizer.zero_grad()logits = model(data)loss = criterion(logits, label)loss.backward()nn.utils.clip_grad_value_(model.parameters(), clip_value=5)optimizer.step()_, train_loss, _ = eval_tool.evaluate(model, criterion, train_loader)f1, dev_loss, _ = eval_tool.evaluate(model, criterion, dev_loader)print('[%03d] train_loss: %.3f | dev_loss: %.3f | micro f1 on dev: %.4f'% (epoch, train_loss, dev_loss, f1), end=' ')if f1 > min_f1:min_f1 = f1torch.save(model.state_dict(), os.path.join(config.model_dir, 'model.pkl'))print('>>> save models!')else:print()def test(model, criterion, loader, config):print('--------------------------------------')print('start test ...')_, _, test_loader = loadermodel.load_state_dict(torch.load(os.path.join(config.model_dir, 'model.pkl')))eval_tool = Eval(config)f1, test_loss, predict_label = eval_tool.evaluate(model, criterion, test_loader)print('test_loss: %.3f | micro f1 on test:  %.4f' % (test_loss, f1))return predict_labelif __name__ == '__main__':config = Config()print('--------------------------------------')print('some config:')config.print_config()print('--------------------------------------')print('start to load data ...')word2id, word_vec = WordEmbeddingLoader(config).load_embedding()rel2id, id2rel, class_num = RelationLoader(config).get_relation()loader = SemEvalDataLoader(rel2id, word2id, config)train_loader, dev_loader = None, Noneif config.mode == 1:  # train modetrain_loader = loader.get_train()dev_loader = loader.get_dev()test_loader = loader.get_test()loader = [train_loader, dev_loader, test_loader]print('finish!')print('--------------------------------------')model = Att_BLSTM(word_vec=word_vec, class_num=class_num, config=config)model = model.to(config.device)criterion = nn.CrossEntropyLoss()if config.mode == 1:  # train modetrain(model, criterion, loader, config)predict_label = test(model, criterion, loader, config)print_result(predict_label, id2rel)

6、evaluate.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6import numpy as np
import torchdef semeval_scorer(predict_label, true_label, class_num=10):import mathassert true_label.shape[0] == predict_label.shape[0]confusion_matrix = np.zeros(shape=[class_num, class_num], dtype=np.float32)xDIRx = np.zeros(shape=[class_num], dtype=np.float32)for i in range(true_label.shape[0]):true_idx = math.ceil(true_label[i]/2)predict_idx = math.ceil(predict_label[i]/2)if true_label[i] == predict_label[i]:confusion_matrix[predict_idx][true_idx] += 1else:if true_idx == predict_idx:xDIRx[predict_idx] += 1else:confusion_matrix[predict_idx][true_idx] += 1col_sum = np.sum(confusion_matrix, axis=0).reshape(-1)row_sum = np.sum(confusion_matrix, axis=1).reshape(-1)f1 = np.zeros(shape=[class_num], dtype=np.float32)for i in range(0, class_num):  # ignore the 'Other'try:p = float(confusion_matrix[i][i]) / float(col_sum[i] + xDIRx[i])r = float(confusion_matrix[i][i]) / float(row_sum[i] + xDIRx[i])f1[i] = (2 * p * r / (p + r))except:passactual_class = 0total_f1 = 0.0for i in range(1, class_num):if f1[i] > 0.0:  # classes that not in the predict label are not consideredactual_class += 1total_f1 += f1[i]try:macro_f1 = total_f1 / actual_classexcept:macro_f1 = 0.0return macro_f1class Eval(object):def __init__(self, config):self.device = config.devicedef evaluate(self, model, criterion, data_loader):predict_label = []true_label = []total_loss = 0.0with torch.no_grad():model.eval()for _, (data, label) in enumerate(data_loader):data = data.to(self.device)label = label.to(self.device)logits = model(data)loss = criterion(logits, label)total_loss += loss.item() * logits.shape[0]_, pred = torch.max(logits, dim=1)  # replace softmax with max function, same impactspred = pred.cpu().detach().numpy().reshape((-1, 1))label = label.cpu().detach().numpy().reshape((-1, 1))predict_label.append(pred)true_label.append(label)predict_label = np.concatenate(predict_label, axis=0).reshape(-1).astype(np.int64)true_label = np.concatenate(true_label, axis=0).reshape(-1).astype(np.int64)eval_loss = total_loss / predict_label.shape[0]f1 = semeval_scorer(predict_label, true_label)return f1, eval_loss, predict_label

7、util.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6import os
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoaderclass WordEmbeddingLoader(object):"""A loader for pre-trained word embedding"""def __init__(self, config):self.path_word = config.embedding_path  # path of pre-trained word embeddingself.word_dim = config.word_dim  # dimension of word embeddingdef load_embedding(self):word2id = dict()  # word to wordIDword_vec = list()  # wordID to word embeddingword2id['PAD'] = len(word2id)  # PAD characterword2id['UNK'] = len(word2id)  # out of vocabularyword2id['<e1>'] = len(word2id)word2id['<e2>'] = len(word2id)word2id['</e1>'] = len(word2id)word2id['</e2>'] = len(word2id)with open(self.path_word, 'r', encoding='utf-8') as fr:for line in fr:line = line.strip().split()if len(line) != self.word_dim + 1:continueword2id[line[0]] = len(word2id)word_vec.append(np.asarray(line[1:], dtype=np.float32))word_vec = np.stack(word_vec)vec_mean, vec_std = word_vec.mean(), word_vec.std()special_emb = np.random.normal(vec_mean, vec_std, (6, self.word_dim))special_emb[0] = 0  # <pad> is initialize as zeroword_vec = np.concatenate((special_emb, word_vec), axis=0)word_vec = word_vec.astype(np.float32).reshape(-1, self.word_dim)word_vec = torch.from_numpy(word_vec)return word2id, word_vecclass RelationLoader(object):def __init__(self, config):self.data_dir = config.data_dirdef __load_relation(self):relation_file = os.path.join(self.data_dir, 'relation2id.txt')rel2id = {}id2rel = {}with open(relation_file, 'r', encoding='utf-8') as fr:for line in fr:relation, id_s = line.strip().split()id_d = int(id_s)rel2id[relation] = id_did2rel[id_d] = relationreturn rel2id, id2rel, len(rel2id)def get_relation(self):return self.__load_relation()class SemEvalDateset(Dataset):def __init__(self, filename, rel2id, word2id, config):self.filename = filenameself.rel2id = rel2idself.word2id = word2idself.max_len = config.max_lenself.data_dir = config.data_dirself.dataset, self.label = self.__load_data()def __symbolize_sentence(self, sentence):"""Args:sentence (list)"""mask = [1] * len(sentence)words = []length = min(self.max_len, len(sentence))mask = mask[:length]for i in range(length):words.append(self.word2id.get(sentence[i].lower(), self.word2id['UNK']))if length < self.max_len:for i in range(length, self.max_len):mask.append(0)  # 'PAD' mask is zerowords.append(self.word2id['PAD'])unit = np.asarray([words, mask], dtype=np.int64)unit = np.reshape(unit, newshape=(1, 2, self.max_len))return unitdef __load_data(self):path_data_file = os.path.join(self.data_dir, self.filename)data = []labels = []with open(path_data_file, 'r', encoding='utf-8') as fr:for line in fr:line = json.loads(line.strip())label = line['relation']sentence = line['sentence']label_idx = self.rel2id[label]one_sentence = self.__symbolize_sentence(sentence)data.append(one_sentence)labels.append(label_idx)return data, labelsdef __getitem__(self, index):data = self.dataset[index]label = self.label[index]return data, labeldef __len__(self):return len(self.label)class SemEvalDataLoader(object):def __init__(self, rel2id, word2id, config):self.rel2id = rel2idself.word2id = word2idself.config = configdef __collate_fn(self, batch):data, label = zip(*batch)  # unzip the batch datadata = list(data)label = list(label)data = torch.from_numpy(np.concatenate(data, axis=0))label = torch.from_numpy(np.asarray(label, dtype=np.int64))return data, labeldef __get_data(self, filename, shuffle=False):dataset = SemEvalDateset(filename, self.rel2id, self.word2id, self.config)loader = DataLoader(dataset=dataset,batch_size=self.config.batch_size,shuffle=shuffle,num_workers=2,collate_fn=self.__collate_fn)return loaderdef get_train(self):return self.__get_data('train.json', shuffle=True)def get_dev(self):return self.__get_data('test.json', shuffle=False)def get_test(self):return self.__get_data('test.json', shuffle=False)if __name__ == '__main__':from config import Configconfig = Config()word2id, word_vec = WordEmbeddingLoader(config).load_embedding()rel2id, id2rel, class_num = RelationLoader(config).get_relation()loader = SemEvalDataLoader(rel2id, word2id, config)test_loader = loader.get_train()for step, (data, label) in enumerate(test_loader):print(type(data), data.shape)print(type(label), label.shape)break

NLP-信息抽取-关系抽取-2016:Attention-BiLSTM实体关系分类器【基于双向LSTM及注意力机制的关系分类】【数据集:SemEval-2010 Task 8】相关推荐

  1. NLP中的关系抽取方法归纳

    文章目录 前言 命名实体识别任务 Softmax和CRF 指针网络 span排列 关系分类任务 半监督学习方法 基于远程监督的优化 多示例学习 强化学习 预训练 监督学习方法 联合抽取 共享参数的联合 ...

  2. 必读!信息抽取(Information Extraction)【关系抽取】

    来源: AINLPer 微信公众号(每日给你好看-) 编辑: ShuYini 校稿: ShuYini 时间: 2020-08-11 引言     信息抽取(information extraction ...

  3. ACL 2021 | 基于词依存信息类型映射记忆神经网络的关系抽取

    ©作者 | 陈桂敏 来源 | QTrade AI研究中心 QTrade AI 研究中心是一支将近 30 人的团队,主要研究方向包括:预训练模型.信息抽取.对话机器人.内容推荐等.本文介绍的是一篇信息抽 ...

  4. 实体关系抽取学习笔记

    1 关系抽取概述 1.1 简介 信息抽取旨在从大规模非结构或半结构的自然语言文本中抽取结构化信息.关系抽取是其中的重要子任务之一,主要目的是从文本中识别实体并抽取实体之间的语义关系. 关系抽取对于很多 ...

  5. 论文阅读课4-Long-tail Relation Extraction via Knowledge Graph Embeddings(GCN,关系抽取,2019,远程监督,少样本不平衡,2注意

    文章目录 abstract 1.introduction 2.相关工作 2.1 关系提取 2.2 KG embedding 2.3 GCNN 3. 方法 3.1符号 3.2框架 3.2.1 Insta ...

  6. 知识图谱课程报告-关系抽取文献综述

    关系抽取文献综述 引言: ​ 随着大数据的不断发展,在海量的结构化数据或非结构化数据中更低成本的抽取出有价值的信息越来越重要,可以说信息抽取是自然语言处理领域的一项最基本任务,信息抽取进而可被分成三个 ...

  7. 基于监督学习和远程监督的神经关系抽取

    基于监督学习和远程监督的神经关系抽取 作者:王嘉宁  QQ:851019059  Email:lygwjn@126.com 最新:博主发表在华东师范大学学报(自然科学版)的<基于远程监督的关系抽 ...

  8. 知识图谱最新权威综述论文解读:关系抽取

    上期我们介绍了2020年知识图谱最新权威综述论文<A Survey on Knowledge Graphs: Representation, Acquisition and Applicatio ...

  9. 【论文】Awesome Relation Extraction Paper(关系抽取)(PART III)

    0. 写在前面 回头看了一遍之前的博客,好些介绍的论文主要是属于关系分类的领域,于是就把前几篇的标题给修改了一下哈哈.关系分类和之前的文本分类,基于目标词的情感识别还都挺像的,baseline模型也都 ...

最新文章

  1. 购物搜索引擎架构的变与不变——淘宝网曲琳
  2. 魔改ResNet反超Transformer再掀架构之争!作者说“没一处是创新”,这些优化trick值得学...
  3. C++迭代器失效的几种情况总结
  4. android shell用户界面,shell界面下安装和卸载Android应用程序(apk包)
  5. vue使用element ui实现下拉列表分页的功能!!!
  6. html bootstrap复选框全选,javascript+bootstrap+html实现层级多选框全层全选和多选功能代码实例...
  7. JAVA中数字格式异常,java - Java数字格式异常 - 堆栈内存溢出
  8. nginx负载均衡器处理session共享的几种方法(转)
  9. arm linux vlc移值,vlc-3.0.8在飞凌开发板i.mx6q上移植
  10. java贪吃蛇总结报告_java贪吃蛇开发总结
  11. 芯片AD库转换之贸泽 Library Loader使用
  12. ESP32 系统篇: 优化系统启动时间
  13. 腾讯TIM实现即时通信 v3+ts实践
  14. mc服务器怎么回到床的位置,《我的世界》MC床的功能居然跟这四个指令有关系?很多人不知道!...
  15. 基于JAVA Frame的太阳系行星运转系统
  16. 51单片机仿真——中断系统(2)
  17. 《滚雪球:巴菲特和他的财富人生》【美】艾丽斯·施罗德著
  18. windows10家庭版下找不到gpedit.msc
  19. 【chrome插件】公众号后台,固定侧边栏,自动定位菜单位置。
  20. 独立开发变现周刊(第83期):建在Stripe上的应用,年收入70万美元

热门文章

  1. Python使用正则表达式识别代码中的中文、英文和数字实例演示
  2. 微信和支付宝扫码之后,需要加载各种业务模块:
  3. UserWarning: image file could not be identified because WEBP support not install
  4. 视频、音乐播放器大家都听说过,那么图片播放器呢
  5. 最新汽车娱乐系统测试,你了解吗?
  6. IT人物之《Netty权威指南》中文作者 专访华为李林锋:我与Netty那些不得不说的事
  7. Spring整合AMQ
  8. WIN7使用过360系统急救箱后出现的任务计划程序文件夹删除的办法
  9. 给网站添加一个初音未来代码
  10. 小米潘多拉路由器添加节点_腾讯网游加速器联手小米路由器,共创全场景游戏加速体验!...