代码实现及说明

# python 3.6
# TensorFlow实现简单的鸢尾花分类器
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn import datasetssess = tf.Session()#导入数据
iris = datasets.load_iris()
# 是否是山鸢尾 0/1
binary_target = np.array([1. if x == 0 else 0. forx in iris.target])
# 选择两个特征:花瓣长度和宽度
iris_2d = np.array([[x[2],x[3]] for x in iris.data])# 声明批训练大小、占位符和变量
# tf.float32降低float字节数 可以提高算法性能
batch_size = 20
x1_data = tf.placeholder(shape=[None,1],dtype=tf.float32)
x2_data = tf.placeholder(shape=[None,1],dtype=tf.float32)
y_target = tf.placeholder(shape=[None,1],dtype=tf.float32)
# 声明变量 A 和 b (0 = x1 - A*x2 + b)
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))# 定义线性模型
# 线性模型的表达式为:x1=x2*A+b。
# 如果找到的数据点在直线以上,则将数据点代入x1-x2*A-b计算出的结果大于0;
# 同理找到的数据点在直线以下,则将数据点代入x1-x2*A-b计算出的结果小于0。
# 将公式x1-x2*A-b传入sigmoid函数,然后预测结果1或者0
# TensorFlow有内建的sigmoid损失函数,所以这里仅仅需要定义模型输出
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data,my_add)# 增加分类损失函数 这里用两类交叉熵损失函数 cross entropy
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output,labels=y_target)# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)# 循环
for i in range(1000):rand_index = np.random.choice(len(iris_2d),size=batch_size)rand_x = iris_2d[rand_index]rand_x1 = np.array([[x[0]] for x in rand_x])rand_x2 = np.array([[x[1]] for x in rand_x])rand_y = np.array([[y] for y in binary_target[rand_index]])sess.run(train_step, feed_dict={x1_data:rand_x1,x2_data:rand_x2,y_target:rand_y})if (i+1)%200 == 0:print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b)))# 结果可视化
[[slope]] = sess.run(A) # 斜率
# 因为A的shape是(1,1)所以要写成一行一列的形式
[[intercept]] = sess.run(b) # 截距# 创建拟合线
x = np.linspace(0, 3, num=50) # 0~3 50个均匀间隔的数字
ablineValues = []
for i in x:ablineValues.append(slope*i+intercept)# 绘图
setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1]
setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1]
non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0]
non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0]
plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa')
plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa')
plt.plot(x, ablineValues, 'b-')
plt.xlim([0.0, 2.7])
plt.ylim([0.0, 7.1])
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.legend(loc='lower right')
plt.show()

绘图结果

总结

这里利用花瓣长度和花瓣宽度的特征在山鸢尾与其他物种间拟合一条直线,然后通过该直线来分割两类目标(山鸢尾和非山鸢尾),直线是迭代1000次得到的线性分割,通过直线分割两个目标并不是最好的模型。

