作者 | Ocktavia Nurima Putri

来源 | Medium

编辑 | 代码医生团队

在这篇文章中,将使用Scikit-learn在Python中实现几种机器学习算法。将使用一个简单的数据集来训练分类器,以区分不同类型的水果。在数据科学中存在的每种方法中,希望比较不同的算法并选择最适合的算法。

数据

水果数据集由爱丁堡大学的Iain Murray博士创建,然后密歇根大学的教授稍微格式化了水果数据,可以从这里下载。

https://github.com/susanli2016/Machine-Learning-with-Python/blob/master/fruit_data_with_colors.txt

%matplotlib inline

import pandas as pd

import matplotlib.pyplot as plt

fruits = pd.read_table('fruit_data_with_colors.txt')

fruits.head()

图1

在数据的前几行中,数据集的每一行代表一个水果,由表格列中的几个要素表示,如标签,名称,子类型,质量,宽度,高度和颜色。

通过下面的脚本,知道数据集中有59个水果和7个特征。

print(fruits.shape)

(59, 7)

然后,有四种水果。

print(fruits['fruit_name'].unique())

[‘apple’ ‘mandarin’ ‘orange’ ‘lemon’]

通过下面的输出,可以假设数据非常平衡,除了mandarin。

print(fruits.groupby('fruit_name').size())

图2

import seaborn as sns

sns.countplot(fruits['fruit_name'],label="Count")

plt.show()

图3

可视化

每个数字变量的箱形图将更清楚地了解输入变量的分布:

fruits.drop('fruit_label', axis=1).plot(kind='box', subplots=True, layout=(2,2), sharex=False, sharey=False, figsize=(9,9),

title='Box Plot for each input variable')

plt.savefig('fruits_box')

plt.show()

图4

看起来颜色分数可能接近高斯分布。

import pylab as pl

fruits.drop('fruit_label' ,axis=1).hist(bins=30, figsize=(9,9))

pl.suptitle("Histogram for each numeric input variable")

plt.savefig('fruits_hist')

plt.show()

图5

一些属性对是相关的(质量和宽度)。这表明高度相关性和可预测的关系。

from pandas.tools.plotting import scatter_matrix

from matplotlib import cm

feature_names = ['mass', 'width', 'height', 'color_score']

X = fruits[feature_names]

y = fruits['fruit_label']

cmap = cm.get_cmap('gnuplot')

scatter = pd.scatter_matrix(X, c = y, marker = 'o', s=40, hist_kwds={'bins':15}, figsize=(9,9), cmap = cmap)

plt.suptitle('Scatter-matrix for each input variable')

plt.savefig('fruits_scatter_matrix')

图6

统计摘要

图7

可以看到数值不具有相同的比例。需要将缩放应用于为训练集计算的测试集。

创建训练和测试集并应用缩放

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()

X_train = scaler.fit_transform(X_train)

X_test = scaler.transform(X_test)

构建模型

Logistic回归

from sklearn.linear_model import LogisticRegression

logreg = LogisticRegression()

logreg.fit(X_train, y_train)

print('Accuracy of Logistic regression classifier on training set: {:.2f}'

.format(logreg.score(X_train, y_train)))

print('Accuracy of Logistic regression classifier on test set: {:.2f}'

.format(logreg.score(X_test, y_test)))

Logistic回归分类器在训练集上的准确度:0.70

Logistic回归分类器在测试集上的准确度:0.40

决策树

from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier().fit(X_train, y_train)

print('Accuracy of Decision Tree classifier on training set: {:.2f}'

.format(clf.score(X_train, y_train)))

print('Accuracy of Decision Tree classifier on test set: {:.2f}'

.format(clf.score(X_test, y_test)))

训练集

上决策树分类器的准确度:1.00

测试集上决策树分类器的准确度:0.73

K-Nearest Neighbors

from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier()

knn.fit(X_train, y_train)

print('Accuracy of K-NN classifier on training set: {:.2f}'

.format(knn.score(X_train, y_train)))

print('Accuracy of K-NN classifier on test set: {:.2f}'

.format(knn.score(X_test, y_test)))

训练集上K-NN分类器的准确度:0.95

测试集上K-NN分类器的准确度:1.00

线性判别分析

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

lda = LinearDiscriminantAnalysis()

lda.fit(X_train, y_train)

print('Accuracy of LDA classifier on training set: {:.2f}'

.format(lda.score(X_train, y_train)))

print('Accuracy of LDA classifier on test set: {:.2f}'

.format(lda.score(X_test, y_test)))

LDA分类器在训练集上的准确度:0.86

LDA分类器在测试集上的准确度:0.67

高斯朴素贝叶斯

from sklearn.naive_bayes import GaussianNB

