文章目录

  • 本文内容
  • 环境配置
  • 全局变量
  • 模型构建
  • 损失函数
  • 模型训练
    • 构造Dataset
    • 构造Dataloader
    • 训练
  • 模型评估
  • 模型使用
  • 参考文献

代码地址 :https://github.com/iioSnail/MDCSpell_pytorch


本文内容

本文为MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction论文的Pytorch实现。

论文地址: https://aclanthology.org/2022.findings-acl.98/

论文年份:2022

论文笔记:https://blog.csdn.net/zhaohongfei_358/article/details/126973451

论文大致内容:作者基于Transformer和BERT设计了一个多任务的网络来进行CSC(Chinese Spell Checking)任务(中文拼写纠错)。多任务分别是找出哪个字是错的和对错字进行纠正。

由于作者并没有公开代码,所以我就尝试自己实现一个,最终我的实验结果如下表:

Dataset Model D_Precision D_Recall D_F1 C_Prec C_Rec C_F1
SIGHAN 13 MDCSpell 89.1 78.3 83.4 87.5 76.8 81.8
SIGHAN 13 MDCSpell(复现) 80.2 79.9 80.0 77.2 76.9 77.1
SIGHAN 14 MDCSpell 70.2 68.8 69.5 69.0 67.7 68.3
SIGHAN 14 MDCSpell(复现) 82.8 66.6 73.8 79.9 64.3 71.2
SIGHAN 15 MDCSpell 80.8 80.6 80.7 78.4 78.2 78.3
SIGHAN 15 MDCSpell(复现) 86.7 76.1 81.1 72.5 82.7 77.3

这里是我训练了2个epoch的结果,与作者的结论相差不大。如果我增加训练次数的话,也许可以和作者的结果达到一致。

补充:这里有问题,论文使用的是Sentence-level,而我的是Character-level,所以我并没有复现出作者的效果。待后续有时间再尝试一下。

环境配置

try:import transformers
except:!pip install transformers
import os
import copy
import pickleimport torch
import transformersfrom torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
torch.__version__
'1.12.1+cu113'
transformers.__version__
'4.21.3'

全局变量

# 句子的长度,作者并没有说明。我这里就按经验取一个
max_length = 128
# 作者使用的batch_size
batch_size = 32
# epoch数,作者并没有具体说明,按经验取一个
epochs = 10# 每${log_after_step}步,打印一次日志
log_after_step = 20# 模型存放的位置。
model_path = './drive/MyDrive/models/MDCSpell/'
os.makedirs(model_path, exist_ok=True)
model_path = model_path + 'MDCSpell-model.pt'device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
Device: cuda

模型构建


Correction Network 的数据流向如下:

1.将token序列 [CLS] 遇 到 逆 竟 [SEP] 送给Word Embedding模块进行embeddings,得到向量 { e C L S w , e 1 w , e 2 w , e 3 w , e 4 w , e S E P w } \{e_{CLS}^w, e_1^w, e_2^w, e_3^w, e_4^w,e_{SEP}^w\} {eCLSw,e1w,e2w,e3w,e4w,eSEPw}

个人认为此时的embedding仅仅是Word Embeding,并不包含Position Embedding和Segment Embedding。

2.之后将 { e C L S w , e 1 w , e 2 w , e 3 w , e 4 w , e S E P w } \{e_{CLS}^w, e_1^w, e_2^w, e_3^w, e_4^w,e_{SEP}^w\} {eCLSw,e1w,e2w,e3w,e4w,eSEPw}向量送入BERT,增加Position Embedding和Segment Embedding,得到 { e C , e 1 , e 2 , e 3 , e 4 , e S } \{e_C, e_1, e_2, e_3, e_4, e_S\} {eC,e1,e2,e3,e4,eS}

3.在BERT内部,会经历多层的TransformerEncoder,最终的得到输出向量 H c = { h C c , h 1 c , h 2 c , h 3 c , h 4 c , h S c } H^c=\{h_C^c, h_1^c, h_2^c, h_3^c, h_4^c, h_S^c\} Hc={hCc,h1c,h2c,h3c,h4c,hSc}.

