在本文中,我们将使用Python中最流行的机器学习工具Scikit-learn在Python中实现几种机器学习算法。使用简单的数据集来训练分类器以区分不同类型的水果。

本文的目的是确定最适合手头问题的机器学习算法; 因此,我们想要比较不同的算法,选择效果最好的算法。

数据

水果数据集由爱丁堡大学的Iain Murray博士创建。他买了几十个不同品种的桔子、柠檬和苹果,并记录了他们的尺寸。

让我们看一下数据的前几行。

%matplotlib inlineimport pandas as pdimport matplotlib.pyplot as pltfruits = pd.read_table('fruit_data_with_colors.txt')fruits.head()

数据集的每一行表示水果的一个部分,由列表中的几个特征表示。

我们的数据集中有59个水果和7个特征:print(fruits.shape)

(59,7)

我们的数据集中有四种类型的水果:

print(fruits['fruit_name'].unique())

['苹果''橘子(mandarin)''橙子''柠檬']

除橘子外,数据非常平衡。我们必须坚持下去。print(fruits.groupby('fruit_name').size())

import seaborn as snssns.countplot(fruits['fruit_name'],label='Count')plt.show()

可视化每个数字变量的方形图将使我们更清楚地了解输入变量的分布:fruits.drop('fruit_label', axis=1).plot(kind='box', subplots=True, layout=(2,2), sharex=False, sharey=False, figsize=(9,9), title='Box Plot for each input variable')plt.savefig('fruits_box')plt.show()

颜色分数近似于高斯分布。

import pylab as plfruits.drop('fruit_label' ,axis=1).hist(bins=30, figsize=(9,9))pl.suptitle('Histogram for each numeric input variable')plt.savefig('fruits_hist')plt.show()

一些属性对是相关的(质量和宽度)。这表明高度相关性和可预测的关系。from pandas.tools.plotting import scatter_matrixfrom matplotlib import cmfeature_names = ['mass', 'width', 'height', 'color_score']X = fruits[feature_names]y = fruits['fruit_label']cmap = cm.get_cmap('gnuplot')scatter = pd.scatter_matrix(X, c = y, marker = 'o', s=40, hist_kwds={'bins':15}, figsize=(9,9), cmap = cmap)plt.suptitle('Scatter-matrix for each input variable')plt.savefig('fruits_scatter_matrix')

统计摘要

我们可以看到数值没有相同的比例。我们需要对我们为训练集计算的测试集扩展应用。

创建训练和测试集扩展到应用。

from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)from sklearn.preprocessing import MinMaxScalerscaler = MinMaxScaler()X_train = scaler.fit_transform(X_train)X_test = scaler.transform(X_test)

构建模型

Logistic回归from sklearn.linear_model import LogisticRegressionlogreg = LogisticRegression()logreg.fit(X_train, y_train)print('Accuracy of Logistic regression classifier on training set: {:.2f}' .format(logreg.score(X_train, y_train)))print('Accuracy of Logistic regression classifier on test set: {:.2f}' .format(logreg.score(X_test, y_test)))

Logistic回归分类器在训练集上的准确率:0.70

Logistic回归分类器在测试集上的准确率:0.40

决策树

from sklearn.tree import DecisionTreeClassifierclf = DecisionTreeClassifier().fit(X_train, y_train)print('Accuracy of Decision Tree classifier on training set: {:.2f}' .format(clf.score(X_train, y_train)))print('Accuracy of Decision Tree classifier on test set: {:.2f}' .format(clf.score(X_test, y_test)))

决策树分类器在训练集上的准确率:1.00

决策树分类器在测试集上的准确率:0.73

K-Nearest Neighborsfrom sklearn.neighbors import KNeighborsClassifierknn = KNeighborsClassifier()knn.fit(X_train, y_train)print('Accuracy of K-NN classifier on training set: {:.2f}' .format(knn.score(X_train, y_train)))print('Accuracy of K-NN classifier on test set: {:.2f}' .format(knn.score(X_test, y_test)))

K-NN分类器在训练集上的准确率:0.95

K-NN分类器在测试集上的准确率:1.00

线性判别分析

from sklearn.discriminant_analysis import LinearDiscriminantAnalysislda = LinearDiscriminantAnalysis()lda.fit(X_train, y_train)print('Accuracy of LDA classifier on training set: {:.2f}' .format(lda.score(X_train, y_train)))print('Accuracy of LDA classifier on test set: {:.2f}' .format(lda.score(X_test, y_test)))

