scikit-learn是一个开源的Python语言机器学习工具包。它涵盖了几乎所有主流机器学习算法的实现,并且提供了一致的调用接口。它基于Numpy和SciPy等Python数值计算库,提供了高效的算法实现。总结起来,scikit-learn工具包有以下几个优点:

  • 文档齐全:官方文档齐全,更新及时。
  • 接口易用:针对所有的算法提供了一致的接口调用规则,不管是KNN、K-Means还是PCA。
  • 算法全面:涵盖主流机器学习任务的算法,包括回归算法、分类算法、聚类分析、数据降维处理等。

当然,scikit-learn不支持分布式计算,不适合用来处理超大型数据,但这并不影响scikit-learn作为一个优秀的机器学习工具库这个事实。许多知名的公司,包括Evemote和Spotify都使用scikit-learn作为他们的机器学习应用。

1.scikit-learn示例

回顾前面章节介绍的机器学习应用开发的典型步骤,我们使用scikit-learn来完成一个手写数字识别的例子。这是一个有监督的学习,数据是标记过的手写数字的图片。即通过采集足够多的手写数字样本数据,选择合适的模型,并使用采集到的数据进行模型训练,最后验证手写识别程序的正确性。

(1)数据采集和标记

如果我们从头实现一个手写数字识别程序,需要先采集数据,即让尽量多不同手写习惯的用户,写出从0~9的所有数字,然后把用户写出来的数字进行标记,即用户每写出一个数字,就标记他写出的是哪个数字。
为什么要采集尽量多不同书写习惯的用户写的数字呢?因为只有这样,采集到的数据才有代表性,才能保证最终训练出来的模型的准确性。极端的例子,我们采集的都是习惯写出瘦高形数字的人,那么针对习惯写出矮胖形数字的人写出来的数字,模型的识别成功率就会很低。
所幸我们不需要从头开始这项工作。scikit-learn自带了一些数据集,其中一个是数字识别图片的数据。使用以下代码来加载数据。

from sklearn import dataset
digits = datasets.load_digits()

可以在ipython notebook环境下把数据所表示的图片用matplotlib显示出来:

from sklearn import datasets
from matplotlib import pyplot as plt
digits = datasets.load_digits()
images_and_labels = list(zip(digits.images,digits.target))
plt.figure(figsize=(8,6),dpi=200)
for index,(image,label)in enumerate(images_and_labels[:8]):plt.subplot(2,4,index+1)plt.axis('off')plt.imshow(image,cmap=plt.cm.gray_r,interpolation='nearest')plt.title('Digit:%i'%label,fontsize=20)


从上图中可以看出,图片是一个个手写的数字。

(2)特征选择

针对一个手写的图片数据,应该怎么样来选择特征呢?一个直观的方法是,直接使用图片的每个像素点作为一个特征。比如一个图片是200x200的分辨率,那么我们就有40000个特征,即特征向量的长度是40000.
实际上,scikit-learn使用Numpy的array对象来表示数据,所有的图片数据保存在digits.images里,每个元素都是一个8x8尺寸的灰阶图片。我们在进行机器学习时,需要把数据保存为[样本个数]x[特征个数]格式的array对象,针对手写数字识别这个案例,scikit-learn已经为我们转换好了,它就保存在digits.data数据里,可以通过digits.data.shape来查看它的数据格式:

print("shape of raw image data: {0}".format(digits.images.shape))
print("shape if data: {0}".format(digits.data.shape))shape of raw image data: (1797, 8, 8)
shape if data: (1797, 64)

可以看到,总共有1797个训练样本,其中原始的数据是8x8的图片,而用来训练的数据是把图片的64个像素点都转换为特征。下面将直接使用digits.data作为训练数据。

(3)数据清洗

