使用ERNIE在DuReader_robust上进行阅读理解

1. 实验内容

机器阅读理解 (Machine Reading Comprehension) 是指让机器阅读文本,然后回答和阅读内容相关的问题。阅读理解是自然语言处理和人工智能领域的重要前沿课题,对于提升机器的智能水平、使机器具有持续知识获取的能力等具有重要价值,近年来受到学术界和工业界的广泛关注。

阅读理解模型的鲁棒性是衡量该技术能否在实际应用中大规模落地的重要指标之一。随着当前技术的进步,模型虽然能够在一些阅读理解测试集上取得较好的性能,但在实际应用中,这些模型所表现出的鲁棒性仍然难以令人满意。

Dureaderrobust数据集是首个关注阅读理解模型鲁棒性的中文数据集,旨在考察模型在真实应用场景中的过敏感性、过稳定性以及泛化能力等问题。本实验讲基于该数据集进行阅读理解任务。对于一个给定的问题q和一个篇章p,根据篇章内容,给出该问题的答案a。数据集中的每个样本,是一个三元组<q, p, a>,例如:

问题 q: 乔丹打了多少个赛季

篇章 p: 迈克尔.乔丹在NBA打了15个赛季。他在84年进入nba,期间在1993年10月6日第一次退役改打棒球,95年3月18日重新回归,在99年1月13日第二次退役,后于2001年10月31日复出,在03年最终退役…

参考答案 a: [‘15个’,‘15个赛季’]

2. 实验环境

  • PaddlePaddle 安装

    本项目依赖于 PaddlePaddle 2.0 及以上版本,请参考 安装指南 进行安装

  • PaddleNLP 安装

    pip install --upgrade paddlenlp -i https://pypi.org/simple
    
  • 环境依赖

    Python的版本要求 3.6+

3. 实验设计

本实验将按照这样几个阶段进行:数据处理–>模型结构–>训练配置–>模型训练和评估

3.1 数据处理

3.1.1 加载Dureaderrobust数据集

Dureaderrobust是paddleNLP内置数据集,因此可以通过PaddleNLP提供的load_datasetAPI,即可一键完成数据集加载。

import paddle
from paddlenlp.data import Stack, Dict, Pad
import paddlenlp
from paddlenlp.datasets import load_dataset
from utils import prepare_train_features, prepare_validation_features
from functools import partialtrain_ds, dev_ds, test_ds = load_dataset('dureader_robust', splits=('train', 'dev', 'test'))for idx in range(2):print(train_ds[idx]['question'])print(train_ds[idx]['context'])print(train_ds[idx]['answers'])print(train_ds[idx]['answer_starts'])print()
100%|██████████| 20038/20038 [00:00<00:00, 60151.14it/s]仙剑奇侠传3第几集上天界
第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。
['第35集']
[0]燃气热水器哪个牌子好
选择燃气热水器时,一定要关注这几个问题:1、出水稳定性要好,不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好,要装有安全报警装置 市场上燃气热水器品牌众多,购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级:9秒速热,可快速进入洗浴模式;水温持久稳定,不会出现忽热忽冷的现象,并通过水量伺服技术将出水温度精确控制在±0.5℃,可满足家里宝贝敏感肌肤洗护需求;配备CO和CH4双气体报警装置更安全(市场上一般多为CO单气体报警)。另外,这款热水器还有智能WIFI互联功能,只需下载个手机APP即可用手机远程操作热水器,实现精准调节水温,满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能,可以有效吸附水中的铁锈、铁屑等微小杂质,防止细菌滋生,使沐浴水质更洁净,长期使用磁化水沐浴更利于身体健康。
['方太']
[110]

3.1.2 将数据转换成特征形式

DuReaderrubust数据集采用SQuAD数据格式,InputFeature使用滑动窗口的方法生成,即一个example可能对应多个InputFeature。

由于文章加问题的文本长度可能大于max_seq_length,答案出现的位置有可能出现在文章最后,所以不能简单的对文章进行截断。

那么对于过长的文章,则采用滑动窗口将文章分成多段,分别与问题组合。再用对应的tokenizer转化为模型可接受的feature。doc_stride参数就是每次滑动的距离。滑动窗口生成InputFeature的过程如下图:

本实验中,我们使用的预训练模型是ERNIE,ERNIE对中文数据的处理是以字为单位。PaddleNLP对于各种预训练模型已经内置了相应的tokenizer,指定想要使用的模型名字即可加载对应的tokenizer。

tokenizer的作用是将原始输入文本转化成模型可以接受的输入数据形式,但这里需要注意一下,使用tokenizer处理数据之前,需要将数据转换为id形式。

# 设置模型名称
MODEL_NAME = 'ernie-1.0'
tokenizer = paddlenlp.transformers.ErnieTokenizer.from_pretrained(MODEL_NAME)max_seq_length = 512
doc_stride = 128train_trans_func = partial(prepare_train_features, max_seq_length=max_seq_length, doc_stride=doc_stride,tokenizer=tokenizer)train_ds.map(train_trans_func, batched=True, num_workers=4)dev_trans_func = partial(prepare_validation_features, max_seq_length=max_seq_length, doc_stride=doc_stride,tokenizer=tokenizer)dev_ds.map(dev_trans_func, batched=True, num_workers=4)
test_ds.map(dev_trans_func, batched=True, num_workers=4)
[2021-06-09 19:35:51,201] [    INFO] - Downloading vocab.txt from https://paddlenlp.bj.bcebos.com/models/transformers/ernie/vocab.txt
100%|██████████| 90/90 [00:00<00:00, 3146.30it/s]<paddlenlp.datasets.dataset.MapDataset at 0x7f3d298486d0>

经过以上数据转换之后,从以上结果可以看出,数据集中的example已经被转换成了模型可以接收的feature,包括input_ids、token_type_ids、答案的起始位置等信息。
其中:

  • input_ids: 表示输入文本的token ID。
  • token_type_ids: 表示对应的token属于输入的问题还是答案。(Transformer类预训练模型支持单句以及句对输入)。
  • overflow_to_sample: feature对应的example的编号。
  • offset_mapping: 每个token的起始字符和结束字符在原文中对应的index(用于生成答案文本)。
  • start_positions: 答案在这个feature中的开始位置。
  • end_positions: 答案在这个feature中的结束位置。

数据处理的详细过程请参见utils.py,下边打印一些结果进行展示。