LDA分类器在训练集上的准确率:0.86

LDA分类器在测试集上的准确率:0.67

高斯朴素贝叶斯from sklearn.naive_bayes import GaussianNBgnb = GaussianNB()gnb.fit(X_train, y_train)print('Accuracy of GNB classifier on training set: {:.2f}' .format(gnb.score(X_train, y_train)))print('Accuracy of GNB classifier on test set: {:.2f}' .format(gnb.score(X_test, y_test)))

GNB分类器在训练集上的准确率:0.86

GNB分类器在测试集上的准确率:0.67

支持向量机

from sklearn.svm import SVCsvm = SVC()svm.fit(X_train, y_train)print('Accuracy of SVM classifier on training set: {:.2f}' .format(svm.score(X_train, y_train)))print('Accuracy of SVM classifier on test set: {:.2f}' .format(svm.score(X_test, y_test)))

SVM分类器在训练集上的准确率:0.61

SVM分类器在测试集上的准确率:0.33

KNN算法是我们尝试过的最准确的模型。混淆矩阵表示测试集没有发生错误。但是,测试集非常小。from sklearn.metrics import classification_reportfrom sklearn.metrics import confusion_matrixpred = knn.predict(X_test)print(confusion_matrix(y_test, pred))print(classification_report(y_test, pred))

绘制k-NN分类器的决策边界

import matplotlib.cm as cmfrom matplotlib.colors import ListedColormap, BoundaryNormimport matplotlib.patches as mpatchesimport matplotlib.patches as mpatchesX = fruits[['mass', 'width', 'height', 'color_score']]y = fruits['fruit_label']X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)def plot_fruit_knn(X, y, n_neighbors, weights): X_mat = X[['height', 'width']].as_matrix() y_mat = y.as_matrix()# Create color maps cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF','#AFAFAF']) cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF','#AFAFAF'])clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights) clf.fit(X_mat, y_mat)# Plot the decision boundary by assigning a color in the color map # to each mesh point. mesh_step_size = .01 # step size in the mesh plot_symbol_size = 50 x_min, x_max = X_mat[:, 0].min() - 1, X_mat[:, 0].max() + 1 y_min, y_max = X_mat[:, 1].min() - 1, X_mat[:, 1].max() + 1 xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size), np.arange(y_min, y_max, mesh_step_size)) Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])# Put the result into a color plot Z = Z.reshape(xx.shape) plt.figure() plt.pcolormesh(xx, yy, Z, cmap=cmap_light)# Plot training points plt.scatter(X_mat[:, 0], X_mat[:, 1], s=plot_symbol_size, c=y, cmap=cmap_bold, edgecolor = 'black') plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max())patch0 = mpatches.Patch(color='#FF0000', label='apple') patch1 = mpatches.Patch(color='#00FF00', label='mandarin') patch2 = mpatches.Patch(color='#0000FF', label='orange') patch3 = mpatches.Patch(color='#AFAFAF', label='lemon') plt.legend(handles=[patch0, patch1, patch2, patch3])plt.xlabel('height (cm)')plt.ylabel('width (cm)')plt.title('4-Class classification (k = %i, weights = '%s')' % (n_neighbors, weights)) plt.show()plot_fruit_knn(X_train, y_train, 5, 'uniform')

k_range = range(1, 20)scores = []for k in k_range: knn = KNeighborsClassifier(n_neighbors = k) knn.fit(X_train, y_train) scores.append(knn.score(X_test, y_test))plt.figure()plt.xlabel('k')plt.ylabel('accuracy')plt.scatter(k_range, scores)plt.xticks([0,5,10,15,20])

对于这个特定的数据集,当k = 5时,我们获得最高的准确度

总结

在本文中,我们关注预测的准确性。我们的目标是学习具有良好泛化性能的模型。这种模型使预测精度最大化。我们确定了最适合手头问题的机器学习算法(即水果类型分类); 因此,我们比较了不同的算法并选择了性能最佳的算法。