4.将BERT的输出 H c H^c Hc 和 隔壁Detection Network输出的 H d H^d Hd 进行融合,得到 H = H d + H c H = H^d+H^c H=Hd+Hc

融合时并不对[CLS][SEP]进行融合

5.将 H H H送给全连接层(Dense Layer)做最后的预测。

Correction Network模型细节

  1. BERT:作者使用的是具有12层Transformer Block的BERT-base版。
  2. Dense Layer:Dense Layer的输入通道为词向量维度,输出通道为词典大小。例如:词向量维度为768,词典大小为20000,则Dense Layer则为nn.Linear(768, 20000)
  3. Dense Layer的初始化:Dense Layer的权重使用的是Word Embedding的参数。因为word Embedding是将词index转成词向量,所以其参数刚好是Dense Layer的转置,即Word Embedding是nn.Linear(20000, 768),所以作者就是用Word Embedding的转置来初始化Dense Layer的参数。因为这样可以加速训练,且使模型变的稳定。

Detection Network的数据流向如下:

1.输入为使用BERT得到的word Embedding { e 1 w , e 2 w , e 3 w , e 4 w } \{e_1^w, e_2^w, e_3^w, e_4^w\} {e1w,e2w,e3w,e4w}。虽然图里并不包含[CLS][SEP]的词向量,但个人认为不需要对其特殊处理,因为最后的预测也用不到这两个token.

2.将 { e 1 w , e 2 w , e 3 w , e 4 w } \{e_1^w, e_2^w, e_3^w, e_4^w\} {e1w,e2w,e3w,e4w}增加Position Embedding信息,得到 { e 1 ′ , e 2 ′ , e 3 ′ , e 4 ′ } \{e_1', e_2', e_3', e_4'\} {e1,e2,e3,e4}

在论文中说Detection Network使用的是向量 { e 1 , e 2 , e 3 , e 4 } \{e_1, e_2, e_3, e_4\} {e1,e2,e3,e4},其是word embedding+position embedding+segment embedding。这与图上是矛盾的,这里以图为准了。

3.将 { e 1 ′ , e 2 ′ , e 3 ′ , e 4 ′ } \{e_1', e_2', e_3', e_4'\} {e1,e2,e3,e4}向量送入Transformer Block,得到输出向量 H d = { h 1 d , h 2 d , h 3 d , h 4 d } H^d=\{h_1^d, h_2^d, h_3^d, h_4^d\} Hd={h1d,h2d,h3d,h4d}

4.一方面,将输出向量 H d H^d Hd送给隔壁的Correction Network进行融合;另一方面,将 H d H^d Hd送给后续的全连接层(Dense Layer)来判断哪个token是错误的.

Detection Network的细节:

  1. Transformer Block:Transformer Block是2层的TransformerEncoder。
  2. Transformer Block参数初始化:Transformer Block参数初始化使用的是BERT的权重。
  3. Dense Layer:Dense Layer的输入通道为词向量大小,输出通道为1。使用Sigmoid来判别该token为错字的概率。
