TensorFlow实现鸢尾花分类任务

鸢尾花数据集介绍

Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种。

过程

1、导入所需库和数据

import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt#加载数据
train_url = 'http://download.tensorflow.org/data/iris_training.csv'
train_path = tf.keras.utils.get_file(train_url.split('/')[-1],train_url)
TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv'
test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1],TEST_URL)
df_iris_train = pd.read_csv(train_path,header=0)
df_iris_test = pd.read_csv(test_path,header=0)

2、处理数据并对数据进行可视化

iris_train=np.array(df_iris_train)#120行5列
iris_test=np.array(df_iris_test)
train_x=iris_train[:,0:2]#取前两个属性
train_y=iris_train[:,4]#取最后一列作为标签值
test_x=iris_test[:,0:2]
test_y=iris_test[:,4]
#提取山鸢尾和变色鸢尾
x_train=train_x[train_y<2]
y_train=train_y[train_y<2]
x_test=test_x[test_y<2]
y_test=test_y[test_y<2]num_train = len(x_train)
num_tst = len(x_test)
#绘制散点图,横坐标表示花萼长度,纵坐标表示宽度
plt.figure(figsize=(10,3))
x_train=x_train-np.mean(x_train,axis=0)#中心化处理
x_test=x_test-np.mean(x_test,axis=0)cm_cp = mpl.colors.ListedColormap(['blue','red'])plt.subplot(121)
plt.scatter(x_train[:,0],x_train[:,1],c=y_train,cmap=cm_cp)plt.subplot(122)
plt.scatter(x_test[:,0],x_test[:,1],c=y_test,cmap=cm_cp)
plt.show()


左图为训练集,右图为测试集;横坐标表示花萼长度,纵坐标表示宽度。

3、设置超参数及模型参数初始值

x0_test=np.ones(num_tst).reshape(-1,1)
X_test=tf.cast(tf.concat((x0_test,x_test),axis=1),dtype=tf.float32)
Y_test=tf.cast(y_test.reshape(-1,1),dtype=tf.float32)
#生成多元模型的属性矩阵和标签列向量
x0_train = np.ones(num_train).reshape(-1,1)
X_train=tf.cast(tf.concat((x0_train,x_train),axis=1),tf.float32)
Y_train=tf.cast(y_train.reshape(-1,1),tf.float32)x0_test=np.ones(num_tst).reshape(-1,1)
X_test=tf.cast(tf.concat((x0_test,x_test),axis=1),dtype=tf.float32)
Y_test=tf.cast(y_test.reshape(-1,1),dtype=tf.float32)
#设置超参数
learn_rate = 0.2
iters=120
display_step = 30
#设置模型参数初始值
np.random.seed(612)
W=tf.Variable(np.random.randn(3,1),dtype=tf.float32)

4、训练模型

#训练模型
ce_train = []
acc_train = []
ce_test = []
acc_test = []for i in range(iters+1):with tf.GradientTape() as tape:PRED_train = 1/(1+tf.exp(-tf.matmul(X_train,W)))Loss_train = -tf.reduce_mean(Y_train*tf.math.log(PRED_train)+(1-Y_train)*tf.math.log(1-PRED_train))PRED_test = 1/(1+tf.exp(-tf.matmul(X_test,W)))Loss_test = -tf.reduce_mean(Y_test*tf.math.log(PRED_test)+(1-Y_test)*tf.math.log(1-PRED_test))accuracy_train = tf.reduce_mean(tf.cast(tf.equal(tf.where(PRED_train.numpy()<0.5,0.,1.),Y_train),tf.float32))accuracy_test = tf.reduce_mean(tf.cast(tf.equal(tf.where(PRED_test.numpy()<0.5,0.,1.),Y_test),tf.float32))ce_train.append(Loss_train)acc_train.append(accuracy_train)ce_test.append(Loss_test)acc_test.append(accuracy_test)dL_dW = tape.gradient(Loss_train,W)W.assign_sub(learn_rate*dL_dW)if i%display_step==0:print("i: %i, Acc: %f, Loss: %f"%(i,accuracy_train,Loss_train))


这里我们从150组数据集选取了120组作为训练集,剩余30组作为测试集;对训练集每批次30组进行训练,输出准确率及损失函数。

5、对结果可视化

#可视化
#绘制损失和准确率变化曲线
plt.figure(figsize=(10,3))plt.subplot(121)
plt.plot(ce_train,color='blue',label='train')
plt.plot(ce_test,color='red',label="test")
plt.ylabel("Loss")
plt.legend()plt.subplot(122)
plt.plot(acc_train,color='blue',label='train')
plt.plot(acc_test,color='red',label='test')
plt.ylabel("Accuracy")plt.legend()
plt.show()#绘制边界曲线
plt.scatter(x_train[:,0],x_train[:,1],c=y_train,cmap=cm_cp)
x_=[-1.5,1.5]
y_=-(W[1]*x_+W[0])/W[2]
plt.plot(x_,y_,color='g')
plt.show()


可以看到不论是训练集还是测试集,损失函数都在减小;最后准确率都达到100%(测试集最后停留在80%是因为是120/150);并且线性函数可以将其很好的分类。

