前言

根据官方文档整理而来的,主要是对Iris数据集进行分类。使用tf.contrib.learn.tf.contrib.learn快速搭建一个深层网络分类器,

步骤

  1. 导入csv数据
  2. 搭建网络分类器
  3. 训练网络
  4. 计算测试集正确率
  5. 对新样本进行分类

数据

Iris数据集包含150行数据,有三种不同的Iris品种分类。每一行数据给出了四个特征信息和一个分类信息。
现在已经将数据分为训练集和测试集

  • A training set of 120 samples (iris_training.csv)
  • A test set of 30 samples (iris_test.csv)

网络搭建

1. 首先,导入tensorflow 和 numpy

  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import tensorflow as tf
  5. import numpy as np

2. 导入数据

  1. # 定义数据地址
  2. IRIS_TRAINING = "iris_training.csv"
  3. IRIS_TEST = "iris_test.csv"
  4. # 导入数据
  5. training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
  6. filename=IRIS_TRAINING,
  7. target_dtype=np.int,
  8. features_dtype=np.float32)
  9. test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
  10. filename=IRIS_TEST,
  11. target_dtype=np.int,
  12. features_dtype=np.float32)

load_csv_with_header() 有三个参数

  • filename, 数据地址
  • target_dtype, 目标值的numpy datatype(iris的目标值是0,1,2,所以是np.int)
  • features_dtype, 特征值的numpy datatype .

3. 搭建网络结构

  1. # 每行数据4个特征,都是real-value的
  2. feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
  3. # 3层DNN,3分类问题
  4. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
  5. hidden_units=[10, 20, 10],
  6. n_classes=3,
  7. model_dir="iris_model")

参数解释

  • feature_columns 特征值
  • hidden_units=[10, 20, 10]. 3个隐藏层,包含的隐藏神经元依次是10, 20, 10
  • n_classes 类别个数
  • model_dir 模型保存地址

4. 训练数据

  1. classifier.fit(x=training_set.data, y=training_set.target, steps=2000)

steps 为训练次数

5. 计算准确率

  1. accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]
  2. print('Accuracy: {0:f}'.format(accuracy_score))

运行结果是

  1. Accuracy: 0.966667