人们不可能在8x8这个小的分辨率的图片上写出数字,在采集数据的时候,是让用户在一个大图片上写出这些数字,如果图片是200x200分辨率,那么一个训练样例就有40000个特征,计算量将是巨大的。为了减少计算量,也为了模型的稳定性,我们需要把200x200的样本缩小为8x8的图片。这个过程就是数据清洗,即把采集到的、不适合用来做机器学习训练的数据进行预处理,从而转换为适合机器学习的数据。

(4)模型选择

不同的机器学习算法模型针对特定的机器学习应用有不同的效率,模型的选择和验证留到后面章节详细介绍。此处,我们使用支持向量机来作为手写识别算法的模型。关于支持向量机,后面章节也会详细介绍。

(5)模型训练

在开始训练我们的模型之前,需要先把数据集分成训练数据集和测试数据集。我们可以使用下面的代码吧数据集分出20%作为测试数据集,80%作为训练数据集。

from sklearn.cross_validation import train_test_split
Xtrain,Xtest,Ytrain,Ytest = train_test_split(digits.data,digits.target,test_size=0.20,random_state=2)

接着,使用训练数据集Xtrain和Ytrain来训练模型。

from sklearn import svm
clf = svm.SVC(gamma=0.001,C=100.0)
clf.fit(Xtrain,Ytrain)
from sklearn import svm
clf = svm.SVC(gamma=0.001,C=100.0)
clf.fit(Xtrain,Ytrain)

训练完成后,clf对象就会包含我们训练出来的模型参数,可以使用这个模型来进行预测。

SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0,decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',max_iter=-1, probability=False, random_state=None, shrinking=True,tol=0.001, verbose=False)
(6)模型测试

我们来测试一个训练出来的模型的准确度。一个直观的方法是,我们用训练出来的模型clf预测测试数据集,然后把预测结果Ypred和真正的结果Ytest比较,看有多少个是正确的,这样就能评估出模型的准确度了。所幸,scikit-learn提供了现成的方法来完成这项工作:

clf.score(Xtest,Ytest)

输出结果为:

0.9777777777777777

结果显示出模型有97.8%的准确率。
除此之外,还可以直接把测试数据集里的部分图片显示出来,并且在图片的左下角显示预测值,右下角显示真实值。

Ypred = clf.predict(Xtest)
fig,axes = plt.subplots(4,4,figsize=(8,8))
fig.subplots_adjust(hspace=0.1,wspace=0.1)
for i,ax in enumerate(axes.flat):ax.imshow(Xtest[i].reshape(8,8),cmap=plt.cm.gray_r,interpolation='nearest')ax.text(0.05,0.05,str(Ypred[i]),fontsize=32,transform=ax.transAxes,color='green' if Ypred[i] == Ytest[i] else 'red')ax.text(0.8,0.05,str(Ytest[i]),fontsize=32,transform=ax.transAxes,color='black')ax.set_xticks([])ax.set_yticks([])


从上图可以看出,第二行第一个图片预测出错了,真实数字是4,但是预测成了8。

(7)模型保存与加载

当我们对模型的准确度感到满意后,就可以把模型保存下来。这样下次需要预测时,可以直接加载模型来进行预测,而不是重新训练一遍模型。可以使用下面的代码来保存模型:

from sklearn.externals import joblib
joblib.dump(clf,'E:\digits_svm.pkl')

当我们需要这个模型来进行预测时,直接加载模型即可进行预测。

clf2 = joblib.load('E:\\digits_svm.pkl')
Ypred = cfl2.predict(Xtest)
clf2.score(Xtest,Ytest)

2.scikit-learn一般性原理和通用规则

scikit-learn包含大部分流行的有监督学习算法(分类和回归)和无监督学习算法(聚类和数据降维)的实现。

(1)评估模型对象