class CorrectionNetwork(nn.Module):def __init__(self):super(CorrectionNetwork, self).__init__()# BERT分词器,作者并没提到自己使用的是哪个中文版的bert,我这里就使用一个比较常用的self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")# BERTself.bert = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")# BERT的word embedding,本质就是个nn.Embeddingself.word_embedding_table = self.bert.get_input_embeddings()# 预测层。hidden_size是词向量的大小,len(self.tokenizer)是词典大小self.dense_layer = nn.Linear(self.bert.config.hidden_size, len(self.tokenizer))def forward(self, inputs, word_embeddings, detect_hidden_states):"""Correction Network的前向传递:param inputs: inputs为tokenizer对中文文本的分词结果,里面包含了token对一个的index,attention_mask等:param word_embeddings: 使用BERT的word_embedding对token进行embedding后的结果:param detect_hidden_states: Detection Network输出hidden state:return: Correction Network对个token的预测结果。"""# 1. 使用bert进行前向传递bert_outputs = self.bert(token_type_ids=inputs['token_type_ids'],attention_mask=inputs['attention_mask'],inputs_embeds=word_embeddings)# 2. 将bert的hidden_state和Detection Network的hidden state进行融合。hidden_states = bert_outputs['last_hidden_state'] + detect_hidden_states# 3. 最终使用全连接层进行token预测return self.dense_layer(hidden_states)def get_inputs_and_word_embeddings(self, sequences, max_length=128):"""对中文序列进行分词和word embeddings处理:param sequences: 中文文本序列。例如: ["鸡你太美", "哎呦,你干嘛!"]:param max_length: 文本的最大长度,不足则进行填充,超出进行裁剪。:return: tokenizer的输出和word embeddings."""inputs = self.tokenizer(sequences, padding='max_length', max_length=max_length, return_tensors='pt',truncation=True).to(device)# 使用BERT的work embeddings对token进行embedding,这里得到的embedding并不包含position embedding和segment embeddingword_embeddings = self.word_embedding_table(inputs['input_ids'])return inputs, word_embeddings
class DetectionNetwork(nn.Module):def __init__(self, position_embeddings, transformer_blocks, hidden_size):""":param position_embeddings: bert的position_embeddings,本质是一个nn.Embedding:param transformer: BERT的前两层transformer_block,其是一个ModuleList对象"""super(DetectionNetwork, self).__init__()self.position_embeddings = position_embeddingsself.transformer_blocks = transformer_blocks# 定义最后的预测层,预测哪个token是错误的self.dense_layer = nn.Sequential(nn.Linear(hidden_size, 1),nn.Sigmoid())def forward(self, word_embeddings):# 获取token序列的长度,这里为128sequence_length = word_embeddings.size(1)# 生成position embeddingposition_embeddings = self.position_embeddings(torch.LongTensor(range(sequence_length)).to(device))# 融合work_embedding和position_embeddingx = word_embeddings + position_embeddings# 将x一层一层的使用transformer encoder进行向后传递for transformer_layer in self.transformer_blocks:x = transformer_layer(x)[0]# 最终返回Detection Network输出的hidden states和预测结果hidden_states = xreturn hidden_states, self.dense_layer(hidden_states)
class MDCSpellModel(nn.Module):def __init__(self):super(MDCSpellModel, self).__init__()# 构造Correction Networkself.correction_network = CorrectionNetwork()self._init_correction_dense_layer()# 构造Detection Network# position embedding使用BERT的position_embeddings = self.correction_network.bert.embeddings.position_embeddings# 作者在论文中提到的,Detection Network的Transformer使用BERT的权重# 所以我这里直接克隆BERT的前两层Transformer来完成这个动作transformer = copy.deepcopy(self.correction_network.bert.encoder.layer[:2])# 提取BERT的词向量大小hidden_size = self.correction_network.bert.config.hidden_size# 构造Detection Networkself.detection_network = DetectionNetwork(position_embeddings, transformer, hidden_size)def forward(self, sequences, max_length=128):# 先获取word embedding,Correction Network和Detection Network都要用inputs, word_embeddings = self.correction_network.get_inputs_and_word_embeddings(sequences, max_length)# Detection Network进行前向传递,获取输出的Hidden State和预测结果hidden_states, detection_outputs = self.detection_network(word_embeddings)# Correction Network进行前向传递,获取其预测结果correction_outputs = self.correction_network(inputs, word_embeddings, hidden_states)# 返回Correction Network 和 Detection Network 的预测结果。# 在计算损失时`[PAD]`token不需要参与计算,所以这里将`[PAD]`部分全都变为0return correction_outputs, detection_outputs.squeeze(2) * inputs['attention_mask']def _init_correction_dense_layer(self):"""原论文中提到,使用Word Embedding的weight来对Correction Network进行初始化"""self.correction_network.dense_layer.weight.data = self.correction_network.word_embedding_table.weight.data

