
  • 代码：英文 BERT-WWM（Whole Word Masking，全词掩码）预训练数据的创建方法
def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng,
                                 *, do_whole_word_mask=None):
    """Create the masked-LM (MLM) predictions for one token sequence.

    Candidate positions are every token except ``"[CLS]"`` and ``"[SEP]"``.
    With whole word masking, a WordPiece continuation token (prefixed with
    ``"##"``) is grouped with the preceding candidate so that all pieces of a
    word are masked together. Candidates are shuffled and taken in order until
    the prediction budget is used. Each chosen position is replaced with
    ``"[MASK]"`` 80% of the time, kept unchanged 10% of the time, and replaced
    with a random vocabulary word 10% of the time. This 80/10/10 draw is made
    independently for every sub-token, even under whole word masking.

    Args:
        tokens: Sequence of token *strings* (not ids), e.g. WordPiece tokens
            including ``"[CLS]"``/``"[SEP]"``. It is not modified.
        masked_lm_prob: Fraction of ``len(tokens)`` to mask (typically 0.15).
            The budget is ``min(max_predictions_per_seq,
            max(1, round(len(tokens) * masked_lm_prob)))``. Note that
            ``len(tokens)`` includes the special tokens, which are never
            masked.
        max_predictions_per_seq: Upper bound on the number of masked positions.
        vocab_words: Sequence of tokens used for random replacement. It must
            be non-empty whenever a random replacement can happen;
            ``rng.randint(0, -1)`` raises ``ValueError``.
        rng: A ``random.Random``-like object providing ``shuffle``, ``random``
            and ``randint``. Pass a seeded instance for reproducible output.
        do_whole_word_mask: Whether to group ``"##"`` continuation tokens with
            the preceding token. If ``None`` (the default), the value of the
            global ``FLAGS.do_whole_word_mask`` is used, as before.

    Returns:
        A tuple ``(output_tokens, masked_lm_positions, masked_lm_labels)``:
        ``output_tokens`` is a new list equal to ``tokens`` with the masked
        positions replaced; ``masked_lm_positions`` are the masked indices in
        ascending order; ``masked_lm_labels`` are the original tokens at those
        indices. A whole word that would overflow the budget is skipped, so
        fewer positions than the budget may be masked.
    """
    if do_whole_word_mask is None:
        # Backward-compatible default: read the global flag (once, not per token).
        do_whole_word_mask = FLAGS.do_whole_word_mask

    # Build candidate units. Each unit is a list of positions masked together.
    cand_indexes = []
    for i, token in enumerate(tokens):
        if token in ("[CLS]", "[SEP]"):
            continue
        # When a word is split into WordPieces, the first piece has no marker
        # and later pieces are prefixed with "##". Under whole word masking we
        # append such pieces to the previous unit so the whole word (all its
        # sub-tokens) is masked together. Training is unchanged: each
        # WordPiece is still predicted independently over the full vocabulary.
        if do_whole_word_mask and cand_indexes and token.startswith("##"):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])

    # Shuffle, then consume units from the front until the budget is spent:
    # this is what randomly selects ~masked_lm_prob of the words to mask.
    rng.shuffle(cand_indexes)

    output_tokens = list(tokens)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    masked_lms = []  # (position, original token) pairs
    covered_indexes = set()
    for index_set in cand_indexes:  # a whole word or a single token
        if len(masked_lms) >= num_to_predict:
            break
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, skip this candidate (a shorter one may still fit).
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        if any(index in covered_indexes for index in index_set):
            continue
        for index in index_set:
            covered_indexes.add(index)

            # 80/10/10 masking strategy.
            if rng.random() < 0.8:
                # 80% of the time, replace with [MASK].
                masked_token = "[MASK]"
            elif rng.random() < 0.5:
                # 10% of the time, keep the original token.
                masked_token = tokens[index]
            else:
                # 10% of the time, replace with a random vocabulary word.
                masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

            output_tokens[index] = masked_token
            # Record the label: the original token at this position.
            masked_lms.append((index, tokens[index]))

    assert len(masked_lms) <= num_to_predict
    masked_lms.sort(key=lambda pair: pair[0])

    masked_lm_positions = [index for index, _ in masked_lms]
    masked_lm_labels = [label for _, label in masked_lms]
    return (output_tokens, masked_lm_positions, masked_lm_labels)
  • 中文 BERT-WWM 预训练数据的创建方法（来自 ymcui）