scikit-learn里的所有算法都以一个评估模型对象来对外提供接口。上面的例子里的svm.SVC()函数返回的就是一个支持向量机评估模型对象。创建评估模型对象时,可以指定不同的参数,这个称为评估对象参数,评估对象参数直接影响评估模型训练时的效率和准确度。
我们可以试着修改上面例子里的clf.svm.SVC(gamma=0.001,C=100.0)语句的参数值,看对模型准确度有没有影响。我们暂时忽略这些评估对象参数的意思,后面讲解每个机器学习算法时再详细介绍。
需要特比说明的是,我们学习机器学习算法的原理,其中一项飞虫重要的任务就是了解不同的机器学习算法有哪些可调参数,这些参数代表什么意思,对机器学习算法的性能及准确性有没有什么影响。因为在工程应用上,要从头实现一个机器学习算法的可能性非常低,除非是数值计算科学家。更多的情况下,是分析采集到的数据,根据数据特征选择合适的算法,并且调整算法的参数,从而实现算法效率和准确度之间的平衡。

(2)模型接口

scikit-learn所有的评估模型对象都有fit()这个接口,这是用来训练模型的接口。针对有监督的机器学习(如上面的例子),使用fit(X,y)来进行训练,其中y是标记数据。针对无监督的机器学习算法,使用fit(X)来进行训练,因为无监督机器学习算法的数据集是没有标记的,不需要传入y。
针对所有的有监督机器学习算法,scikit-learn的模型对象提供了predict()接口,进过训练的模型,可以用这个接口进行预测。针对分类问题,有些模型还提供了predict_proba()的接口,用来输出一个待预测的数据属于各种类型的可能性,而predict()接口直接返回了可能性最高的那个类别。
几乎所有的模型都提供了score()接口来评价一个模型的好坏,得分越高越好。需要说明的是,不是所有的问题都有准确度这个评价标准,比如针对异常检测系统,一些产品不良率可以控制到10%以下,这个时候一个最简单的模型是无条件地全部预测为合格,即无条件返回1,其准确率将达到99.999%以上,但实际上这不是一个好的模型。评价这种模型,就需要使用 查准率 和 召回率 来衡量。相关概念我们后面会详细介绍。
针对无监督的机器学习算法,scikit-learn的模型对象也提供了predict()接口,它是用来对数据进行聚类分析,即把新数据归入某个聚类里。除此之外,无监督学习算法还有transform()接口,这个接口用来进行转换,比如使用PCA(主成分分析)算法时即可把一个三维数据转换为对应的二维数据。
模型接口也是scikit-learn工具包的最大优势之一,即把不同的算法抽象出来,对外提供一致的接口调用。

(3)模型检测

机器学习应用开发的一个非常重要的方面就是模型检测,即需要检测我们训练出来的模型,针对“没见过的”陌生数据其预测准确性如何。除了模型提供的score()接口外,在sklearn.metrics包的下面有一系列用来检测模型性能的方法。

(4)模型选择

模型选择是个非常重要的课题,根据要处理的问题性质,数据是否经过标记,数据规模多大等等这些问题,可以对模型有个初步的选择。scikit-learn的官方网站上提供了一个模型速查表,只要回答几个简单的问题就可以选择一个相对合适的模型。感兴趣的读者可以搜索scikit-learn algorithm cheat sheet来查看这个图片(见下图),现在先大概有个印象,等阅读完本书再回头看这张图片,感受一下自己对其理解的变化和收获。