【TensorFlow】实现简单的鸢尾花分类器相关推荐

  1. 如何使用TensorFlow构建简单的图像识别系统(第2部分)

    by Wolfgang Beyer 沃尔夫冈·拜尔(Wolfgang Beyer) 如何使用TensorFlow构建简单的图像识别系统(第2部分) (How to Build a Simple Ima ...

  2. 梯度下降的线性回归用python_运用TensorFlow进行简单实现线性回归、梯度下降示例...

    线性回归属于监督学习,因此方法和监督学习应该是一样的,先给定一个训练集,根据这个训练集学习出一个线性函数,然后测试这个函数训练的好不好(即此函数是否足够拟合训练集数据),挑选出最好的函数(cost f ...

  3. 【入门】Pytorch实现简单的图片分类器

    系列文章目录 [入门]Pytorch实现简单的图片分类器 [入门]GPU训练图片分类器 文章目录 系列文章目录 前言 导入库 数据归一化 查看训练集 构造网络 定义损失函数和优化器 开始训练 查看分类 ...

  4. TensorFlow实现简单线性回归

    文章目录 实验主题-TensorFlow实现简单线性回归 案例实现 实验效果 逻辑回归或线性回归是用于对离散类别进行分类的监督机器学习方法.在本章中的目标是构建一个模型,用户可以通过该模型预测预测变量 ...

  5. python遥感影像分类代码_【博客翻译】使用 Python Tensorflow 实现简单的神经网络卫星遥感影像分类...

    Landsat 5 多光谱数据分类指导手册原作者:Pratyush Tripathy 翻译:荆雪涵 姐妹篇雪涵:[博客翻译]CNN 与中分辨率遥感影像分类​zhuanlan.zhihu.com 深度学 ...

  6. 使用tensorflow构建简单卷积神经网络

    一 概要 CIFAR-10分类问题是机器学习领域的一个通用基准,其问题是将32X32像素的RGB图像分类成10种类别:飞机,手机,鸟,猫,鹿,狗,青蛙,马,船和卡车.  更多信息请移步CIFAR-10 ...

  7. TensorFlow图像分类:如何构建分类器

    导言 图像分类对于我们来说是一件非常容易的事情,但是对于一台机器来说,在人工智能和深度学习广泛使用之前,这是一项艰巨的任务.自动驾驶汽车能够实时检测物体并采取相应必要的行动,并且由于TensorFlo ...

  8. 基于TensorFlow的简单验证码识别

    TensorFlow 可以用来实现验证码识别的过程,这里识别的验证码是图形验证码,首先用标注好的数据来训练一个模型,然后再用模型来实现这个验证码的识别. 生成验证码 首先生成验证码,这里使用 Pyth ...

  9. RNN循环神经网络的直观理解:基于TensorFlow的简单RNN例子

    RNN 直观理解 一个非常棒的RNN入门Anyone Can learn To Code LSTM-RNN in Python(Part 1: RNN) 基于此文章,本文给出我自己的一些愚见 基于此文 ...

最新文章

  1. Android - HttpURLConnection 抛出异常
  2. 简单的http服务器示例
  3. 【黑客免杀攻防】读书笔记14 - 面向对象逆向-虚函数、MFC逆向
  4. 弱密码校验_TomCat8 弱密码上传getshell
  5. 在Linux系统中应用su和sudo
  6. VS 2010 复制代码到word出现乱码解决办法
  7. matlab批量将csv转换成xls,如何批量将CSV格式的文件转化成excel格式 |
  8. html模拟鼠标点击图标,易语言模拟鼠标点击实现方法
  9. mongo数据库索引原理
  10. 02时态(2):一般现在时、疑问句主语相同的句子
  11. 两个冲击函数相乘的傅里叶变换_通信第三章常见函数的傅里叶变换.ppt
  12. vue-router 基本使用
  13. DEFCON 26 | 利用传真功能漏洞渗透进入企业内网(Faxploit)
  14. Matplotlib显示灰度图
  15. Detect-and-Track: Efficient Pose Estimation in Videos(检测和追踪:视频中有效的姿态评估)论文解读
  16. 【Java攻城狮宝典】04-for循环(答案)
  17. 关于python的文献综述_关于毕业论文文献综述,史上最全总结.doc
  18. 003云数据中心基础原理笔记
  19. 计算机考试重点题目与答案
  20. MySQL - SQL语句增加字段/修改字段/修改类型/修改默认值

热门文章

  1. 你应该知道这些有意思的代码
  2. GPG96244QS1屏驱动难题
  3. 二叉树小球下落问题c语言,#C++初学记录(树和二叉树)
  4. android sdk 封装html5,Android平台以WebView方式集成HTML5+SDK方法
  5. mp4 拍摄时间如何看_时间不多了,如何备考期末最有效?这些复习技巧,看了你就会了...
  6. 智慧交通day02-车流量检测实现03:辅助功能(交并比and候选框的表现形式)
  7. 1+X web中级 Laravel学习笔记——查询构造器简介及新增、更新、删除、查询数据
  8. 《数据结构与算法之美》学习汇总
  9. LeetCode 2018. 判断单词是否能放入填字游戏内(模拟)
  10. LeetCode 1903. 字符串中的最大奇数