目录
前言
源码解析
主函数
自定义模型
遮蔽词预测
下一句预测
规范化数据集
前言
本部分介绍BERT的训练过程。BERT模型是在谷歌自己的TPU上训练的,这部分我没有做过研究,所以不做深入探讨。BERT同时针对两个任务进行训练:1. 下一句预测;2. 遮蔽词预测。
下面介绍BERT的预训练脚本run_pretraining.py是如何进行训练的。

源码解析
主函数
训练过程主要使用了TensorFlow的Estimator(此处为TPUEstimator)。Estimator允许用户自定义模型函数(model_fn)和输入函数(input_fn),传入训练数据后即可自动完成训练、保存检查点和验证。详情见注释。

def main(_):
    """Run BERT pre-training and/or evaluation with a TPUEstimator.

    All settings come from the module-level ``FLAGS``. Training and evaluation
    read the same ``input_files`` (``FLAGS.input_file`` is a comma-separated
    list of glob patterns). Evaluation metrics are logged and also written to
    ``<FLAGS.output_dir>/eval_results.txt``.

    Raises:
      ValueError: if neither ``FLAGS.do_train`` nor ``FLAGS.do_eval`` is set,
        or if the input patterns match no files.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    # Expand every comma-separated glob pattern into concrete file paths.
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    # Fail fast: with no files the input pipeline only breaks later and in a
    # harder-to-diagnose way (e.g. the training pipeline would call
    # parallel_interleave with cycle_length=min(num_cpu_threads, 0) == 0).
    if not input_files:
        raise ValueError(
            "No input files matched `input_file`: %s" % FLAGS.input_file)

    tf.logging.info("*** Input Files ***")
    for input_file in input_files:
        tf.logging.info(" %s" % input_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    # Run configuration: output/checkpoint directory, checkpoint frequency and
    # TPU loop/shard layout.
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    # Custom model function consumed by the estimator. One-hot embedding
    # lookups are enabled exactly when running on TPU.
    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.do_train:
        tf.logging.info("***** Running training *****")
        tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
        # Shuffled, repeated training pipeline over the input files.
        train_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True)
        estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)

        # Unshuffled evaluation pipeline over the same input files.
        eval_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)

        # The eval dataset repeats forever, so a fixed step count bounds it.
        result = estimator.evaluate(
            input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info(" %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
自定义模型
首先从features中取出各项输入数据,传入上一篇中定义的BERT模型。对于下一句预测任务,取出模型中[CLS]位置经过池化层的输出(pooled output);对于遮蔽词预测任务,取出模型最后一层encoder的输出(sequence output)。然后分别计算两个任务的loss,最后将两者相加作为总loss。详情见注释。

def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
    """Returns `model_fn` closure for TPUEstimator.

    The returned ``model_fn`` builds a BERT model, attaches the masked-LM and
    next-sentence-prediction heads, sums their losses, and returns a
    ``TPUEstimatorSpec``: with an optimizer ``train_op`` in TRAIN mode, or
    with accuracy/loss metrics in EVAL mode. Any other mode raises ValueError.

    Args:
      bert_config: ``modeling.BertConfig`` describing the model.
      init_checkpoint: checkpoint to warm-start matching variables from;
        falsy to start from fresh initializers.
      learning_rate: forwarded to ``optimization.create_optimizer``.
      num_train_steps: forwarded to ``optimization.create_optimizer``.
      num_warmup_steps: forwarded to ``optimization.create_optimizer``.
      use_tpu: selects the TPU checkpoint-restore path (via ``scaffold_fn``)
        and is forwarded to the optimizer.
      use_one_hot_embeddings: forwarded to ``modeling.BertModel``.
    """

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))

        # Unpack the pre-training features produced by the input_fn.
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # Feed the inputs through the BERT encoder.
        model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        # Masked-LM head on the final sequence output, projected through the
        # (shared) embedding table: weighted mean loss, per-position loss,
        # log-probabilities.
        (masked_lm_loss,
         masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
             bert_config, model.get_sequence_output(), model.get_embedding_table(),
             masked_lm_positions, masked_lm_ids, masked_lm_weights)

        # Next-sentence head on the pooled [CLS] output: mean loss,
        # per-example loss, log-probabilities.
        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)

        # Both pre-training objectives are optimized jointly.
        total_loss = masked_lm_loss + next_sentence_loss

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        # Warm-start from an existing checkpoint if one was given.
        if init_checkpoint:
            (assignment_map, initialized_variable_names
            ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:
                # On TPU the restore is deferred into a scaffold_fn that is
                # handed to TPUEstimatorSpec below.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        # Log every trainable variable, marking those restored from the
        # checkpoint.
        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Project-specific optimizer (learning-rate warmup/decay lives in
            # the optimization module).
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

            # TPUEstimator requires model_fn to return a TPUEstimatorSpec.
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                          masked_lm_weights, next_sentence_example_loss,
                          next_sentence_log_probs, next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                # Flatten to [num_predictions, vocab_size] and take the arg-max
                # token as the prediction.
                masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                                 [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(
                    masked_lm_log_probs, axis=-1, output_type=tf.int32)
                # One loss / label / weight per prediction slot.
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                # Weighting by masked_lm_weights excludes padded slots.
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(
                    next_sentence_log_probs, axis=-1, output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels, predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            # TPUEstimator expects metrics as (metric_fn, tensors) so that
            # metric_fn can be evaluated off-TPU.
            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

        return output_spec

    return model_fn
遮蔽词预测
输入BERT模型最后一层encoder的输出(sequence output)和词向量表,输出遮蔽词预测任务的loss、每个预测位置的loss以及对数概率矩阵。详情见注释。

def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
                         label_ids, label_weights):
    """Get loss and log probs for the masked LM.

    Args:
      bert_config: model config; ``hidden_size``, ``hidden_act``,
        ``initializer_range`` and ``vocab_size`` are read here.
      input_tensor: final encoder-layer output (the caller passes
        ``model.get_sequence_output()``); presumably
        [batch_size, seq_length, hidden_size] -- TODO confirm.
      output_weights: the word-embedding table (the caller passes
        ``model.get_embedding_table()``). It is reused, transposed, as the
        output projection, so it must be [vocab_size, hidden_size].
      positions: positions of the masked tokens to predict.
      label_ids: original token ids at ``positions``.
      label_weights: 1.0 for every real prediction, 0.0 for padding slots.

    Returns:
      (loss, per_example_loss, log_probs): the label-weighted mean loss
      (scalar), the unweighted loss of every flattened prediction slot, and
      log-probabilities of shape [num_prediction_slots, vocab_size].
    """
    # Keep only the encoder outputs at the masked positions. gather_indexes is
    # defined elsewhere; presumably returns
    # [batch_size * max_predictions_per_seq, hidden_size] -- TODO confirm.
    input_tensor = gather_indexes(input_tensor, positions)

    with tf.variable_scope("cls/predictions"):
        # We apply one more non-linear transformation before the output layer.
        # This matrix is not used after pre-training.
        with tf.variable_scope("transform"):
            # Dense + activation keeps the width at hidden_size.
            input_tensor = tf.layers.dense(
                input_tensor,
                units=bert_config.hidden_size,
                activation=modeling.get_activation(bert_config.hidden_act),
                kernel_initializer=modeling.create_initializer(
                    bert_config.initializer_range))
            input_tensor = modeling.layer_norm(input_tensor)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        output_bias = tf.get_variable(
            "output_bias",
            shape=[bert_config.vocab_size],
            initializer=tf.zeros_initializer())
        # Scores over the vocabulary: [num_prediction_slots, vocab_size].
        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        # Flatten labels/weights so they line up with the rows of log_probs.
        label_ids = tf.reshape(label_ids, [-1])
        label_weights = tf.reshape(label_weights, [-1])

        # One-hot targets: [num_prediction_slots, vocab_size].
        one_hot_labels = tf.one_hot(
            label_ids, depth=bert_config.vocab_size, dtype=tf.float32)

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        # Cross-entropy per prediction slot (negative log-prob of the label).
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
        # Weighted mean over real predictions; the 1e-5 avoids dividing by zero
        # when every weight is 0.
        numerator = tf.reduce_sum(label_weights * per_example_loss)
        denominator = tf.reduce_sum(label_weights) + 1e-5
        loss = numerator / denominator

    return (loss, per_example_loss, log_probs)
下一句预测
输入BERT模型[CLS]位置经过池化层后的输出(pooled output),输出下一句预测任务的loss、每个样本的loss以及对数概率矩阵。详情见注释。

def get_next_sentence_output(bert_config, input_tensor, labels):
    """Get loss and log probs for the next sentence prediction.

    Args:
      bert_config: model config; ``hidden_size`` and ``initializer_range``
        are read here.
      input_tensor: pooled [CLS] representation (the caller passes
        ``model.get_pooled_output()``); its last dimension must be
        ``hidden_size``, presumably [batch_size, hidden_size] -- TODO confirm.
      labels: 0 = "next sentence", 1 = "random sentence"; reshaped to one
        label per example (the input pipeline supplies one int per example).

    Returns:
      (loss, per_example_loss, log_probs): the mean loss (scalar), the loss
      of each example, and log-probabilities with 2 classes on the last axis.
    """

    # Simple binary classification. Note that 0 is "next sentence" and 1 is
    # "random sentence". This weight matrix is not used after pre-training.
    with tf.variable_scope("cls/seq_relationship"):
        # Classifier weights: one row per class, [2, hidden_size].
        output_weights = tf.get_variable(
            "output_weights",
            shape=[2, bert_config.hidden_size],
            initializer=modeling.create_initializer(bert_config.initializer_range))
        # One bias per class: shape [2].
        output_bias = tf.get_variable(
            "output_bias", shape=[2], initializer=tf.zeros_initializer())

        # Class scores with 2 columns.
        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        labels = tf.reshape(labels, [-1])
        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
        # Cross-entropy per example, then unweighted mean over the batch.
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, per_example_loss, log_probs)
规范化数据集
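Estimator要求通过输入函数input_fn提供数据,input_fn需返回一个tf.data.Dataset(TPUEstimator会通过params["batch_size"]传入批大小)。因此这里用闭包input_fn_builder把TFRecord文件的读取、解析和分批封装成input_fn。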

def input_fn_builder(input_files,
                     max_seq_length,
                     max_predictions_per_seq,
                     is_training,
                     num_cpu_threads=4):
    """Creates an `input_fn` closure to be passed to TPUEstimator.

    The returned ``input_fn(params)`` reads pre-training examples from the
    TFRecord ``input_files`` and yields batches of ``params["batch_size"]``.

    * Training: the file list is repeated and shuffled, files are read in an
      interleaved (non-deterministic) order by up to ``num_cpu_threads``
      readers, and records are shuffled again with a buffer of 100.
    * Evaluation: files are read sequentially with no shuffling, and the
      stream repeats forever so a fixed number of eval steps never runs dry.

    Both streams are infinite, and batches are built with
    ``drop_remainder=True`` so every batch has a static size (required on
    TPU); with an infinite stream no partial batch is lost anyway.
    """

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        # Parsing specs shared by several features (FixedLenFeature is an
        # immutable namedtuple, so one instance can back multiple keys).
        token_spec = tf.FixedLenFeature([max_seq_length], tf.int64)
        prediction_spec = tf.FixedLenFeature([max_predictions_per_seq], tf.int64)
        name_to_features = {
            "input_ids": token_spec,
            "input_mask": token_spec,
            "segment_ids": token_spec,
            "masked_lm_positions": prediction_spec,
            "masked_lm_ids": prediction_spec,
            "masked_lm_weights":
                tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
            "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
        }

        if not is_training:
            # Evaluation: deterministic sequential read, repeated indefinitely
            # because evaluation runs for a fixed number of steps.
            records = tf.data.TFRecordDataset(input_files)
            records = records.repeat()
        else:
            num_files = len(input_files)
            # Endless, shuffled stream of file names.
            filenames = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
            filenames = filenames.repeat().shuffle(buffer_size=num_files)
            # Read several files at once; `sloppy` lets records interleave
            # non-deterministically for extra randomness.
            records = filenames.apply(
                tf.contrib.data.parallel_interleave(
                    tf.data.TFRecordDataset,
                    sloppy=is_training,
                    cycle_length=min(num_cpu_threads, num_files)))
            records = records.shuffle(buffer_size=100)

        def parse_record(record):
            """Decode one serialized example using the specs above."""
            return _decode_record(record, name_to_features)

        # Parse and batch in one fused step, always emitting full batches.
        return records.apply(
            tf.contrib.data.map_and_batch(
                parse_record,
                batch_size=batch_size,
                num_parallel_batches=num_cpu_threads,
                drop_remainder=True))

    return input_fn
---------------------
作者:保持一份率性
来源:CSDN
原文:https://blog.csdn.net/weixin_39470744/article/details/84619903
版权声明:本文为博主原创文章,转载请附上博文链接!

谷歌BERT预训练源码解析(三):训练过程相关推荐

  1. Disruptor源码解析三 RingBuffer解析

    目录 系列索引 前言 主要内容 RingBuffer的要点 源码解析 系列索引 Disruptor源码解析一 Disruptor高性能之道 Disruptor源码解析二 Sequence相关类解析 D ...

  2. 谷歌BERT预训练源码解析(一):训练数据生成

    目录 预训练源码结构简介 输入输出 源码解析 参数 主函数 创建训练实例 下一句预测&实例生成 随机遮蔽 输出 结果一览 预训练源码结构简介 关于BERT,简单来说,它是一个基于Transfo ...

  3. 谷歌BERT预训练源码解析(二):模型构建

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_39470744/arti ...

  4. OkHttp3源码解析(三)——连接池复用

    OKHttp3源码解析系列 OkHttp3源码解析(一)之请求流程 OkHttp3源码解析(二)--拦截器链和缓存策略 本文基于OkHttp3的3.11.0版本 implementation 'com ...

  5. ReactiveSwift源码解析(三) Signal代码的基本实现

    上篇博客我们详细的聊了ReactiveSwift源码中的Bag容器,详情请参见<ReactiveSwift源码解析之Bag容器>.本篇博客我们就来聊一下信号量,也就是Signal的的几种状 ...

  6. 并发编程与源码解析 (三)

    并发编程 (三) 1 Fork/Join分解合并框架 1.1 什么是fork/join ​ Fork/Join框架是JDK1.7提供的一个用于并行执行任务的框架,开发者可以在不去了解如Thread.R ...

  7. 前端入门之(vuex源码解析三)

    上两节前端入门之(vuex源码解析二)我们把vuex的源码大概的撸了一遍,还剩下(插件.getters跟module),我们继续哈~ 插件童鞋们可以去看看vuex在各个浏览器的状态显示插件,小伙伴可以 ...

  8. 拆轮子-RxDownload2源码解析(三)

    本文为博主原创文章,未经允许不得转载 造轮子者:Season_zlc 轮子用法请戳作者链接 ↑ 前言 本文主要讲述 RxDownload2 的多线程断点下载技术. 断点下载技术前提 服务器必须支持按 ...

  9. Spring 源码解析 - Bean创建过程 以及 解决循环依赖

    一.Spring Bean创建过程以及循环依赖 上篇文章对 Spring Bean资源的加载注册过程进行了源码梳理和解析,我们可以得到结论,资源文件中的 bean 定义信息,被组装成了 BeanDef ...

最新文章

  1. 工作笔记-2019.7.8
  2. WTMPlus 低代码平台来了
  3. 数据库设计三大范式详解
  4. 韩梦飞沙Android应用集合 想法
  5. Windows Server 2008 R2之三十八 Hyper-V的授权管理
  6. SDN(软件定义网络)详解
  7. App Store 审核指南 审核失败对照
  8. 抗击肺炎,我们能做到的,就是别让爱隔离——python分析B站三个视频弹幕内容,云图数据。
  9. ADAMoracle预言机将数据传至链上实现区块链落地应用
  10. VR基础——PicoVR SDK接入及使用整理
  11. python语言编写爬虫_自写Python小爬虫一个 - 『编程语言区』 - 吾爱破解 - LCG - LSG |安卓破解|病毒分析|www.52pojie.cn...
  12. 近期全球知识图谱相关行业动态、会议讲座、综述推荐
  13. Servlet 执行原理
  14. 计算机考研网课平台哪个好,考研网课哪家排名好
  15. linux read 少读末尾一行的问题
  16. 考研日语线上笔记(四):中级日语语法总结20课(1~10)
  17. 银凤湖公园项目-工业矿坑变公园 | 用科技与艺术让城市“绽放”
  18. doris insert数据时出现问题:Invalid floating-point literal
  19. 设计模式-单一职责原则-实践运用
  20. ABP官方文档翻译 1.3 模块系统

热门文章

  1. 2022-2028年中国醋酸行业投资分析及前景预测报告
  2. 微服务架构必备的几点知识
  3. intellij idea 常见遇到的问题整理
  4. 格式化_icecream_python
  5. 通俗易懂word2vec详解词嵌入-深度学习
  6. LeetCode简单题之有序数组中出现次数超过25%的元素
  7. 华为计算平台MDC810发布量产
  8. 英特尔® 至强® 平台集成 AI 加速构建数据中心智慧网络
  9. YOLOv3和YOLOv4长篇核心综述(下)
  10. VsCode 开发工具中英文切换