6. 对新样本进行预测

  1. # Classify two new flower samples.
  2. new_samples = np.array(
  3. [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
  4. y = list(classifier.predict(new_samples, as_iterable=True))
  5. print('Predictions: {}'.format(str(y)))

运行结果为:

  1. Prediction: [1 2]

完整代码

  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import tensorflow as tf
  5. import numpy as np
  6. IRIS_TRAINING = "iris_training.csv"
  7. IRIS_TEST = "iris_test.csv"
  8. training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
  9. filename=IRIS_TRAINING,
  10. target_dtype=np.int,
  11. features_dtype=np.float32)
  12. test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
  13. filename=IRIS_TEST,
  14. target_dtype=np.int,
  15. features_dtype=np.float32)
  16. feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
  17. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
  18. hidden_units=[10, 20, 10],
  19. n_classes=3,
  20. model_dir="iris_model")
  21. classifier.fit(x=training_set.data,
  22. y=training_set.target,
  23. steps=2000)
  24. accuracy_score = classifier.evaluate(x=test_set.data,
  25. y=test_set.target)["accuracy"]
  26. print('Accuracy: {0:f}'.format(accuracy_score))
  27. new_samples = np.array(
  28. [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
  29. y = list(classifier.predict(new_samples, as_iterable=True))
  30. print('Predictions: {}'.format(str(y)))

参考

  • tf.contrib.learn Quickstart
  • tf.contrib.learn API
原文地址: http://www.datalearner.com/blog/1051488938031745

TFboys:使用Tensorflow搭建深层网络分类器相关推荐

  1. TensorFlow搭建VGG-Siamese网络

    TensorFlow搭建VGG-Siamese网络 Siamese原理 Siamese网络,中文称为孪生网络.大致结构如下图所示: Siamese网络有两个输入,一个输出.其中,两个输入经过相同的网络 ...

  2. 5.3 使用tensorflow搭建GoogLeNet网络 笔记

    B站资源 csdn本家 文章目录 model model_add_bn train train_add_bn trainGPU predict model from tensorflow.keras ...

  3. 利用TensorFlow搭建CNN,DNN网络实现图像手写识别,总结。

    利用TensorFlow搭建CNN,DNN网络实现图像手写识别,总结. 摘要 一.神经网络与卷积网络的对比 1.数据处理 2.对获取到的数据进行归一化和独热编码 二.开始我们的tensorflow神经 ...

  4. python训练手势分类器_机器学习零基础?手把手教你用TensorFlow搭建图像分类器|干货...

    编者按:Pete Warden是TensorFlow移动团队的技术负责人.曾在Jetpac担任首次技术官.Jetpac的深度学习技术经过优化,可在移动和嵌入式设备上运行.该公司已于2014年被谷歌收购 ...

  5. TensorFlow 使用 slim 模块搭建复杂网络

    原文链接: TensorFlow 使用 slim 模块搭建复杂网络 上一篇: scrapy 代理使用 下一篇: TensorFlow infogan 生成 mnist 数据集 参考 https://b ...

  6. #教计算机学画卡通人物#生成式对抗神经网络GAN原理、Tensorflow搭建网络生成卡通人脸

    生成式对抗神经网络GAN原理.Tensorflow搭建网络生成卡通人脸 下面这张图是我教计算机学画画,计算机学会之后画出来的,具体实现在下面. ▲以下是对GAN形象化地表述 ●赵某不务正业.游手好闲, ...

  7. tensorflow随笔——VGG网络

    这次用slim搭个稍微大一点的网络VGG16,VGG16和VGG19实际上差不多,所以本例程的代码以VGG16来做5类花的分类任务. VGG网络相比之前的LeNet,AlexNet引入如下几个特点: ...

  8. 从自我学习到深层网络

     从自我学习到深层网络 From Ufldl Jump to: navigation, search 在前一节中,我们利用自编码器来学习输入至 softmax 或 logistic 回归分类器的特 ...

  9. Stanford UFLDL教程 从自我学习到深层网络

    从自我学习到深层网络 在前一节中,我们利用自编码器来学习输入至 softmax 或 logistic 回归分类器的特征.这些特征仅利用未标注数据学习获得.在本节中,我们描述如何利用已标注数据进行微调, ...

最新文章

  1. 一个管理者的反思(太深刻了!)
  2. windows8.1 windows defender service无法启动解决方案
  3. Windows下MySQL 5.6.19 general_log的设置(亲测)
  4. SpringBoot2.0 基础案例(17):自定义启动页,项目打包和指定运行环境
  5. 2 HTTP和HTTPS
  6. 一种虚拟现实技术用计算机,虚拟现实技术有哪几大分类?
  7. 深入理解目标检测与YOLO(从v1到v3)
  8. HTML——HTML基础语法
  9. Ruby语言快速入门
  10. STVD+Cosmic搭建STM8开发环境
  11. Tarjan的缩点割点概述
  12. 数据库:order by排序语句的用法
  13. 自制hdmi线一头改vga图_杀鸡取卵 | 破拆电脑VGA电缆获取收音机天线零件:双目铁氧体磁芯...
  14. 有一个Map集合里面存储的是学生的姓名和年龄,内容如下{赵四=21,王二=17,张三=18,小丫=25,李四=26,王五=38}(15分) * a.将里面的元素用两种遍历方式打印到控制台上 *
  15. mysql 触发器很慢_mysql之视图、触发器、事物、存储过程、函数、流程控制、索引与慢查优化...
  16. 基于Python的阴阳师后台全平台辅助
  17. C# Code Review Checklist
  18. 提高免疫力的食物 十种提升免疫力食材
  19. matlab小作业答案,MATLAB编程作业答案.doc
  20. 遇到错误:python文件读写权限permission denied

热门文章

  1. 3月16日 winform
  2. Linux服务器与windows本地之间的数据同步
  3. 高性能计算中并行的概念理解
  4. KDE/QT vs GNOME/GTK
  5. 【云炬大学生创业基础笔记】第1章第1节 创新和创业有什么样的关系?
  6. 云炬创业政策学习笔记20210104
  7. [一维粒子模拟 version3.6]renormalization
  8. 吴恩达《Machine Learning》精炼笔记 7:支持向量机 SVM
  9. 独家干货 | 林轩田机器学习课程精炼笔记!
  10. 如何将ipynb转换为html,md,pdf等格式