• 使用paddlenlp中与训练好的语言模型来进行实体抽取:
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.import argparse
import os
from functools import partialimport paddle
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.transformers import AutoTokenizer, AutoModelForTokenClassification
from paddlenlp.metrics import ChunkEvaluatorfrom model import ErnieCrfForTokenClassification
from data import load_dict, load_dataset, parse_decodesparser = argparse.ArgumentParser()# yapf: disable
parser.add_argument("--save_dir", default='./ernie_crf_ckpt', type=str, help="The output directory where the model checkpoints will be written.")
parser.add_argument("--epochs", default=2, type=int, help="Total number of training epochs to perform.")
parser.add_argument("--batch_size", default=200, type=int, help="Batch size per GPU/CPU for training.")  #default=200
parser.add_argument("--device", default="cpu", type=str, choices=["cpu", "gpu"] ,help="The device to select to train the model, is must be cpu/gpu.")
parser.add_argument("--data_dir", default='./waybill/data', type=str, help="The folder where the dataset is located.")args = parser.parse_args()
# yapf: enabledef convert_to_features(example, tokenizer, label_vocab):tokens, labels = exampletokenized_input = tokenizer(tokens,return_length=True,is_split_into_words=True)# Token '[CLS]' and '[SEP]' will get label 'O'labels = ['O'] + labels + ['O']tokenized_input['labels'] = [label_vocab[x] for x in labels]return tokenized_input['input_ids'], tokenized_input['token_type_ids'], tokenized_input['seq_len'], tokenized_input['labels']@paddle.no_grad()
def evaluate(model, metric, data_loader):model.eval()metric.reset()for input_ids, seg_ids, lens, labels in data_loader:preds = model(input_ids, seg_ids, lengths=lens)n_infer, n_label, n_correct = metric.compute(lens, preds, labels)metric.update(n_infer.numpy(), n_label.numpy(), n_correct.numpy())precision, recall, f1_score = metric.accumulate()print("[EVAL] Precision: %f - Recall: %f - F1: %f" %(precision, recall, f1_score))model.train()@paddle.no_grad()
def predict(model, data_loader, ds, label_vocab):all_preds = []all_lens = []for input_ids, seg_ids, lens, labels in data_loader:preds = model(input_ids, seg_ids, lengths=lens)# Drop CLS predictionpreds = [pred[1:] for pred in preds.numpy()]all_preds.append(preds)all_lens.append(lens)sentences = [example[0] for example in ds.data]results = parse_decodes(sentences, all_preds, all_lens, label_vocab)return resultsif __name__ == '__main__':paddle.set_device(args.device)# Create dataset, tokenizer and dataloader.train_ds, dev_ds, test_ds = load_dataset(datafiles=(os.path.join(args.data_dir, 'train.txt'),os.path.join(args.data_dir, 'dev.txt'),os.path.join(args.data_dir, 'test.txt')))label_vocab = load_dict(os.path.join(args.data_dir, 'tag.dic'))tokenizer = AutoTokenizer.from_pretrained('ernie-3.0-medium-zh')trans_func = partial(convert_to_features,tokenizer=tokenizer,label_vocab=label_vocab)train_ds.map(trans_func)dev_ds.map(trans_func)test_ds.map(trans_func)batchify_fn = lambda samples, fn=Tuple(Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int32'),  # input_idsPad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int32'),  # token_type_idsStack(dtype='int64'),  # seq_lenPad(axis=0, pad_val=label_vocab.get("O", 0), dtype='int64')  # labels): fn(samples)train_loader = paddle.io.DataLoader(dataset=train_ds,batch_size=args.batch_size,return_list=True,collate_fn=batchify_fn)dev_loader = paddle.io.DataLoader(dataset=dev_ds,batch_size=args.batch_size,return_list=True,collate_fn=batchify_fn)test_loader = paddle.io.DataLoader(dataset=test_ds,batch_size=args.batch_size,return_list=True,collate_fn=batchify_fn)# Define the model netword and its lossernie = AutoModelForTokenClassification.from_pretrained("ernie-3.0-medium-zh", num_classes=len(label_vocab))model = ErnieCrfForTokenClassification(ernie)metric = ChunkEvaluator(label_list=label_vocab.keys(), suffix=True)optimizer = paddle.optimizer.AdamW(learning_rate=2e-5,parameters=model.parameters())step = 0for epoch in range(args.epochs):for input_ids, token_type_ids, lengths, labels in train_loader:loss = model(input_ids,token_type_ids,lengths=lengths,labels=labels)avg_loss = paddle.mean(loss)avg_loss.backward()optimizer.step()optimizer.clear_grad()step += 1print("[TRAIN] Epoch:%d - Step:%d - Loss: %f" %(epoch, step, avg_loss))evaluate(model, metric, dev_loader)paddle.save(model.state_dict(),os.path.join(args.save_dir, 'model_%d' % step))preds = predict(model, test_loader, test_ds, label_vocab)file_path = "ernie_crf_results.txt"with open(file_path, "w", encoding="utf8") as fout:fout.write("\n".join(preds))# Print some examplesprint("The results have been saved in the file: %s, some examples are shown below: "% file_path)print("\n".join(preds[:10]))#损失值为0:出现梯度消失。

