TensorFlow实现鸢尾花分类
TensorFlow实现鸢尾花分类任务
鸢尾花数据集介绍
Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种。
过程
1、导入所需库和数据
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt#加载数据
train_url = 'http://download.tensorflow.org/data/iris_training.csv'
train_path = tf.keras.utils.get_file(train_url.split('/')[-1],train_url)
TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv'
test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1],TEST_URL)
df_iris_train = pd.read_csv(train_path,header=0)
df_iris_test = pd.read_csv(test_path,header=0)
2、处理数据并对数据进行可视化
iris_train=np.array(df_iris_train)#120行5列
iris_test=np.array(df_iris_test)
train_x=iris_train[:,0:2]#取前两个属性
train_y=iris_train[:,4]#取最后一列作为标签值
test_x=iris_test[:,0:2]
test_y=iris_test[:,4]
#提取山鸢尾和变色鸢尾
x_train=train_x[train_y<2]
y_train=train_y[train_y<2]
x_test=test_x[test_y<2]
y_test=test_y[test_y<2]num_train = len(x_train)
num_tst = len(x_test)
#绘制散点图,横坐标表示花萼长度,纵坐标表示宽度
plt.figure(figsize=(10,3))
x_train=x_train-np.mean(x_train,axis=0)#中心化处理
x_test=x_test-np.mean(x_test,axis=0)cm_cp = mpl.colors.ListedColormap(['blue','red'])plt.subplot(121)
plt.scatter(x_train[:,0],x_train[:,1],c=y_train,cmap=cm_cp)plt.subplot(122)
plt.scatter(x_test[:,0],x_test[:,1],c=y_test,cmap=cm_cp)
plt.show()
左图为训练集,右图为测试集;横坐标表示花萼长度,纵坐标表示宽度。
3、设置超参数及模型参数初始值
x0_test=np.ones(num_tst).reshape(-1,1)
X_test=tf.cast(tf.concat((x0_test,x_test),axis=1),dtype=tf.float32)
Y_test=tf.cast(y_test.reshape(-1,1),dtype=tf.float32)
#生成多元模型的属性矩阵和标签列向量
x0_train = np.ones(num_train).reshape(-1,1)
X_train=tf.cast(tf.concat((x0_train,x_train),axis=1),tf.float32)
Y_train=tf.cast(y_train.reshape(-1,1),tf.float32)x0_test=np.ones(num_tst).reshape(-1,1)
X_test=tf.cast(tf.concat((x0_test,x_test),axis=1),dtype=tf.float32)
Y_test=tf.cast(y_test.reshape(-1,1),dtype=tf.float32)
#设置超参数
learn_rate = 0.2
iters=120
display_step = 30
#设置模型参数初始值
np.random.seed(612)
W=tf.Variable(np.random.randn(3,1),dtype=tf.float32)
4、训练模型
#训练模型
ce_train = []
acc_train = []
ce_test = []
acc_test = []for i in range(iters+1):with tf.GradientTape() as tape:PRED_train = 1/(1+tf.exp(-tf.matmul(X_train,W)))Loss_train = -tf.reduce_mean(Y_train*tf.math.log(PRED_train)+(1-Y_train)*tf.math.log(1-PRED_train))PRED_test = 1/(1+tf.exp(-tf.matmul(X_test,W)))Loss_test = -tf.reduce_mean(Y_test*tf.math.log(PRED_test)+(1-Y_test)*tf.math.log(1-PRED_test))accuracy_train = tf.reduce_mean(tf.cast(tf.equal(tf.where(PRED_train.numpy()<0.5,0.,1.),Y_train),tf.float32))accuracy_test = tf.reduce_mean(tf.cast(tf.equal(tf.where(PRED_test.numpy()<0.5,0.,1.),Y_test),tf.float32))ce_train.append(Loss_train)acc_train.append(accuracy_train)ce_test.append(Loss_test)acc_test.append(accuracy_test)dL_dW = tape.gradient(Loss_train,W)W.assign_sub(learn_rate*dL_dW)if i%display_step==0:print("i: %i, Acc: %f, Loss: %f"%(i,accuracy_train,Loss_train))
这里我们从150组数据集选取了120组作为训练集,剩余30组作为测试集;对训练集每批次30组进行训练,输出准确率及损失函数。
5、对结果可视化
#可视化
#绘制损失和准确率变化曲线
plt.figure(figsize=(10,3))plt.subplot(121)
plt.plot(ce_train,color='blue',label='train')
plt.plot(ce_test,color='red',label="test")
plt.ylabel("Loss")
plt.legend()plt.subplot(122)
plt.plot(acc_train,color='blue',label='train')
plt.plot(acc_test,color='red',label='test')
plt.ylabel("Accuracy")plt.legend()
plt.show()#绘制边界曲线
plt.scatter(x_train[:,0],x_train[:,1],c=y_train,cmap=cm_cp)
x_=[-1.5,1.5]
y_=-(W[1]*x_+W[0])/W[2]
plt.plot(x_,y_,color='g')
plt.show()
可以看到不论是训练集还是测试集,损失函数都在减小;最后准确率都达到100%(测试集最后停留在80%是因为是120/150);并且线性函数可以将其很好的分类。
TensorFlow实现鸢尾花分类相关推荐
- 断点续训 Pytorch 和 Tensorflow 框架 VGG16 模型 猫狗大战 鸢尾花分类
神经网络训练模型的过程中,如果程序突然中断,竹篮打水一场空? >>>断点续训来解决! 目录 (1)Pytorch框架的断点续训(猫狗大战) (2)Tensorflow框架的断点续训( ...
- 机器学习 | 使用TensorFlow搭建神经网络实现鸢尾花分类
鸢尾花分类问题是机器学习领域一个非常经典的问题,本文将利用神经网络来实现鸢尾花分类 实验环境:Windows10.TensorFlow2.0.Spyder 参考资料:人工智能实践:TensorFlow ...
- 【神经网络学习】鸢尾花分类的实现
目录 1.问题 2.问题解决思路 3.神经网络理论准备 4.Tensor Flow编程基础 5. 鸢尾花分类神经网络实现 1.问题 鸢尾花分为:狗尾草鸢尾.杂色鸢尾.弗吉尼亚鸢尾: 通过测量:花萼长. ...
- 神经网络实现鸢尾花分类(Tensorflow2.0)
1鸢尾花数据 回顾鸢尾花数据集,其提供了150组鸢尾花数据,每组包括鸢尾花的花萼长.花萼宽.花瓣长.花瓣宽 4个输入特征,同时还给出了这一组特征对应的鸢尾花类别.类别包括狗尾鸢尾.杂色鸢尾.弗吉尼亚鸢 ...
- Tensorflow2.x框架-神经网络实现鸢尾花分类
神经网络实现鸢尾花分类 一.数据准备 1.数据集读入 2.数据集乱序 3.生成训练集和测试集(即 x_train / y_train,x_test / y_test) 4.配成(输入特征,标签 ...
- 最简单的单层神经网络实现鸢尾花分类
一,知识背景 鸢尾花的分类由四个数据定义,分别是花萼长.花萼宽.花瓣长.花瓣宽.我们把这样的一组数据称为是一组特征,根据特征可以分为三类鸢尾花. 二,神经元模型 神经元采用最简单的简化MP(麦卡洛克- ...
- 基于Keras实现鸢尾花分类
神经网络原理与实现(以鸢尾花分类为例) 环境准备 实现步骤 第一步:导入Keras模型库,创建模型对象 Keras构建神经网络的两种模型 导入keras库 用顺序模型的构建和使用神经网络的基本步骤 第 ...
- 机器学习 鸢尾花分类的原理和实现(一)
机器学习 鸢尾花分类的原理和实现(一) 前言: 鸢尾花数据集是机器学习中的经典小规模数据集.通过查阅资料和视频进行学习,将整个实验的学习心得和实验过程分享,希望对喜爱机器学习并入门的新手提供帮助,同时 ...
- 神经网络实现鸢尾花分类
神经网络实现鸢尾花分类 一.数据集介绍 共有数据150组,每组包括花萼长.花萼宽.花瓣长.花瓣宽4个输入特征. 同时给出了,这一组特征对应的鸢尾花类别.类别包括Setosa Iris(狗尾草 鸢尾), ...
最新文章
- python人工智能-Python和人工智能的关系,看完你就明白了!
- Struts2里的Action返回Json数据
- python提高——类(私有化,封装、继承、多态)
- MLAPP————第十四章 核方法
- Java 数组的三种创建方法
- cocos2dx[2.x](9)--编辑框之一CCTextFieldTTF
- ue4 中动画控制,利用conduit节点
- Android开发之殇
- html 输入框加搜索框,如何实现一个input搜索框?
- linux系统外接硬盘_Mac如何在外置硬盘上安装Linux
- 无法访问此网站 localhost 拒绝了我们的连接请求。
- 理解js执行的过程:JS运行三部曲
- 拟牛顿法算法的设计与实现c语言,牛顿法与拟牛顿法的故事
- Android工程师面试准备知识点
- 基于JavaWeb的收银台系统
- opencv4中未定义标识符CV_CAP_PROP_FPS;CV_CAP_PROP_FRAME_COUNT;CV_CAP_PROP_POS_FRAMES问题
- matlab的imshow, image, imagesc区别
- 服务器租用对比托管的优势
- 同一个Maven项目移机出错解决办法
- 微信小程序开发(四)跳转页面、传递参数获得数据
热门文章
- Poi excel 导出 工具类参考
- 给小学生科普计算机知识,小学生必懂的15个科普知识
- 异步电机仿真为什么转速不是0
- 吃货程序猿怎么区分低热量食品
- 自然语言表达处理笔记01—— 1.正则表达式 2.文本标记化 3.词干提取和词形还原 4.中文分词
- Win10 下报错 WerFault.exe -解决方法亲测有效
- java8新特性总结——lambda表达式
- 【Python技能树共建】requests-html库初识
- Unity 基于PDFViewer制作读取横板PDF,改为横向滑动读取并做自适应(可网络同步)
- 十五天学会Autodesk Inventor,看完这一系列就够了(二),软件界面