TensorFlow实现鸢尾花分类相关推荐

  1. 断点续训 Pytorch 和 Tensorflow 框架 VGG16 模型 猫狗大战 鸢尾花分类

    神经网络训练模型的过程中,如果程序突然中断,竹篮打水一场空? >>>断点续训来解决! 目录 (1)Pytorch框架的断点续训(猫狗大战) (2)Tensorflow框架的断点续训( ...

  2. 机器学习 | 使用TensorFlow搭建神经网络实现鸢尾花分类

    鸢尾花分类问题是机器学习领域一个非常经典的问题,本文将利用神经网络来实现鸢尾花分类 实验环境:Windows10.TensorFlow2.0.Spyder 参考资料:人工智能实践:TensorFlow ...

  3. 【神经网络学习】鸢尾花分类的实现

    目录 1.问题 2.问题解决思路 3.神经网络理论准备 4.Tensor Flow编程基础 5. 鸢尾花分类神经网络实现 1.问题 鸢尾花分为:狗尾草鸢尾.杂色鸢尾.弗吉尼亚鸢尾: 通过测量:花萼长. ...

  4. 神经网络实现鸢尾花分类(Tensorflow2.0)

    1鸢尾花数据 回顾鸢尾花数据集,其提供了150组鸢尾花数据,每组包括鸢尾花的花萼长.花萼宽.花瓣长.花瓣宽 4个输入特征,同时还给出了这一组特征对应的鸢尾花类别.类别包括狗尾鸢尾.杂色鸢尾.弗吉尼亚鸢 ...

  5. Tensorflow2.x框架-神经网络实现鸢尾花分类

    神经网络实现鸢尾花分类 一.数据准备     1.数据集读入 2.数据集乱序 3.生成训练集和测试集(即 x_train / y_train,x_test / y_test) 4.配成(输入特征,标签 ...

  6. 最简单的单层神经网络实现鸢尾花分类

    一,知识背景 鸢尾花的分类由四个数据定义,分别是花萼长.花萼宽.花瓣长.花瓣宽.我们把这样的一组数据称为是一组特征,根据特征可以分为三类鸢尾花. 二,神经元模型 神经元采用最简单的简化MP(麦卡洛克- ...

  7. 基于Keras实现鸢尾花分类

    神经网络原理与实现(以鸢尾花分类为例) 环境准备 实现步骤 第一步:导入Keras模型库,创建模型对象 Keras构建神经网络的两种模型 导入keras库 用顺序模型的构建和使用神经网络的基本步骤 第 ...

  8. 机器学习 鸢尾花分类的原理和实现(一)

    机器学习 鸢尾花分类的原理和实现(一) 前言: 鸢尾花数据集是机器学习中的经典小规模数据集.通过查阅资料和视频进行学习,将整个实验的学习心得和实验过程分享,希望对喜爱机器学习并入门的新手提供帮助,同时 ...

  9. 神经网络实现鸢尾花分类

    神经网络实现鸢尾花分类 一.数据集介绍 共有数据150组,每组包括花萼长.花萼宽.花瓣长.花瓣宽4个输入特征. 同时给出了,这一组特征对应的鸢尾花类别.类别包括Setosa Iris(狗尾草 鸢尾), ...

最新文章

  1. python人工智能-Python和人工智能的关系,看完你就明白了!
  2. Struts2里的Action返回Json数据
  3. python提高——类(私有化,封装、继承、多态)
  4. MLAPP————第十四章 核方法
  5. Java 数组的三种创建方法
  6. cocos2dx[2.x](9)--编辑框之一CCTextFieldTTF
  7. ue4 中动画控制,利用conduit节点
  8. Android开发之殇
  9. html 输入框加搜索框,如何实现一个input搜索框?
  10. linux系统外接硬盘_Mac如何在外置硬盘上安装Linux
  11. 无法访问此网站 localhost 拒绝了我们的连接请求。
  12. 理解js执行的过程:JS运行三部曲
  13. 拟牛顿法算法的设计与实现c语言,牛顿法与拟牛顿法的故事
  14. Android工程师面试准备知识点
  15. 基于JavaWeb的收银台系统
  16. opencv4中未定义标识符CV_CAP_PROP_FPS;CV_CAP_PROP_FRAME_COUNT;CV_CAP_PROP_POS_FRAMES问题
  17. matlab的imshow, image, imagesc区别
  18. 服务器租用对比托管的优势
  19. 同一个Maven项目移机出错解决办法
  20. 微信小程序开发(四)跳转页面、传递参数获得数据

热门文章

  1. Poi excel 导出 工具类参考
  2. 给小学生科普计算机知识,小学生必懂的15个科普知识
  3. 异步电机仿真为什么转速不是0
  4. 吃货程序猿怎么区分低热量食品
  5. 自然语言表达处理笔记01—— 1.正则表达式 2.文本标记化 3.词干提取和词形还原 4.中文分词
  6. Win10 下报错 WerFault.exe -解决方法亲测有效
  7. java8新特性总结——lambda表达式
  8. 【Python技能树共建】requests-html库初识
  9. Unity 基于PDFViewer制作读取横板PDF,改为横向滑动读取并做自适应(可网络同步)
  10. 十五天学会Autodesk Inventor,看完这一系列就够了(二),软件界面