跟我一起学scikit-learn07:scikit-learn简介相关推荐

  1. 头条 上传图片大小_2021今日头条!龙华《学仕名府》规划《学仕名府》简介《学仕名府》全国上市!-龙华新区名盘导购...

    2021今日头条!龙华<学仕名府>规划<学仕名府>简介<学仕名府>全国上市! 项目营销中心:400-763-1618转74066 深圳龙华<学仕名府> ...

  2. 跟我一起学.NetCore之中间件(Middleware)简介和解析请求管道构建

    前言 中间件(Middleware)对于Asp.NetCore项目来说,不能说重要,而是不能缺少,因为Asp.NetCore的请求管道就是通过一系列的中间件组成的:在服务器接收到请求之后,请求会经过请 ...

  3. 新手学html 第一节:html简介

    什么是 HTML? HTML(Hypertext Markup Language)文本标记语言,是用于描述网页文档的一种标记语言. HTML 是用来描述网页的一种语言. HTML 指的是超文本标记语言 ...

  4. 北京学易星科技有限公司·简介

    北京学易星科技有限公司  成立于2005年,是一家专门从事教育科研.提供教育资源.教育服务与网络技术的实体性高新科技企业,以变革传统的学习方式和教育方式为己任,依托国内最大教学门户网站--中学学科网, ...

  5. 【打CF,学算法】CodeForces网站简介

    转自豆瓣:https://www.douban.com/review/5800694/ 你应当知道的关于Codeforces的事情 关于codeforces的文字 Codeforces 简称: cf( ...

  6. python scikit_Python SciKit学习教程

    python scikit Scikit学习 (Scikit Learn) Scikit-learn is a machine learning library for Python. It feat ...

  7. python 数据挖掘_Python数据挖掘框架scikit数据集之iris

    一.iris数据集简介 iris数据集的中文名是安德森鸢尾花卉数据集,英文全称是Anderson's Iris data set.iris包含150个样本,对应数据集的每行数据.每行数据包含每个样本的 ...

  8. python零基础能学吗-Python真的零基础可以学会吗?

    Python语言简单,对新手极其友好,但想要零基础学习,还需要解决一些基本问题,不能蛮干. 首先要弄明白自己的需求.为什么要学Python? 我学Python是想要给自己带来什么? 能够解决我的什么问 ...

  9. TensorFlow-4: tf.contrib.learn 快速入门

    学习资料: https://www.tensorflow.org/get_started/tflearn 相应的中文翻译: http://studyai.site/2017/03/05/%E3%80% ...

  10. 跟着数百万人编程导师学C语言!

    点击关注 异步图书,置顶公众号 每天与你分享 IT好书 技术干货 职场知识​ ​ ​​参与文末话题讨论,每日赠送异步图书 --异步小编 为什么说这不是一本完全C语言的书?因为<"笨办法 ...

最新文章

  1. 图解 SQL,这也太形象了吧!
  2. IT民工创业之殇---续1
  3. AcWing 199. 余数之和 (除法分块)打卡
  4. php解释命令行的参数
  5. ML之Clustering之K-means:K-means算法简介、应用、经典案例之详细攻略
  6. PHP连接达梦数据库
  7. poj 3678 Katu Puzzle(2-sat)
  8. 由于在客户端检测到一个协议错误_TLS协议的分析
  9. day20 派生属性和方法,钻石继承
  10. Java-打印三角形
  11. php 高德地图计算距离,距离、长度、面积
  12. 接口与抽象类的区别和联系
  13. java接口参数类型为枚举_Spring MVC处理参数中的枚举类型通用实现方法
  14. 精通JavaScript攻击框架:AttackAPI(上)
  15. Unity Asset Store——独立游戏开发者的素材插件商店
  16. 极速版RPS选股,一秒出结果的方案是如何实现的!股票量化分析工具QTYX-V2.5.3...
  17. 公网IP、私网IP、动态IP、静态IP
  18. hbuildX使用夜神模拟器配置
  19. lighttpd http响应报文(Response)增加安全头Referrer-Policy和X-Permitted-Cross-Domain-Policies方法
  20. MIMICIV数据库下载导入

热门文章

  1. 设计模式2之策略模式(整理笔记)
  2. 软件测试--基础知识1--测试简介、软件质量等
  3. windows2012R2更新KB2919355
  4. Portal技术介绍
  5. Linux中修改终端登录欢迎界面
  6. 汽车液压制动系统设计
  7. 深圳高新技术企业申请条件以及流程简单说明
  8. GIS讲堂第四课-大量POI点的展示
  9. LuaJit Trace Compiler剖析
  10. ijkplayer笔记