1.框架

Estimator是属于High level的API

Mid-level API分别是 —— Layers:用来构建网络结构、Datasets: 用来构建数据读取pipeline、Metrics:用来评估网络性能

2.使用

创建一个或多个输入函数,即input_fn

定义模型的特征列,即feature_columns

实例化 Estimator,指定特征列和各种超参数

在 Estimator 对象上调用一个或多个方法,传递适当的输入函数作为数据的来源。(train, evaluate, predict)

伪代码形式介绍如何使用Estimator

创建一个或多个输入函数,即input_fn

注意, features需要是字典 (另外此处的feature与我们常说的提取特征的feature还不太一样,也可以指原图数据(raw image),或者其他未作处理的数据)。下面定义的my_feature_column会传给Estimator用于解析features。

def train_input_fn(features, labels, batch_size):"""An input function for training"""# Convert the inputs to a Dataset.dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))# Shuffle, repeat, and batch the examples.return dataset.shuffle(1000).repeat().batch(batch_size)

定义模型的特征列,即feature_columns

# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys():          my_feature_columns.append(tf.feature_column.numeric_column(key=key))

实例化 Estimator,指定特征列和各种超参数

注意在实例化Estimator的时候不用把数据传进来,你只需要把feature_columns传进来即可,告诉Estimator需要解析哪些特征值,而数据集需要在训练和评估模型的时候才传

# Build a DNN with 2 hidden layers and 10 nodes in each hidden layer.
classifier = tf.estimator.DNNClassifier(feature_columns=my_feature_columns,# Two hidden layers of 10 nodes each.hidden_units=[10, 10],# The model must choose between 3 classes.n_classes=3)

在 Estimator 对象上调用一个或多个方法,传递适当的输入函数作为数据的来源

train(训练)
# Train the Model.
classifier.train(input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),steps=args.train_steps)evaluate(评估)
# Evaluate the model.
eval_result = classifier.evaluate(input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))predict(预测)
# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {'SepalLength': [5.1, 5.9, 6.9],'SepalWidth': [3.3, 3.0, 3.1],'PetalLength': [1.7, 4.2, 5.4],'PetalWidth': [0.5, 1.5, 2.1],
}predictions = classifier.predict(input_fn=lambda:iris_data.eval_input_fn(predict_x,batch_size=args.batch_size))

3.深入理解

上面的示例中简单地介绍了Estimator,网络使用的是预创建好的DNNClassifier,其他预创建网络结构有如下

3.1 源码理解

class Estimator(object):def __init__(self, model_fn, model_dir=None, config=None, params=None, warm_start_from=None):...model_dir: 指定checkpoints和其他日志存放的路径
model_fn: 这个是需要我们自定义的网络模型函数,后面详细介绍
config: 用于控制内部和checkpoints等,如果model_fn函数也定义config这个变量,则会将config传给model_fn
params: 该参数的值会传递给model_fn
warm_start_from: 指定checkpoint路径,会导入该checkpoint开始训练

3.2 构建model_fn

