本文是我自己对官方文档的记录

基于TensorFlow预先定义的一些Estimator,编写程序,需要遵从以下步骤:

1. 创建一个或多个输入函数

2. 定义模型的特征列(feature_column)

3. 实例化Estimator,指定特征列和各种超参数

4. 在Estimator对象上调用一个或多个方法,传递适当的输入函数作为数据源。

(1)输入函数应该是一个返回 tf.data.Dataset 的函数,返回的dataset 输出two-element tuple:

features: (是python中的一个字典,key(键)是feature的名字,value(值)是包含所有该键对应的值的数组)

label:包含所有样本标签值的数组。

(2)feature_column用于描述model如何使用原始输入数据。当创建Estimator时,需要传入feature_column来告诉模型,会传入什么特征。在本例中,因为传入的特征为4个数值,因此我们创建feature_column告诉Estimator model,每个特征都使用32位浮点数来表示。

(3)实例化Estimator

鸢尾花是一个经典的分类问题,TensorFlow提供了现成的model,

(4)训练,评估和测试

调用model.train()方法

model.evaluate()

代码:

import tensorflow as tf

import pandas as pd

import argparse

TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"

TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',

'PetalLength', 'PetalWidth', 'Species']

SPECIES = ['Setosa', 'Versicolor', 'Virginica']

def maybeDownload():

pathTrain = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1],TRAIN_URL)

pathTest = tf.keras.utils.get_file(TEST_URL.split('/')[-1],TEST_URL)

return pathTrain,pathTest

def loadData(label_name='Species'):

pathTrain ,pathTest =maybeDownload()

Train = pd.read_csv(pathTrain,names=CSV_COLUMN_NAMES,header=0)

Test = pd.read_csv(pathTest,names=CSV_COLUMN_NAMES,header=0)

Train_X,Train_Y = Train,Train.pop(label_name)

Test_X,Test_Y = Test, Test.pop(label_name)

return (Train_X,Train_Y), (Test_X,Test_Y)

def trainInFunc (features ,labels ,batchsize):

dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

dataset = dataset.shuffle(1000).repeat().batch(batchsize)

return dataset

def testInFunc (features, labels, batchsize):

features = dict(features)

if labels is None:

input = features

else:

input = (features, labels)

dataset = tf.data.Dataset.from_tensor_slices(input)

assert batchsize is not None, "batch_size must not be None"

dataset = dataset.batch(batchsize)

return dataset

parser = argparse.ArgumentParser()

parser.add_argument('--batch_size', default=100, type=int, help='batch size')

parser.add_argument('--train_steps', default=1000, type=int,

help='number of training steps')

def main(argv):

arg = parser.parse_args(argv[1:])

(Train_X, Train_Y) ,(Test_X, Test_Y) = loadData()

feature_column = []

for key in Test_X.keys():

feature_column.append(tf.feature_column.numeric_column(key=key))

classifier = tf.estimator.DNNClassifier(hidden_units=[10,10],feature_columns=feature_column,n_classes=3)

classifier.train(lambda :trainInFunc(Train_X,Train_Y,100),steps=1000)

accuracy = classifier.evaluate(lambda :testInFunc(Test_X,Test_Y,100))

print(accuracy)

if __name__ == '__main__':

tf.logging.set_verbosity(tf.logging.INFO)

tf.app.run(main)

