MNIST数据集手写数字识别(二)
上一篇对MNIST数据集有了一些了解,数据集包含着60000张训练图片与标签值和10000张测试图片与标签值的数据集,数据集有了,现在我们来构造神经网络,预测下对这测试的10000张图片的正确识别率,也就是看下手写数字的识别率的情况。
def getTestData():"""获取测试数据"""(x_train,t_train),(x_test,t_test)=load_mnist(normalize=True,flatten=True,one_hot_label=False)return x_test,t_test
getTestData()[0].shape
(10000, 784)
getTestData()[1].shape
(10000,)
获取测试的数据,通过形状可以看到有10000张784(28x28像素)的图片的矩阵和10000张标签值的一维数组!
在构造神经网络过程中,不可或缺的就是权重和偏置,我们读取已预设好的权重和偏置文文件(sample_weight.pkl),在后期的学习中将如何选定参数。
import sys,os,numpy as np
sys.path.append('D:\Anaconda3\TONYTEST')
os.chdir('D:\Anaconda3\TONYTEST\dataset')
from dataset.mnist import load_mnist
import pickledef getweights():"""获取权重和偏置数据"""with open('sample_weight.pkl','rb') as f:wbs=pickle.load(f)return wbs
sample_weight.pkl是一个权重和偏置的参数的pkl文件,内容是字典类型保存,分别查看下权重和偏置的形状,我们看到形状对应维度的数量一致,知道可以做点积运算
wb=getweights()
wb['W1'].shape,wb['W2'].shape,wb['W3'].shape
((784, 50), (50, 100), (100, 10))
wb['b1'].shape,wb['b2'].shape,wb['b3'].shape
((50,), (100,), (10,))
def predict(wbs,x):"""以NumPy数组的形式输出各个标签对应的概率"""W1,W2,W3=wbs['W1'],wbs['W2'],wbs['W3']b1,b2,b3=wbs['b1'],wbs['b2'],wbs['b3']a1=np.dot(x,W1)+b1z1=sigmoid(a1)a2=np.dot(z1,W2)+b2z2=sigmoid(a2)a3=np.dot(z2,W3)+b3y=softmax(a3)return y
现在来看下使用测试数据来查看识别精度,测试识别精度,都使用测试数据,不能使用训练数据
x,t=getTestData()
wb=getweights()
accuracy_cnt=0
for i in range(len(x)):y=predict(wb,x[i])p=np.argmax(y) #获取概率最高的元素的索引值#概率最高的索引值如果和测试的标签值相等就加1if p==t[i]:accuracy_cnt+=1
print('识别精确率:{:.2%}'.format(accuracy_cnt/len(x)))
识别精确率:93.52%
也就是说10000张测试图片,正确识别了9352张
上面用到的一些公用函数,可以统一写在一个common目录里面的funtions.py
import numpy as np
def sigmoid(x):return 1 / (1 + np.exp(-x)) def softmax(x):if x.ndim == 2:x = x.T#矩阵转置x = x - np.max(x, axis=0)y = np.exp(x) / np.sum(np.exp(x), axis=0)return y.T x = x - np.max(x) # 溢出对策return np.exp(x) / np.sum(np.exp(x))
使用方法跟上一篇文章中的使用dataset一样
from common.functions import sigmoid,softmax
上面的识别精确率,是一张一张处理,现在改进一下做批处理,也就是一次性输入比如100张图片,那10000张图片就100批次搞定,批处理可以大大缩短处理时间。
x,t=getTestData()
wb=getweights()
batch_size=100#每次批处理的数量
accuracy_cnt=0
for i in range(0,len(x),batch_size):x_batch=x[i:i+batch_size]y_batch=predict(wb,x_batch)p=np.argmax(y_batch,axis=1)accuracy_cnt+=np.sum(p==t[i:i+batch_size])
print("精确率:{:0.2%}".format(accuracy_cnt/len(x)))
精确率:93.52%
其中axis=1表示沿着第一维方向(第一维为轴),或者我是这么理解它,axis=0代表列,axis=1代表行,如下
q=np.array([[1,2,3],[22,39,1],[10,9,12],[0.9,0.1,8]])
'''
array([[ 1. , 2. , 3. ],[ 22. , 39. , 1. ],[ 10. , 9. , 12. ],[ 0.9, 0.1, 8. ]])
'''
a0=np.argmax(q,axis=0)
array([1, 1, 2], dtype=int64)a1=np.argmax(q,axis=1)
array([2, 1, 2, 2], dtype=int64)
一般来说都是axis=1,因为q.shape是(4,3),4行3列,那么我们取最大值的索引值,也是每行取一个最大值,回到前面的(10000,784)也是这样,10000张图片,每张图片是784列,那么我们肯定也是取10000张标签值,所以每行取一个!
np.max(q,axis=1)array([ 3., 39., 12., 8.])
MNIST数据集手写数字识别(二)相关推荐
- [Pytorch系列-41]:卷积神经网络 - 模型参数的恢复/加载 - 搭建LeNet-5网络与MNIST数据集手写数字识别
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...
- MNIST数据集手写数字识别
1 数据集介绍 MNIST 包括6万张28x28的训练样本,1万张测试样本,很多教程都会对它"下手"几乎成为一个 "典范",可以说它就是计算机视觉里面的Hell ...
- MNIST数据集手写数字识别(一)
MNIST数据集是初步学习神经网络的很好的数据集,也是拿来教学,不可多得的好教材,有很多知识点在里面.官网下载地址,可以自己手动下载,当然也可以通过下面的代码自动下载[urllib.request(3 ...
- Python深度学习之分类模型示例,MNIST数据集手写数字识别
MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片. 我们把60000个训练样本分成两部分,前 ...
- PyTorch:MNIST数据集手写数字识别
MNIST 包括6万张28x28的训练样本,1万张测试样本,很多教程都会对它"下手"几乎成为一个 "典范",可以说它就是计算机视觉里面的Hello World. ...
- tensorflow笔记(曹健老师):mnist数据集手写数字识别
分为三部分:前向传播,反向传播,数据测试.适应一下简单的结构化编程 第一部分:前向传播(mnist_forward.py) #前向传播,两层神经网络 import tensorflow as tf i ...
- MNIST数据集手写数字识别(CNN)
- 基于tensorflow+RNN的MNIST数据集手写数字分类
2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...
- MNIST数据集手写数字分类
参考 MNIST数据集手写数字分类 - 云+社区 - 腾讯云 目录 0.编程环境 1.下载并解压数据集 2.完整代码 3.数据准备 4.数据观察 4.1 查看变量mnist的方法和属性 4.2 对 ...
最新文章
- Day 05 名人能树立好榜样吗
- 洛谷 P1036 选数
- Eclipse实现hibernate反向工程:从数据库逆向生成实体类和hbm文件
- python绘制繁花曲线代码_使用python和pygame绘制繁花曲线的方法
- 【优雅代码】深入浅出 妙用Javascript中apply、call、bind
- PPT 如何做好关卡设计
- swagger2-接口文档
- 微信小程序+.NET(十八) ffmpeg音频转码/拼接/混合
- moodle php代码解读_Moodle插件开发笔记
- 整理出的安卓国家码,简称,语言的Json文件,可以一一对应国旗
- 希捷服务器硬盘格式化不了,希捷硬盘专用分区格式化Seagate DiscWizard16.0 官方版...
- 基于html的项目的选题报告,团队项目-选题报告
- 企业微信好友无上限,私域流量即将迎来春天?
- startActivitystartActivities有什么不同?
- agent常见处理问题的处理
- Robot Framework Selenium UI自动化测试 --- 实战篇
- 剑指offer中使用辅助栈方法的题目的整理(待更)
- 安装依赖报错:An unexpected error occurred: “E:\\ReactProject\\umi-project\\package.json:
- 人称代词I/my/mine/me 用法
- 蓝牙耳机+大鼠标垫+笔记本电脑支架
热门文章
- Qt on Android 蓝牙通信开发
- 只有在配置文件或 Page 指令中将 enableSessionState”的异常解决办法
- AndroidStudio_使用gradle添加依赖jar包_依赖模块---Android原生开发工作笔记78
- Configuration property name ‘fdfs.thumbImage‘ is not valid---springcloud工作笔记163
- 数据库工作笔记002---Linux下开启,重启,关闭mysql
- el 表达式 可以解析的数据类型
- 发现个特别合胃口的仓鼠、小鱼和计数器代码
- BCD与ASCII码互转-C语言实现
- 添加中文菜单项出现乱码的解决办法
- 给定一个N位数,得到一个N-k位的数中最小的数