tf.estimator.Estimator是TF比较高级的接口。

最近在使用bert预训练模型的时候用到了tf.estimator.Estimator。使用该接口的时候需要开发者完成的工作比较少,一共3个步骤:

第一步,设置input_fun,第二步,设置model_fun,第三步,开始训练。

第一步的input_fun完成的功能是数据的输入准备工作,比如读取一个tfrecord文件,然后解析里面的内容,返回dataset;或者读取音频、图像等数据,返回相应的结果,目前来说返回的结果为dataset格式比较好。

第二步的model_fun完成的功能有:创建模型(输入feature,输出predict这种),设置loss,设置优化器,返回结果是tf.estimator.EstimatorSpec。(后续会说明tf.estimator.EstimatorSpec是什么,怎么设置)

第三步的开始训练是:参数准备(比如学习率什么的,就是上面的步骤1-2中需要用到的参数),设置config(用于训练模型是指定模型的保存路径,多长时间保存一次模型,使用GPU的一些情况),开始根据情况调用estimator.train 和 estimator.evaluate 或者 estimator.predict。

第一步:input_fun

def input_fn(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):"""每次调用,从TFRecord文件中读取一个大小为batch_size的batchArgs:filenames: TFRecord文件batch_size: batch_size大小num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止perform_shuffle: 是否乱序Returns:tensor格式的,一个batch的数据"""def _parse_fn(record):features = {"label": tf.FixedLenFeature([], tf.int64),"image": tf.FixedLenFeature([], tf.string),}parsed = tf.parse_single_example(record, features)# imageimage = tf.decode_raw(parsed["image"], tf.uint8)image = tf.reshape(image, [28, 28])# labellabel = tf.cast(parsed["label"], tf.int64)return {"image": image}, label# Extract lines from input files using the Dataset API, can pass one filename or filename listdataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch# Randomizes input using a window of 256 elements (read into memory)if perform_shuffle:dataset = dataset.shuffle(buffer_size=256)# epochs from blending together.dataset = dataset.repeat(num_epochs)dataset = dataset.batch(batch_size)  # Batch size to useiterator = dataset.make_one_shot_iterator()batch_features, batch_labels = iterator.get_next()return batch_features, batch_labels

第二步:model_fun

def model_fn(features, labels, mode, params):""":param features::param labels::param mode: 指定训练、验证和测试三种模式tf.estimator.ModeKeys.TRAIN    tf.estimator.ModeKeys.EVAL  tf.estimator.ModeKeys.PREDICT:param params: 包含学习率等超参数的设计:return:"""# step1: 构建模型logits = create_model(features)predict = tf.nn.softmax(logits, axis=-1)# step2: 构建loss、optimization等loss = get_loss(logits, labels)train_op = tf.train.GradientDescentOptimizer(params['lr']).minimize(loss)# step3: 根据mode,构建不同情况下的tf.estimator.EstimatorSpec# For mode == ModeKeys.TRAIN: 需要的参数是 loss and train_op.# For mode == ModeKeys.EVAL:  需要的参数是  loss.# For mode == ModeKeys.PREDICT: 需要的参数是 predictions.if mode == tf.estimator.ModeKeys.TRAIN:# logging_hook是模型训练/测试的工具,主要执行特定的任务,如判断是否需要停止训练的EarlyStopping,# 改变学习速率的LearningRateScheduler,共性就是在每个step开始/结束或者每个epoch开始/结束时需要执行某个操作。output_spec = tf.estimator.EstimatorSpec(mode=mode,loss=loss,train_op=train_op,training_hooks=[logging_hook])elif mode == tf.estimator.ModeKeys.EVAL:output_spec = tf.estimator.EstimatorSpec(mode=mode,loss=loss,eval_metric_ops=eval_metrics)else:output_spec = tf.estimator.EstimatorSpec(mode=mode,predictions={"probabilities": predict})return output_spec

第三步:main

def main_():# 1. 设置超参数params = {'lr', 0.0001}# 2. 设置config,用于控制模型保存的位置,多久保存一次等session_config = tf.ConfigProto(log_device_placement=False,inter_op_parallelism_threads=0,intra_op_parallelism_threads=0,allow_soft_placement=True)run_config = tf.estimator.RunConfig(model_dir=model_output_dir,save_checkpoints_steps=5000,keep_checkpoint_max=3,session_config=session_config)# 3. 开始训练estimator = tf.estimator.Estimator(model_fn=model_fn,config=run_config,params=params)if do_train:train_input_fn = input_fun(...)estimator.train(input_fn=train_input_fn)elif do_eval:eval_input_fn = input_fun(...)estimator.train(input_fn=eval_input_fn)else:predict_input_fn = input_fun(...)estimator.train(input_fn=predict_input_fn)

