数据集加载,本来想使用sklearn中的 fetch_openml函数直接从网站下载数据集,然而现在这条命令不行(似乎是网站问题),因此,尝试用使用本地加载首先在

链接:https://pan.baidu.com/s/163MTS_89EKpJZsO6da5J3w
提取码:it3v
复制这段内容后打开百度网盘手机App,操作更方便哦

下载MNIST文件,里面一共有7w个手写数字样本数据,每个数据有28*28=784维。

import numpy as np
from sklearn.datasets import fetch_openml#使用此命令失败

文件目录结构

下载下来的数据:mnist-original.mat
使用scipy.io 读取.mat文件

import scipy.io as sio
mnist = sio.loadmat('datasets/mnist-original.mat')
print(mnist)
{'__header__': b'MATLAB 5.0 MAT-file Platform: posix, Created on: Sun Mar 30 03:19:02 2014', '__version__': '1.0', '__globals__': [], 'mldata_descr_ordering': array([[array(['label'], dtype='<U5'), array(['data'], dtype='<U4')]],dtype=object), 'data': array([[0, 0, 0, ..., 0, 0, 0],[0, 0, 0, ..., 0, 0, 0],[0, 0, 0, ..., 0, 0, 0],...,[0, 0, 0, ..., 0, 0, 0],[0, 0, 0, ..., 0, 0, 0],[0, 0, 0, ..., 0, 0, 0]], dtype=uint8), 'label': array([[0., 0., 0., ..., 9., 9., 9.]])}

字典结构
我们用的到的data,label,读取到data和label之后对其进行转置。shape[0]为样本个数70000,shape[1]为样本维度个数784

X,y = mnist['data'],mnist['label']
X = X.T
y = y.T

看看shape

print(X.shape)
print(y.shape)

ok,没什么问题

(70000, 784)
(70000, 1)

train_test_split

X_train = np.array(X[:60000],dtype=float)
y_train = np.array(y[:60000],dtype=float)
X_test = np.array(X[60000:],dtype=float)
y_test = np.array(y[60000:],dtype=float)
print(X_train.shape)
print(y_train.shape)
(60000, 784)
(60000, 1)
print(X_test.shape)
print(y_test.shape)
(10000, 784)
(10000, 1)

使用KNN方法预测
首先不进行降维,训练时主要关注模型训练时间和精度

#使用KNN
from sklearn.neighbors import KNeighborsClassifier
knn_clf = KNeighborsClassifier()
%time knn_clf.fit(X_train,y_train)

27.2S

Wall time: 27.2 s

预测时间

%time knn_clf.score(X_test,y_test)

预测时间

Wall time: 11min 5s

预测精度

0.9688

然后对28*28数据进行降维
希望保持0.9的信息量

#使用PCA进行降维
from sklearn.decomposition import PCA
pca = PCA(0.9)
pca.fit(X_train)
X_train_reduction = pca.transform(X_train)

看看降维后的样本维度,87维,却保留了0.9的信息量

print(X_train_reduction.shape)
(60000, 87)

此时,使用降维后的数据训练分类器

knn_clf = KNeighborsClassifier()
%time knn_clf.fit(X_train_reduction,y_train)

用时1.42S

Wall time: 1.45 s

降维后的测试数据

X_test_reduction = pca.transform(X_test)

看看预测时花费的时间,和预测精度

%time knn_clf.score(X_test_reduction,y_test)#准确率还提高了。。(降噪)

时间减少是在意料之中,精度从0.9688上升到0.9728,信息量不是损失了吗?为什么预测精度却上升了?
PCA还有另一种作用:降噪!…

Wall time: 1min 2s
0.9728