for idx in range(2):print(train_ds[idx]['input_ids'])print(train_ds[idx]['token_type_ids'])print(train_ds[idx]['overflow_to_sample'])print(train_ds[idx]['offset_mapping'])print(train_ds[idx]['start_positions'])print(train_ds[idx]['end_positions'])print()
[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 12043, 37, 10, 561, 125, 43, 8, 445, 86, 576, 65, 1448, 2969, 4, 469, 1586, 118, 776, 5, 1993, 2602, 4, 108, 25, 179, 51, 1993, 2602, 498, 1052, 122, 12043, 1082, 1994, 1616, 11, 262, 4, 518, 171, 813, 109, 1084, 270, 12043, 539, 8, 3006, 580, 11, 31, 4, 2473, 306, 34, 87, 889, 280, 846, 573, 12043, 561, 125, 14, 539, 889, 810, 276, 182, 4, 67, 351, 14, 889, 1182, 118, 776, 156, 952, 4, 539, 889, 16, 38, 4, 445, 15, 200, 61, 12043, 2]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
0
[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (242, 243), (243, 244), (244, 245), (245, 246), (246, 247), (247, 248), (248, 249), (249, 250), (250, 251), (251, 252), (252, 253), (253, 254), (254, 255), (255, 256), (256, 257), (257, 258), (258, 259), (259, 260), (260, 261), (261, 262), (262, 263), (263, 264), (264, 265), (265, 266), (266, 267), (267, 268), (268, 269), (269, 270), (270, 271), (271, 272), (272, 273), (273, 274), (274, 275), (275, 276), (276, 277), (277, 278), (278, 279), (279, 280), (280, 281), (281, 282), (282, 283), (283, 284), (284, 285), (285, 286), (286, 287), (287, 288), (288, 289), (289, 290), (290, 291), (291, 292), (292, 293), (293, 294), (294, 295), (295, 296), (296, 297), (297, 298), (298, 299), (299, 300), (300, 301), (301, 302), (302, 303), (303, 304), (304, 305), (305, 306), (306, 307), (307, 308), (308, 309), (309, 310), (310, 311), (311, 312), (312, 313), (313, 314), (314, 315), (315, 316), (316, 317), (317, 318), (318, 319), (319, 320), (320, 321), (321, 322), (322, 323), (323, 324), (324, 325), (325, 326), (326, 327), (327, 328), (328, 329), (329, 330), (330, 331), (331, 332), (0, 0)]
14
16[1, 1404, 266, 506, 101, 361, 1256, 27, 664, 85, 170, 2, 352, 790, 1404, 266, 506, 101, 361, 36, 4, 7, 91, 41, 129, 490, 47, 553, 27, 358, 281, 74, 208, 6, 39, 101, 862, 91, 92, 41, 170, 4, 16, 52, 39, 87, 1745, 506, 1745, 888, 5, 87, 528, 249, 6, 532, 537, 45, 302, 94, 91, 5, 413, 323, 101, 565, 284, 6, 868, 25, 41, 826, 52, 6, 58, 518, 397, 6, 204, 62, 92, 41, 170, 4, 41, 371, 9, 204, 62, 337, 1023, 371, 521, 99, 191, 28, 1404, 266, 506, 101, 361, 100, 664, 539, 65, 4, 817, 1042, 36, 201, 413, 65, 120, 51, 277, 14, 2081, 541, 1190, 348, 12043, 58, 512, 508, 17, 57, 445, 5, 1512, 73, 1664, 565, 506, 101, 361, 11, 175, 29, 82, 412, 58, 76, 388, 15, 62, 76, 658, 222, 74, 701, 1866, 537, 506, 4, 48, 532, 537, 71, 109, 1123, 1600, 469, 220, 12048, 101, 565, 303, 876, 862, 91, 4, 16, 32, 39, 87, 1745, 506, 1745, 888, 5, 87, 528, 4, 145, 124, 93, 101, 150, 3466, 231, 164, 133, 174, 39, 101, 565, 130, 326, 524, 586, 108, 11, 17963, 42, 17963, 4, 48, 596, 581, 50, 155, 707, 1358, 1443, 345, 1455, 1411, 1123, 455, 413, 323, 12048, 483, 366, 4850, 14, 6215, 9488, 653, 266, 82, 337, 1023, 371, 521, 263, 204, 62, 78, 99, 191, 28, 7, 689, 65, 13, 4850, 269, 266, 82, 337, 1023, 77, 12043, 770, 137, 4, 47, 699, 506, 101, 361, 201, 9, 826, 52, 4177, 756, 387, 369, 52, 4, 297, 413, 86, 763, 27, 247, 98, 3887, 444, 48, 29, 247, 98, 629, 163, 868, 25, 506, 101, 361, 4, 79, 87, 326, 378, 290, 377, 101, 565, 4, 596, 581, 50, 8, 65, 314, 73, 5, 1123, 1600, 413, 323, 12043, 153, 187, 58, 512, 5, 1512, 73, 1664, 565, 135, 517, 57, 41, 5, 10, 385, 120, 1512, 73, 369, 52, 4, 48, 22, 9, 344, 813, 912, 101, 12, 5, 754, 2337, 6, 754, 2880, 43, 702, 96, 792, 207, 4, 510, 735, 541, 1101, 1989, 21, 4, 175, 2873, 1600, 101, 207, 263, 1308, 1158, 4, 84, 195, 175, 29, 1512, 73, 101, 2873, 1600, 263, 217, 37, 262, 82, 691, 736, 12043, 2]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
1
[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 193), (193, 194), (194, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 217), (217, 218), (218, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 241), (241, 242), (242, 243), (243, 244), (244, 245), (245, 246), (246, 247), (247, 248), (248, 249), (249, 250), (250, 251), (251, 252), (252, 253), (253, 254), (254, 255), (255, 256), (256, 257), (257, 258), (258, 259), (259, 260), (260, 264), (264, 265), (265, 266), (266, 267), (267, 268), (268, 269), (269, 270), (270, 271), (271, 272), (272, 273), (273, 274), (274, 275), (275, 276), (276, 279), (279, 280), (280, 281), (281, 282), (282, 283), (283, 284), (284, 285), (285, 286), (286, 287), (287, 288), (288, 289), (289, 290), (290, 291), (291, 292), (292, 293), (293, 294), (294, 295), (295, 296), (296, 297), (297, 298), (298, 299), (299, 300), (300, 301), (301, 302), (302, 303), (303, 304), (304, 305), (305, 306), (306, 307), (307, 308), (308, 309), (309, 310), (310, 311), (311, 312), (312, 313), (313, 314), (314, 315), (315, 316), (316, 317), (317, 318), (318, 319), (319, 320), (320, 321), (321, 322), (322, 323), (323, 324), (324, 325), (325, 326), (326, 327), (327, 328), (328, 329), (329, 330), (330, 331), (331, 332), (332, 333), (333, 334), (334, 335), (335, 336), (336, 337), (337, 338), (338, 339), (339, 340), (340, 341), (341, 342), (342, 343), (343, 344), (344, 345), (345, 346), (346, 347), (347, 348), (348, 349), (349, 350), (350, 351), (351, 352), (352, 353), (353, 354), (354, 355), (355, 356), (356, 357), (357, 358), (358, 359), (359, 360), (360, 361), (361, 362), (362, 363), (363, 364), (364, 365), (365, 366), (366, 367), (367, 368), (368, 369), (369, 370), (370, 371), (371, 372), (372, 373), (373, 374), (374, 375), (375, 376), (376, 377), (377, 378), (378, 379), (379, 380), (380, 381), (381, 382), (382, 383), (383, 384), (384, 385), (385, 386), (386, 387), (387, 388), (388, 389), (0, 0)]
121
122

