tf.estimator的用法

利用 tf.estimator 训练模型时需要写两个重要的函数,一个用于数据输入的函数(input_fn),另一个用于模型创建的函数(model_fn)。下面逐一来说明。(数据格式采用 TFRecord)。

首先我们从输入到输出调用顺序来介绍一下大概的训练过程(完整官方文档:tf.estimator):

定义输入函数 input_fn,(就是先把图片、标签打包成一个dataset形式)返回如下两种格式之一:

tf.data.Dataset 对象:这个对象的输出必须是元组队 (features, labels),而且必须满足下一条返回格式的同等约束;
元组 (features, labels):features 以及 labels 都必须是一个张量或由张量组成的字典。
下面是一个例子:

def train(data_dir):#遍历出图片路径列表paths = walk_type(data_dir +r'\train\\*\\','*.bmp')#遍历出标签列表labels = []for path in paths:label = int(path[14:16])labels.append(label)# 图片路径列表转tensor常量filenames = tf.constant(paths)#<tf.Tensor 'args_0:0' shape=() dtype=string># 标签列表转tensor常量labels = tf.constant(labels)# tensor常量转datasetdataset = tf.data.Dataset.from_tensor_slices((filenames, labels))# <DatasetV1Adapter shapes: ((), ()), types: (tf.string, tf.int32)># 此时dataset中的一个元素是(image_resized, label)dataset = dataset.map(_parse_function)#<DatasetV1Adapter shapes: ((64, 64, ?), ()), types: (tf.float32, tf.int32)>return dataset
def train_input_fn(data_dir, # dataparams #{'learning_rate': 0.001, 'batch_size': 1, 'num_epochs': 20, 'num_channels': 32, 'use_batch_norm': False, 'bn_momentum': 0.9, 'margin': 0.5, 'embedding_size': 64, 'triplet_strategy': 'batch_all', 'squared': False, 'image_size': 28, 'num_labels': 10, 'train_size': 50000, 'eval_size': 10000, 'num_parallel_calls': 4, 'save_summary_steps': 50}):# 把data_dir数据集中的image和label打包成元组张量dataset = img_label_to_dataset.train(data_dir # data) #<DatasetV1Adapter shapes: ((64, 64, ?), ()), types: (tf.float32, tf.int32)># 打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小。单位是以图片(张量)为单位,而不是byte;dataset = dataset.shuffle(params['train_size'] # 50000)  # <DatasetV1Adapter shapes: ((64, 64, ?), ()), types: (tf.float32, tf.int32)># 将整个数据集重复 params['num_epochs'] 次dataset = dataset.repeat(params['num_epochs'] # 20)  # <DatasetV1Adapter shapes: ((64, 64, ?), ()), types: (tf.float32, tf.int32)># 将 params['batch_size'] 个元素组合成batchdataset = dataset.batch(params['batch_size'] # 1) # <DatasetV1Adapter shapes: ((?, 64, 64, ?), (?,)), types: (tf.float32, tf.int32)># 预先载入(并行运算,加快数据运算速度)dataset = dataset.prefetch(1)  # <DatasetV1Adapter shapes: ((?, 64, 64, ?), (?,)), types: (tf.float32, tf.int32)>return dataset

tf.estimator.EstimatorSpec 的完整形式是:

tf.estimator.EstimatorSpec(mode,                       #指定当前是处于训练、验证还是预测状态predictions=None,           #是预测的一个张量,或者是由张量组成的一个字典loss=None,                  #是损失张量train_op=None,              #指定优化操作eval_metric_ops=None,       #指定各种评估度量的字典export_outputs=None,        #参数 export_outputs 只用于模型保存,描述了导出到 SavedModel 的输出格式training_chief_hooks=None,training_hooks=None,scaffold=None,              #是一个 tf.train.Scaffold 对象,可以在训练阶段初始化、保存等时使用。evaluation_hooks=None,prediction_hooks=None)

定义模型函数 model_fn,返回类 tf.estimator.EstimatorSpec 的一个实例。model_fn 的完整定义形式是(函数名任取):

def model_fn(
features,    #从input_fn中传入
labels,      #从input_fn中传入
mode,        #指定训练模式,可以取 (TRAIN, EVAL, PREDICT)三者之一
params=None  #是一个(可要可不要的)字典,指定其它超参数。):params = params or {}loss, train_op, ... = None, None, ...prediction_dict = ...if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):loss = ...#必填项(损失函数)if mode == tf.estimator.ModeKeys.TRAIN:train_op = ...#必填项(训练图)if mode == tf.estimator.ModeKeys.PREDICT:   predictions = ...#必填项(预测结果)return tf.estimator.EstimatorSpec(mode=mode,predictions=prediction_dict,  #预测结果loss=loss,                    #损失函数train_op=train_op,            #训练图...)

使用 tf.estimator.TrainSpec 指定训练输入函数及相关参数。该类的完整形式是:

train_spec = tf.estimator.TrainSpec(
input_fn,    #提供训练时的输入数据
max_steps,   #指定总共训练多少步
hooks        #一个 tf.train.SessionRunHook 对象,用来配置分布式训练等参数。)

使用 tf.estimator.EvalSpec 指定验证输入函数及相关参数。该类的完整形式是:

EvalSpec = tf.estimator.EvalSpec(input_fn,              #用来提供验证时的输入数据steps=100,             #指定总共验证多少步(一般设定为 None 即可)name=None,             hooks=None,            #用来配置分布式训练等参数           exporters=None,        #Exporter 迭代器,会参与到每次的模型验证start_delay_secs=120,  #指定多少秒之后开始模型验证throttle_secs=600      #指定多少秒之后重新开始新一轮模型验证 )

使用 tf.estimator.Estimator 定义 Estimator 实例 estimator。类 Estimator 的完整形式是:

estimator = tf.estimator.Estimator(
model_fn,              #模型函数
model_dir=None,        #训练时模型保存的路径
config=None,           #tf.estimator.RunConfig 的配置对象
params=None,           #传入 model_fn 的超参数字典
warm_start_from=None   #或者是一个预训练文件的路径,或者是一个 tf.estimator.WarmStartSettings 对象,用于完整的配置热启动参数)

下面就有两种调用方法:

  1. 在 Estimator 对象上调用一个或多个方法,传递适当的输入函数作为数据的来源

    estimator.train()

    # Train the Model.
    Estimator.train(input_fn = lambda: train_input_fn(args.data_dir,params ))
    

    estimator.evaluate()

    # Evaluate the model.
    eval_result = classifier.evaluate(input_fn = lambda: train_input_fn(args.data_dir,params ))
    

    estimator.predict()

    predictions = classifier.predict(input_fn=lambda:train_input_fn(args.data_dir,params ))
    
  2. 使用 tf.estimator.train_and_evaluate 启动训练和验证过程。该函数的完整形式是:

    tf.estimator.train_and_evaluate(
    estimator,  #tf.estimator.Estimator 对象,用于指定模型函数以及其它相关参数
    train_spec, #tf.estimator.TrainSpec 对象,用于指定训练的输入函数以及其它参数
    eval_spec   # tf.estimator.EvalSpec 对象,用于指定验证的输入函数以及其它参数)
    

tf.estimator的用法相关推荐

  1. tf.estimator用法

    estimator:估算器 tf.estimator -----一种高级TensorFlow API.估算器封装以下操作: 训练(training) 评价(evaluation) 预测(predict ...

  2. 值域范围 tf.clip_by_value的用法

    tf.clip_by_value的用法 tf.clip_by_value(A, min, max):输入一个张量A,把A中的每一个元素的值都压缩在min和max之间.小于min的让它等于min,大于m ...

  3. tf35:tf.estimator

    MachineLP的Github(欢迎follow):https://github.com/MachineLP tf.estimator 是Tensorflow的高级API, 可快速训练和评估各种传统 ...

  4. 【TensorFlow基础函数】tf.concat的用法

    tf.concat 的用法 TF官方的文档 tf.concat(values,axis,name='concat' ) 连接多个Tensor的操作 values 多个Tensor axis是哪个纬度 ...

  5. tf.estimator.train_and_evaluate 详解

    TensorFlow 版本:1.11.0 在 TensorFlow 1.4 版本中,Google 新引入了一个新 API:tf.estimator.train_and_evaluate.提出这个 AP ...

  6. Tensorflow API 讲解——tf.estimator.Estimator

    class Estimator(builtins.object) #介绍 Estimator 类,用来训练和验证 TensorFlow 模型. Estimator 对象包含了一个模型 model_fn ...

  7. tf.estimator.EstimatorSpec讲解

    作用 是一个class(类),是定义在model_fn中的,并且model_fn返回的也是它的一个实例,这个实例是用来初始化Estimator类的 (Ops and objects returned ...

  8. tf.estimator.Estimator解析

    Estimator类代表了一个模型,以及如何对这个模型进行训练和评估, class Estimator(builtins.object) 可以按照下面方式创建一个E def resnet_v1_10_ ...

  9. tf.estimator.Estimator的使用

    tf.estimator.Estimator是TF比较高级的接口. 最近在使用bert预训练模型的时候用到了tf.estimator.Estimator.使用该接口的时候需要开发者完成的工作比较少,一 ...

最新文章

  1. 报告 | 2019年全球数字化转型现状研究报告
  2. .NET简谈组件程序设计之(详解NetRemoting结构)
  3. 为什么Tomcat的webapps目录下新建的目录不能访问html文件?
  4. QTP的那些事--调用外部的文件的方法
  5. hadoop简单介绍_Hadoop:简单介绍
  6. MySQL Mac安装教程
  7. 《Javascript权威指南》学习笔记之十二:数组、多维数组和符合数组(哈希映射)...
  8. oracle级联查询 level,ORACLE 数据库的级联查询 一句sql搞定(部门多级)
  9. php关联数组和哈希表,12、哈希表(关联数组) - RGSS 入门教程
  10. (html+css)静态小米闪购主页仿制
  11. 12月第1周网络安全报告:境内95.8万主机感染病毒
  12. 使用nssm管理tomcat服务操作步骤
  13. PowerScript--功能强大的智能卡,USB Key, POS脚本命令工具
  14. 2021年 全网最细大数据学习笔记(一):初识 Hadoop
  15. linux 卸载dnw命令,linux下面安装dnw
  16. 南京师范大学计算机技术研究生就业,重磅!2017年南京师范大学毕业研究生就业质量报告新鲜出炉...
  17. matlab中的中间值,matlab - 在MATLAB中获取中间值的索引 - 堆栈内存溢出
  18. 解决android sdk中找不到tools目录Android sdkmanager tool not found (D:\Android\SDK\tools\bin\sdkmanager).
  19. mac 安装homebrew 报错 curl: (7) Failed to connect to raw.githubusercontent.com port 443: Connection refu
  20. 萧乾升:4.14黄金,白银TD,纸白银,最新行情分析

热门文章

  1. Mask R-CNN:实例分割与检测算法
  2. LeetCode 52. N皇后 II
  3. ICML2021 | 自提升策略规划真实且可执行的分子逆合成路线
  4. 第二课.图卷积神经网络
  5. 第二课.PyTorch入门
  6. 边缘分布律_概率论笔记-Ch3随机向量及其分布
  7. 宏基因组实战3. MEGAHIT组装拼接及quast评估
  8. 价值4500元的微生物组培训资料
  9. java http 下载网页代码_Java下http下载文件客户端和上传文件客户端实例代码
  10. Python使用matplotlib可视化箱图、seaborn中的boxplot函数可视化分组箱图、在箱图中添加抖动数据点(Dot + Box Plot)