定义好模型后,我们来简单的尝试一下:

model = MDCSpellModel().to(device)
correction_outputs, detection_outputs = model(["鸡你太美", "哎呦,你干嘛!"])
print("correction_outputs shape:", correction_outputs.size())
print("detection_outputs shape:", detection_outputs.size())
correction_outputs shape: torch.Size([2, 128, 21128])
detection_outputs shape: torch.Size([2, 128])

损失函数

Correction Network和Detection Network使用的都是Cross Entropy。之后进行相加即可:

L = λ L c + ( 1 − λ ) L d L = \lambda L^c + (1-\lambda) L^d L=λLc+(1λ)Ld

其中 λ ∈ [ 0 , 1 ] \lambda \in [0,1] λ[0,1] 。作者通过实验得出 λ = 0.85 \lambda=0.85 λ=0.85 时效果最好。

class MDCSpellLoss(nn.Module):def __init__(self, coefficient=0.85):super(MDCSpellLoss, self).__init__()# 定义Correction Network的Loss函数self.correction_criterion = nn.CrossEntropyLoss(ignore_index=0)# 定义Detection Network的Loss函数,因为是二分类,所以用Binary Cross Entropyself.detection_criterion = nn.BCELoss()# 权重系数self.coefficient = coefficientdef forward(self, correction_outputs, correction_targets, detection_outputs, detection_targets):""":param correction_outputs: Correction Network的输出,Shape为(batch_size, sequence_length, hidden_size):param correction_targets: Correction Network的标签,Shape为(batch_size, sequence_length):param detection_outputs: Detection Network的输出,Shape为(batch_size, sequence_length):param detection_targets: Detection Network的标签,Shape为(batch_size, sequence_length):return:"""# 计算Correction Network的loss,因为Shape维度为3,所以要把batch_size和sequence_length进行合并才能计算correction_loss = self.correction_criterion(correction_outputs.view(-1, correction_outputs.size(2)),correction_targets.view(-1))# 计算Detection Network的lossdetection_loss = self.detection_criterion(detection_outputs, detection_targets)# 对两个loss进行加权平均return self.coefficient * correction_loss + (1 - self.coefficient) * detection_loss

模型训练

作者的训练方式:

  1. 第一步,首先使用 Wang271K(自己造的假数据) 数据集进行训练。batch size为32, learning rate为2e-5

  2. 第二步,使用SIGHAN训练集进行fine-tune。 batch size为32,learning rate为1e-5

作者并没有提到使用的是什么Optimizer,但看这个学习率,应该是Adam。

在第一步,作者说的是使用了几乎3M个,但作者只提到过Wang271K这个数据集,我猜可能作者看错了,这个是0.3M条数据,而不是3M。

作者首先使用了Wang271K数据集进行对模型进行训练,然后又使用SIGHAN训练集对模型进行fine-tune。这里我就不进行fine-tune了,直接进行训练。我这里使用的是 ReaLiSe论文 处理好的数据集,其就是Wang271K和SIGHAN。

百度网盘链接 :https://pan.baidu.com/s/1x67LPiYAjLKhO1_2CI6aOA?pwd=skda

下载好直接解压即可。

构造Dataset

class CSCDataset(Dataset):def __init__(self):super(CSCDataset, self).__init__()with open("data/trainall.times2.pkl", mode='br') as f:train_data = pickle.load(f)self.train_data = train_datadef __getitem__(self, index):src = self.train_data[index]['src']tgt = self.train_data[index]['tgt']return src, tgtdef __len__(self):return len(self.train_data)
train_data = CSCDataset()
train_data.__getitem__(0)
('纽约早盘作为基准的低硫轻油,五月份交割价攀升一点三四美元,来到每桶二十八点二五美元,而上周五曾下挫一豪元以上。','纽约早盘作为基准的低硫轻油,五月份交割价攀升一点三四美元,来到每桶二十八点二五美元,而上周五曾下挫一美元以上。')

