上一篇对MNIST数据集有了一些了解,数据集包含着60000张训练图片与标签值和10000张测试图片与标签值的数据集,数据集有了,现在我们来构造神经网络,预测下对这测试的10000张图片的正确识别率,也就是看下手写数字的识别率的情况。

def getTestData():"""获取测试数据"""(x_train,t_train),(x_test,t_test)=load_mnist(normalize=True,flatten=True,one_hot_label=False)return x_test,t_test

getTestData()[0].shape
(10000, 784)
getTestData()[1].shape
(10000,)

获取测试的数据,通过形状可以看到有10000张784(28x28像素)的图片的矩阵和10000张标签值的一维数组!

在构造神经网络过程中,不可或缺的就是权重和偏置,我们读取已预设好的权重和偏置文文件(sample_weight.pkl),在后期的学习中将如何选定参数。

import sys,os,numpy as np
sys.path.append('D:\Anaconda3\TONYTEST')
os.chdir('D:\Anaconda3\TONYTEST\dataset')
from dataset.mnist import load_mnist
import pickledef getweights():"""获取权重和偏置数据"""with open('sample_weight.pkl','rb') as f:wbs=pickle.load(f)return wbs

sample_weight.pkl是一个权重和偏置的参数的pkl文件,内容是字典类型保存,分别查看下权重和偏置的形状,我们看到形状对应维度的数量一致,知道可以做点积运算
wb=getweights()
wb['W1'].shape,wb['W2'].shape,wb['W3'].shape
((784, 50), (50, 100), (100, 10))
wb['b1'].shape,wb['b2'].shape,wb['b3'].shape
((50,), (100,), (10,))

def predict(wbs,x):"""以NumPy数组的形式输出各个标签对应的概率"""W1,W2,W3=wbs['W1'],wbs['W2'],wbs['W3']b1,b2,b3=wbs['b1'],wbs['b2'],wbs['b3']a1=np.dot(x,W1)+b1z1=sigmoid(a1)a2=np.dot(z1,W2)+b2z2=sigmoid(a2)a3=np.dot(z2,W3)+b3y=softmax(a3)return y

现在来看下使用测试数据来查看识别精度,测试识别精度,都使用测试数据,不能使用训练数据

x,t=getTestData()
wb=getweights()
accuracy_cnt=0
for i in range(len(x)):y=predict(wb,x[i])p=np.argmax(y) #获取概率最高的元素的索引值#概率最高的索引值如果和测试的标签值相等就加1if p==t[i]:accuracy_cnt+=1
print('识别精确率:{:.2%}'.format(accuracy_cnt/len(x)))

识别精确率:93.52%
也就是说10000张测试图片,正确识别了9352张

上面用到的一些公用函数,可以统一写在一个common目录里面的funtions.py

import numpy as np
def sigmoid(x):return 1 / (1 + np.exp(-x))    def softmax(x):if x.ndim == 2:x = x.T#矩阵转置x = x - np.max(x, axis=0)y = np.exp(x) / np.sum(np.exp(x), axis=0)return y.T x = x - np.max(x) # 溢出对策return np.exp(x) / np.sum(np.exp(x))

使用方法跟上一篇文章中的使用dataset一样
from common.functions import sigmoid,softmax

上面的识别精确率,是一张一张处理,现在改进一下做批处理,也就是一次性输入比如100张图片,那10000张图片就100批次搞定,批处理可以大大缩短处理时间。

x,t=getTestData()
wb=getweights()
batch_size=100#每次批处理的数量
accuracy_cnt=0
for i in range(0,len(x),batch_size):x_batch=x[i:i+batch_size]y_batch=predict(wb,x_batch)p=np.argmax(y_batch,axis=1)accuracy_cnt+=np.sum(p==t[i:i+batch_size])
print("精确率:{:0.2%}".format(accuracy_cnt/len(x)))

精确率:93.52%

其中axis=1表示沿着第一维方向(第一维为轴),或者我是这么理解它,axis=0代表列,axis=1代表行,如下

q=np.array([[1,2,3],[22,39,1],[10,9,12],[0.9,0.1,8]])
'''
array([[  1. ,   2. ,   3. ],[ 22. ,  39. ,   1. ],[ 10. ,   9. ,  12. ],[  0.9,   0.1,   8. ]])
'''
a0=np.argmax(q,axis=0)
array([1, 1, 2], dtype=int64)a1=np.argmax(q,axis=1)
array([2, 1, 2, 2], dtype=int64)

