《机器学习实战》第三章
记
第三章学习分类任务——引入MNIST数据集,分类0-9十个数字。
1.载入数据集
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
mnist.keys()
将数据集分成训练集和测试集
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
2.使用SGD分类器训练5和非5
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)from sklearn.linear_model import SGDClassifiersgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)
sgd_clf.fit(X_train, y_train_5)sgd_clf.predict([some_digit])
3.性能测量
交叉验证
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
计算混淆矩阵
from sklearn.model_selection import cross_val_predicty_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
from sklearn.metrics import confusion_matrixconfusion_matrix(y_train_5, y_train_pred)
计算精度:
from sklearn.metrics import precision_score, recall_scoreprecision_score(y_train_5, y_train_pred)
计算召回率:
recall_score(y_train_5, y_train_pred)
计算F1分数
from sklearn.metrics import f1_scoref1_score(y_train_5, y_train_pred)
绘制精度和召回率相对于阈值的函数图:
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)plt.legend(loc="center right", fontsize=16) plt.xlabel("Threshold", fontsize=16) plt.grid(True) plt.axis([-50000, 50000, 0, 1]) recall_90_precision = recalls[np.argmax(precisions >= 0.90)]
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]plt.figure(figsize=(8, 4))
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.plot([threshold_90_precision, threshold_90_precision], [0., 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [0.9, 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [recall_90_precision, recall_90_precision], "r:")
plt.plot([threshold_90_precision], [0.9], "ro")
plt.plot([threshold_90_precision], [recall_90_precision], "ro")
save_fig("precision_recall_vs_threshold_plot")
plt.show()
精度与召回率的关系函数:
def plot_precision_vs_recall(precisions, recalls):plt.plot(recalls, precisions, "b-", linewidth=2)plt.xlabel("Recall", fontsize=16)plt.ylabel("Precision", fontsize=16)plt.axis([0, 1, 0, 1])plt.grid(True)plt.figure(figsize=(8, 6))
plot_precision_vs_recall(precisions, recalls)
plt.plot([recall_90_precision, recall_90_precision], [0., 0.9], "r:")
plt.plot([0.0, recall_90_precision], [0.9, 0.9], "r:")
plt.plot([recall_90_precision], [0.9], "ro")
save_fig("precision_vs_recall_plot")
plt.show()
绘制ROC曲线:
from sklearn.metrics import roc_curvefpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
def plot_roc_curve(fpr, tpr, label=None):plt.plot(fpr, tpr, linewidth=2, label=label)plt.plot([0, 1], [0, 1], 'k--') # dashed diagonal
plot_roc_curve(fpr, tpr)
plt.show()
计算AUC
from sklearn.metrics import roc_auc_scoreroc_auc_score(y_train_5, y_scores)
比较随机森林分类器和SGD分类器:
from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,method="predict_proba")y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)recall_for_forest = tpr_forest[np.argmax(fpr_forest >= fpr_90)]plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, "b:", linewidth=2, label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.plot([fpr_90, fpr_90], [0., recall_90_precision], "r:")
plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], "r:")
plt.plot([fpr_90], [recall_90_precision], "ro")
plt.plot([fpr_90, fpr_90], [0., recall_for_forest], "r:")
plt.plot([fpr_90], [recall_for_forest], "ro")
plt.grid(True)
plt.legend(loc="lower right", fontsize=16)
save_fig("roc_curve_comparison_plot")
plt.show()
4.多类分类器:
SVM分类器:
from sklearn.svm import SVCsvm_clf = SVC(gamma="auto", random_state=42)
svm_clf.fit(X_train[:1000], y_train[:1000]) # y_train, not y_train_5
svm_clf.predict([some_digit])
5分数最高识别成功,
书中还列举了其他方法。
5.误差分析:
查看混淆矩阵:
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
conf_mx
可视化:
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()
主对角线比较 明显清晰,分类效果良好。
接下来关注错误值:
第8列较亮,可见较多图片被误分类为8.那么需要另设计算法加以区分。
6.多标签分类
将一个实例输出多个类,例如是否大数,是否奇数:
from sklearn.neighbors import KNeighborsClassifiery_train_large = (y_train >= 7)
y_train_odd = (y_train % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
knn_clf.predict([some_digit])
结果正确。
7.多输出分类
输出多个标签,标签有多个值(去除图片噪声为例):
noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
y_train_mod = X_train
y_test_mod = X_test
some_index = 0
plt.subplot(121); plot_digit(X_test_mod[some_index])
plt.subplot(122); plot_digit(y_test_mod[some_index])
save_fig("noisy_digit_example_plot")
plt.show()
《机器学习实战》第三章相关推荐
- 第三章 UT单元测试——CPU与内存使用率限制
系列文章目录 第一章 UT单元测试--GoogleTest通用构建说明 第二章 UT单元测试--GTest框架实例 第三章 UT单元测试--CPU与内存使用率限制 文章目录 系列文章目录 前言 一.环 ...
- 慕课软件质量保证与测试(第三章.单元测试)
慕课金陵科技学院.软件质量保证与测试.第三章.黑盒测试.单元测试 0 目录 3 黑盒测试 3.9 单元测试 3.9.1课堂重点 3.9.2测试与作业 4 下一章 0 目录 3 黑盒测试 3.9 单元测 ...
- 《构建之法》前三章读后感
通过第一章讲述的概论,理解到软件工程到底是什么,又为何要叫软件工程,他对我们的生活又有什么影响. 通过一些实例我也认识到客户需求分析的重要,就阿超那样的四则运算一样,渐渐的功能和需求就多了. 在第二章 ...
- 走向.NET架构设计—第三章—分层设计,初涉架构(后篇)
走向.NET架构设计-第三章-分层设计,初涉架构(后篇) 前言:本篇主要是接着前两篇文章继续讲述! 本篇的议题如下: 4. 数据访问层设计 5. 显示层设计 6. UI层设计 4. 数据访问层设 ...
- 软考中项第三章 信息系统集成专业知识
第三章 信息系统集成专业知识 信息系统的生命周期可以分为立项.开发.运维及消亡四个阶段 立项阶段:概念阶段或需求阶段,这一阶段根据用户业务发展和经营管理的需要,提出建设信息系统的初步构想,然后对企业信 ...
- 构建之法前三章读后感—软件工程
本教材不同于其他教材一贯的理知识直接灌溉,而是以对话形式向我们传授知识的,以使我们更好地理解知识点,更加清晰明确. 第一章 第一章的概述中,书本以多种方式,形象生动地向我们阐述了软件工程的内容,也让我 ...
- 关于对《Spring Security3》翻译 (第一章 - 第三章)
原文:http://lengyun3566.iteye.com/category/153689?page=2 翻译说明 最近阅读了<Spring Security3>一书,颇有收获(封面见 ...
- C++ API 设计 08 第三章 模式
第三章 模式 前一章所讨论的品质是用来区分设计良好和糟糕的API.在接下来的几个章节将重点关注构建高品质的API的技术和原则.这个特殊的章节将涵盖一些有用的设计模式和C++ API设计的相关习惯用法. ...
- 敏捷整洁之道 -- 第三章 业务实践
敏捷整洁之道 -- 第三章 业务实践 0. 引子 1. 计划游戏 1.1 三元分析 1.2 故事和点数 1.3 故事 1.4 故事估算 1.5 对迭代进行管理 1.6 速率 2. 小步发布 3. 验收 ...
- 第三章 信息系统集成专业技术知识
第三章 信息系统集成专业技术知识 知识点 1.信息系统的生命周期有哪几个过程 2.信息系统开发的方法有几种:各种用于什么情况的项目. 3.软件需求的定义及分类: 4.软件设计的基本原则是什么: 5.软 ...
最新文章
- APL开发日志--2012-11-08
- 文件格式用Latex排版论文(1)如何将Visio画图文件转换成Latex支持的.eps文件
- 文巾解题 1480. 一维数组的动态和
- 与优秀的人在一起进步:我发起的“乐学”分享活动
- Spring5参考指南:SpringAOP简介
- 通过shell访问hive_【HIVE】SHELL调用Hive查询
- python怎么让py里面逐行运行_怎样在安卓上运行python
- mysql5.7改了配置文件怎么生效_如何找到并修改MySQL57的配置文件m
- JAVA小白启蒙篇:第一个SSM框架搭建示例(附源码下载)
- 爬取世界各国历年的GDP数据
- P1048 采药 洛谷Oj
- 在电脑上如何剪辑音乐?
- Linux-千兆网卡驱动实现机制浅析
- Spyder中出现IndentationError:unindent does not match any outer indentation level错误
- 2020暨南大学计算机专硕考研经验分享
- 王者战力查询接口(免费)
- wxpy 建群 err_code: 1 err_msg:
- JAVA打印300以内的质数
- ASEMI整流模块MSAD165-16参数,MSAD165-16规格
- 如何配置NAT Server?