构造Dataloader

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
def collate_fn(batch):src, tgt = zip(*batch)src, tgt = list(src), list(tgt)src_tokens = tokenizer(src, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']tgt_tokens = tokenizer(tgt, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']correction_targets = tgt_tokensdetection_targets = (src_tokens != tgt_tokens).float()return src, correction_targets, detection_targets, src_tokens  # src_tokens在计算Correction的精准率时要用到
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)

训练

criterion = MDCSpellLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
start_epoch = 0  # 从哪个epoch开始
total_step = 0  # 一共更新了多少次参数
# 恢复之前的训练
if os.path.exists(model_path):if not torch.cuda.is_available():checkpoint = torch.load(model_path, map_location='cpu')else:checkpoint = torch.load(model_path)model.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])start_epoch = checkpoint['epoch']total_step = checkpoint['total_step']print("恢复训练,epoch:", start_epoch)
恢复训练,epoch: 2
model = model.to(device)
model = model.train()

训练这里代码量看起来很大,但实际大多都是计算recall和precision的代码。这里对于Detection的recall和precision的计算使用的是Detection Network的预测结果。

total_loss = 0.  # 记录lossd_recall_numerator = 0  # Detection的Recall的分子
d_recall_denominator = 0  # Detection的Recall的分母
d_precision_numerator = 0  # Detection的precision的分子
d_precision_denominator = 0  # Detection的precision的分母
c_recall_numerator = 0  # Correction的Recall的分子
c_recall_denominator = 0  # Correction的Recall的分母
c_precision_numerator = 0  # Correction的precision的分子
c_precision_denominator = 0  # Correction的precision的分母for epoch in range(start_epoch, epochs):step = 0for sequences, correction_targets, detection_targets, correction_inputs in train_loader:correction_targets, detection_targets = correction_targets.to(device), detection_targets.to(device)correction_inputs = correction_inputs.to(device)correction_outputs, detection_outputs = model(sequences)loss = criterion(correction_outputs, correction_targets, detection_outputs, detection_targets)loss.backward()optimizer.step()optimizer.zero_grad()step += 1total_step += 1total_loss += loss.detach().item()# 计算Detection的recall和precision指标# 大于0.5,认为是错误token,反之为正确tokend_predicts = detection_outputs >= 0.5# 计算错误token中被网络正确预测到的数量d_recall_numerator += d_predicts[detection_targets == 1].sum().item()# 计算错误token的数量d_recall_denominator += (detection_targets == 1).sum().item()# 计算网络预测的错误token的数量d_precision_denominator += d_predicts.sum().item()# 计算网络预测的错误token中,有多少是真错误的tokend_precision_numerator += (detection_targets[d_predicts == 1]).sum().item()# 计算Correction的recall和precision# 将输出映射成index,即将correction_outputs的Shape由(32, 128, 21128)变为(32,128)correction_outputs = correction_outputs.argmax(2)# 对于填充、[CLS]和[SEP]这三个token不校验correction_outputs[(correction_targets == 0) | (correction_targets == 101) | (correction_targets == 102)] = 0# correction_targets的[CLS]和[SEP]也要变为0correction_targets[(correction_targets == 101) | (correction_targets == 102)] = 0# Correction的预测结果,其中True表示预测正确,False表示预测错误或无需预测c_predicts = correction_outputs == correction_targets# 计算错误token中被网络正确纠正的token数量c_recall_numerator += c_predicts[detection_targets == 1].sum().item()# 计算错误token的数量c_recall_denominator += (detection_targets == 1).sum().item()# 计算网络纠正token的数量correction_inputs[(correction_inputs == 101) | (correction_inputs == 102)] = 0c_precision_denominator += (correction_outputs != correction_inputs).sum().item()# 计算在网络纠正的这些token中,有多少是真正被纠正对的c_precision_numerator += c_predicts[correction_outputs != correction_inputs].sum().item()if total_step % log_after_step == 0:loss = total_loss / log_after_stepd_recall = d_recall_numerator / (d_recall_denominator + 1e-9)d_precision = d_precision_numerator / (d_precision_denominator + 1e-9)c_recall = c_recall_numerator / (c_recall_denominator + 1e-9)c_precision = c_precision_numerator / (c_precision_denominator + 1e-9)print("Epoch {}, ""Step {}/{}, ""Total Step {}, ""loss {:.5f}, ""detection recall {:.4f}, ""detection precision {:.4f}, ""correction recall {:.4f}, ""correction precision {:.4f}".format(epoch, step, len(train_loader), total_step,loss,d_recall,d_precision,c_recall,c_precision))total_loss = 0.total_correct = 0total_num = 0d_recall_numerator = 0d_recall_denominator = 0d_precision_numerator = 0d_precision_denominator = 0c_recall_numerator = 0c_recall_denominator = 0c_precision_numerator = 0c_precision_denominator = 0torch.save({'model': model.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch + 1,'total_step': total_step,}, model_path)
。。。
Epoch 2, Step 15/8882, Total Step 8900, loss 0.02403, detection recall 0.4118, detection precision 0.8247, correction recall 0.8192, correction precision 0.9485
Epoch 2, Step 35/8882, Total Step 8920, loss 0.03479, detection recall 0.3658, detection precision 0.8055, correction recall 0.8029, correction precision 0.9125
。。。

