【论文复现】MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction论文复现
文章目录
- 本文内容
- 环境配置
- 全局变量
- 模型构建
- 损失函数
- 模型训练
- 构造Dataset
- 构造Dataloader
- 训练
- 模型评估
- 模型使用
- 参考文献
代码地址 :https://github.com/iioSnail/MDCSpell_pytorch
本文内容
本文为MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction论文的Pytorch实现。
论文地址: https://aclanthology.org/2022.findings-acl.98/
论文年份:2022
论文笔记:https://blog.csdn.net/zhaohongfei_358/article/details/126973451
论文大致内容:作者基于Transformer和BERT设计了一个多任务的网络来进行CSC(Chinese Spell Checking)任务(中文拼写纠错)。多任务分别是找出哪个字是错的和对错字进行纠正。
由于作者并没有公开代码,所以我就尝试自己实现一个,最终我的实验结果如下表:
Dataset | Model | D_Precision | D_Recall | D_F1 | C_Prec | C_Rec | C_F1 |
---|---|---|---|---|---|---|---|
SIGHAN 13 | MDCSpell | 89.1 | 78.3 | 83.4 | 87.5 | 76.8 | 81.8 |
SIGHAN 13 | MDCSpell(复现) | 80.2 | 79.9 | 80.0 | 77.2 | 76.9 | 77.1 |
SIGHAN 14 | MDCSpell | 70.2 | 68.8 | 69.5 | 69.0 | 67.7 | 68.3 |
SIGHAN 14 | MDCSpell(复现) | 82.8 | 66.6 | 73.8 | 79.9 | 64.3 | 71.2 |
SIGHAN 15 | MDCSpell | 80.8 | 80.6 | 80.7 | 78.4 | 78.2 | 78.3 |
SIGHAN 15 | MDCSpell(复现) | 86.7 | 76.1 | 81.1 | 72.5 | 82.7 | 77.3 |
这里是我训练了2个epoch的结果,与作者的结论相差不大。如果我增加训练次数的话,也许可以和作者的结果达到一致。
补充:这里有问题,论文使用的是Sentence-level,而我的是Character-level,所以我并没有复现出作者的效果。待后续有时间再尝试一下。
环境配置
try:import transformers
except:!pip install transformers
import os
import copy
import pickleimport torch
import transformersfrom torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
torch.__version__
'1.12.1+cu113'
transformers.__version__
'4.21.3'
全局变量
# 句子的长度,作者并没有说明。我这里就按经验取一个
max_length = 128
# 作者使用的batch_size
batch_size = 32
# epoch数,作者并没有具体说明,按经验取一个
epochs = 10# 每${log_after_step}步,打印一次日志
log_after_step = 20# 模型存放的位置。
model_path = './drive/MyDrive/models/MDCSpell/'
os.makedirs(model_path, exist_ok=True)
model_path = model_path + 'MDCSpell-model.pt'device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
Device: cuda
模型构建
Correction Network 的数据流向如下:
1.将token序列 [CLS] 遇 到 逆 竟 [SEP]
送给Word Embedding模块进行embeddings,得到向量 { e C L S w , e 1 w , e 2 w , e 3 w , e 4 w , e S E P w } \{e_{CLS}^w, e_1^w, e_2^w, e_3^w, e_4^w,e_{SEP}^w\} {eCLSw,e1w,e2w,e3w,e4w,eSEPw}。
个人认为此时的embedding仅仅是Word Embeding,并不包含Position Embedding和Segment Embedding。
2.之后将 { e C L S w , e 1 w , e 2 w , e 3 w , e 4 w , e S E P w } \{e_{CLS}^w, e_1^w, e_2^w, e_3^w, e_4^w,e_{SEP}^w\} {eCLSw,e1w,e2w,e3w,e4w,eSEPw}向量送入BERT,增加Position Embedding和Segment Embedding,得到 { e C , e 1 , e 2 , e 3 , e 4 , e S } \{e_C, e_1, e_2, e_3, e_4, e_S\} {eC,e1,e2,e3,e4,eS}。
3.在BERT内部,会经历多层的TransformerEncoder,最终的得到输出向量 H c = { h C c , h 1 c , h 2 c , h 3 c , h 4 c , h S c } H^c=\{h_C^c, h_1^c, h_2^c, h_3^c, h_4^c, h_S^c\} Hc={hCc,h1c,h2c,h3c,h4c,hSc}.
4.将BERT的输出 H c H^c Hc 和 隔壁Detection Network输出的 H d H^d Hd 进行融合,得到 H = H d + H c H = H^d+H^c H=Hd+Hc
融合时并不对
[CLS]
和[SEP]
进行融合
5.将 H H H送给全连接层(Dense Layer)做最后的预测。
Correction Network模型细节:
- BERT:作者使用的是具有12层Transformer Block的BERT-base版。
- Dense Layer:Dense Layer的输入通道为词向量维度,输出通道为词典大小。例如:词向量维度为768,词典大小为20000,则Dense Layer则为
nn.Linear(768, 20000)
- Dense Layer的初始化:Dense Layer的权重使用的是Word Embedding的参数。因为word Embedding是将词index转成词向量,所以其参数刚好是Dense Layer的转置,即Word Embedding是
nn.Linear(20000, 768)
,所以作者就是用Word Embedding的转置来初始化Dense Layer的参数。因为这样可以加速训练,且使模型变的稳定。
Detection Network的数据流向如下:
1.输入为使用BERT得到的word Embedding { e 1 w , e 2 w , e 3 w , e 4 w } \{e_1^w, e_2^w, e_3^w, e_4^w\} {e1w,e2w,e3w,e4w}。虽然图里并不包含[CLS]
和[SEP]
的词向量,但个人认为不需要对其特殊处理,因为最后的预测也用不到这两个token.
2.将 { e 1 w , e 2 w , e 3 w , e 4 w } \{e_1^w, e_2^w, e_3^w, e_4^w\} {e1w,e2w,e3w,e4w}增加Position Embedding信息,得到 { e 1 ′ , e 2 ′ , e 3 ′ , e 4 ′ } \{e_1', e_2', e_3', e_4'\} {e1′,e2′,e3′,e4′}
在论文中说Detection Network使用的是向量 { e 1 , e 2 , e 3 , e 4 } \{e_1, e_2, e_3, e_4\} {e1,e2,e3,e4},其是word embedding+position embedding+segment embedding。这与图上是矛盾的,这里以图为准了。
3.将 { e 1 ′ , e 2 ′ , e 3 ′ , e 4 ′ } \{e_1', e_2', e_3', e_4'\} {e1′,e2′,e3′,e4′}向量送入Transformer Block,得到输出向量 H d = { h 1 d , h 2 d , h 3 d , h 4 d } H^d=\{h_1^d, h_2^d, h_3^d, h_4^d\} Hd={h1d,h2d,h3d,h4d}
4.一方面,将输出向量 H d H^d Hd送给隔壁的Correction Network进行融合;另一方面,将 H d H^d Hd送给后续的全连接层(Dense Layer)来判断哪个token是错误的.
Detection Network的细节:
- Transformer Block:Transformer Block是2层的TransformerEncoder。
- Transformer Block参数初始化:Transformer Block参数初始化使用的是BERT的权重。
- Dense Layer:Dense Layer的输入通道为词向量大小,输出通道为1。使用Sigmoid来判别该token为错字的概率。
class CorrectionNetwork(nn.Module):def __init__(self):super(CorrectionNetwork, self).__init__()# BERT分词器,作者并没提到自己使用的是哪个中文版的bert,我这里就使用一个比较常用的self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")# BERTself.bert = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")# BERT的word embedding,本质就是个nn.Embeddingself.word_embedding_table = self.bert.get_input_embeddings()# 预测层。hidden_size是词向量的大小,len(self.tokenizer)是词典大小self.dense_layer = nn.Linear(self.bert.config.hidden_size, len(self.tokenizer))def forward(self, inputs, word_embeddings, detect_hidden_states):"""Correction Network的前向传递:param inputs: inputs为tokenizer对中文文本的分词结果,里面包含了token对一个的index,attention_mask等:param word_embeddings: 使用BERT的word_embedding对token进行embedding后的结果:param detect_hidden_states: Detection Network输出hidden state:return: Correction Network对个token的预测结果。"""# 1. 使用bert进行前向传递bert_outputs = self.bert(token_type_ids=inputs['token_type_ids'],attention_mask=inputs['attention_mask'],inputs_embeds=word_embeddings)# 2. 将bert的hidden_state和Detection Network的hidden state进行融合。hidden_states = bert_outputs['last_hidden_state'] + detect_hidden_states# 3. 最终使用全连接层进行token预测return self.dense_layer(hidden_states)def get_inputs_and_word_embeddings(self, sequences, max_length=128):"""对中文序列进行分词和word embeddings处理:param sequences: 中文文本序列。例如: ["鸡你太美", "哎呦,你干嘛!"]:param max_length: 文本的最大长度,不足则进行填充,超出进行裁剪。:return: tokenizer的输出和word embeddings."""inputs = self.tokenizer(sequences, padding='max_length', max_length=max_length, return_tensors='pt',truncation=True).to(device)# 使用BERT的work embeddings对token进行embedding,这里得到的embedding并不包含position embedding和segment embeddingword_embeddings = self.word_embedding_table(inputs['input_ids'])return inputs, word_embeddings
class DetectionNetwork(nn.Module):def __init__(self, position_embeddings, transformer_blocks, hidden_size):""":param position_embeddings: bert的position_embeddings,本质是一个nn.Embedding:param transformer: BERT的前两层transformer_block,其是一个ModuleList对象"""super(DetectionNetwork, self).__init__()self.position_embeddings = position_embeddingsself.transformer_blocks = transformer_blocks# 定义最后的预测层,预测哪个token是错误的self.dense_layer = nn.Sequential(nn.Linear(hidden_size, 1),nn.Sigmoid())def forward(self, word_embeddings):# 获取token序列的长度,这里为128sequence_length = word_embeddings.size(1)# 生成position embeddingposition_embeddings = self.position_embeddings(torch.LongTensor(range(sequence_length)).to(device))# 融合work_embedding和position_embeddingx = word_embeddings + position_embeddings# 将x一层一层的使用transformer encoder进行向后传递for transformer_layer in self.transformer_blocks:x = transformer_layer(x)[0]# 最终返回Detection Network输出的hidden states和预测结果hidden_states = xreturn hidden_states, self.dense_layer(hidden_states)
class MDCSpellModel(nn.Module):def __init__(self):super(MDCSpellModel, self).__init__()# 构造Correction Networkself.correction_network = CorrectionNetwork()self._init_correction_dense_layer()# 构造Detection Network# position embedding使用BERT的position_embeddings = self.correction_network.bert.embeddings.position_embeddings# 作者在论文中提到的,Detection Network的Transformer使用BERT的权重# 所以我这里直接克隆BERT的前两层Transformer来完成这个动作transformer = copy.deepcopy(self.correction_network.bert.encoder.layer[:2])# 提取BERT的词向量大小hidden_size = self.correction_network.bert.config.hidden_size# 构造Detection Networkself.detection_network = DetectionNetwork(position_embeddings, transformer, hidden_size)def forward(self, sequences, max_length=128):# 先获取word embedding,Correction Network和Detection Network都要用inputs, word_embeddings = self.correction_network.get_inputs_and_word_embeddings(sequences, max_length)# Detection Network进行前向传递,获取输出的Hidden State和预测结果hidden_states, detection_outputs = self.detection_network(word_embeddings)# Correction Network进行前向传递,获取其预测结果correction_outputs = self.correction_network(inputs, word_embeddings, hidden_states)# 返回Correction Network 和 Detection Network 的预测结果。# 在计算损失时`[PAD]`token不需要参与计算,所以这里将`[PAD]`部分全都变为0return correction_outputs, detection_outputs.squeeze(2) * inputs['attention_mask']def _init_correction_dense_layer(self):"""原论文中提到,使用Word Embedding的weight来对Correction Network进行初始化"""self.correction_network.dense_layer.weight.data = self.correction_network.word_embedding_table.weight.data
定义好模型后,我们来简单的尝试一下:
model = MDCSpellModel().to(device)
correction_outputs, detection_outputs = model(["鸡你太美", "哎呦,你干嘛!"])
print("correction_outputs shape:", correction_outputs.size())
print("detection_outputs shape:", detection_outputs.size())
correction_outputs shape: torch.Size([2, 128, 21128])
detection_outputs shape: torch.Size([2, 128])
损失函数
Correction Network和Detection Network使用的都是Cross Entropy。之后进行相加即可:
L = λ L c + ( 1 − λ ) L d L = \lambda L^c + (1-\lambda) L^d L=λLc+(1−λ)Ld
其中 λ ∈ [ 0 , 1 ] \lambda \in [0,1] λ∈[0,1] 。作者通过实验得出 λ = 0.85 \lambda=0.85 λ=0.85 时效果最好。
class MDCSpellLoss(nn.Module):def __init__(self, coefficient=0.85):super(MDCSpellLoss, self).__init__()# 定义Correction Network的Loss函数self.correction_criterion = nn.CrossEntropyLoss(ignore_index=0)# 定义Detection Network的Loss函数,因为是二分类,所以用Binary Cross Entropyself.detection_criterion = nn.BCELoss()# 权重系数self.coefficient = coefficientdef forward(self, correction_outputs, correction_targets, detection_outputs, detection_targets):""":param correction_outputs: Correction Network的输出,Shape为(batch_size, sequence_length, hidden_size):param correction_targets: Correction Network的标签,Shape为(batch_size, sequence_length):param detection_outputs: Detection Network的输出,Shape为(batch_size, sequence_length):param detection_targets: Detection Network的标签,Shape为(batch_size, sequence_length):return:"""# 计算Correction Network的loss,因为Shape维度为3,所以要把batch_size和sequence_length进行合并才能计算correction_loss = self.correction_criterion(correction_outputs.view(-1, correction_outputs.size(2)),correction_targets.view(-1))# 计算Detection Network的lossdetection_loss = self.detection_criterion(detection_outputs, detection_targets)# 对两个loss进行加权平均return self.coefficient * correction_loss + (1 - self.coefficient) * detection_loss
模型训练
作者的训练方式:
第一步,首先使用 Wang271K(自己造的假数据) 数据集进行训练。batch size为32, learning rate为2e-5
第二步,使用SIGHAN训练集进行fine-tune。 batch size为32,learning rate为1e-5
作者并没有提到使用的是什么Optimizer,但看这个学习率,应该是Adam。
在第一步,作者说的是使用了几乎3M个,但作者只提到过Wang271K这个数据集,我猜可能作者看错了,这个是0.3M条数据,而不是3M。
作者首先使用了Wang271K数据集进行对模型进行训练,然后又使用SIGHAN训练集对模型进行fine-tune。这里我就不进行fine-tune了,直接进行训练。我这里使用的是 ReaLiSe论文 处理好的数据集,其就是Wang271K和SIGHAN。
百度网盘链接 :https://pan.baidu.com/s/1x67LPiYAjLKhO1_2CI6aOA?pwd=skda
下载好直接解压即可。
构造Dataset
class CSCDataset(Dataset):def __init__(self):super(CSCDataset, self).__init__()with open("data/trainall.times2.pkl", mode='br') as f:train_data = pickle.load(f)self.train_data = train_datadef __getitem__(self, index):src = self.train_data[index]['src']tgt = self.train_data[index]['tgt']return src, tgtdef __len__(self):return len(self.train_data)
train_data = CSCDataset()
train_data.__getitem__(0)
('纽约早盘作为基准的低硫轻油,五月份交割价攀升一点三四美元,来到每桶二十八点二五美元,而上周五曾下挫一豪元以上。','纽约早盘作为基准的低硫轻油,五月份交割价攀升一点三四美元,来到每桶二十八点二五美元,而上周五曾下挫一美元以上。')
构造Dataloader
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
def collate_fn(batch):src, tgt = zip(*batch)src, tgt = list(src), list(tgt)src_tokens = tokenizer(src, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']tgt_tokens = tokenizer(tgt, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']correction_targets = tgt_tokensdetection_targets = (src_tokens != tgt_tokens).float()return src, correction_targets, detection_targets, src_tokens # src_tokens在计算Correction的精准率时要用到
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
训练
criterion = MDCSpellLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
start_epoch = 0 # 从哪个epoch开始
total_step = 0 # 一共更新了多少次参数
# 恢复之前的训练
if os.path.exists(model_path):if not torch.cuda.is_available():checkpoint = torch.load(model_path, map_location='cpu')else:checkpoint = torch.load(model_path)model.load_state_dict(checkpoint['model'])optimizer.load_state_dict(checkpoint['optimizer'])start_epoch = checkpoint['epoch']total_step = checkpoint['total_step']print("恢复训练,epoch:", start_epoch)
恢复训练,epoch: 2
model = model.to(device)
model = model.train()
训练这里代码量看起来很大,但实际大多都是计算recall和precision的代码。这里对于Detection的recall和precision的计算使用的是Detection Network的预测结果。
total_loss = 0. # 记录lossd_recall_numerator = 0 # Detection的Recall的分子
d_recall_denominator = 0 # Detection的Recall的分母
d_precision_numerator = 0 # Detection的precision的分子
d_precision_denominator = 0 # Detection的precision的分母
c_recall_numerator = 0 # Correction的Recall的分子
c_recall_denominator = 0 # Correction的Recall的分母
c_precision_numerator = 0 # Correction的precision的分子
c_precision_denominator = 0 # Correction的precision的分母for epoch in range(start_epoch, epochs):step = 0for sequences, correction_targets, detection_targets, correction_inputs in train_loader:correction_targets, detection_targets = correction_targets.to(device), detection_targets.to(device)correction_inputs = correction_inputs.to(device)correction_outputs, detection_outputs = model(sequences)loss = criterion(correction_outputs, correction_targets, detection_outputs, detection_targets)loss.backward()optimizer.step()optimizer.zero_grad()step += 1total_step += 1total_loss += loss.detach().item()# 计算Detection的recall和precision指标# 大于0.5,认为是错误token,反之为正确tokend_predicts = detection_outputs >= 0.5# 计算错误token中被网络正确预测到的数量d_recall_numerator += d_predicts[detection_targets == 1].sum().item()# 计算错误token的数量d_recall_denominator += (detection_targets == 1).sum().item()# 计算网络预测的错误token的数量d_precision_denominator += d_predicts.sum().item()# 计算网络预测的错误token中,有多少是真错误的tokend_precision_numerator += (detection_targets[d_predicts == 1]).sum().item()# 计算Correction的recall和precision# 将输出映射成index,即将correction_outputs的Shape由(32, 128, 21128)变为(32,128)correction_outputs = correction_outputs.argmax(2)# 对于填充、[CLS]和[SEP]这三个token不校验correction_outputs[(correction_targets == 0) | (correction_targets == 101) | (correction_targets == 102)] = 0# correction_targets的[CLS]和[SEP]也要变为0correction_targets[(correction_targets == 101) | (correction_targets == 102)] = 0# Correction的预测结果,其中True表示预测正确,False表示预测错误或无需预测c_predicts = correction_outputs == correction_targets# 计算错误token中被网络正确纠正的token数量c_recall_numerator += c_predicts[detection_targets == 1].sum().item()# 计算错误token的数量c_recall_denominator += (detection_targets == 1).sum().item()# 计算网络纠正token的数量correction_inputs[(correction_inputs == 101) | (correction_inputs == 102)] = 0c_precision_denominator += (correction_outputs != correction_inputs).sum().item()# 计算在网络纠正的这些token中,有多少是真正被纠正对的c_precision_numerator += c_predicts[correction_outputs != correction_inputs].sum().item()if total_step % log_after_step == 0:loss = total_loss / log_after_stepd_recall = d_recall_numerator / (d_recall_denominator + 1e-9)d_precision = d_precision_numerator / (d_precision_denominator + 1e-9)c_recall = c_recall_numerator / (c_recall_denominator + 1e-9)c_precision = c_precision_numerator / (c_precision_denominator + 1e-9)print("Epoch {}, ""Step {}/{}, ""Total Step {}, ""loss {:.5f}, ""detection recall {:.4f}, ""detection precision {:.4f}, ""correction recall {:.4f}, ""correction precision {:.4f}".format(epoch, step, len(train_loader), total_step,loss,d_recall,d_precision,c_recall,c_precision))total_loss = 0.total_correct = 0total_num = 0d_recall_numerator = 0d_recall_denominator = 0d_precision_numerator = 0d_precision_denominator = 0c_recall_numerator = 0c_recall_denominator = 0c_precision_numerator = 0c_precision_denominator = 0torch.save({'model': model.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch + 1,'total_step': total_step,}, model_path)
。。。
Epoch 2, Step 15/8882, Total Step 8900, loss 0.02403, detection recall 0.4118, detection precision 0.8247, correction recall 0.8192, correction precision 0.9485
Epoch 2, Step 35/8882, Total Step 8920, loss 0.03479, detection recall 0.3658, detection precision 0.8055, correction recall 0.8029, correction precision 0.9125
。。。
模型评估
模型评估使用了SIGHAN 2013,2014,2015三个数据集对模型进行评估。对于Detection的Precision和Recall的评估,使用的是Correction Network的结果,这和训练阶段有所不同,这是因为Detection Network只是帮助Correction Network训练的,其结果在使用时不具备参考价值。
model = model.eval()
def evaluation(test_data):d_recall_numerator = 0 # Detection的Recall的分子d_recall_denominator = 0 # Detection的Recall的分母d_precision_numerator = 0 # Detection的precision的分子d_precision_denominator = 0 # Detection的precision的分母c_recall_numerator = 0 # Correction的Recall的分子c_recall_denominator = 0 # Correction的Recall的分母c_precision_numerator = 0 # Correction的precision的分子c_precision_denominator = 0 # Correction的precision的分母prograss = tqdm(range(len(test_data)))for i in prograss:src, tgt = test_data[i]['src'], test_data[i]['tgt']src_tokens = tokenizer(src, return_tensors='pt', max_length=128, truncation=True)['input_ids'][0][1:-1]tgt_tokens = tokenizer(tgt, return_tensors='pt', max_length=128, truncation=True)['input_ids'][0][1:-1]# 正常情况下,src和tgt的长度应该是一致的if len(src_tokens) != len(tgt_tokens):print("第%d条数据异常" % i)continuecorrection_outputs, _ = model(src)predict_tokens = correction_outputs[0][1:len(src_tokens) + 1].argmax(1).detach().cpu()# 计算错误token的数量d_recall_denominator += (src_tokens != tgt_tokens).sum().item()# 计算在这些错误token,有多少网络也认为它是错误的d_recall_numerator += (predict_tokens != src_tokens)[src_tokens != tgt_tokens].sum().item()# 计算网络找出的错误token的数量d_precision_denominator += (predict_tokens != src_tokens).sum().item()# 计算在网络找出的这些错误token中,有多少是真正错误的d_precision_numerator += (src_tokens != tgt_tokens)[predict_tokens != src_tokens].sum().item()# 计算Detection的recall、precision和f1-scored_recall = d_recall_numerator / (d_recall_denominator + 1e-9)d_precision = d_precision_numerator / (d_precision_denominator + 1e-9)d_f1_score = 2 * (d_recall * d_precision) / (d_recall + d_precision + 1e-9)# 计算错误token的数量c_recall_denominator += (src_tokens != tgt_tokens).sum().item()# 计算在这些错误token中,有多少网络预测对了c_recall_numerator += (predict_tokens == tgt_tokens)[src_tokens != tgt_tokens].sum().item()# 计算网络找出的错误token的数量c_precision_denominator += (predict_tokens != src_tokens).sum().item()# 计算网络找出的错误token中,有多少是正确修正的c_precision_numerator += (predict_tokens == tgt_tokens)[predict_tokens != src_tokens].sum().item()# 计算Correction的recall、precision和f1-scorec_recall = c_recall_numerator / (c_recall_denominator + 1e-9)c_precision = c_precision_numerator / (c_precision_denominator + 1e-9)c_f1_score = 2 * (c_recall * c_precision) / (c_recall + c_precision + 1e-9)prograss.set_postfix({'d_recall': d_recall,'d_precision': d_precision,'d_f1_score': d_f1_score,'c_recall': c_recall,'c_precision': c_precision,'c_f1_score': c_f1_score,})
with open("data/test.sighan13.pkl", mode='br') as f:sighan13 = pickle.load(f)
evaluation(sighan13)
100%|██████████| 1000/1000 [00:11<00:00, 90.12it/s, d_recall=0.799, d_precision=0.802, d_f1_score=0.8, c_recall=0.769, c_precision=0.772, c_f1_score=0.771]
with open("data/test.sighan14.pkl", mode='br') as f:sighan14 = pickle.load(f)
evaluation(sighan14)
100%|██████████| 1062/1062 [00:12<00:00, 85.48it/s, d_recall=0.666, d_precision=0.828, d_f1_score=0.738, c_recall=0.643, c_precision=0.799, c_f1_score=0.712]
with open("data/test.sighan15.pkl", mode='br') as f:sighan15 = pickle.load(f)
evaluation(sighan15)
100%|██████████| 1100/1100 [00:11<00:00, 92.04it/s, d_recall=0.761, d_precision=0.867, d_f1_score=0.811, c_recall=0.725, c_precision=0.827, c_f1_score=0.773]
模型使用
最后,我们来真正的使用一下该模型,看下效果:
def predict(text):sequences = [text]correction_outputs, _ = model(sequences)tokens = correction_outputs[0][1:len(text) + 1].argmax(1)return ''.join(tokenizer.convert_ids_to_tokens(tokens))
predict("今天早上我吃了以个火聋果")
'今天早上我吃了一个火聋果'
predict("我是联系时长两年半的个人练习生蔡徐鲲,喜欢唱跳RAP蓝球")
'我是联系时长两年半的个人练习生蔡徐鲲,喜欢唱跳ra##p蓝球[SEP]'
虽然在数据上模型表现还不错,但在真正使用场景上,效果还是不够好。中文文本纠错果然是一个比较难的任务 T_T !
参考文献
MDCSpell论文: https://aclanthology.org/2022.findings-acl.98/
MDCSpell论文笔记:https://blog.csdn.net/zhaohongfei_358/article/details/126973451
【论文复现】MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction论文复现相关推荐
- 【论文笔记】MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction
文章目录 论文内容 论文思路 模型架构 损失函数 训练细节 实验结果 个人总结 论文复现 : https://blog.csdn.net/zhaohongfei_358/article/details ...
- 论文解读:DCSpell:A Detector-Corrector Framework for Chinese Spelling Error Correction
论文解读:DCSpell:A Detector-Corrector Framework for Chinese Spelling Error Correction 简要信息: 序号 属性 值 1 模型 ...
- 论文解读:SpellBERT:A Lightweight Pretrained Model for Chinese Spelling Checking
论文解读:SpellBERT:A Lightweight Pretrained Model for Chinese Spelling Checking 简要信息: 序号 属性 值 1 模型名称 Spe ...
- 多智能体强化学习Multi agent,多任务强化学习Multi task以及多智能体多任务强化学习Multi agent Multi task概述
概述 在我之前的工作中,我自己总结了一些多智能体强化学习的算法和通俗的理解. 首先,关于题目中提到的这三个家伙,大家首先想到的就是强化学习的五件套: 状态:s 奖励:r 动作值:Q 状态值:V 策略: ...
- multi task训练torch_采用single task模型蒸馏到Multi-Task Networks
论文地址. 这篇论文主要研究利用各个single task model来分别作为teacher model,用knowledge distillation的方法指导一个multi task model ...
- Multi task learning多任务学习背景简介
2020-06-16 23:22:33 本篇文章将介绍在机器学习中效果比较好的一种模式,多任务学习(Multi task Learning,MTL).已经有一篇机器之心翻译的很好的博文介绍多任务学习了 ...
- 【论文相关】1.1 T 的 arXiv 数据集:170 万篇论文,可以看到下辈子
By 超神经 内容提要:近日,arXiv 将 170 万+ 篇的论文,打包成数据集,放在了 kaggle 平台,以后访问和下载论文,就更方便了.该数据集目前大小 1.1 TB 左右,而且之后还会随着每 ...
- 论文投稿新规则,不用跑出SOTA,还能“内定”发论文?!
文 | Sheryc_王苏 从5月初开始,CV圈似乎开始了一阵MLP"文艺复兴"的热潮:在短短4天时间里,来自谷歌.清华.牛津.Facebook四个顶级研究机构的研究者分别独立发布 ...
- 【学术相关】顶级论文创新点怎么找?中国高校首次获CVPR最佳学生论文奖有感...
几天前,同济大学公布了一条重磅消息:本校学生陈涵晟获得CVPR2022最佳学生论文奖,这也是CVPR自2001年设立最佳学生论文奖以来,获奖论文的第一作者首次来自中国高校. 华人在CV领域崛起 最近几 ...
最新文章
- iOS MMDrawerController源码解读(一)
- python编写篮球_Python编程2——Python实现计算篮球比赛是否领先安全的程序
- python unpack_ip地址处理每天10行python代码系列
- nginx+web.py+fastcgi(spawn-fcgi)的session失效問題
- Linux文件系统中的inode
- Wi-Fi 真的安全吗?一行代码就可让周边无线网络全部瘫痪!| 原力计划
- 自学android编程教程,安卓编程入门教程 安卓编程如何自学
- lg g2 android 5.0 rom,LG G2(D802)升级Flyme4.5图文教程
- Delphi中使用Imageen控件将图像文件转换成PDF
- 解决Android部分手机图片剪切返回崩溃问题
- 基于STM32单片机的FM调频TEA5767功放收音机方案原理图设计
- 职称英语职称计算机如何折算为学时,发表论文算继续教育多少学时
- 讯飞语音报错:未经授权的语音应用.(错误码:11210)
- PC常见故障及解决思路汇总(网络方面)
- 2020-2021前端面试题合集,面试题附答案
- 教你用 Python 快速批量转换 HEIC 文件
- deepin外置键盘无法打开键盘背光灯的解决方法
- 小米路由器的服务器无响应怎么回事,小米路由器常见问题与解决方法(高级功能) | 192路由网...
- 从51开始的单片机之旅(二)----LCD1602液晶、ADC0809、DAC0832
- `Computer-Algorithm` 最小生成树MST,Prim,Kruskal,次小生成树
热门文章
- 目前已确定转行开个淘宝店,想征集一个淘宝店名。
- 超链接做按钮 禁止跳转 submit 或 location 导致 return false 不起作用
- OleDbDataReader的一点属性和方法
- iOS精仿唱吧下载按钮、仿知乎日报、自定义提示视图、过渡动画、记录应用等源码...
- 【万字总结】以插排和分治为例来看如何分析与设计算法
- win10休眠按钮在“选择电源按钮功能”配置中找不到
- 服装创业未来的出路在哪里?
- mysql declare 语法_sql_declare等语法 | 学步园
- 快乐是福二级域名分发网美化版源码
- ldd命令 ubuntu_技术|简单介绍 ldd 命令