===未完待续===

之后会更新关于hook等如何设置

参考文献:

https://zhuanlan.zhihu.com/p/129018863

https://zhuanlan.zhihu.com/p/106400162

https://www.jianshu.com/p/5495f87107e7

tf.estimator.Estimator的使用相关推荐

  1. tf.estimator.Estimator解析

    Estimator类代表了一个模型,以及如何对这个模型进行训练和评估, class Estimator(builtins.object) 可以按照下面方式创建一个E def resnet_v1_10_ ...

  2. tf.estimator.Estimator讲解

    tf.estimator.Estimator 简单介绍 是一个class 所以需要初始化,作用是用来 训练和评价 tensorflow 模型的 Estimator对象包装由一个名为model_fn函数 ...

  3. [tensorflow]tf.estimator.Estimator构建tensorflow模型

    目录 一.Estimator简介 二.数据集 三.定义特征列 四.estimator创建模型 五.模型训练.评估和预测 六.模型保存和恢复 一.Estimator简介 Estimator是Tensor ...

  4. Tensorflow API 讲解——tf.estimator.Estimator

    class Estimator(builtins.object) #介绍 Estimator 类,用来训练和验证 TensorFlow 模型. Estimator 对象包含了一个模型 model_fn ...

  5. tf.estimator的用法

    tf.estimator的用法 利用 tf.estimator 训练模型时需要写两个重要的函数,一个用于数据输入的函数(input_fn),另一个用于模型创建的函数(model_fn).下面逐一来说明 ...

  6. 机器学习笔记5-Tensorflow高级API之tf.estimator

    前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...

  7. tf.estimator用法

    estimator:估算器 tf.estimator -----一种高级TensorFlow API.估算器封装以下操作: 训练(training) 评价(evaluation) 预测(predict ...

  8. tf.estimator API技术手册(16)——自定义Estimator

    tf.estimator API技术手册(16)--自定义Estimator (一)前 言 (二)自定义estimator的一般步骤 (三)准备训练数据 (四)自定义estimator实践 (1)创建 ...

  9. tf.estimator使用入门

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 6.6 tf.estimator使用入门 学习目标 目标 知道 ...

最新文章

  1. 扩增子图表解读8网络图:节点OTU或类Venn比较
  2. Intel汇编语言程序设计学习-第六章 条件处理-中
  3. ESP32的FLASH、NVS、SPIFFS、OTA等存储分布以及启动过程
  4. quartz可以指定方法名吗_大理石可以自己抛光吗?大理石自己抛光方法解答
  5. java comparator相等_详解Java中Comparable和Comparator接口的区别
  6. 怎么样才算是精通 JavaScript?
  7. mysql5.7 非gtid同步
  8. Jenkins 使用slave管理进行持续集成测试说明
  9. Android Binder 分析——匿名共享内存(好文)
  10. Java 常用正则表达式搜集ing
  11. 软件测试日志怎么写,为什么要进行日志测试和如何进行日志测试?
  12. 计算机动画推导,AE表达式实现逼真弹性动画
  13. 【锐捷交换机】清除密码
  14. php laravel mix,Laravel前端工程化之mix
  15. 一文带你了解redux的工作流程——action/reducer/store
  16. GYM CERC 16 K Key Knocking 构造
  17. java系统_Java 系统
  18. 湖南大学计算机学硕推免率,好几个帖子都在讨论清北华五的推免生源我来发一下b类大学湖大今...
  19. C# 进行 Starlink 仿真03:72轨道面 * 22颗卫星 F相位因子==11 的Walker星座,创建3168条星间链路,并与 icarus 论文的Python结果相对比。
  20. 【虚幻引擎】UE4/UE5 后期处理盒子(PostProcessVolume)

热门文章

  1. div布局改进treeview导航
  2. 直流电机系统模型识别
  3. 与贸易有关的知识产权协议 (转)
  4. 手动清除网卡IP的Linux命令
  5. 网站反爬虫策略VS反反爬虫策略
  6. 指针学习中二维数组解引用问题
  7. Lattice Mico8在LMS创建一个工程和创建LED程序
  8. 组播地址MAC的计算
  9. 【微信小程序】判断手机号是否合法
  10. oracle之动态sql