python调用数据集mnist_使用MNIST数据集进行分类
本文是对书《机器学习实战:基于Scikit-Learn和Tensorflow》第三章的知识学习以及代码复现,欢迎大家一起学习一起进步。
获取数据集
提前将MNIST数据集下载好,并放在’\scikit_learn_data’目录之下from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original', data_home=r'C:\Users\12637\scikit_learn_data')
mnist
001.pngX, y = mnist["data"], mnist["target"]
X.shape
002.png
一共有70000张图片,每张图片有784个特征。因为图片是28×28像素,每个特征代表了一个像素点的强度,从0(白色)到255(黑色)。随手抓取一个实例的特征向量,将其重新形成一个28X28数组,然后用Matplotlib的imshow()函数将其显示出来%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
some_digit = X[36000]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap=matplotlib.cm.binary, interpolation="nearest")
plt.axis("off")
plt.show()
003.png# MNIST数据集中的部分数字图像
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
# 给训练集数据洗牌
import numpy as np
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
训练一个二元分类器# 训练一个二元分类器
y_train_5 = (y_train == 5) # True for all 5s, False for all other digits
y_test_5 = (y_test == 5)
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
sgd_clf.predict([some_digit])
004.png
实施交叉验证# 实施交叉验证
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
skfolds = StratifiedKFold(n_splits=3, random_state=42)
for train_index, test_index in skfolds.split(X_train, y_train_5):
clone_clf = clone(sgd_clf)
X_train_folds = X_train[train_index]
y_train_folds = (y_train_5[train_index])
X_test_fold = X_train[test_index]
y_test_fold = (y_train_5[test_index])
clone_clf.fit(X_train_folds, y_train_folds)
y_pred = clone_clf.predict(X_test_fold)
n_correct = sum(y_pred == y_test_fold)
print(n_correct / len(y_pred))
005.png
006.png
混淆矩阵# 混淆矩阵
from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_5, y_train_pred)
007.png
精度和召回率以及精度/召回率平衡
008.pngy_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method="decision_function")
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
plt.xlabel("Threshold")
plt.legend(loc="upper left")
plt.ylim([0, 1])
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()
009.png
ROC曲线绘制# ROC曲线
from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
def plot_roc_curve(fpr, tpr, label=None):
plt.plot(fpr, tpr, linewidth=2, label=label)
plt.plot([0, 1], [0, 1], 'k--')
plt.axis([0, 1, 0, 1])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plot_roc_curve(fpr, tpr)
plt.show()
010.pngfrom sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method="predict_proba")
y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)
plt.plot(fpr, tpr, "b:", label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.legend(loc="bottom right")
plt.show()
011.png
多类别分类器
012.png
013.png
错误分析# 错误分析
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
conf_mx
014.png
015.png
016.png
多标签分类# 多标签分类
from sklearn.neighbors import KNeighborsClassifier
y_train_large = (y_train >= 7)
y_train_odd = (y_train % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
knn_clf.predict([some_digit])
多输出分类# 多输出分类
noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
y_train_mod = X_train
y_test_mod = X_test
def plot_digit(data):
image = data.reshape(28, 28)
plt.imshow(image, cmap = matplotlib.cm.binary,
interpolation="nearest")
plt.axis("off")
some_index = 5500
plt.subplot(121); plot_digit(X_test_mod[some_index])
plt.subplot(122); plot_digit(y_test_mod[some_index])
plt.show()
017.pngknn_clf.fit(X_train_mod, y_train_mod)
clean_digit = knn_clf.predict([X_test_mod[some_index]])
plot_digit(clean_digit)
018.png
https://www.jianshu.com/p/b6cf853975cc
python调用数据集mnist_使用MNIST数据集进行分类相关推荐
- Python实现bp神经网络识别MNIST数据集
title: "Python实现bp神经网络识别MNIST数据集" date: 2018-06-18T14:01:49+08:00 tags: [""] cat ...
- [Pytorch系列-33]:数据集 - torchvision与MNIST数据集
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...
- 基于Python实现的神经网络分类MNIST数据集
神经网络分类MNIST数据集 目录 神经网络分类MNIST数据集 1 一 .问题背景 1 1.1 神经网络简介 1 前馈神经网络模型: 1 1.2 MINST 数据说明 4 1.3 TensorFlo ...
- Python 手写数字识别 MNIST数据集下载失败
目录 一.MNIST数据集下载失败 1 失败的解决办法(经验教训): 2 亲测有效的解决方法: 一.MNIST数据集下载失败 场景复现:想要pytorch+MINIST数据集来实现手写数字识别,首先就 ...
- 使用mnist数据集_使用MNIST数据集上的t分布随机邻居嵌入(t-SNE)进行降维
使用mnist数据集 It is easy for us to visualize two or three dimensional data, but once it goes beyond thr ...
- python调用数据集mnist_Python读取MNIST数据集
importnumpy as npimportmatplotlib.pyplot as plt'''试验transpose() def back (a,b): return a,b if __name ...
- 机器学习之sklearn使用下载MNIST数据集进行分类识别
机器学习之sklearn使用下载MNIST数据集进行分类识别 一.MNIST数据集 1.MNIST数据集简介 2.获取MNIST数据集 二.训练一个二分类器 1.随机梯度下降(SGD)分类器 2.分类 ...
- 全面理解主成分分析(PCA)和MNIST数据集的Python降维实现
注:本博文为原创博文,如需转载请注明原创链接!!! 这篇博文主要讲述主成分分析的原理并用该方法来实现MNIST数据集的降维. 一.引言 主成分分析是一种降维和主成分解释的方法.举一个比较容易理 ...
- python处理MNIST数据集
1. MNIST数据集 1.1 MNIST数据集获取 MNIST数据集是入门机器学习/模式识别的最经典数据集之一.最早于1998年Yan Lecun在论文: Gradient-based learni ...
- 读取mnist数据集方法大全(train-images-idx3-ubyte.gz,train-labels.idx1-ubyte等)(python读取gzip文件)
文章目录 gzip包 keras读取mnist数据集 本地读取mnist数据集 下载数据集 解压读取 方法一 方法二 gzip包读取 读取bytes数据 注:import导入的包如果未安装使用pip安 ...
最新文章
- [Math][Algebra]--线性代数中的各种空间
- boost::type_index::type_id相关的测试程序
- 史上最简单的SpringCloud教程 | 第九篇: 服务链路追踪(Spring Cloud Sleuth)
- MFCC梅尔频率倒谱系数
- c++中RTTI的观念和使用
- OSG仿真案例(5)——创建火光、爆炸(碎片)
- 我的CSDN资源下载怎么被自动设置了积分
- pydobc连接sql server_python – PyOdbc无法连接到SQL Server实例
- Spring定时器corn表达式
- PROE 安装提示注册号丢失
- 软件企业变更管理流程
- 网络速率与TCP窗口大小的关系
- OSError: exception: access violation writing 0x000000003F800000
- 【计蒜客 - 蓝桥训练】蒜厂年会(循环数列的最大子段和)
- 软件开发自学靠谱吗?
- php快捷方式 图标ie,pubwin删除IE快捷方式并自行创建IE快捷方式的问题解决方法...
- 台式电脑键盘错乱会出现计算机模式怎么办,键盘错乱怎么修复错位(台式电脑键盘按键错乱)...
- html长图转换成pdf,iOS将HTML页面转换成PDF文件保存到本地并分享传输文件
- Scrum敏捷开发之我的总结
- python 利用cip.cc查询IP归属地