一般来说都是axis=1,因为q.shape是(4,3),4行3列,那么我们取最大值的索引值,也是每行取一个最大值,回到前面的(10000,784)也是这样,10000张图片,每张图片是784列,那么我们肯定也是取10000张标签值,所以每行取一个!

np.max(q,axis=1)array([  3.,  39.,  12.,   8.])

MNIST数据集手写数字识别(二)相关推荐

  1. [Pytorch系列-41]:卷积神经网络 - 模型参数的恢复/加载 - 搭建LeNet-5网络与MNIST数据集手写数字识别

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  2. MNIST数据集手写数字识别

    1 数据集介绍 MNIST 包括6万张28x28的训练样本,1万张测试样本,很多教程都会对它"下手"几乎成为一个 "典范",可以说它就是计算机视觉里面的Hell ...

  3. MNIST数据集手写数字识别(一)

    MNIST数据集是初步学习神经网络的很好的数据集,也是拿来教学,不可多得的好教材,有很多知识点在里面.官网下载地址,可以自己手动下载,当然也可以通过下面的代码自动下载[urllib.request(3 ...

  4. Python深度学习之分类模型示例,MNIST数据集手写数字识别

    MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片. 我们把60000个训练样本分成两部分,前 ...

  5. PyTorch:MNIST数据集手写数字识别

    MNIST 包括6万张28x28的训练样本,1万张测试样本,很多教程都会对它"下手"几乎成为一个 "典范",可以说它就是计算机视觉里面的Hello World. ...

  6. tensorflow笔记(曹健老师):mnist数据集手写数字识别

    分为三部分:前向传播,反向传播,数据测试.适应一下简单的结构化编程 第一部分:前向传播(mnist_forward.py) #前向传播,两层神经网络 import tensorflow as tf i ...

  7. MNIST数据集手写数字识别(CNN)

  8. 基于tensorflow+RNN的MNIST数据集手写数字分类

    2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...

  9. MNIST数据集手写数字分类

    参考   MNIST数据集手写数字分类 - 云+社区 - 腾讯云 目录 0.编程环境 1.下载并解压数据集 2.完整代码 3.数据准备 4.数据观察 4.1 查看变量mnist的方法和属性 4.2 对 ...

最新文章

  1. Day 05 名人能树立好榜样吗
  2. 洛谷 P1036 选数
  3. Eclipse实现hibernate反向工程:从数据库逆向生成实体类和hbm文件
  4. python绘制繁花曲线代码_使用python和pygame绘制繁花曲线的方法
  5. 【优雅代码】深入浅出 妙用Javascript中apply、call、bind
  6. PPT 如何做好关卡设计
  7. swagger2-接口文档
  8. 微信小程序+.NET(十八) ffmpeg音频转码/拼接/混合
  9. moodle php代码解读_Moodle插件开发笔记
  10. 整理出的安卓国家码,简称,语言的Json文件,可以一一对应国旗
  11. 希捷服务器硬盘格式化不了,希捷硬盘专用分区格式化Seagate DiscWizard16.0 官方版...
  12. 基于html的项目的选题报告,团队项目-选题报告
  13. 企业微信好友无上限,私域流量即将迎来春天?
  14. startActivitystartActivities有什么不同?
  15. agent常见处理问题的处理
  16. Robot Framework Selenium UI自动化测试 --- 实战篇
  17. 剑指offer中使用辅助栈方法的题目的整理(待更)
  18. 安装依赖报错:An unexpected error occurred: “E:\\ReactProject\\umi-project\\package.json:
  19. 人称代词I/my/mine/me 用法
  20. 蓝牙耳机+大鼠标垫+笔记本电脑支架

热门文章

  1. Qt on Android 蓝牙通信开发
  2. 只有在配置文件或 Page 指令中将 enableSessionState”的异常解决办法
  3. AndroidStudio_使用gradle添加依赖jar包_依赖模块---Android原生开发工作笔记78
  4. Configuration property name ‘fdfs.thumbImage‘ is not valid---springcloud工作笔记163
  5. 数据库工作笔记002---Linux下开启,重启,关闭mysql
  6. el 表达式 可以解析的数据类型
  7. 发现个特别合胃口的仓鼠、小鱼和计数器代码
  8. BCD与ASCII码互转-C语言实现
  9. 添加中文菜单项出现乱码的解决办法
  10. 给定一个N位数,得到一个N-k位的数中最小的数