模型评估

模型评估使用了SIGHAN 2013,2014,2015三个数据集对模型进行评估。对于Detection的Precision和Recall的评估,使用的是Correction Network的结果,这和训练阶段有所不同,这是因为Detection Network只是帮助Correction Network训练的,其结果在使用时不具备参考价值。

model = model.eval()
def evaluation(test_data):d_recall_numerator = 0  # Detection的Recall的分子d_recall_denominator = 0  # Detection的Recall的分母d_precision_numerator = 0  # Detection的precision的分子d_precision_denominator = 0  # Detection的precision的分母c_recall_numerator = 0  # Correction的Recall的分子c_recall_denominator = 0  # Correction的Recall的分母c_precision_numerator = 0  # Correction的precision的分子c_precision_denominator = 0  # Correction的precision的分母prograss = tqdm(range(len(test_data)))for i in prograss:src, tgt = test_data[i]['src'], test_data[i]['tgt']src_tokens = tokenizer(src, return_tensors='pt', max_length=128, truncation=True)['input_ids'][0][1:-1]tgt_tokens = tokenizer(tgt, return_tensors='pt', max_length=128, truncation=True)['input_ids'][0][1:-1]# 正常情况下,src和tgt的长度应该是一致的if len(src_tokens) != len(tgt_tokens):print("第%d条数据异常" % i)continuecorrection_outputs, _ = model(src)predict_tokens = correction_outputs[0][1:len(src_tokens) + 1].argmax(1).detach().cpu()# 计算错误token的数量d_recall_denominator += (src_tokens != tgt_tokens).sum().item()# 计算在这些错误token,有多少网络也认为它是错误的d_recall_numerator += (predict_tokens != src_tokens)[src_tokens != tgt_tokens].sum().item()# 计算网络找出的错误token的数量d_precision_denominator += (predict_tokens != src_tokens).sum().item()# 计算在网络找出的这些错误token中,有多少是真正错误的d_precision_numerator += (src_tokens != tgt_tokens)[predict_tokens != src_tokens].sum().item()# 计算Detection的recall、precision和f1-scored_recall = d_recall_numerator / (d_recall_denominator + 1e-9)d_precision = d_precision_numerator / (d_precision_denominator + 1e-9)d_f1_score = 2 * (d_recall * d_precision) / (d_recall + d_precision + 1e-9)# 计算错误token的数量c_recall_denominator += (src_tokens != tgt_tokens).sum().item()# 计算在这些错误token中,有多少网络预测对了c_recall_numerator += (predict_tokens == tgt_tokens)[src_tokens != tgt_tokens].sum().item()# 计算网络找出的错误token的数量c_precision_denominator += (predict_tokens != src_tokens).sum().item()# 计算网络找出的错误token中,有多少是正确修正的c_precision_numerator += (predict_tokens == tgt_tokens)[predict_tokens != src_tokens].sum().item()# 计算Correction的recall、precision和f1-scorec_recall = c_recall_numerator / (c_recall_denominator + 1e-9)c_precision = c_precision_numerator / (c_precision_denominator + 1e-9)c_f1_score = 2 * (c_recall * c_precision) / (c_recall + c_precision + 1e-9)prograss.set_postfix({'d_recall': d_recall,'d_precision': d_precision,'d_f1_score': d_f1_score,'c_recall': c_recall,'c_precision': c_precision,'c_f1_score': c_f1_score,})
with open("data/test.sighan13.pkl", mode='br') as f:sighan13 = pickle.load(f)
evaluation(sighan13)
100%|██████████| 1000/1000 [00:11<00:00, 90.12it/s, d_recall=0.799, d_precision=0.802, d_f1_score=0.8, c_recall=0.769, c_precision=0.772, c_f1_score=0.771]
with open("data/test.sighan14.pkl", mode='br') as f:sighan14 = pickle.load(f)
evaluation(sighan14)
100%|██████████| 1062/1062 [00:12<00:00, 85.48it/s, d_recall=0.666, d_precision=0.828, d_f1_score=0.738, c_recall=0.643, c_precision=0.799, c_f1_score=0.712]
with open("data/test.sighan15.pkl", mode='br') as f:sighan15 = pickle.load(f)
evaluation(sighan15)
100%|██████████| 1100/1100 [00:11<00:00, 92.04it/s, d_recall=0.761, d_precision=0.867, d_f1_score=0.811, c_recall=0.725, c_precision=0.827, c_f1_score=0.773]

