tf.estimator的用法
tf.estimator的用法
利用 tf.estimator
训练模型时需要写两个重要的函数,一个用于数据输入的函数(input_fn
),另一个用于模型创建的函数(model_fn
)。下面逐一来说明。(数据格式采用 TFRecord)。
首先我们从输入到输出调用顺序来介绍一下大概的训练过程(完整官方文档:tf.estimator):
定义输入函数 input_fn
,(就是先把图片、标签打包成一个dataset形式)返回如下两种格式之一:
tf.data.Dataset 对象:这个对象的输出必须是元组队 (features, labels),而且必须满足下一条返回格式的同等约束;
元组 (features, labels):features 以及 labels 都必须是一个张量或由张量组成的字典。
下面是一个例子:
def train(data_dir):#遍历出图片路径列表paths = walk_type(data_dir +r'\train\\*\\','*.bmp')#遍历出标签列表labels = []for path in paths:label = int(path[14:16])labels.append(label)# 图片路径列表转tensor常量filenames = tf.constant(paths)#<tf.Tensor 'args_0:0' shape=() dtype=string># 标签列表转tensor常量labels = tf.constant(labels)# tensor常量转datasetdataset = tf.data.Dataset.from_tensor_slices((filenames, labels))# <DatasetV1Adapter shapes: ((), ()), types: (tf.string, tf.int32)># 此时dataset中的一个元素是(image_resized, label)dataset = dataset.map(_parse_function)#<DatasetV1Adapter shapes: ((64, 64, ?), ()), types: (tf.float32, tf.int32)>return dataset
def train_input_fn(data_dir, # dataparams #{'learning_rate': 0.001, 'batch_size': 1, 'num_epochs': 20, 'num_channels': 32, 'use_batch_norm': False, 'bn_momentum': 0.9, 'margin': 0.5, 'embedding_size': 64, 'triplet_strategy': 'batch_all', 'squared': False, 'image_size': 28, 'num_labels': 10, 'train_size': 50000, 'eval_size': 10000, 'num_parallel_calls': 4, 'save_summary_steps': 50}):# 把data_dir数据集中的image和label打包成元组张量dataset = img_label_to_dataset.train(data_dir # data) #<DatasetV1Adapter shapes: ((64, 64, ?), ()), types: (tf.float32, tf.int32)># 打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小。单位是以图片(张量)为单位,而不是byte;dataset = dataset.shuffle(params['train_size'] # 50000) # <DatasetV1Adapter shapes: ((64, 64, ?), ()), types: (tf.float32, tf.int32)># 将整个数据集重复 params['num_epochs'] 次dataset = dataset.repeat(params['num_epochs'] # 20) # <DatasetV1Adapter shapes: ((64, 64, ?), ()), types: (tf.float32, tf.int32)># 将 params['batch_size'] 个元素组合成batchdataset = dataset.batch(params['batch_size'] # 1) # <DatasetV1Adapter shapes: ((?, 64, 64, ?), (?,)), types: (tf.float32, tf.int32)># 预先载入(并行运算,加快数据运算速度)dataset = dataset.prefetch(1) # <DatasetV1Adapter shapes: ((?, 64, 64, ?), (?,)), types: (tf.float32, tf.int32)>return dataset
类 tf.estimator.EstimatorSpec
的完整形式是:
tf.estimator.EstimatorSpec(mode, #指定当前是处于训练、验证还是预测状态predictions=None, #是预测的一个张量,或者是由张量组成的一个字典loss=None, #是损失张量train_op=None, #指定优化操作eval_metric_ops=None, #指定各种评估度量的字典export_outputs=None, #参数 export_outputs 只用于模型保存,描述了导出到 SavedModel 的输出格式training_chief_hooks=None,training_hooks=None,scaffold=None, #是一个 tf.train.Scaffold 对象,可以在训练阶段初始化、保存等时使用。evaluation_hooks=None,prediction_hooks=None)
定义模型函数 model_fn
,返回类 tf.estimator.EstimatorSpec 的一个实例。model_fn 的完整定义形式是(函数名任取):
def model_fn(
features, #从input_fn中传入
labels, #从input_fn中传入
mode, #指定训练模式,可以取 (TRAIN, EVAL, PREDICT)三者之一
params=None #是一个(可要可不要的)字典,指定其它超参数。):params = params or {}loss, train_op, ... = None, None, ...prediction_dict = ...if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):loss = ...#必填项(损失函数)if mode == tf.estimator.ModeKeys.TRAIN:train_op = ...#必填项(训练图)if mode == tf.estimator.ModeKeys.PREDICT: predictions = ...#必填项(预测结果)return tf.estimator.EstimatorSpec(mode=mode,predictions=prediction_dict, #预测结果loss=loss, #损失函数train_op=train_op, #训练图...)
使用 tf.estimator.TrainSpec
指定训练输入函数及相关参数。该类的完整形式是:
train_spec = tf.estimator.TrainSpec(
input_fn, #提供训练时的输入数据
max_steps, #指定总共训练多少步
hooks #一个 tf.train.SessionRunHook 对象,用来配置分布式训练等参数。)
使用 tf.estimator.EvalSpec
指定验证输入函数及相关参数。该类的完整形式是:
EvalSpec = tf.estimator.EvalSpec(input_fn, #用来提供验证时的输入数据steps=100, #指定总共验证多少步(一般设定为 None 即可)name=None, hooks=None, #用来配置分布式训练等参数 exporters=None, #Exporter 迭代器,会参与到每次的模型验证start_delay_secs=120, #指定多少秒之后开始模型验证throttle_secs=600 #指定多少秒之后重新开始新一轮模型验证 )
使用 tf.estimator.Estimator
定义 Estimator 实例 estimator。类 Estimator 的完整形式是:
estimator = tf.estimator.Estimator(
model_fn, #模型函数
model_dir=None, #训练时模型保存的路径
config=None, #tf.estimator.RunConfig 的配置对象
params=None, #传入 model_fn 的超参数字典
warm_start_from=None #或者是一个预训练文件的路径,或者是一个 tf.estimator.WarmStartSettings 对象,用于完整的配置热启动参数)
下面就有两种调用方法:
在 Estimator 对象上调用一个或多个方法,传递适当的输入函数作为数据的来源
estimator.train()
# Train the Model. Estimator.train(input_fn = lambda: train_input_fn(args.data_dir,params ))
estimator.evaluate()
# Evaluate the model. eval_result = classifier.evaluate(input_fn = lambda: train_input_fn(args.data_dir,params ))
estimator.predict()
predictions = classifier.predict(input_fn=lambda:train_input_fn(args.data_dir,params ))
使用
tf.estimator.train_and_evaluate
启动训练和验证过程。该函数的完整形式是:tf.estimator.train_and_evaluate( estimator, #tf.estimator.Estimator 对象,用于指定模型函数以及其它相关参数 train_spec, #tf.estimator.TrainSpec 对象,用于指定训练的输入函数以及其它参数 eval_spec # tf.estimator.EvalSpec 对象,用于指定验证的输入函数以及其它参数)
tf.estimator的用法相关推荐
- tf.estimator用法
estimator:估算器 tf.estimator -----一种高级TensorFlow API.估算器封装以下操作: 训练(training) 评价(evaluation) 预测(predict ...
- 值域范围 tf.clip_by_value的用法
tf.clip_by_value的用法 tf.clip_by_value(A, min, max):输入一个张量A,把A中的每一个元素的值都压缩在min和max之间.小于min的让它等于min,大于m ...
- tf35:tf.estimator
MachineLP的Github(欢迎follow):https://github.com/MachineLP tf.estimator 是Tensorflow的高级API, 可快速训练和评估各种传统 ...
- 【TensorFlow基础函数】tf.concat的用法
tf.concat 的用法 TF官方的文档 tf.concat(values,axis,name='concat' ) 连接多个Tensor的操作 values 多个Tensor axis是哪个纬度 ...
- tf.estimator.train_and_evaluate 详解
TensorFlow 版本:1.11.0 在 TensorFlow 1.4 版本中,Google 新引入了一个新 API:tf.estimator.train_and_evaluate.提出这个 AP ...
- Tensorflow API 讲解——tf.estimator.Estimator
class Estimator(builtins.object) #介绍 Estimator 类,用来训练和验证 TensorFlow 模型. Estimator 对象包含了一个模型 model_fn ...
- tf.estimator.EstimatorSpec讲解
作用 是一个class(类),是定义在model_fn中的,并且model_fn返回的也是它的一个实例,这个实例是用来初始化Estimator类的 (Ops and objects returned ...
- tf.estimator.Estimator解析
Estimator类代表了一个模型,以及如何对这个模型进行训练和评估, class Estimator(builtins.object) 可以按照下面方式创建一个E def resnet_v1_10_ ...
- tf.estimator.Estimator的使用
tf.estimator.Estimator是TF比较高级的接口. 最近在使用bert预训练模型的时候用到了tf.estimator.Estimator.使用该接口的时候需要开发者完成的工作比较少,一 ...
最新文章
- 报告 | 2019年全球数字化转型现状研究报告
- .NET简谈组件程序设计之(详解NetRemoting结构)
- 为什么Tomcat的webapps目录下新建的目录不能访问html文件?
- QTP的那些事--调用外部的文件的方法
- hadoop简单介绍_Hadoop:简单介绍
- MySQL Mac安装教程
- 《Javascript权威指南》学习笔记之十二:数组、多维数组和符合数组(哈希映射)...
- oracle级联查询 level,ORACLE 数据库的级联查询 一句sql搞定(部门多级)
- php关联数组和哈希表,12、哈希表(关联数组) - RGSS 入门教程
- (html+css)静态小米闪购主页仿制
- 12月第1周网络安全报告:境内95.8万主机感染病毒
- 使用nssm管理tomcat服务操作步骤
- PowerScript--功能强大的智能卡,USB Key, POS脚本命令工具
- 2021年 全网最细大数据学习笔记(一):初识 Hadoop
- linux 卸载dnw命令,linux下面安装dnw
- 南京师范大学计算机技术研究生就业,重磅!2017年南京师范大学毕业研究生就业质量报告新鲜出炉...
- matlab中的中间值,matlab - 在MATLAB中获取中间值的索引 - 堆栈内存溢出
- 解决android sdk中找不到tools目录Android sdkmanager tool not found (D:\Android\SDK\tools\bin\sdkmanager).
- mac 安装homebrew 报错 curl: (7) Failed to connect to raw.githubusercontent.com port 443: Connection refu
- 萧乾升:4.14黄金,白银TD,纸白银,最新行情分析
热门文章
- Mask R-CNN:实例分割与检测算法
- LeetCode 52. N皇后 II
- ICML2021 | 自提升策略规划真实且可执行的分子逆合成路线
- 第二课.图卷积神经网络
- 第二课.PyTorch入门
- 边缘分布律_概率论笔记-Ch3随机向量及其分布
- 宏基因组实战3. MEGAHIT组装拼接及quast评估
- 价值4500元的微生物组培训资料
- java http 下载网页代码_Java下http下载文件客户端和上传文件客户端实例代码
- Python使用matplotlib可视化箱图、seaborn中的boxplot函数可视化分组箱图、在箱图中添加抖动数据点(Dot + Box Plot)