python卖水果_用Python解决一个简单的水果分类问题相关推荐

  1. python七彩同心圆_用pygame做一个简单的python小游戏---七彩同心圆

    用pygame做一个简单的python小游戏---七彩同心圆 用pygame做一个简单的python小游戏-七彩同心圆 这个小游戏原是我同学python课的课后作业,并不是很难,就简单实现了一下,顺便 ...

  2. python对象引用计数器_在Python中借助计数器对象对项目进行计数

    python对象引用计数器 前提 (The Premise) When we deal with data containers, such as tuples and lists, in Pytho ...

  3. python 概率分布模型_使用python的概率模型进行公司估值

    python 概率分布模型 Note from Towards Data Science's editors: While we allow independent authors to publis ...

  4. python 时间序列预测_使用Python进行动手时间序列预测

    python 时间序列预测 Time series analysis is the endeavor of extracting meaningful summary and statistical ...

  5. python小项目实例流程-Python小项目:快速开发出一个简单的学生管理系统

    原标题:Python小项目:快速开发出一个简单的学生管理系统 本文根据实际项目中的一部分api 设计抽象出来,实例化成一个简单小例子,暂且叫作「学生管理系统」. 这个系统主要完成下面增删改查的功能: ...

  6. python小项目案例-Python小项目:快速开发出一个简单的学生管理系统

    本文根据实际项目中的一部分api 设计抽象出来,实例化成一个简单小例子,暂且叫作「学生管理系统」. 这个系统主要完成下面增删改查的功能: 包括: 学校信息的管理 教师信息的管理 学生信息的管理 根据A ...

  7. python项目开发实例-Python小项目:快速开发出一个简单的学生管理系统

    本文根据实际项目中的一部分api 设计抽象出来,实例化成一个简单小例子,暂且叫作「学生管理系统」. 这个系统主要完成下面增删改查的功能: 包括: 学校信息的管理 教师信息的管理 学生信息的管理 根据A ...

  8. python股票接口_股市python接口

    广告关闭 腾讯云11.11云上盛惠 ,精选热门产品助力上云,云服务器首年88元起,买的越多返的越多,最高返5000元! 最近发现一个很有趣的事情,受到全球经济大环境的影响,周围越来越多的人开始关注股市 ...

  9. 用Python实现音频卷积,并制作一个简单的HRTF效果

    用Python实现音频卷积,并制作一个简单的HRTF效果 作为一个刚刚入门Python的小白用户,写出这篇文章还是废了我很大的力气,不过幸运的是,在网上到处东拼西凑,我还是把它给做出来了. 废话不多说 ...

  10. Python开发第一步:如何制作一个简单的桌面应用

    Python开发第一步:如何制作一个简单的桌面应用 前言 大家好,我是baifagg, 一个热爱Python的编程爱好者. 今天我们来学习一下, 如何用Python制作一个简单的桌面应用程序. 虽然桌 ...

最新文章

  1. pandas使用bdate_range函数获取起始时间(start)和结束时间(end)范围内的所有工作日日期(business day)
  2. Android selector 使用
  3. loadrunner—参数化
  4. centos7 3行命令安装powershell
  5. Win7 64有点找不到MSVCP71.DLL和MSVCR71.dll
  6. Linux学习总结(8)——VMware v12.1.1 专业版以及永久密钥
  7. day 45 SQLAlchemy,和增删查改
  8. vxworks驱动开发基础
  9. Java学习系列(十七)Java面向对象之开发聊天工具
  10. 错别字检测的软件有哪些?自动检查错别字的工具 文字校对 文本纠错 查错别字 校对软件 错别字检查 论文格式 在线校对
  11. Python3.GrADS的二进制码数据
  12. Python开发之路(1)— 使用Pyaudio进行录音和播音
  13. 0712CF解题报告
  14. 2019前端面试常问
  15. 位地址和字节地址换算_IP地址详解
  16. 北斗导航公共服务平台首次落户四川
  17. win10全屏之后任务栏不消失的问题
  18. 计算机课数据排序与筛选ppt,《计算机应用基础》PPT课件
  19. 室内定位方案之化工厂访客定位监测系统,一种室内定位管理方案
  20. 专访智链ChainNova CTO谢文杰:区块链容器化与水平扩展实践

热门文章

  1. AUTOCAD--实时缩放
  2. ANC降噪耳机声学参数合成与校准的2种方式
  3. 精彩Linux 篇章
  4. python图像锐化_(python 图像锐化教程)C 实现bmp图像锐化后,锐化的效果很差,求大神帮忙啊...
  5. 博士毕业论文英文参考文献换行_Endnote教程丨本科研究生毕业论文参考文献格式模板,一键搞定...
  6. 2019年java全栈工程师学习大全
  7. 【MySQL】数据库的函数使用
  8. SAP数据接口技术类型
  9. 【转载】FPGA配置方式
  10. mysql 以空间换时间专研