模型使用

最后,我们来真正的使用一下该模型,看下效果:

def predict(text):sequences = [text]correction_outputs, _ = model(sequences)tokens = correction_outputs[0][1:len(text) + 1].argmax(1)return ''.join(tokenizer.convert_ids_to_tokens(tokens))
predict("今天早上我吃了以个火聋果")
'今天早上我吃了一个火聋果'
predict("我是联系时长两年半的个人练习生蔡徐鲲,喜欢唱跳RAP蓝球")
'我是联系时长两年半的个人练习生蔡徐鲲,喜欢唱跳ra##p蓝球[SEP]'

虽然在数据上模型表现还不错,但在真正使用场景上,效果还是不够好。中文文本纠错果然是一个比较难的任务 T_T !


参考文献

MDCSpell论文: https://aclanthology.org/2022.findings-acl.98/

MDCSpell论文笔记:https://blog.csdn.net/zhaohongfei_358/article/details/126973451

【论文复现】MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction论文复现相关推荐

  1. 【论文笔记】MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction

    文章目录 论文内容 论文思路 模型架构 损失函数 训练细节 实验结果 个人总结 论文复现 : https://blog.csdn.net/zhaohongfei_358/article/details ...

  2. 论文解读:DCSpell:A Detector-Corrector Framework for Chinese Spelling Error Correction

    论文解读:DCSpell:A Detector-Corrector Framework for Chinese Spelling Error Correction 简要信息: 序号 属性 值 1 模型 ...

  3. 论文解读:SpellBERT:A Lightweight Pretrained Model for Chinese Spelling Checking

    论文解读:SpellBERT:A Lightweight Pretrained Model for Chinese Spelling Checking 简要信息: 序号 属性 值 1 模型名称 Spe ...

  4. 多智能体强化学习Multi agent,多任务强化学习Multi task以及多智能体多任务强化学习Multi agent Multi task概述

    概述 在我之前的工作中,我自己总结了一些多智能体强化学习的算法和通俗的理解. 首先,关于题目中提到的这三个家伙,大家首先想到的就是强化学习的五件套: 状态:s 奖励:r 动作值:Q 状态值:V 策略: ...

  5. multi task训练torch_采用single task模型蒸馏到Multi-Task Networks

    论文地址. 这篇论文主要研究利用各个single task model来分别作为teacher model,用knowledge distillation的方法指导一个multi task model ...

  6. Multi task learning多任务学习背景简介

    2020-06-16 23:22:33 本篇文章将介绍在机器学习中效果比较好的一种模式,多任务学习(Multi task Learning,MTL).已经有一篇机器之心翻译的很好的博文介绍多任务学习了 ...

  7. 【论文相关】1.1 T 的 arXiv 数据集:170 万篇论文,可以看到下辈子

    By 超神经 内容提要:近日,arXiv 将 170 万+ 篇的论文,打包成数据集,放在了 kaggle 平台,以后访问和下载论文,就更方便了.该数据集目前大小 1.1 TB 左右,而且之后还会随着每 ...

  8. 论文投稿新规则,不用跑出SOTA,还能“内定”发论文?!

    文 | Sheryc_王苏 从5月初开始,CV圈似乎开始了一阵MLP"文艺复兴"的热潮:在短短4天时间里,来自谷歌.清华.牛津.Facebook四个顶级研究机构的研究者分别独立发布 ...

  9. 【学术相关】顶级论文创新点怎么找?中国高校首次获CVPR最佳学生论文奖有感...

    几天前,同济大学公布了一条重磅消息:本校学生陈涵晟获得CVPR2022最佳学生论文奖,这也是CVPR自2001年设立最佳学生论文奖以来,获奖论文的第一作者首次来自中国高校. 华人在CV领域崛起 最近几 ...

