看文章标题,mask可以定位到mask language model,代表模型是bert一系列的成果。

mask可以分为token mask和whole Word mask,怎么实现?
两者的区别是什么?
整个实现过程可以借鉴transformer的源码。

whole Word mask

添加链接描述

@dataclass
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):"""Data collator used for language modeling.- collates batches of tensors, honoring their tokenizer's pad_token- preprocesses batches for masked language modeling"""def __call__(self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:if isinstance(examples[0], (dict, BatchEncoding)):input_ids = [e["input_ids"] for e in examples]else:input_ids = examplesexamples = [{"input_ids": e} for e in examples]batch_input = _collate_batch(input_ids, self.tokenizer)mask_labels = []for e in examples:ref_tokens = []for id in tolist(e["input_ids"]):token = self.tokenizer._convert_id_to_token(id)ref_tokens.append(token)# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]if "chinese_ref" in e:ref_pos = tolist(e["chinese_ref"])len_seq = len(e["input_ids"])for i in range(len_seq):if i in ref_pos:ref_tokens[i] = "##" + ref_tokens[i]mask_labels.append(self._whole_word_mask(ref_tokens))batch_mask = _collate_batch(mask_labels, self.tokenizer)inputs, labels = self.mask_tokens(batch_input, batch_mask)return {"input_ids": inputs, "labels": labels}def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):"""Get 0/1 labels for masked tokens with whole word mask proxy"""cand_indexes = []for (i, token) in enumerate(input_tokens):if token == "[CLS]" or token == "[SEP]":continueif len(cand_indexes) >= 1 and token.startswith("##"):cand_indexes[-1].append(i)else:cand_indexes.append([i])random.shuffle(cand_indexes)num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))masked_lms = []covered_indexes = set()for index_set in cand_indexes:if len(masked_lms) >= num_to_predict:break# If adding a whole-word mask would exceed the maximum number of# predictions, then just skip this candidate.if len(masked_lms) + len(index_set) > num_to_predict:continueis_any_index_covered = Falsefor index in index_set:if index in covered_indexes:is_any_index_covered = Truebreakif is_any_index_covered:continuefor index in index_set:covered_indexes.add(index)masked_lms.append(index)assert len(covered_indexes) == len(masked_lms)mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]return mask_labels

中文实现whole Word mask

参考原文

# coding=utf-8
'''Whole word mask for bert
'''
import pkuseg
from transformers import BertConfig, BertForMaskedLM, DataCollatorForWholeWordMask,\BertTokenizer, TrainingArguments, Trainer
from torch.utils.data import Dataset
from tqdm import tqdm
import torchclass My_wwm_pretrain_dataset(Dataset):def __init__(self, path, tokenizer, dup_factor=5,max_length=512): # dup_factor : dynamic mask for 5 timesself.examples = []with open(path,'r',encoding='utf-8') as f:total_data = f.readlines()with tqdm(total_data * dup_factor) as loader:for data in loader:# clean datadata = data.replace('\n', '').replace('\r', '').replace('\t','').replace(' ','').replace(' ', '')chinese_ref = self.get_new_segment(data)input_ids = tokenizer.encode_plus(data,truncation=True,max_length=max_length).input_idsdict_data = {'input_ids' : input_ids, 'chinese_ref' : chinese_ref}self.examples.append(dict_data)loader.set_description(f'loading data')def get_new_segment(self,segment):"""使用分词工具获取 whole word maske.g [喜,欢]-> [喜,##欢]"""seq_cws = seg.cut("".join(segment))  # 利用 pkuseg 进行医学领域分词chinese_ref = []index = 1for seq in seq_cws:for i, word in enumerate(seq):if i>0:chinese_ref.append(index)index +=1return chinese_refdef __getitem__(self, index):return self.examples[index]def __len__(self):return len(self.examples)if __name__ == '__main__':# configurationepoch = 100batch_size = 1pretrian_model = 'mc-bert-base'train_file = 'data/train.txt'save_epoch = 10 # every 10 epoch save checkpointbert_file = '../../pretrained_models/' + pretrian_modeltokenizer_model_path='../../pretrained_models/pkuseg_medical'#device = 'cuda' if torch.cuda.is_available() else 'cpu'seg = pkuseg.pkuseg(model_name=tokenizer_model_path)config = BertConfig.from_pretrained(bert_file)tokenizer = BertTokenizer.from_pretrained(bert_file)train_dataset = My_wwm_pretrain_dataset(train_file,tokenizer)model = BertForMaskedLM.from_pretrained(bert_file).to(device)print('No of parameters: ', model.num_parameters())data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)print('No. of lines: ', len(train_dataset))save_step = len(train_dataset) * save_epochtot_step = int(len(train_dataset)/batch_size *  epoch)print(f'\n\t***** Running training *****\n'f'\tNum examples = {len(train_dataset)}\n'f'\tNum Epochs = {epoch}\n'f'\tBatch size = {batch_size}\n'f'\tTotal optimization steps = {tot_step}\n')# official trainingtraining_args = TrainingArguments(output_dir='./outputs/',overwrite_output_dir=True,num_train_epochs=epoch,per_device_train_batch_size=batch_size,save_steps=save_step,)trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=train_dataset,)trainer.train()trainer.save_model('pretrain_outputs/wwm/')

预训练语言模型复现-2 whole word mask相关推荐

  1. 预训练语言模型复现CPT-1Restructure_pretrain

    (1)CPT -pretrain CPT参数初始化,不是random initilized 是inference Robert的参数. roberta_zh/: Place the checkpoin ...

  2. 预训练语言模型(三):RNN和LSTM

    目录 RNN LSTM 参考一个很全的总结: 预训练语言模型的前世今生 - 从Word Embedding到BERT RNN部分参考了这个: 循环神经网络 LSTM部分参考了这两个: LSTM以及三重 ...

  3. 预训练语言模型(四):ELMo模型

    目录 ELMo模型 模型结构 公式 参考一个很全的总结: 预训练语言模型的前世今生 - 从Word Embedding到BERT ELMo也参考了这个: [NLP-13]ELMo模型(Embeddin ...

  4. 大模型系统和应用——Transformer预训练语言模型

    引言 最近在公众号中了解到了刘知远团队退出的视频课程<大模型交叉研讨课>,看了目录觉得不错,因此拜读一下. 观看地址: https://www.bilibili.com/video/BV1 ...

  5. yolov4 火灾检测,烟雾检测、 古文预训练语言模型等AI开源项目分享

    ~ 文末免费送书 ~ 项目一:FinBERT基于 BERT 架构的金融领域预训练语言模型 项目地址: https://github.com/valuesimplex/FinBERT 为了促进自然语言处 ...

  6. 赠书 | 一文了解预训练语言模型

    来源 | 博文视点 头图 | 下载于视觉中国 近年来,在深度学习和大数据的支撑下,自然语言处理技术迅猛发展.而预训练语言模型把自然语言处理带入了一个新的阶段,也得到了工业界的广泛关注. 通过大数据预训 ...

  7. 西湖大学蓝振忠:预训练语言模型的前沿发展趋势

    蓝振忠,西湖大学助理教授 报告 | 蓝振忠 撰文 | 沈磊贤 我的报告主题为<预训练语言模型的前沿发展趋势>,主要从以下三个方面展开: ‍1.为什么全网络预训练模型如此重要? 2.为什么语 ...

  8. 微软统一预训练语言模型UniLM 2.0解读

    微软研究院在2月底发布的UniLM 2.0(Unified Language Model \ 统一语言模型)文章,相比于19年上半年发布的UniLM 1.0,更加有效地整合了自然语言理解(NLU)与自 ...

  9. 首个金融领域的开源中文预训练语言模型FinBERT了解下

    背景及下载地址 为了促进自然语言处理技术在金融科技领域的应用和发展,熵简科技 AI Lab 近期开源了基于 BERT 架构的金融领域预训练语言模型 FinBERT 1.0.据我们所知,这是国内首个在金 ...

最新文章

  1. java 环境配置 maven 环境配置
  2. Go 高性能编程技法
  3. 图书漂流系统的设计和研究_研究在设计系统中的作用
  4. ROS机器人导航仿真(kinetic版本)
  5. 浅谈RSocket与响应式编程
  6. Mysql外键约束foreign key
  7. HDOJ(HDU) 2123 An easy problem(简单题...)
  8. 在浏览的地址栏中,直接调用js「javascript:alert(hello wrold);」。
  9. atitit.attilax.com产品 软件项目通用框架类库总结
  10. LightGBM 二元分类、多类分类、 Python的回归和分类器应用
  11. 计算机系统操作权限,已过GSP认证文件计算机系统操作权限确认表.docx
  12. Python爬取淘宝商品附加cookie修改
  13. java rf14bug_让云平台发生重大宕机事故的15个方法
  14. 计算机硬盘模式,详细教你bios设置硬盘模式
  15. html合并边框线,css中border-collapse属性设置表格边框线的方法
  16. 炫酷的ViewPager翻页动画
  17. 8月第4周基金排行榜 | TokenInsight
  18. 防止U盘中毒的一个小技巧
  19. pythongui界面实现爬取b站弹幕_Python爬虫自动化爬取b站实时弹幕实例方法
  20. matlab求市场清算价格不停,MATLAB数学实验(201516年第2学期)试题题目及答案,课程2020最新期末考试题库,章节测验答案...

热门文章

  1. 使用ggplot2画 点图、箱线图、小提琴图、蜂窝图、云雨图
  2. 踔厉奋发,笃行不怠 润和软件 HiHope 2021 OpenHarmony大事记
  3. 麦克马斯特大学计算机系褚令洋招收硕士生、博士生啦!
  4. Linux驱动笔试知识
  5. kotlin “${ }”是什么意思?
  6. Problem H. 小凡与英雄救美
  7. 实验作品集:百度随心听air桌面版
  8. 应用商店打开服务器错误,应用商店出错的修复方法
  9. 折叠屏手机阵亡,三星的手机梦会不会被彻底折叠?
  10. Matlab中switch, case, otherwise语句