gnb = GaussianNB()

gnb.fit(X_train, y_train)

print('Accuracy of GNB classifier on training set: {:.2f}'

.format(gnb.score(X_train, y_train)))

print('Accuracy of GNB classifier on test set: {:.2f}'

.format(gnb.score(X_test, y_test)))

GNB分类器对训练集的准确度:0.86

GNB分类器在测试集上的准确度:0.67

支持向量机

from sklearn.svm import SVC

svm = SVC()

svm.fit(X_train, y_train)

print('Accuracy of SVM classifier on training set: {:.2f}'

.format(svm.score(X_train, y_train)))

print('Accuracy of SVM classifier on test set: {:.2f}'

.format(svm.score(X_test, y_test)))

SVM分类器在训练集上的准确度:0.61

测试集上SVM分类器的准确度:0.33

KNN算法是尝试过的最准确的模型。混淆矩阵提供了对测试集没有错误的指示。但是,测试集非常小。

from sklearn.metrics import classification_report

from sklearn.metrics import confusion_matrix

pred = knn.predict(X_test)

print(confusion_matrix(y_test, pred))

print(classification_report(y_test, pred))

图7

绘制k-NN分类器的决策边界

import matplotlib.cm as cm

from matplotlib.colors import ListedColormap, BoundaryNorm

import matplotlib.patches as mpatches

import matplotlib.patches as mpatches

X = fruits[['mass', 'width', 'height', 'color_score']]

y = fruits['fruit_label']

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

def plot_fruit_knn(X, y, n_neighbors, weights):

X_mat = X[['height', 'width']].as_matrix()

y_mat = y.as_matrix()

# Create color maps

cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF','#AFAFAF'])

cmap_bold  = ListedColormap(['#FF0000', '#00FF00', '#0000FF','#AFAFAF'])

clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights)

clf.fit(X_mat, y_mat)

# Plot the decision boundary by assigning a color in the color map

# to each mesh point.

mesh_step_size = .01  # step size in the mesh

plot_symbol_size = 50

x_min, x_max = X_mat[:, 0].min() - 1, X_mat[:, 0].max() + 1

y_min, y_max = X_mat[:, 1].min() - 1, X_mat[:, 1].max() + 1

xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size),

np.arange(y_min, y_max, mesh_step_size))

Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

# Put the result into a color plot

Z = Z.reshape(xx.shape)

plt.figure()

plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

# Plot training points

plt.scatter(X_mat[:, 0], X_mat[:, 1], s=plot_symbol_size, c=y, cmap=cmap_bold, edgecolor = 'black')

plt.xlim(xx.min(), xx.max())

plt.ylim(yy.min(), yy.max())

patch0 = mpatches.Patch(color='#FF0000', label='apple')

patch1 = mpatches.Patch(color='#00FF00', label='mandarin')

patch2 = mpatches.Patch(color='#0000FF', label='orange')

patch3 = mpatches.Patch(color='#AFAFAF', label='lemon')

plt.legend(handles=[patch0, patch1, patch2, patch3])

plt.xlabel('height (cm)')

plt.ylabel('width (cm)')

plt.title("4-Class classification (k = %i, weights = '%s')"

% (n_neighbors, weights))

plt.show()

plot_fruit_knn(X_train, y_train, 5, 'uniform')

图8

k_range = range(1, 20)

scores = []

for k in k_range:

knn = KNeighborsClassifier(n_neighbors = k)

knn.fit(X_train, y_train)

scores.append(knn.score(X_test, y_test))

plt.figure()

plt.xlabel('k')

plt.ylabel('accuracy')

plt.scatter(k_range, scores)

plt.xticks([0,5,10,15,20])

图9

对于这个特定的日期集,当k = 5时,获得最高的准确度。

结论

在这篇文章中,关注预测的准确性。目标是确定最适合手头问题的机器学习算法。因此比较了不同的算法并选择了性能最佳的算法。

源代码:

https://github.com/susanli2016/Machine-Learning-with-Python/blob/master/Solving%20A%20Simple%20Classification%20Problem%20with%20Python.ipynb

关于图书

《深度学习之TensorFlow:入门、原理与进阶实战》和《Python带我起飞——入门、进阶、商业实战》两本图书是代码医生团队精心编著的 AI入门与提高的精品图书。配套资源丰富:配套视频、QQ读者群、实例源码、 配套论坛:http://bbs.aianaconda.com。更多请见:aianaconda.com

点击“阅读原文”配套图书资源