keras模型 鸾尾花数据集_TensorFlow 入门(鸢尾花数据集)(一)相关推荐

  1. python鸢尾花数据集_Python实现鸢尾花数据集分类问题——使用LogisticRegression分类器...

    . 逻辑回归 逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题,常见的是二分类或二项分布问题,也可以处理多分类问题,它实际上是属于一种分类方法. 概率p与因变量往 ...

  2. python导入鸢尾花数据集_Python实现鸢尾花数据集分类问题——基于skearn的SVM

    1 #!/usr/bin/env python 2 #encoding: utf-8 3 __author__ = 'Xiaolin Shen' 4 from sklearn importsvm5 i ...

  3. knn鸢尾花数据集java_机器学习——鸢尾花数据集(Knn分类)

    Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理.Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集.数据集包含150个数据样本,分为3类,每类50个数据,每个数据包含4个 ...

  4. knn鸢尾花数据集java_1.从鸢尾花数据集与KNN说起

    序 这是一个全新的系列--机器学习系列.鉴于我的研究方向和兴趣,我打算在这个系列中,从简到难系统地回顾各类机器学习方法,包括最优化,监督学习与无监督学习,神经网络,强化学习,迁移学习等内容.对于一些重 ...

  5. 机器学习(1)机器学习基础 鸢尾花数据集

    目录 一.机器学习基础理论 1.机器学习过程 2.机器学习分类 3.数据集返回值介绍 二.鸢尾花数据集(实战) 1.首先是获取数据集 2.显示数据集信息(可以不要) 三.数据集划分 1.数据集划分AP ...

  6. TensorFlow基础1(波士顿房价/鸢尾花数据集可视化)

    记录TensorFlow听课笔记 文章目录 记录TensorFlow听课笔记 一,波士顿房价数据集可视化 1.1介绍波士顿房价数据集 1.2波士顿房价数据集加载 1.3将平均房间数与房价之间的关系可视 ...

  7. 实验一:鸢尾花数据集分类

    实验一:鸢尾花数据集分类 一.问题描述 利用机器学习算法构建模型,根据鸢尾花的花萼和花瓣大小,区分鸢尾花的品种.实现一个基础的三分类问题. 二.数据集分析 Iris 鸢尾花数据集内包含 3 种类别,分 ...

  8. KNN算法实现鸢尾花数据集分类

    KNN算法实现鸢尾花数据集分类 作者介绍 数据集介绍 KNN算法介绍 用KNN实现鸢尾花分类 作者介绍 乔冠华,女,西安工程大学电子信息学院,2020级硕士研究生,张宏伟人工智能课题组. 研究方向:机 ...

  9. 机器学习入门案例:鸢尾花数据集分类 绘制PR曲线

    案例使用鸢尾花数据集进行分类预测,并绘制评价分类性能的PR曲线图 认识分类任务和数据集 Iris(鸢尾花)数据集 案例演示中使用的是有监督的机器学习算法:SVM 支持向量机 建立模型的流程如下: 训练 ...

最新文章

  1. Android base64 上传图片
  2. 回顾我学过的编程语言
  3. 3.1.11 段页式管理方式
  4. 计算机网络和机器视觉,一文读懂计算机视觉和机器人视觉
  5. NOIP2016全国信息学分区普级组 买铅笔(c++版)
  6. centos7公司内网环境搭建集群性能测试环境(ip+域名部署)
  7. BZOJ 2442: [Usaco2011 Open]修剪草坪 单调队列
  8. SAP Fiori Elements 框架里 Smart Table 控件的工作原理介绍
  9. oModel.create will also send to backend directly
  10. 2015蓝桥杯省赛---java---B---3(三羊献瑞)
  11. oracle更换rac节点,Oracle-rac 更改VIP地址—2节点的
  12. 计算机组成原理试卷分析,《计算机组成原理与汇编语言》试卷分析报告.doc.docx...
  13. 求职软件测试工程师英文简历,软件测试工程师英文简历范文
  14. Clang vs Other Open Source Compilers
  15. 强烈推荐这些值得下载的神仙工具,每一个都让人惊喜
  16. [CF1603D] Artistic Partition——欧拉函数,线段树优化DP
  17. 什么是内部类?内部类的作用
  18. 双机热备、双机互备、双机双工之间的区别
  19. 用U盘安装XP操作系统
  20. 二进制基带信号的时域特性

热门文章

  1. 计算机伦理的发展,人工智能技术发展的伦理困境研究
  2. H3C 胖AP设置(非VLAN模式)
  3. 2022-10-17 环境映射
  4. 用相关法辨识系统的脉冲响应 matlab,利用相关分析法辨识脉冲响应
  5. 逆天了!全地形、四舵轮、八连杆、独立悬挂的机器人运动结构方案,来了!
  6. mysql vtype_ExtJs6学习笔记 -- 自定义 vtype
  7. 用HTML做窗体程序界面
  8. Windows10系统如何多开微信程序(上班划水必备)
  9. 2001-2019年300多个城市进口额、出口额、进出口额汇总
  10. 机器学习六步曲——“小马医生”养成记