3.1.3 Batchify和数据读入

使用paddle.io.BatchSamplerpaddlenlp.data中提供的方法把数据组成batch。

然后使用paddle.io.DataLoader接口多线程异步加载数据。

batchify_fn详解:

batch_size = 12# 定义BatchSampler
train_batch_sampler = paddle.io.DistributedBatchSampler(train_ds, batch_size=batch_size, shuffle=True)dev_batch_sampler = paddle.io.BatchSampler(dev_ds, batch_size=batch_size, shuffle=False)test_batch_sampler = paddle.io.BatchSampler(test_ds, batch_size=batch_size, shuffle=False)# 定义batchify_fn
train_batchify_fn = lambda samples, fn=Dict({"input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),"token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),"start_positions": Stack(dtype="int64"),"end_positions": Stack(dtype="int64")
}): fn(samples)dev_batchify_fn = lambda samples, fn=Dict({"input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),"token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
}): fn(samples)# 构造DataLoader
train_data_loader = paddle.io.DataLoader(dataset=train_ds,batch_sampler=train_batch_sampler,collate_fn=train_batchify_fn,return_list=True)dev_data_loader = paddle.io.DataLoader(dataset=dev_ds,batch_sampler=dev_batch_sampler,collate_fn=dev_batchify_fn,return_list=True)test_data_loader = paddle.io.DataLoader(dataset=test_ds,batch_sampler=test_batch_sampler,collate_fn=dev_batchify_fn,return_list=True)for step, batch in enumerate(train_data_loader, start=1):input_ids, segment_ids, start_positions, end_positions = batchprint(input_ids)break
Tensor(shape=[12, 512], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,[[1   , 585 , 8   , ..., 0   , 0   , 0   ],[1   , 161 , 1538, ..., 0   , 0   , 0   ],[1   , 1138, 1966, ..., 0   , 0   , 0   ],...,[1   , 208 , 1713, ..., 0   , 0   , 0   ],[1   , 323 , 57  , ..., 0   , 0   , 0   ],[1   , 1526, 73  , ..., 73  , 21  , 2   ]])

3.2 模型结构

本节将开始介绍如何使用ERNIE Fine-tune DuReaderrobust阅读理解任务。DuReaderrobust阅读理解任务的本质是答案抽取任务。根据输入的问题和文章,从预训练模型的sequence_output中预测答案在文章中的起始位置和结束位置。原理如下图所示:

目前PaddleNLP已经内置了包括ERNIE在内的多种基于预训练模型的常用任务的下游网络,包括机器阅读理解。这些网络在paddlenlp.transformers下,均可实现一键调用。相应代码如下:

from paddlenlp.transformers import ErnieForQuestionAnsweringmodel = ErnieForQuestionAnswering.from_pretrained(MODEL_NAME)
[2021-06-09 19:40:13,342] [    INFO] - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/ernie/ernie_v1_chn_base.pdparams and saved to /home/aistudio/.paddlenlp/models/ernie-1.0
[2021-06-09 19:40:13,344] [    INFO] - Downloading ernie_v1_chn_base.pdparams from https://paddlenlp.bj.bcebos.com/models/transformers/ernie/ernie_v1_chn_base.pdparams
100%|██████████| 392507/392507 [00:08<00:00, 44543.90it/s]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1297: UserWarning: Skip loading for classifier.weight. classifier.weight is not found in the provided dict.warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1297: UserWarning: Skip loading for classifier.bias. classifier.bias is not found in the provided dict.warnings.warn(("Skip loading for {}. ".format(key) + str(err)))