这里使用的是cpu,可以根据需要,自己修改一下默认的参数即可。
预测的最终结果:

('黑龙江省', 'A1')('双鸭山市', 'A2')('尖山区八马路与东平行路交叉口北40米韦业涛18600009172', 'A3')
('广西壮族自治区桂林市', 'A1')('雁山区', 'A3')('雁山镇西龙村老年活动中心17610348888羊卓卫', 'A4')
('15652864561河南省', 'A1')('开', 'A2')('封市', 'A3')
('河北省', 'A1')('唐山市', 'A2')('玉田县无终大街159号18614253058尚汉生', 'A2')
('台湾', 'A1')('台中市北区北区锦新街18号18511226708', 'A3')('蓟丽', 'A3')
('廖梓琪18514743222', 'A3')('湖北省宜昌市', 'A1')('长阳土家族自治县', 'A3')('贺家坪镇贺家坪村一组临河1号', 'A4')
('江苏省南通市', 'A1')('海门市孝威村孝威路88号18611840623计星仪', 'A3')
('17601674746', 'P')('赵春丽内蒙古自治区', 'A3')('乌兰', 'A3')('察布市', 'A3')
('云南省临沧市', 'A1')('耿马傣族佤族自治县鑫源路法院对面许贞爱18510566685', 'A2')
('四川省', 'A2')('成都市', 'A3')
('湖南省娄底市', 'A1')('娄星区乐坪大道48号潘苹18500039123', 'A3')
('韩泽涛', 'A1')('山东省威海市', 'A1')('文登区', 'A3')('横山路天润小区39号18600274912', 'A4')
('15712917351', 'P')('宗忆珍山东省青岛市', 'A3')
('程云燊河南省', 'A2')('商丘市', 'A2')
('秋庆龙13051428666', 'A1')('台湾', 'A1')('新北市', 'A3')
('江西省', 'A1')('萍乡市', 'A2')('湘东区泉湖北路18511082664焦博', 'A3')
('邵春雨15611120818四川省', 'A3')('德阳市', 'A2')('旌阳区岷江西路550号', 'A3')
('13520006391', 'A1')('海南省', 'A2')
('勾雪睿17610370000', 'A3')('黑龙江省', 'A1')('伊春市', 'A1')('乌', 'A3')('伊岭区学府路', 'A3')
('林绰刈湖南省郴州市', 'A3')
  • 结果中,大部分还是比较正确,但是还是有一些错误的切分情况。例如:云南省临沧市,就没有被很好的区分开,这应该是两个实体才对。
  • 数据集下载地址:
  • URL =数据集

