oneR即“一条规则”。oneR算法根据已有的数据中,具有相同特征值的个体最可能属于哪个类别来进行分类。
以鸢尾data为例,该算法实现过程可解读为以下六步:

文章目录

  • 一、 导包与获取数据
  • 二、划分为训练集和测试集
  • 三、定义函数:获取某特征值出现次数最多的类别及错误率
  • 四、定义函数:获取每个特征值下出现次数最多的类别、错误率
  • 五、调用函数,获取最佳特征值
  • 六、测试算法

一、 导包与获取数据

以均值为阈值,将大于或等于阈值的特征标记为1,低于阈值的特征标记为0。

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from collections import defaultdict
from operator import itemgetter
import warnings
from sklearn.metrics import classification_report# 加载内置iris数据,并保存
dataset = load_iris()
X = dataset.data
y = dataset.targetattribute_means = X.mean(axis=0)  # 得到一个列表,列表元素个数为特征值个数,列表值为每个特征的均值
X_d = np.array(X >= attribute_means, dtype='int')  # 转bool类型

数据到此已获取完毕,接下来将其划分为训练集和测试集。

二、划分为训练集和测试集

使用默认的0.25作为分割比例。即训练集:测试集=3:1。

X_train, X_test, y_train, y_test = train_test_split(X_d, y, random_state=random_state)

数据描述:
本例中共有四个特征,
原数据集有150个样本,分割后训练集有112个数据,测试集有38个数据。
标签一共分为三类,取值可以是0,1,2。

三、定义函数:获取某特征值出现次数最多的类别及错误率

首先遍历特征的每一个取值,对于每一个特征值,统计它在各个类别中出现的次数。
定义一个函数,有以下四个参数:

  • X, y_true即 训练集数据和标签
  • feature是特征的索引值,可以是0,1,2,3。
  • value是特征可以有的取值,这里为0,1。

该函数的意义在于,对于训练集数据,对于某个特征,依次遍历样本在该特征的真实取值,判断其是否等于特征的某个可以有的取值 (即value)(以0为例)。如果判定成功,则在字典class_counts中记录,以三个类别(0,1,2)中该样本对应的类别为键值,表示该类别出现的次数加一。

首先得到的字典(class_counts)形如:
{0: x1, 1.0: x2, 2.0:x3}
其中元素不一定是三个
x1:类别0中,某个特征feature的特征值为value(0或1)出现的次数
x2:类别0中,某个特征feature的特征值为value(0或1)出现的次数
x3:类别0中,某个特征feature的特征值为value(0或1)出现的次数

然后将class_counts按照值的大小排序,取出指定特征的特征值出现次数最多的类别:most_frequent_class。
该规则即为:该特征的该特征值出现在其出现次数最多的类别上是合理的,出现在其它类别上是错误的。

最后计算该规则的错误率:error
错误率具有该特征的个体在除出现次数最多的类别出现的次数,代表分类规则不适用的个体的数量

最后返回待预测的个体类别错误率

def train_feature_value(X, y_true, feature, value):class_counts = defaultdict(int)for sample, y_t in zip(X, y_true):if sample[feature] == value:class_counts[y_t] += 1sorted_class_counts = sorted(class_counts.items(), key=itemgetter(1), reverse=True) # 降序most_frequent_class = sorted_class_counts[0][0]error = sum([class_count for class_value, class_count in class_counts.items()if class_value != most_frequent_class])return most_frequent_class, error

返回值most_frequent_class是一个字典, error是一个数字

四、定义函数:获取每个特征值下出现次数最多的类别、错误率

def train(X, y_true, feature):n_samples, n_features = X.shapeassert 0 <= feature < n_features# 获取样本中某特征所有可能的取值values = set(X[:, feature])predictors = dict()errors = []for current_value in values:most_frequent_class, error = train_feature_value(X, y_true, feature, current_value)predictors[current_value] = most_frequent_classerrors.append(error)total_error = sum(errors)return predictors, total_error

因为most_frequent_class是一个字典,所以predictors是一个键为特征可以的取值(0和1),值为字典most_frequent_class的 字典。
total_error是一个数字,为每个特征值下的错误率的和。

五、调用函数,获取最佳特征值

all_predictors = {variable: train(X_train, y_train, variable) for variable in range(X_train.shape[1])}
Errors = {variable: error for variable, (mapping, error) in all_predictors.items()}
# 找到错误率最低的特征
best_variable, best_error = sorted(Errors.items(), key=itemgetter(1))[0]  # 升序
print("The best model is based on feature {0} and has error {1:.2f}".format(best_variable, best_error))
# 找到最佳特征值,创建model模型
model = {'variable': best_variable,'predictor': all_predictors[best_variable][0]}
print(model)


根据代码运行结果,最佳特征值是特征2(索引值为2的feature,即第三个特征)。

对于初学者这里的代码逻辑比较复杂,可以对变量进行逐个打印查看,阅读blog学习时要盯准字眼,细品其逻辑。

print(all_predictors)
print(all_predictors[best_variable])
print(all_predictors[best_variable][0])

六、测试算法

定义预测函数,对测试集数据进行预测

def predict(X_test, model):variable = model['variable']predictor = model['predictor']y_predicted = np.array([predictor[int(sample[variable])] for sample in X_test])return y_predicted# 对测试集数据进行预测
y_predicted = predict(X_test, model)
print(y_predicted)

预测结果:

# 统计预测准确率
accuracy = np.mean(y_predicted == y_test) * 100
print("The test accuracy is {:.1f}%".format(accuracy))


根据打印结果,该模型预测的准确率可达65.8%,对于只有一条规则的oneR算法而言,结果是比较良好的。到此便实现了oneR算法的一次完整应用。


最后,还可以使用classification_report()方法,传入测试集的真实值和预测值,打印出模型评估报告。

# 屏蔽警告
warnings.filterwarnings("ignore")
# 打印模型评估报告
print(classification_report(y_test, y_predicted))  # 参数为测试集的真实数据和预测数据

小啾祝您学习顺利!

python机器学习实现oneR算法 以鸢尾data为例相关推荐

  1. python机器学习手写算法系列——逻辑回归

    从机器学习到逻辑回归 今天,我们只关注机器学习到线性回归这条线上的概念.别的以后再说.为了让大家听懂,我这次也不查维基百科了,直接按照自己的理解用大白话说,可能不是很严谨. 机器学习就是机器可以自己学 ...

  2. 大数据基石python学习_资源 | 177G Python/机器学习/深度学习/算法/TensorFlow等视频,涵盖入门/中级/项目各阶段!...

    原标题:资源 | 177G Python/机器学习/深度学习/算法/TensorFlow等视频,涵盖入门/中级/项目各阶段! 这是一份比较全面的视频教程,基本上包括了市面上所有关于机器学习,统计学习, ...

  3. python机器学习手写算法系列——线性回归

    本系列另一篇文章<决策树> https://blog.csdn.net/juwikuang/article/details/89333344 本文源代码: https://github.c ...

  4. python机器学习手写算法系列——kmeans聚类

    从机器学习到kmeans 聚类是一种非监督学习,他和监督学习里的分类有相似之处,两者都是把样本分布到不同的组里去.区别在于,分类分析是有标签的,聚类是没有标签的.或者说,分类是有y的,聚类是没有y的, ...

  5. Python 机器学习/深度学习/算法专栏 - 导读目录

    目录 一.简介 二.机器学习 三.深度学习 四.数据结构与算法 五.日常工具 一.简介 Python 机器学习.深度学习.算法主要是博主从研究生到工作期间接触的一些机器学习.深度学习以及一些算法的实现 ...

  6. python 机器学习——从感知机算法到各种最优化方法的应用(python)

    一 准备 1 数据集 2 基本工具 21 pandasread in data 22 numpyprocess data 23 matplotlibvisualize data 二 基本概念与定义 三 ...

  7. [python] 机器学习 随机森林算法RandomForestRegressor

    前言 随机森林Python版本有很可以调用的库,使用随机森林非常方便,主要用到以下的库 sklearn Scikit learn 也简称 sklearn, 是机器学习领域当中最知名的 python 模 ...

  8. python机器学习 | K近邻算法学习(1)

    K近邻算法学习 1 K近邻算法介绍 1.1算法定义 1.2算法原理 1.3算法讨论 1.3.1 K值选择 1.3.2距离计算 1.3.3 KD树 2 K近邻算法实现 2.1scikit-learn工具 ...

  9. Python机器学习:KNN算法03训练数据集,测试数据集train test split

    示例代码 首先引入相关包 import numpy as np import matplotlib.pyplot as plt from sklearn import datasets import ...

最新文章

  1. 滴滴CEO程维:当初把产品拿给美团王兴看,他说了两个字“垃圾”!
  2. mysql_connect 废弃_解决Deprecated: mysql_connect():
  3. 通过VsPhere体验MAC OS X
  4. 一些通过SAP ABAP代码审查得出的ABAP编程代码优化建议
  5. 【最小环】【Floyed】观光旅游(ssl 1763)
  6. bbs.php168,PHP168 下载安装教程
  7. Android插件GsonFormat
  8. spring boot session超时设置
  9. Excel常用电子表格公式大全1-2
  10. SysML-Sec: A Model-Driven Environment for Developing Secure Embedded Systems
  11. 免费、好用、强大的开源笔记软件综合评测
  12. 【英语:基础进阶_核心词汇扩充】E4.常见词根拓词
  13. 织梦cms怎么上传html模板,织梦dedecms 本地模板安装图文方法
  14. 假设有一段英文,其中有单独的字母I误写为i,请编写程序进行纠正。
  15. 如何优雅地使用Sublime Text
  16. jquery向服务器发送ajax请求标准写法
  17. 没有shell63号单元_如何在Ansys/lsdyna中给Shell 163 赋值变厚度(注:不是Ansys下的Shell63号单元)...
  18. Win10删除右键菜单快捷键方法
  19. 通过SSH远程控制服务器
  20. 完美复刻小米路由器Misstar Tools(MT工具箱)BY:蜜罐版

热门文章

  1. 天梯赛题目练习——打印杨辉三角(附带PTA测试点)
  2. ROS自主驾驶割草机
  3. 无为而无不为和企业管理
  4. 一加5t ,安卓p系统卡在更新页面,安卓p降级教程(一加5t测试通过)
  5. 运营商开始悄悄火拼5G价格战,19元套餐开始涌现
  6. 【转载】UEBA架构设计之路
  7. 当复制Web浏览器的SVN地址到TorioseSVN上时显示错误,无法解析URL
  8. python证件照_python实现证件照换底功能
  9. 表单-微信小程序前端制作切片演示
  10. PDF如何修改,PDF怎么删除其中一页