3.3 训练配置

3.3.1 损失函数设计

模型的网络结构确定后我们就可以设计loss function了。ErineForQuestionAnswering模型对将ErnieModel的sequence_output拆开成start_logits和end_logits输出,所以DuReaderrobust的loss由start_loss和end_loss两部分组成,我们需要自己定义loss function。

对于答案起始位置和结束位置的预测可以分别看成两个分类任务。所以设计的loss function如下:

class CrossEntropyLossForRobust(paddle.nn.Layer):def __init__(self):super(CrossEntropyLossForRobust, self).__init__()def forward(self, y, label):start_logits, end_logits = y   # both shape are [batch_size, seq_len]start_position, end_position = labelstart_position = paddle.unsqueeze(start_position, axis=-1)end_position = paddle.unsqueeze(end_position, axis=-1)start_loss = paddle.nn.functional.softmax_with_cross_entropy(logits=start_logits, label=start_position, soft_label=False)start_loss = paddle.mean(start_loss)end_loss = paddle.nn.functional.softmax_with_cross_entropy(logits=end_logits, label=end_position, soft_label=False)end_loss = paddle.mean(end_loss)loss = (start_loss + end_loss) / 2return loss

3.3.2 参数和模型设定

本节将进行一些超参数,计算资源,优化器等的设定,用来训练模型。

# 训练过程中的最大学习率
learning_rate = 3e-5 # 训练轮次
epochs = 2# 学习率预热比例
warmup_proportion = 0.1# 权重衰减系数,类似模型正则项策略,避免模型过拟合
weight_decay = 0.01num_training_steps = len(train_data_loader) * epochs# 学习率衰减策略
lr_scheduler = paddlenlp.transformers.LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)decay_params = [p.name for n, p in model.named_parameters()if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler,parameters=model.parameters(),weight_decay=weight_decay,apply_decay_param_fun=lambda x: x in decay_params)

3.4 模型训练与评估

模型训练的过程通常有以下步骤:

  1. 从dataloader中取出一个batch data。
  2. 将batch data喂给model,做前向计算。
  3. 将前向计算结果传给损失函数,计算loss。
  4. loss反向回传,更新梯度。重复以上步骤。

每训练一个epoch时,程序通过evaluate()调用paddlenlp.metric.squad中的squad_evaluate(), compute_predictions()评估当前模型训练的效果,其中:

  • compute_predictions()用于生成可提交的答案;

  • squad_evaluate()用于返回评价指标。

二者适用于所有符合squad数据格式的答案抽取任务。这类任务使用F1和exact来评估预测的答案和真实答案的相似程度。

from utils import evaluatecriterion = CrossEntropyLossForRobust()
global_step = 0
for epoch in range(1, epochs + 1):for step, batch in enumerate(train_data_loader, start=1):global_step += 1input_ids, segment_ids, start_positions, end_positions = batchlogits = model(input_ids=input_ids, token_type_ids=segment_ids)loss = criterion(logits, (start_positions, end_positions))if global_step % 100 == 0 :print("global step %d, epoch: %d, batch: %d, loss: %.5f" % (global_step, epoch, step, loss))loss.backward()optimizer.step()lr_scheduler.step()optimizer.clear_grad()evaluate(model=model, data_loader=dev_data_loader)

NLP领域中的ERNIE模型在阅读理解中的应用