python输入水果求个数问题_水果爱好者:用Python解决一个简单的分类问题相关推荐

  1. python自学篇十[ 面向对象 (四) :王者荣耀小游戏+模拟一个简单的银行进行业务办理的类]

    python基础系列: python自学篇一[ Anaconda3安装 ] python自学篇二[ pycharm安装及使用 ] python自学篇三[ 判断语句if的使用 ] python自学篇四[ ...

  2. python输入成绩求总分和平均分_python脚本如何输入成绩求平均分?

    python脚本如何输入成绩求平均分? python脚本输入成绩求平均分的方法: 脚本要实现功能: 1.输入学生学号: 2.依次输入学生的三门科目成绩: 3.计算该学生的平均成绩,并打印: 4.平均成 ...

  3. python的输入函数是什么意思_「小白学Python」像风一样自由的输入:input( )函数详解...

    从使用Python写出第一行代码:print("Hello Python")时,我就怀揣着一个梦想,有一天,我一定要输入自己想要的内容.今天这个梦想终于实现了,多亏了input( ...

  4. python 类的内置方法_【转】[python] 类常用的内置方法

    原文:http://xukaizijian.blog.163.com/blog/static/170433119201111894228877/ 内置方法 说明 __init__(self,...) ...

  5. python提取发票信息发票识别_(附完整python源码)基于tensorflow、opencv的入门案例_发票识别二:字符分割...

    (附完整python源码)基于tensorflow.opencv的入门案例_发票识别二:字符分割 发布时间:2018-05-14 20:16, 浏览次数:1201 , 标签: python tenso ...

  6. python虚拟环境的安装和配置_基于virtualenv的Python虚拟环境的安装配置(Mac环境)...

    一.安装前提 明确自己的开发所需的python版本, Python 2.7.x 或者Python 3.6.x . 安装 Python 2.7.x 或Python 3.6.x 版的virtualenv. ...

  7. python高效编程15个利器_你不知道的18个Python高效编程技巧

    来源 | Python编程时光 初识Python语言,觉得python满足了我上学时候对编程语言的所有要求.python语言的高效编程技巧让我们这些大学曾经苦逼学了四年c或者c++的人,兴奋的不行不行 ...

  8. python的简单程序代码_小白学编程?从一个简单的程序开始学习Python编程

    笔者思虑再三还是决定选择图文(因为百家的视频发布画质真不怎么样[囧]). 笔者学习编程的时间也挺长的,因为业余,因为时间不多,各种原因,自学编程的路特别难走.然后笔者发现,自己能为小白贡献一些力量,然 ...

  9. python在建筑施工方面的应用_有哪些关于 Python 在建筑中的应用和教程?

    2018.02.09更新 (發現距離上一次更新馬上就要兩年了--) 嗯,兩年間發生了很多事.我也莫名其妙跑到ETH來了. 做起了Fab的優化,python已經完全不能滿足效率和複雜度的要求,走上了C+ ...

最新文章

  1. HDU 5384 Danganronpa (2015年多校比赛第8场)
  2. django时差8个小时问题
  3. 6.3-4 zip、unzip
  4. Python3文件操作详解 Python3文件操作大全
  5. ArcGIS客户端开发学习笔记(二)——XML
  6. jdbc 链接不了mysql_JDBC链接Mysql失败
  7. 开课吧python小课学了有用吗-(内推实习)年薪30万,大量缺人,这个技能在金融圈到底有多吃香?...
  8. iOS中转义后的html标签如何还原
  9. php 文件上传框架,Laravel框架实现文件上传的方法分析
  10. 前端优化,包括css,jss,img,cookie
  11. 搜索引擎为什么不收录原创文章
  12. web浏览器显示网站小图标
  13. 如何在手机上压缩图片?两种免费方法了解一下
  14. 安卓开发-接收系统广播
  15. 跨平台移动开发工具:PhoneGap与Titanium全方位比拼
  16. excel应用(1)
  17. 深度干货:史上最全的市场推广渠道大全(附攻略和技巧)
  18. A股中的level1跟Level2有什么区别
  19. 培训机构炒出来的Unity就业没问题吗
  20. iPhone 11忘记了密码怎么办?

热门文章

  1. html input 传值 request接到值为null,解决jsp向servlet传值为null的问题
  2. vue可视化拖拽生成工具_vdesjs: 基于vue的可视化拖拽,代码生成工具。提升前端开发效率,或者集成至项目作为在线拖拽工具。(持续迭代升级中)...
  3. python去除中文停用词_删除停止词Python
  4. 用python倒序输出一个字符串_Python字符串逆序输出的实例讲解
  5. linux查看证书位数,查看Linux系统是32位还是64位(getconf WORD_BIT误区)
  6. Linux 操作命令记录
  7. python读取视频分辨率_Python实现以不同分辨率分类视频
  8. vue 部门tree样式_vue+Element实现tree树形数据展示
  9. Kafka 基本原理
  10. andengine游戏引擎总结基础篇