def my_model_fn(features,    # This is batch_features from input_fn,`Tensor` or dict of `Tensor` (depends on data passed to `fit`).labels,     # This is batch_labels from input_fnmode,      # An instance of tf.estimator.ModeKeysparams,    # Additional configurationconfig=None):

前两个参数是从输入函数中返回的特征和标签批次;也就是说,features 和 labels 是模型将使用的数据

params 是一个字典,它可以传入许多参数用来构建网络或者定义训练方式等。例如通过设置params['n_classes']来定义最终输出节点的个数等。

config 通常用来控制checkpoint或者分布式什么,这里不深入研究。

mode 参数表示调用程序是请求训练、评估还是预测,分别通过tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT 来定义。另外通过观察DNNClassifier的源代码可以看到,mode这个参数并不用手动传入,因为Estimator会自动调整。例如当你调用estimator.train(...)的时候,mode则会被赋值tf.estimator.ModeKeys.TRAIN

model_fn需要对于不同的模式提供不同的处理方式,并且都需要返回一个tf.estimator.EstimatorSpec的实例。

咋听起来可能有点不知所云,大白话版本就是:模型有训练,验证和测试三种阶段,而且对于不同模式,对数据有不同的处理方式。例如在训练阶段,我们需要将数据喂给模型,模型基于输入数据给出预测值,然后我们在通过预测值和真实值计算出loss,最后用loss更新网络参数,而在评估阶段,我们则不需要反向传播更新网络参数,换句话说,mdoel_fn需要对三种模式设置三套代码。

另外model_fn需要返回什么东西呢?Estimator规定model_fn需要返回tf.estimator.EstimatorSpec,这样它才好更具一般化的进行处理。

3.3 Config

此处的config需要传入tf.estimator.RunConfig,其源代码如下

class RunConfig(object):"""This class specifies the configurations for an `Estimator` run."""def __init__(self,model_dir=None,tf_random_seed=None,save_summary_steps=100,save_checkpoints_steps=_USE_DEFAULT,save_checkpoints_secs=_USE_DEFAULT,session_config=None,keep_checkpoint_max=5,keep_checkpoint_every_n_hours=10000,log_step_count_steps=100,train_distribute=None,device_fn=None,protocol=None,eval_distribute=None,experimental_distribute=None,experimental_max_worker_delay_secs=None,session_creation_timeout_secs=7200):model_dir: 指定存储模型参数,graph等的路径save_summary_steps: 每隔多少step就存一次Summaries,不知道summary是啥save_checkpoints_steps:每隔多少个step就存一次checkpointsave_checkpoints_secs: 每隔多少秒就存一次checkpoint,不可以和save_checkpoints_steps同时指定。如果二者都不指定,则使用默认值,即每600秒存一次。如果二者都设置为None,则不存checkpoints。注意上面三个**save-**参数会控制保存checkpoints(模型结构和参数)和event文件(用于tensorboard),如果你都不想保存,那么你需要将这三个参数都置为FALSEkeep_checkpoint_max:指定最多保留多少个checkpoints,也就是说当超出指定数量后会将旧的checkpoint删除。当设置为None或0时,则保留所有checkpoints。keep_checkpoint_every_n_hours:log_step_count_steps:该参数的作用是,(相对于总的step数而言)指定每隔多少step就记录一次训练过程中loss的值,同时也会记录global steps/s,通过这个也可以得到模型训练的速度快慢。(天啦,终于找到这个参数了。。。。之前用TPU测模型速度,每次都得等好久才输出一次global steps/s的数据。。。蓝瘦香菇)后面这些参数与分布式有关,以后有时间再慢慢了解。train_distributedevice_fnprotocoleval_distributeexperimental_distributeexperimental_max_worker_delay_secs

3.4 tf.estimator.EstimatorSpec

它是一个class(类),是定义在model_fn中的,并且model_fn返回的也是它的一个实例,这个实例是用来初始化Estimator类的

class EstimatorSpec():def __new__(cls,mode,predictions=None,loss=None,train_op=None,eval_metric_ops=None,export_outputs=None,training_chief_hooks=None,training_hooks=None,scaffold=None,evaluation_hooks=None,prediction_hooks=None):

mode:一个ModeKeys,指定是training(训练)、evaluation(计算)还是prediction(预测).
predictions:Predictions Tensor or dict of Tensor.
loss:Training loss Tensor. Must be either scalar, or with shape [1].
train_op:适用于训练的步骤.
eval_metric_ops: Dict of metric results keyed by name.
The values of the dict can be one of the following:
    (1) instance of Metric class.
    (2) Results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

其他参数的作用可参见源代码说明, 不同模式需要传入不同参数

根据mode的值的不同,需要不同的参数,即:

对于mode == ModeKeys.TRAIN:必填字段是loss和train_op.
    对于mode == ModeKeys.EVAL:必填字段是loss.
    对于mode == ModeKeys.PREDICT:必填字段是predictions.

上面的参数说明看起来还是一头雾水,下面给出例子帮助理解:

最简单的情况: predict

只需要传入mode和predictions

# Compute predictions.
predicted_classes = tf.argmax(logits, 1)
if mode == tf.estimator.ModeKeys.PREDICT:predictions = {'class_ids': predicted_classes[:, tf.newaxis],'probabilities': tf.nn.softmax(logits),'logits': logits,}return tf.estimator.EstimatorSpec(mode, predictions=predictions)

评估模式:eval

需要传入mode,loss,eval_metric_ops

如果调用 Estimator 的 evaluate 方法,则 model_fn 会收到 mode = ModeKeys.EVAL。在这种情况下,模型函数必须返回一个包含模型损失和一个或多个指标(可选)的 tf.estimator.EstimatorSpec。

loss示例如下:

# Compute loss.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

TensorFlow提供了一个指标模块tf.metrics来计算常用的指标,这里以accuracy为例

# Compute evaluation metrics.
accuracy = tf.metrics.accuracy(labels=labels, predictions=predicted_classes, name='acc_op')

返回方式

metrics = {'accuracy': accuracy}if mode == tf.estimator.ModeKeys.EVAL:return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)

训练模式:train

需要传入mode,loss,train_op

loss同eval模式:

# Compute loss.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

train_op示例

optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(loss,global_step=tf.train.get_global_step())

返回值

return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

通用模式

model_fn可以填充独立于模式的所有参数.在这种情况下,Estimator将忽略某些参数.在eval和infer模式中,train_op将被忽略.例子如下:

def my_model_fn(mode, features, labels):predictions = ...loss = ...train_op = ...return tf.estimator.EstimatorSpec(mode=mode,predictions=predictions,loss=loss,train_op=train_op)

Reference

https://www.cnblogs.com/marsggbo/p/11232897.html

TensorFlow estimator详解相关推荐

  1. Estimator详解

    Estimator Estimator是tensorflow推出的一个High level的API,用于简化机器学习 Estimator的优点 开发方便 方便整合其它tensorflow高阶api 单 ...

  2. TensorFlow之estimator详解

    Estimator初识 框架结构 在介绍Estimator之前需要对它在TensorFlow这个大框架的定位有个大致的认识,如下图示: 可以看到Estimator是属于High level的API,而 ...

  3. TensorFlow分布式详解

    每次 TensorFlow 运算都被描述成计算图的形式,允许结构和运算操作配置所具备的自由度能够被分配到各个分布式节点上.计算图可以分成多个子图,分配给服务器集群中的不同节点. 强烈推荐读者阅读论文& ...

  4. Tensorflow BatchNormalization详解:4_使用tf.nn.batch_normalization函数实现Batch Normalization操作...

    使用tf.nn.batch_normalization函数实现Batch Normalization操作 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 吴恩达deeplearnin ...

  5. cnn 预测过程代码_FPN的Tensorflow代码详解——特征提取

    @TOC   特征金字塔网络最早于2017年发表于CVPR,与Faster RCNN相比其在多池度特征预测的方式使得其在小目标预测上取得了较好的效果.FPN也作为mmdeteciton的Neck模块, ...

  6. fasterrcnn tensorflow代码详解_pytorch目标检测代码的一些bug调试

    这几天一直在做调包侠,是时候来总结总结了.记录一些我所遇到的不常见的问题. faster rcnn: 参考代码: jwyang/faster-rcnn.pytorch​github.com pytor ...

  7. Tensorflow ExponentialMovingAverage 详解

    tensorflow 中的 ExponentialMovingAverage 这时,再看官方文档中的公式: shadowVariable=decay∗shadowVariable+(1−decay)∗ ...

  8. FM(Factorization Machine)因式分解机 与 TensorFlow实现 详解

    1,线性回归(Linear Regression) 线性回归,即使用多维空间中的一条直线拟合样本数据,如果样本特征为: \[x = ({x_1},{x_2},...,{x_n})\] 模型假设函数如下 ...

  9. faster rcnn接口_Faster R-CNN tensorflow代码详解

    研究背景 根据Faster-RCNN算法的运行和调试情况,对代码进行深入分析. 参考资料 各部分代码分析 1 编译Cython模块 cd tf-faster-rcnn/lib # 首先进入目录Fast ...

最新文章

  1. Servlet 获取IllegelStateException
  2. 流氓网站5599.net修改ie主页分析
  3. Hadoop 在关机重启后,namenode启动报错
  4. C语言链表的来源分析
  5. Spring开启方法异步执行
  6. 前端学习(1387):多人管理项目7登录 数据库连接
  7. [你必须知道的.NET] 第八回:品味类型---值类型与引用类型(上)-内存有理
  8. mysql sql组合_详解mysql 组合查询
  9. b+树时间复杂度_数据结构:线性表,栈,队列,数组,字符串,树和二叉树,哈希表...
  10. [翻译]Writing Custom Wizards 编写自定义的向导
  11. [Elasticsearch] es 6.8 编译成功
  12. catch(…) vs catch(CException *)?
  13. tomcat绿色版及安装版修改内存大小的方法
  14. 脑电波连接计算机游戏,脑电波也能“玩游戏”?这个“挑战杯”全国一等奖告诉你这都不是事儿...
  15. php学生管理系统整理
  16. 在微型计算机中ega,在微机系统中,常有VGA、EGA等说法,它们的含义是什么
  17. 这名程序猿吐了一管口水,便迎来了人生的四大暴击…
  18. excel公式编辑器_只要2步,Excel就能拥有聚光灯效果,让你看清数据
  19. docker pull xxx 失败 超时 timeout
  20. python format函数 日期_Python-日期格式化

热门文章

  1. Shell脚本获得核酸反向互补序列
  2. 自然语言处理—初始自然语言处理技术—自然语言处理的前置技术
  3. 【gdb配置】打印stl容器,.gdbinit文件
  4. 小小数学家(python)
  5. 文字转语音朗读怎么操作?
  6. 把Excel转换成CSV/CSV UTF-8
  7. 通过API接口实现图片上传
  8. 论文精读:Asynchronous, Photometric Feature Tracking using Events and Frames(IJCV 2019)
  9. 机遇挑战药食同源健康产业论坛 万祥军:黑龙江工商联主导
  10. 精通CSS高级Web标准解决方案(第三版)读书笔记