【实践】NLP领域中的ERNIE模型在阅读理解中的应用相关推荐

  1. 机器学习中训练的模型,通俗理解

    概率统计(建模.学习) 很多新手在初学机器学习/深度学习中,会产生这样的疑问?为什么要训练模型,模型是什么,如何训练- 本人刚开始接触时也产生过类似地疑问,现在为大家排解这些疑问. 1.机器学习中大概 ...

  2. 机器阅读理解中文章和问题的深度学习表示方法

    /*版权声明:可以任意转载,转载时请标明文章原始出处和作者信息.*/ author: 张俊林 注:本文是<深度学习解决机器阅读理解任务的研究进展>节选,该文将于近期在"深度学习大 ...

  3. 论文浅尝 | 机器阅读理解中常识知识的显式利用

    论文笔记整理:吴林娟,天津大学硕士,自然语言处理方向. 链接:https://arxiv.org/pdf/1809.03449.pdf 动机 机器阅读理解(MRC)和人类进行阅读理解之间还存在差距,作 ...

  4. 【自然语言处理(NLP)】基于SQuAD的机器阅读理解

    [自然语言处理(NLP)]基于SQuAD的机器阅读理解 作者简介:在校大学生一枚,华为云享专家,阿里云专家博主,腾云先锋(TDP)成员,云曦智划项目总负责人,全国高等学校计算机教学与产业实践资源建设专 ...

  5. 英语阅读理解中表示作者态度的词汇汇总

    英语阅读理解中表示作者态度的词汇汇总 来源:    恩波论坛      一 赞同 positive adj.肯定的, 实际的, 积极的, , 确实的 favorable adj.赞成的, 有利的, 赞 ...

  6. 机器学习中激活函数和模型_探索机器学习中的激活和丢失功能

    机器学习中激活函数和模型 In this post, we're going to discuss the most widely-used activation and loss functions ...

  7. 【Unity3D】使用 FBX 格式的外部模型 ( 向 Unity 中添加 FBX 模型 | 向 Scene 场景中添加 FBX 模型 | 3D 物体渲染 | 3D 物体材质设置 )

    文章目录 一.向 Unity 中添加 FBX 模型 二.向 Scene 场景中添加 FBX 模型 三.3D 物体渲染 四.3D 物体材质设置 一.向 Unity 中添加 FBX 模型 Unity 中使 ...

  8. 【NLP】完全解析!Bert Transformer 阅读理解源码详解

    接上一篇: 你所不知道的 Transformer! 超详细的 Bert 文本分类源码解读 | 附源码 中文情感分类单标签 参考论文: https://arxiv.org/abs/1706.03762 ...

  9. 【论文解读】NER任务中的MRC(机器阅读理解)

    论文:https://arxiv.org/pdf/1910.11476v6.pdf 前沿: 在之前的NER任务中常常分为两种:nested NER和 flat NER.从直观的角度来看,nested N ...

最新文章

  1. 用 Chiron 运行 IronPython 编写的 Silverlight 程序
  2. 命令行下操作常用语句
  3. Android测试中被测应用挂了怎么办?
  4. Oracle下的Databse,Instance,Schemas
  5. 使用Python调用ASA rest API进行配置
  6. keil5安装教程简单易上手
  7. 福昕阅读器 自定义注释快捷键
  8. java根据地址解析省市区信息
  9. python面向对象练习题
  10. Java面向对象知识点总结
  11. 公安大数据平台应用与公安大数据建模
  12. 关机一直显示正在关闭服务器,电脑关机后,显示正在关机,但等半天也关不了 怎么办...
  13. SparkLink星闪技术之SLB概述
  14. 关于Flash的几点思考(Thoughts on Flash)
  15. 苹果自助维修服务上线:维修工具租赁价约为321元
  16. openstack框架搭建云计算平台和各组件运维内容包括mysql、keyston、Glance、Nova、Neutron、Dashboard、Heat、Trove、Ceilometer运维
  17. lego-loam学习笔记(一)
  18. 常见电脑故障维修-硬盘篇
  19. 朴素贝叶斯+语言模型
  20. SSM项目-OA报销单管理系统(三)

热门文章

  1. 服务器修改拔刀剑修改数,关于拔刀剑的(求助大佬!)
  2. 小白学Python之爬虫篇(二)——隐式资源链接查找与爬取
  3. [\u4e00-\u9fa5 ]+是什么
  4. 路飞学城-python爬虫密训-第二章
  5. Gin 框架学习笔记(01)— 自定义结构体绑定表单、绑定URI、自定义log、自定义中间件、路由组、解析查询字符串、上传文件、使用HTTP方法
  6. 解决使用FireFox下Flash上传文件时SESSION丢失的问题(swfupload)
  7. 如何开通国际域名网站
  8. win7关闭UAC、安装Telnet、打开远程登陆
  9. 百度文库中在线阅读器的基本思想
  10. SSRF漏洞的攻击与防御