Python机器学习:PCA与梯度上升:007试手MNIST数据集相关推荐

  1. Python机器学习日记4:监督学习算法的一些样本数据集(持续更新)

    Python机器学习日记4:监督学习算法的一些样本数据集 一.书目与章节 二.forge数据集(二分类) 三.blobs数据集(三/多分类) 四.moons数据集 五.wave数据集(回归) 六.威斯 ...

  2. DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练、预测

    DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练.预测 导读 利用python的numpy计算库,进行自定义搭建2层神经网络TwoLayerN ...

  3. python机器学习——决策树(分类)及“泰坦尼克号沉船事故”数据集案例操作

    决策树(分类)及具体案例操作 一.决策树(分类)算法 (1)算法原理(类似于"分段函数") (2)决策树的变量类型 (3)量化纯度 (4)基本步骤 (5)决策树的优缺点 二.决策树 ...

  4. 机器学习实战10-Artificial Neural Networks人工神经网络简介(mnist数据集)

    目录 一.感知器 1.1.单层感知器 1.2.多层感知器MLP与反向传播 二.用 TensorFlow 高级 API 训练 MLP DNNClassifier(深度神经网络分类器) 2.1.初始化: ...

  5. 【python机器学习】线性回归--梯度下降实现(基于波士顿房价数据集)

    波士顿房价数据集字段说明 crim 房屋所在镇的犯罪率 zn 面积大于25000平凡英尺住宅所占比例 indus 房屋所在镇非零售区域所占比例 chas 房屋是否位于河边 如果在河边,值1 nox 一 ...

  6. Python机器学习:线型回归法007多元线性回归和正规方程的解

  7. python cnn程序_python cnn训练(针对Fashion MNIST数据集)

    本文将和大家一起一步步尝试对Fashion MNIST数据集进行调参,看看每一步对模型精度的影响.(调参过程中,基础模型架构大致保持不变) 废话不多说,先上任务: 模型的主体框架如下(此为拿到的原始代 ...

  8. PCA主成分分析算法专题【Python机器学习系列(十五)】

    PCA主成分分析算法专题[Python机器学习系列(十五)] 文章目录 1. PCA简介 2. python 实现 鸢尾花数据集PCA降维 3. sklearn库实现 鸢尾花数据集PCA降维案例    ...

  9. Python机器学习笔记 使用scikit-learn工具进行PCA降维...

    Python机器学习笔记 使用scikit-learn工具进行PCA降维 之前总结过关于PCA的知识:深入学习主成分分析(PCA)算法原理.这里打算再写一篇笔记,总结一下如何使用scikit-lear ...

最新文章

  1. c语言括号匹配的检验,检验括号匹配的算法
  2. 8、mybatis之增删改查
  3. zcmu1540(二分)
  4. oracle数据库中分析函数大全,Oracle数据库的分析函数
  5. 聚焦2020云栖大会 边缘计算专场畅谈技术应用创新
  6. rgb红色范围_【论文阅读18】RGB-D Object-Oriented Semantic Mapping
  7. 比雷蛇0day更严重:通过虚拟赛睿外设即获取 Windows 管理员权限
  8. drools视频教程(drool实战实例+数据库+视频讲解)
  9. linux+加载迅雷插件,linux下使用aria2c + chrome插件取代迅雷
  10. ev3编码软件linux,乐高ev3编程软件下载
  11. android wifi检测呼吸,WiFi已经逆天了 现在能检测到你的呼吸
  12. xposed框架安全模式_Android 系统上的 Xposed 框架中都有哪些值得推荐的模块?
  13. 遗传算法详解(GA)(个人觉得很形象,很适合初学者)
  14. 微信模板消息html,微信推送模板消息,偶发出现报错errcode
  15. youtube python 中文_GitHub - dousirui001/youtube-streaming-translator-python: 实时翻译油管直播,开发中...
  16. ThinkPHP 提示验证码输入错误
  17. 【STM32】几款常用产品(F1、F4、F7)的区别
  18. 准确查询表空间使用情况
  19. 智能语音技术的深度解析
  20. Arduino WIFI智能小车 无线视频遥控小车 课程设计

热门文章

  1. 【BZOJ1901】Dynamic Rankings,树状数组套主席树
  2. 【codevs1021】玛丽卡,以前屯着的最短路
  3. english 2012020602
  4. 5.过滤器作为模板——寻找沃尔多、不相同的模板匹配_3
  5. Pentium 4处理器架构/微架构/流水线 (11) - NetBurst执行核详解 - Load/Store操作/存储转发
  6. gp3688 uhf2扩频_摩托罗拉GP3688对讲机(VHF、UHF)对讲机维修
  7. linux(centos8 ) 下安装anaconda3
  8. salesforce 架构设计_关于Salesforce证书维护重要通知
  9. 无向图的深度优先遍历非递归_LeetCode133-克隆图(附详细测试用例构建方法)
  10. 图形学理论知识 BRDF 双向反射分布函数