【机器学习】sklearn机器学习入门案例——使用k近邻算法进行鸢尾花分类
1 背景
这个案例恐怕已经被说的很烂了,机器学习方面不同程度的人对该案例还是有着不同的感觉。有的人追求先理解机器学习背后的理论甚至自己推导一遍相关数学公式,再用代码实现;有的人则满足于能够实现相关功能即可。凡是都有两面性,理解算法背后原理,再去实现相关算法,这个对算法理解深刻,更能融会贯通,拓展性强,但是需要有一定的数学基础以及要花费一段时间;若能够实现相关算法,知道各个参数的意义,也是能够尽快处理相关的任务,但是可拓展性就不那么好了。
当前,对于传统的机器学习算法进行很好的实现的Python包也当属sklearn了,本文更注重使用sklearn提供的算法包去完成鸢尾花分类任务,也不用去把相关算法去逐一实现(就不要去造轮子了)。对于相关算法理论只作简要介绍。
2 任务背景
- 假设你有一份数据集:有很多不同类别的鸢尾花,每一条数据有多个特征:花瓣(petal)长、宽、花萼(sepal)长、宽,以及这条数据对应的类别,也就是说有4个特征,1个标签。
- 任务是:使用这些数据采用监督学习的机器学习相关算法训练一个模型,然后对没有类别,只有花瓣长宽、花萼长宽的数据预测其所属类别
- 数据背景:这里使用的sklearn中自带的数据集,其中鸢尾花的类别有三种setosa、versicolor、virginica,相关数据将在后面的内容详细介绍。
- 实验环境:Python3.7, sklearn 0.22.1
3 了解一下数据
from sklearn.datasets import load_irisiris_dataset = load_iris()
print("keys of iris_dataset:{}\n".format(iris_dataset.keys()))
# 结果
"""
keys of iris_dataset:dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
"""
可通过以下方式查看数据形式:
import numpy as np
for key in iris_dataset.keys():print('current key:"{}", key type:{}'.format(key, type(iris_dataset[key])))# 如果为np.ndarray 可以说明是训练数据以及对应的标签if isinstance(iris_dataset[key], np.ndarray):print(iris_dataset[key][0])elif isinstance(iris_dataset[key], str):print(iris_dataset[key][:150])else:print(iris_dataset[key]))
结果如如下:
current key:"data", key type:<class 'numpy.ndarray'>
[5.1 3.5 1.4 0.2]
current key:"target", key type:<class 'numpy.ndarray'>
0
current key:"target_names", key type:<class 'numpy.ndarray'>
setosa
current key:"DESCR", key type:<class 'str'>
.. _iris_dataset:Iris plants dataset
--------------------**Data Set Characteristics:**:Number of Instances: 150 (50 in each of three classes
current key:"feature_names", key type:<class 'list'>
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
current key:"filename", key type:<class 'str'>
C:\software\Development\Python\Anaconda3\envs\commEnv\lib\site-packages\sklearn\datasets\data\iris.csv
数据的具体介绍可以通过iris_dataset[“DESCR”]查看,其他数据也比较清楚了,如:每一条数据如:[5.1 3.5 1.4 0.2]对应的特征名称为:花萼长、宽,花瓣长、宽,数据样本一共有150条;iris_dataset[“target”]是使用0,1,2分别代表setosa,versicolor,virginica这三个类别的。
4 数据预处理
数据只有150,需要将一部分数据拿出来训练得到一个模型,剩余的数据留出来验证模型的泛化能力,也就是看这个模型对不在训练集中的数据的识别能力。当然有时候将还会将数据集划再划分一个验证集,这是因为在训练集中可以构建多个模型,通过训练集选出一个泛化能力更好的模型,然后在使用这个更好的模型在测试集数据上进行测试,这些复杂的过程(根据哪些指标在备选的模型中挑选最好等)暂且不谈。
这里使用scikit-learn中的train_test_split函数打乱数据集(原始数据是每一类都在一起,需要将这个顺序打乱)并拆分为两个部分,默认将75%的数据划分为训练集train_data,25%的数据划分为测试集test_data。操作如下
from sklearn.model_selection import train_test_split
# 划分数据时设置随机数种子,random_state 便于实验的复现
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)
# 查看划分后的数据集情况
print("X_train shape:{}".format(X_train.shape))
print("y_train shape:{}".format(y_train.shape))
print("X_test shape:{}".format(X_test.shape))
print("y_test shape:{}".format(y_test.shape))
"""
X_train shape:(112, 4)
y_train shape:(112,)
X_test shape:(38, 4)
y_test shape:(38,)
"""
数据划分好了,也不要急着去构建机器学习模型,我们先来观察一下数据,可视化一下。由于数据集有4个特征,我们绘制散点图矩阵,也就是查看数据的两两特征之间的情况。绘制代码如下:
import pandas as pd
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
iris_dataframe.head(5)
# 绘图
grr = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', hist_kwds={'bins':20}, s=60, alpha=0.8)
从上面的图可以看出,使用数据集中的几个特征是可以将各个类型的数据区别开的,也就是说使用机器学习算法也就能够区分开了。
5 模型构建
在机器学习中,分类的算法有很多,这里使用一个简单比较容易理解的KNN算法来分类。这里的K的含义是,待预测的数据与训练集中最近的任意k个邻居,根据这k个邻居的类别确定这个预测数据的类别。下面是构建模型的代码:
from sklearn.neighbors import KNeighborsClassifier
# 1 构建模型对象,设置相关参数
knn = KNeighborsClassifier(n_neighbors=1) # 根据最近的一个训练数据来确定类别
# 2 训练模型
knn.fit(X_train, y_train)
# knn会被返回,如下:
"""
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',metric_params=None, n_jobs=None, n_neighbors=1, p=2,weights='uniform')
"""
当然训练后的一个knn模型包含很多参数,很多参数用于速度优化或非常特殊的用途。
6 模型评估
这里我们使用一个简单的指标精度,
y_pred = knn.predict(X_test)
print('Test set predictions:{}\n'.format(y_pred))
print("Test set accuracy:{:.2f}%".format(100*np.mean(y_pred==y_test)))
结果如下:
Test set predictions:[2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 02]
Test set accuracy:97.37%
总体来看效果还是可以的。
7 预测数据
我们根据训练模型的数据构建预测数据的格式,例如:X_new=[[5, 2.9, 1, 0.2]]
,预测代码如下:
X_new=[[5, 2.9, 1, 0.2]]
# 开始预测
prediction=knn.predict(X_new)
# 预测结果
print("Prediction:{}".format(prediction))
print("Predicted target name:{}".format(iris_dataset['target_names'][prediction]))
"""
Prediction:[0]
Predicted target name:['setosa']
"""
8 总结
其实上面是一个使用sklearn的一个简单入门,不管是理论还是模型调参所涉及的内容都还没介绍太多,不过上面也涉及使用sklearn进行机器学习的大部分步骤。当然也不要被当前深度学习“遍地飞”所震慑,传统的机器学习算法依然在一些领域中发挥着很重要的作用。当然,需要在这个领域中还要进行更深一步的学习和深入。当然本部分内容是参考《Python机器学习基础教程》内容并结合自己的理解写出,所以我还是推荐一下这本书,或者可以在订阅号“AIAS编程有道”中回复“Python机器学习基础教程”获取电子档后决定是否要购买,建议购买正版书籍。
课程推荐
图解Python数据结构与算法-实战篇
【机器学习】sklearn机器学习入门案例——使用k近邻算法进行鸢尾花分类相关推荐
- 【机器学习入门】(1) K近邻算法:原理、实例应用(红酒分类预测)附python完整代码及数据集
各位同学好,今天我向大家介绍一下python机器学习中的K近邻算法.内容有:K近邻算法的原理解析:实战案例--红酒分类预测.红酒数据集.完整代码在文章最下面. 案例简介:有178个红酒样本,每一款红酒 ...
- 基于K-最近邻算法构建鸢尾花分类模型
基于K-最近邻算法构建鸢尾花分类模型 一 任务描述 鸢尾花(Iris)数据集是机器学习中一个经典的数据集.假设有一名植物学爱好者收集了150朵鸢尾花的测量数据:花瓣的长度和宽度以及花萼的长度和宽度,这 ...
- chapter2 机器学习之KNN(k-nearest neighbor algorithm)--K近邻算法从原理到实现
一.引入 K近邻算法作为数据挖掘十大经典算法之一,其算法思想可谓是intuitive,就是从训练集里找离预测点最近的K个样本来预测分类 因为算法思想简单,你可以用很多方法实现它,这时效率就是我们需要慎 ...
- 【计算机视觉之三】运用k近邻算法进行图片分类
这篇文章主要给不知道计算机视觉是啥的人介绍一下图像分类问题以及最近的最近邻算法. 目录 图像分类 1.1 图像分类的原理 1.2 面临的问题 1.3 图像分类任务 最近邻算法 代码实现 L2距离 用k ...
- 一文搞懂K近邻算法(KNN),附带多个实现案例
简介:本文作者为 CSDN 博客作者董安勇,江苏泰州人,现就读于昆明理工大学电子与通信工程专业硕士,目前主要学习机器学习,深度学习以及大数据,主要使用python.Java编程语言.平时喜欢看书,打篮 ...
- 09_分类算法--k近邻算法(KNN)、案例、欧氏距离、k-近邻算法API、KNeighborsClassifier、及其里面的案例(网络资料+学习资料整理笔记)
1 分类算法–k近邻算法(KNN) 定义:如果一个样本在特征空间中**k个最相似(即特征空间中最邻近)**的样本中的大多数属于某一个类别,则该样本也属于这个类别,则该样本也属于这个类别. k-近邻算法 ...
- K近邻算法的Python实现
作为『十大机器学习算法』之一的K-近邻(K-Nearest Neighbors)算法是思想简单.易于理解的一种分类和回归算法.今天,我们来一起学习KNN算法的基本原理,并用Python实现该算法,最后 ...
- 统计学习方法笔记(一)-k近邻算法原理及python实现
k近邻法 k近邻算法 算法原理 距离度量 距离度量python实现 k近邻算法实现 案例地址 k近邻算法 kkk近邻法(kkk-NN)是一种基本分类和回归方法. 算法原理 输入:训练集 T={(x1, ...
- C++实现的简单k近邻算法(K-Nearest-Neighbour,K-NN)
C++实现的简单的K近邻算法(K-Nearest Neighbor,K-NN) 前一段时间学习了K近邻算法,对K近邻算法有了一个初步的了解,也存在一定的问题,下面我来简单介绍一下K近邻算法.本博客将从 ...
最新文章
- python 批量下载 代码_Python + Selenium +Chrome 批量下载网页代码修改
- Vue 组件间的通讯
- Python基础(list和tuple)可变集合和‘不可变’集合
- leetcode -eleven:Container With Most Water
- 手把手教你做产品经理,视频课教程已经发布,欢迎观看
- android 导航动画,安利一个Android导航库
- java多线程之Executor框架
- Android中动态初始化布局参数以及ConstraintLayout使用中遇到的坑
- 量子计算机 程序,量子计算机程序 会早于量子计算机出现
- HTML和XHTML解析(HTMLParser、BeautifulSoup)
- 分计算iv值_【美股期權】多高的IV才算高?理解IV percentile
- java tolist_Java Collectors toList()用法及代码示例
- IERS EOP 文件的解读
- WEB/HTTP服务器搭建
- 微信公众号网页授权--前端获取code及用户信息(vue)
- Java设计模式 Design Pattern:包装模式 Decorator Pattern
- python爬取酷狗音乐_Python爬取酷狗音乐
- Outlook Business Contact Manager 2010入门
- 中国没有乔布斯,美国没有史玉柱
- pt-archive使用
热门文章
- 直积与张量积的数学定义与物理定义异同
- 【3d建模】全网最全3dmax快捷键【附软件安装包和角色基础教程下载】
- c语言输入名字判断姓是否缩写,C语言复习笔记
- 计算机无法连接蓝牙键盘,电脑如何连接无线键盘_电脑上怎么连接蓝牙键盘-win7之家...
- Codeforces 821B Okabe and Banana Trees 题解
- 空间直线的最小二乘拟合
- 细说Java性能测试第三课 性能测试详解2
- 无线网服务器端口是什么,无线路由器上的lan端口是什么意思?
- MFC OnFileNew OnFileOpen过程分析代码(一)
- 一张专家推荐的最健康的作息时间表,你能做到吗?