文本关键信息抽取——实体抽取代码实现相关推荐

  1. lstm+crf 信息抽取 实体识别 代码

    目录 可以作为毕业设计 可以用来练手 可以用作论文基础模型 任务描述: 数据集: 运行环境: 数据说明 数据处理 处理数据集

  2. 徐阿衡 | 知识抽取-实体及关系抽取(一)

    本文转载自公众号:徐阿衡. 这一篇是关于知识抽取,整理并补充了上学时的两篇笔记 NLP笔记 - Information Extraction 和 NLP笔记 - Relation Extraction ...

  3. Python爬取百度百科,BeautifulSoup提取关键信息

    本文主要爬取演员杨幂的百度百科,用到的python库有:requests和BeautifulSoup 主要内容共分为以下两个方面: 1. 用requests爬取网页内容 2. 用BeautifulSo ...

  4. PaddleNLP通用信息抽取技术UIE【一】产业应用实例:信息抽取{实体关系抽取、中文分词、精准实体标。情感分析等}、文本纠错、问答系统、闲聊机器人、定制训练

    相关文章: 1.快递单中抽取关键信息[一]----基于BiGRU+CR+预训练的词向量优化 2.快递单信息抽取[二]基于ERNIE1.0至ErnieGram + CRF预训练模型 3.快递单信息抽取[ ...

  5. 【PaddleOCR-kie】关键信息抽取1:使用VI-LayoutXLM模型推理预测(SER+RE)

    背景:在训练自己数据集进行kie之前,想跑一下md里面的例程,但md教程内容混乱,而且同一个内容有多个手册,毕竟是多人合作的项目,可能是为了工程解耦,方便更新考虑--需要运行的模型和运行步骤散落在不用 ...

  6. 在线文本实体抽取能力,助力应用解析海量文本数据

    随着信息化的发展,很多具有重要价值的知识隐藏分布在海量数据中,影响了人们获取知识的效率,如何处理繁杂的非结构化文本数据成为难题. 近日,HMS Core机器学习服务6.5.0版本新增在线文本实体抽取能 ...

  7. 文字表格信息抽取模型介绍——实体抽取方法:NER模型(上)

    导读: 将深度学习技术应用于NER有三个核心优势.首先,NER受益于非线性转换,它生成从输入到输出的非线性映射.与线性模型(如对数线性HMM和线性链CRF)相比,基于DL的模型能够通过非线性激活函数从 ...

  8. 【PaddleNLP-kie】关键信息抽取2:UIE模型做图片信息提取全流程

    文章目录 本文参考 UIE理论部分 step0.UIEX原始模型使用 网页体验 本机安装使用 环境安装 使用docker的环境安装 快速开始 step1.UIEX模型微调(小样本学习) 数据标注(la ...

  9. ####好好好######信息抽取——实体关系联合抽取

    信息抽取--实体关系联合抽取 目录 简介 实体关系联合抽取 Model 1: End-to-End Relation Extraction using LSTMs on Sequences and T ...

最新文章

  1. AD7705 16-bit Delta-Sigma AD 转换器
  2. NABCD项目需求分析
  3. 运行 composer update,提示 Allowed memory size of bytes exhausted
  4. 如何通过标签体系,打造精细化运营?
  5. SQLyog连接Mysql8.0提示 Authentication plugin ‘caching_sha2_password‘ cannot be loaded
  6. IBASE deletion timestamp verification
  7. java.sql.SQLException: Access denied for user ‘root‘@‘hadoop001‘ (using password: YES)
  8. 一次使用BeanPostProcessor疏漏引起的重大bug
  9. Kotlin从入门到放弃(三)——协程
  10. 网络收藏夹--用来收藏我经常访问的网站
  11. 什么是弹性(display: flex)布局 ?
  12. dns和私人dns是什么意思?企业如何预防dns劫持?
  13. js/vue:video 视频播放器
  14. Closures in OOC
  15. 06-SparkSQL
  16. w7电脑蓝屏怎么解决_详解win7电脑蓝屏怎么办
  17. 关于数据清洗的常见方式
  18. 明月镜片官宣品牌代言人刘昊然;Crocs与欧阳娜娜打造全新联名系列 | 美通企业日报...
  19. Proguard 常用规则
  20. 我的CSDN现在没有C币,没办法下载

热门文章

  1. 搜狗与360加紧抢食百度份额
  2. 商店购物 (shopping.c/cpp/pas)
  3. html中 media的作用
  4. 我们前端跟后端是怎么合作的
  5. uos系统中windows格utf8编码文件转GBK
  6. 【NOIP2017】时间复杂度
  7. 【HDU 5945】Fxx and game
  8. 车牌识别中车牌信息以及如何做车牌识别的思路整理
  9. C++语言基础:计算圆的周长与面积
  10. 云顶之弈法机器人_云顶之弈机器人怎么用最好 云顶之弈机器人使用技巧