MachineLP的Github(欢迎follow):https://github.com/MachineLP

tf.estimator 是Tensorflow的高级API, 可快速训练和评估各种传统机器学习模型。

看下面一段代码, 使用神经网络应用到Iris数据集上。

import os
from six.moves.urllib.request import urlopenimport numpy as np
import tensorflow as tf# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"def main():# If the training and test sets aren't stored locally, download them.if not os.path.exists(IRIS_TRAINING):raw = urlopen(IRIS_TRAINING_URL).read()with open(IRIS_TRAINING, "wb") as f:f.write(raw)if not os.path.exists(IRIS_TEST):raw = urlopen(IRIS_TEST_URL).read()with open(IRIS_TEST, "wb") as f:f.write(raw)# Load datasets.training_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TRAINING,target_dtype=np.int,features_dtype=np.float32)test_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TEST,target_dtype=np.int,features_dtype=np.float32)# Specify that all features have real-value datafeature_columns = [tf.feature_column.numeric_column("x", shape=[4])]# Build 3 layer DNN with 10, 20, 10 units respectively.classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,hidden_units=[10, 20, 10],n_classes=3,model_dir="/tmp/iris_model")# Define the training inputstrain_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": np.array(training_set.data)},y=np.array(training_set.target),num_epochs=None,shuffle=True)# Train model.classifier.train(input_fn=train_input_fn, steps=2000)# Define the test inputstest_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": np.array(test_set.data)},y=np.array(test_set.target),num_epochs=1,shuffle=False)# Evaluate accuracy.accuracy_score = classifier.evaluate(input_fn=test_input_fn)["accuracy"]print("\nTest Accuracy: {0:f}\n".format(accuracy_score))# Classify two new flower samples.new_samples = np.array([[6.4, 3.2, 4.5, 1.5],[5.8, 3.1, 5.0, 1.7]], dtype=np.float32)predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": new_samples},num_epochs=1,shuffle=False)predictions = list(classifier.predict(input_fn=predict_input_fn))predicted_classes = [p["classes"] for p in predictions]print("New Samples, Class Predictions:    {}\n".format(predicted_classes))if __name__ == "__main__":main()

tf35:tf.estimator相关推荐

  1. tf.estimator的用法

    tf.estimator的用法 利用 tf.estimator 训练模型时需要写两个重要的函数,一个用于数据输入的函数(input_fn),另一个用于模型创建的函数(model_fn).下面逐一来说明 ...

  2. tf.estimator.train_and_evaluate 详解

    TensorFlow 版本:1.11.0 在 TensorFlow 1.4 版本中,Google 新引入了一个新 API:tf.estimator.train_and_evaluate.提出这个 AP ...

  3. tensorflow综合示例4:逻辑回归:使用Estimator

    文章目录 1.加载csv格式的数据集并生成Dataset 1.1 pandas读取csv数据生成Dataframe 1.2 将Dataframe生成Dataset 2.将数据封装成Feature co ...

  4. Tensorflow API 讲解——tf.estimator.Estimator

    class Estimator(builtins.object) #介绍 Estimator 类,用来训练和验证 TensorFlow 模型. Estimator 对象包含了一个模型 model_fn ...

  5. tf.estimator.EstimatorSpec讲解

    作用 是一个class(类),是定义在model_fn中的,并且model_fn返回的也是它的一个实例,这个实例是用来初始化Estimator类的 (Ops and objects returned ...

  6. tf.estimator.Estimator解析

    Estimator类代表了一个模型,以及如何对这个模型进行训练和评估, class Estimator(builtins.object) 可以按照下面方式创建一个E def resnet_v1_10_ ...

  7. tf.estimator.Estimator的使用

    tf.estimator.Estimator是TF比较高级的接口. 最近在使用bert预训练模型的时候用到了tf.estimator.Estimator.使用该接口的时候需要开发者完成的工作比较少,一 ...

  8. 机器学习笔记5-Tensorflow高级API之tf.estimator

    前言 本文接着上一篇继续来聊Tensorflow的接口,上一篇中用较低层的接口实现了线性模型,本篇中将用更高级的API--tf.estimator来改写线性模型. 还记得之前的文章<机器学习笔记 ...

  9. tf.estimator用法

    estimator:估算器 tf.estimator -----一种高级TensorFlow API.估算器封装以下操作: 训练(training) 评价(evaluation) 预测(predict ...

最新文章

  1. 4K P60 444 相关的事
  2. 使用MySql保存session
  3. ABAP:为Table Control创建Context Menu
  4. StringBuilder的构造方法和append方法
  5. Java数组的十大方法
  6. ASP.NET Core 开源论坛项目 NETCoreBBS
  7. 华为P50渲染图再曝光:居中打孔直屏+徕卡四摄
  8. 解决Eclipse,MyEclipse出现An error has occurred,See error log for more details的错误
  9. LeetCode周赛
  10. vue-json-editor高度调整
  11. Java中取多个集合的交集——retainAll()
  12. python中--snip--是什么意思
  13. 华为计算机网络基础知识,华为HCNE专题一:网络基础知识
  14. 开发中mock什么意思_开发中
  15. 在OTFS学习中的一些总结
  16. python数据分析与应用第五章实训 2_第五章实训(二)
  17. 【分享】凡是不以风控为核心的创新都是在耍流氓
  18. IO回忆录之怎样过目不忘(BIO/NIO/AIO/Netty)2017版
  19. Kubernetes(三):k8s集群部署之kubeadm
  20. 轻松上手CSS Grid网格布局

热门文章

  1. 炸机不可怕,可怕的是你不知道为什么炸
  2. 盘点2015年英特尔旧金山IDF峰会上的黑科技
  3. 什么是EC, EC与多副本的对比分析
  4. YOLOv3使用笔记
  5. windows服务器文件上传与下载(不需要下载软件)
  6. 证明婚内出轨的几种证据
  7. c++算法基础必刷题目——前缀和与差分
  8. 明解C语言(入门篇)第二章
  9. 多模态 跨模态|人机交互新突破!
  10. 11 Daemonset:忠实可靠的看门狗