tf.estimator.Estimator的使用
tf.estimator.Estimator是TF比较高级的接口。
最近在使用bert预训练模型的时候用到了tf.estimator.Estimator。使用该接口的时候需要开发者完成的工作比较少,一共3个步骤:
第一步,设置input_fun,第二步,设置model_fun,第三步,开始训练。
第一步的input_fun完成的功能是数据的输入准备工作,比如读取一个tfrecord文件,然后解析里面的内容,返回dataset;或者读取音频、图像等数据,返回相应的结果,目前来说返回的结果为dataset格式比较好。
第二步的model_fun完成的功能有:创建模型(输入feature,输出predict这种),设置loss,设置优化器,返回结果是tf.estimator.EstimatorSpec。(后续会说明tf.estimator.EstimatorSpec是什么,怎么设置)
第三步的开始训练是:参数准备(比如学习率什么的,就是上面的步骤1-2中需要用到的参数),设置config(用于训练模型是指定模型的保存路径,多长时间保存一次模型,使用GPU的一些情况),开始根据情况调用estimator.train 和 estimator.evaluate 或者 estimator.predict。
第一步:input_fun
def input_fn(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):"""每次调用,从TFRecord文件中读取一个大小为batch_size的batchArgs:filenames: TFRecord文件batch_size: batch_size大小num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止perform_shuffle: 是否乱序Returns:tensor格式的,一个batch的数据"""def _parse_fn(record):features = {"label": tf.FixedLenFeature([], tf.int64),"image": tf.FixedLenFeature([], tf.string),}parsed = tf.parse_single_example(record, features)# imageimage = tf.decode_raw(parsed["image"], tf.uint8)image = tf.reshape(image, [28, 28])# labellabel = tf.cast(parsed["label"], tf.int64)return {"image": image}, label# Extract lines from input files using the Dataset API, can pass one filename or filename listdataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000) # multi-thread pre-process then prefetch# Randomizes input using a window of 256 elements (read into memory)if perform_shuffle:dataset = dataset.shuffle(buffer_size=256)# epochs from blending together.dataset = dataset.repeat(num_epochs)dataset = dataset.batch(batch_size) # Batch size to useiterator = dataset.make_one_shot_iterator()batch_features, batch_labels = iterator.get_next()return batch_features, batch_labels
第二步:model_fun
def model_fn(features, labels, mode, params):""":param features::param labels::param mode: 指定训练、验证和测试三种模式tf.estimator.ModeKeys.TRAIN tf.estimator.ModeKeys.EVAL tf.estimator.ModeKeys.PREDICT:param params: 包含学习率等超参数的设计:return:"""# step1: 构建模型logits = create_model(features)predict = tf.nn.softmax(logits, axis=-1)# step2: 构建loss、optimization等loss = get_loss(logits, labels)train_op = tf.train.GradientDescentOptimizer(params['lr']).minimize(loss)# step3: 根据mode,构建不同情况下的tf.estimator.EstimatorSpec# For mode == ModeKeys.TRAIN: 需要的参数是 loss and train_op.# For mode == ModeKeys.EVAL: 需要的参数是 loss.# For mode == ModeKeys.PREDICT: 需要的参数是 predictions.if mode == tf.estimator.ModeKeys.TRAIN:# logging_hook是模型训练/测试的工具,主要执行特定的任务,如判断是否需要停止训练的EarlyStopping,# 改变学习速率的LearningRateScheduler,共性就是在每个step开始/结束或者每个epoch开始/结束时需要执行某个操作。output_spec = tf.estimator.EstimatorSpec(mode=mode,loss=loss,train_op=train_op,training_hooks=[logging_hook])elif mode == tf.estimator.ModeKeys.EVAL:output_spec = tf.estimator.EstimatorSpec(mode=mode,loss=loss,eval_metric_ops=eval_metrics)else:output_spec = tf.estimator.EstimatorSpec(mode=mode,predictions={"probabilities": predict})return output_spec
第三步:main
def main_():# 1. 设置超参数params = {'lr', 0.0001}# 2. 设置config,用于控制模型保存的位置,多久保存一次等session_config = tf.ConfigProto(log_device_placement=False,inter_op_parallelism_threads=0,intra_op_parallelism_threads=0,allow_soft_placement=True)run_config = tf.estimator.RunConfig(model_dir=model_output_dir,save_checkpoints_steps=5000,keep_checkpoint_max=3,session_config=session_config)# 3. 开始训练estimator = tf.estimator.Estimator(model_fn=model_fn,config=run_config,params=params)if do_train:train_input_fn = input_fun(...)estimator.train(input_fn=train_input_fn)elif do_eval:eval_input_fn = input_fun(...)estimator.train(input_fn=eval_input_fn)else:predict_input_fn = input_fun(...)estimator.train(input_fn=predict_input_fn)
===未完待续===
之后会更新关于hook等如何设置
参考文献:
https://zhuanlan.zhihu.com/p/129018863
https://zhuanlan.zhihu.com/p/106400162
https://www.jianshu.com/p/5495f87107e7
tf.estimator.Estimator的使用相关推荐
- tf.estimator.Estimator解析
Estimator类代表了一个模型,以及如何对这个模型进行训练和评估, class Estimator(builtins.object) 可以按照下面方式创建一个E def resnet_v1_10_ ...
- tf.estimator.Estimator讲解
tf.estimator.Estimator 简单介绍 是一个class 所以需要初始化,作用是用来 训练和评价 tensorflow 模型的 Estimator对象包装由一个名为model_fn函数 ...
- [tensorflow]tf.estimator.Estimator构建tensorflow模型
目录 一.Estimator简介 二.数据集 三.定义特征列 四.estimator创建模型 五.模型训练.评估和预测 六.模型保存和恢复 一.Estimator简介 Estimator是Tensor ...
- Tensorflow API 讲解——tf.estimator.Estimator
class Estimator(builtins.object) #介绍 Estimator 类,用来训练和验证 TensorFlow 模型. Estimator 对象包含了一个模型 model_fn ...
- tf.estimator的用法
tf.estimator的用法 利用 tf.estimator 训练模型时需要写两个重要的函数,一个用于数据输入的函数(input_fn),另一个用于模型创建的函数(model_fn).下面逐一来说明 ...
- 机器学习笔记5-Tensorflow高级API之tf.estimator
前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...
- tf.estimator用法
estimator:估算器 tf.estimator -----一种高级TensorFlow API.估算器封装以下操作: 训练(training) 评价(evaluation) 预测(predict ...
- tf.estimator API技术手册(16)——自定义Estimator
tf.estimator API技术手册(16)--自定义Estimator (一)前 言 (二)自定义estimator的一般步骤 (三)准备训练数据 (四)自定义estimator实践 (1)创建 ...
- tf.estimator使用入门
日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 6.6 tf.estimator使用入门 学习目标 目标 知道 ...
最新文章
- 扩增子图表解读8网络图:节点OTU或类Venn比较
- Intel汇编语言程序设计学习-第六章 条件处理-中
- ESP32的FLASH、NVS、SPIFFS、OTA等存储分布以及启动过程
- quartz可以指定方法名吗_大理石可以自己抛光吗?大理石自己抛光方法解答
- java comparator相等_详解Java中Comparable和Comparator接口的区别
- 怎么样才算是精通 JavaScript?
- mysql5.7 非gtid同步
- Jenkins 使用slave管理进行持续集成测试说明
- Android Binder 分析——匿名共享内存(好文)
- Java 常用正则表达式搜集ing
- 软件测试日志怎么写,为什么要进行日志测试和如何进行日志测试?
- 计算机动画推导,AE表达式实现逼真弹性动画
- 【锐捷交换机】清除密码
- php laravel mix,Laravel前端工程化之mix
- 一文带你了解redux的工作流程——action/reducer/store
- GYM CERC 16 K Key Knocking 构造
- java系统_Java 系统
- 湖南大学计算机学硕推免率,好几个帖子都在讨论清北华五的推免生源我来发一下b类大学湖大今...
- C# 进行 Starlink 仿真03:72轨道面 * 22颗卫星 F相位因子==11 的Walker星座,创建3168条星间链路,并与 icarus 论文的Python结果相对比。
- 【虚幻引擎】UE4/UE5 后期处理盒子(PostProcessVolume)