第三章学习分类任务——引入MNIST数据集,分类0-9十个数字。

1.载入数据集

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
mnist.keys()

将数据集分成训练集和测试集

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

2.使用SGD分类器训练5和非5

y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)from sklearn.linear_model import SGDClassifiersgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)
sgd_clf.fit(X_train, y_train_5)sgd_clf.predict([some_digit])

3.性能测量

交叉验证

from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")

计算混淆矩阵

from sklearn.model_selection import cross_val_predicty_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
from sklearn.metrics import confusion_matrixconfusion_matrix(y_train_5, y_train_pred)

计算精度:

from sklearn.metrics import precision_score, recall_scoreprecision_score(y_train_5, y_train_pred)

计算召回率:

recall_score(y_train_5, y_train_pred)

计算F1分数

from sklearn.metrics import f1_scoref1_score(y_train_5, y_train_pred)

绘制精度和召回率相对于阈值的函数图:

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)plt.legend(loc="center right", fontsize=16) plt.xlabel("Threshold", fontsize=16)        plt.grid(True)                              plt.axis([-50000, 50000, 0, 1])            recall_90_precision = recalls[np.argmax(precisions >= 0.90)]
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]plt.figure(figsize=(8, 4))
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.plot([threshold_90_precision, threshold_90_precision], [0., 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [0.9, 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [recall_90_precision, recall_90_precision], "r:")
plt.plot([threshold_90_precision], [0.9], "ro")
plt.plot([threshold_90_precision], [recall_90_precision], "ro")
save_fig("precision_recall_vs_threshold_plot")
plt.show()

精度与召回率的关系函数:

def plot_precision_vs_recall(precisions, recalls):plt.plot(recalls, precisions, "b-", linewidth=2)plt.xlabel("Recall", fontsize=16)plt.ylabel("Precision", fontsize=16)plt.axis([0, 1, 0, 1])plt.grid(True)plt.figure(figsize=(8, 6))
plot_precision_vs_recall(precisions, recalls)
plt.plot([recall_90_precision, recall_90_precision], [0., 0.9], "r:")
plt.plot([0.0, recall_90_precision], [0.9, 0.9], "r:")
plt.plot([recall_90_precision], [0.9], "ro")
save_fig("precision_vs_recall_plot")
plt.show()

绘制ROC曲线:

from sklearn.metrics import roc_curvefpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
def plot_roc_curve(fpr, tpr, label=None):plt.plot(fpr, tpr, linewidth=2, label=label)plt.plot([0, 1], [0, 1], 'k--') # dashed diagonal
plot_roc_curve(fpr, tpr)
plt.show()

计算AUC

from sklearn.metrics import roc_auc_scoreroc_auc_score(y_train_5, y_scores)

比较随机森林分类器和SGD分类器:

from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,method="predict_proba")y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)recall_for_forest = tpr_forest[np.argmax(fpr_forest >= fpr_90)]plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, "b:", linewidth=2, label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.plot([fpr_90, fpr_90], [0., recall_90_precision], "r:")
plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], "r:")
plt.plot([fpr_90], [recall_90_precision], "ro")
plt.plot([fpr_90, fpr_90], [0., recall_for_forest], "r:")
plt.plot([fpr_90], [recall_for_forest], "ro")
plt.grid(True)
plt.legend(loc="lower right", fontsize=16)
save_fig("roc_curve_comparison_plot")
plt.show()

4.多类分类器:

SVM分类器:

from sklearn.svm import SVCsvm_clf = SVC(gamma="auto", random_state=42)
svm_clf.fit(X_train[:1000], y_train[:1000]) # y_train, not y_train_5
svm_clf.predict([some_digit])

5分数最高识别成功,

书中还列举了其他方法。

5.误差分析:

查看混淆矩阵:

y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
conf_mx

可视化:

plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()

主对角线比较 明显清晰,分类效果良好。

接下来关注错误值:

第8列较亮,可见较多图片被误分类为8.那么需要另设计算法加以区分。

6.多标签分类

将一个实例输出多个类,例如是否大数,是否奇数:

from sklearn.neighbors import KNeighborsClassifiery_train_large = (y_train >= 7)
y_train_odd = (y_train % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
knn_clf.predict([some_digit])

结果正确。

7.多输出分类

输出多个标签,标签有多个值(去除图片噪声为例):

noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
y_train_mod = X_train
y_test_mod = X_test
some_index = 0
plt.subplot(121); plot_digit(X_test_mod[some_index])
plt.subplot(122); plot_digit(y_test_mod[some_index])
save_fig("noisy_digit_example_plot")
plt.show()

《机器学习实战》第三章相关推荐

  1. 第三章 UT单元测试——CPU与内存使用率限制

    系列文章目录 第一章 UT单元测试--GoogleTest通用构建说明 第二章 UT单元测试--GTest框架实例 第三章 UT单元测试--CPU与内存使用率限制 文章目录 系列文章目录 前言 一.环 ...

  2. 慕课软件质量保证与测试(第三章.单元测试)

    慕课金陵科技学院.软件质量保证与测试.第三章.黑盒测试.单元测试 0 目录 3 黑盒测试 3.9 单元测试 3.9.1课堂重点 3.9.2测试与作业 4 下一章 0 目录 3 黑盒测试 3.9 单元测 ...

  3. 《构建之法》前三章读后感

    通过第一章讲述的概论,理解到软件工程到底是什么,又为何要叫软件工程,他对我们的生活又有什么影响. 通过一些实例我也认识到客户需求分析的重要,就阿超那样的四则运算一样,渐渐的功能和需求就多了. 在第二章 ...

  4. 走向.NET架构设计—第三章—分层设计,初涉架构(后篇)

    走向.NET架构设计-第三章-分层设计,初涉架构(后篇) 前言:本篇主要是接着前两篇文章继续讲述! 本篇的议题如下: 4. 数据访问层设计 5. 显示层设计 6. UI层设计   4.  数据访问层设 ...

  5. 软考中项第三章 信息系统集成专业知识

    第三章 信息系统集成专业知识 信息系统的生命周期可以分为立项.开发.运维及消亡四个阶段 立项阶段:概念阶段或需求阶段,这一阶段根据用户业务发展和经营管理的需要,提出建设信息系统的初步构想,然后对企业信 ...

  6. 构建之法前三章读后感—软件工程

    本教材不同于其他教材一贯的理知识直接灌溉,而是以对话形式向我们传授知识的,以使我们更好地理解知识点,更加清晰明确. 第一章 第一章的概述中,书本以多种方式,形象生动地向我们阐述了软件工程的内容,也让我 ...

  7. 关于对《Spring Security3》翻译 (第一章 - 第三章)

    原文:http://lengyun3566.iteye.com/category/153689?page=2 翻译说明 最近阅读了<Spring Security3>一书,颇有收获(封面见 ...

  8. C++ API 设计 08 第三章 模式

    第三章 模式 前一章所讨论的品质是用来区分设计良好和糟糕的API.在接下来的几个章节将重点关注构建高品质的API的技术和原则.这个特殊的章节将涵盖一些有用的设计模式和C++ API设计的相关习惯用法. ...

  9. 敏捷整洁之道 -- 第三章 业务实践

    敏捷整洁之道 -- 第三章 业务实践 0. 引子 1. 计划游戏 1.1 三元分析 1.2 故事和点数 1.3 故事 1.4 故事估算 1.5 对迭代进行管理 1.6 速率 2. 小步发布 3. 验收 ...

  10. 第三章 信息系统集成专业技术知识

    第三章 信息系统集成专业技术知识 知识点 1.信息系统的生命周期有哪几个过程 2.信息系统开发的方法有几种:各种用于什么情况的项目. 3.软件需求的定义及分类: 4.软件设计的基本原则是什么: 5.软 ...

最新文章

  1. APL开发日志--2012-11-08
  2. 文件格式用Latex排版论文(1)如何将Visio画图文件转换成Latex支持的.eps文件
  3. 文巾解题 1480. 一维数组的动态和
  4. 与优秀的人在一起进步:我发起的“乐学”分享活动
  5. Spring5参考指南:SpringAOP简介
  6. 通过shell访问hive_【HIVE】SHELL调用Hive查询
  7. python怎么让py里面逐行运行_怎样在安卓上运行python
  8. mysql5.7改了配置文件怎么生效_如何找到并修改MySQL57的配置文件m
  9. JAVA小白启蒙篇:第一个SSM框架搭建示例(附源码下载)
  10. 爬取世界各国历年的GDP数据
  11. P1048 采药 洛谷Oj
  12. 在电脑上如何剪辑音乐?
  13. Linux-千兆网卡驱动实现机制浅析
  14. Spyder中出现IndentationError:unindent does not match any outer indentation level错误
  15. 2020暨南大学计算机专硕考研经验分享
  16. 王者战力查询接口(免费)
  17. wxpy 建群 err_code: 1 err_msg:
  18. JAVA打印300以内的质数
  19. ASEMI整流模块MSAD165-16参数,MSAD165-16规格
  20. 如何配置NAT Server?

热门文章

  1. solidworks有限元分析怎么做?
  2. 个人站长常用的4款网站统计工具
  3. 【C#】控制台应用程序闪退解决方法
  4. 在线进销存软件免费版,哪个可以用?
  5. 肠道健康从核心菌属开始:肠道菌群的关键
  6. DC-DC转换器参数-电源技术基础---和讯康讲堂
  7. FFmpeg从入门到精通-云享读书会
  8. HTTP报文(message)是什么?请求报文、响应报文、报文首部(header)、报文主体(body)
  9. 学习PS都需要准备什么?
  10. Laravel 数据库去重计数