最新文章

  1. iOS MMDrawerController源码解读(一)
  2. python编写篮球_Python编程2——Python实现计算篮球比赛是否领先安全的程序
  3. python unpack_ip地址处理每天10行python代码系列
  4. nginx+web.py+fastcgi(spawn-fcgi)的session失效問題
  5. Linux文件系统中的inode
  6. Wi-Fi 真的安全吗?一行代码就可让周边无线网络全部瘫痪!| 原力计划
  7. 自学android编程教程,安卓编程入门教程 安卓编程如何自学
  8. lg g2 android 5.0 rom,LG G2(D802)升级Flyme4.5图文教程
  9. Delphi中使用Imageen控件将图像文件转换成PDF
  10. 解决Android部分手机图片剪切返回崩溃问题
  11. 基于STM32单片机的FM调频TEA5767功放收音机方案原理图设计
  12. 职称英语职称计算机如何折算为学时,发表论文算继续教育多少学时
  13. 讯飞语音报错:未经授权的语音应用.(错误码:11210)
  14. PC常见故障及解决思路汇总(网络方面)
  15. 2020-2021前端面试题合集,面试题附答案
  16. 教你用 Python 快速批量转换 HEIC 文件
  17. deepin外置键盘无法打开键盘背光灯的解决方法
  18. 小米路由器的服务器无响应怎么回事,小米路由器常见问题与解决方法(高级功能) | 192路由网...
  19. 从51开始的单片机之旅(二)----LCD1602液晶、ADC0809、DAC0832
  20. `Computer-Algorithm` 最小生成树MST,Prim,Kruskal,次小生成树

热门文章

  1. 目前已确定转行开个淘宝店,想征集一个淘宝店名。
  2. 超链接做按钮 禁止跳转 submit 或 location 导致 return false 不起作用
  3. OleDbDataReader的一点属性和方法
  4. iOS精仿唱吧下载按钮、仿知乎日报、自定义提示视图、过渡动画、记录应用等源码...
  5. 【万字总结】以插排和分治为例来看如何分析与设计算法
  6. win10休眠按钮在“选择电源按钮功能”配置中找不到
  7. 服装创业未来的出路在哪里?
  8. mysql declare 语法_sql_declare等语法 | 学步园
  9. 快乐是福二级域名分发网美化版源码
  10. ldd命令 ubuntu_技术|简单介绍 ldd 命令