深度学习入门——03 MNIST手写数字图像集识别实验
- MNIST是机器学习中最有名的数据集之一,由0~9的手写数字图像构成,在下面实验中利用在上一篇初识神经网络中所学习到基本框架做一个简单实验,下面代码中许多基于《深度学习入门 基于Python的理论与实现》这本书提供的代码及资料,此文仅作文个人学习笔记,如有侵权,请联系删除。
- 下面实验主要是为了验证之前学习的神经网络的基本框架,用上述书中所提供的权重和偏置参数,对数据集中的测试图像做一个测试。
- 1、逐张图像进行学习:
- (1)load_mnist函数:
normalize:是否将图像正规化为0.0-1.0的值,如果设置为False,z=则图像输入保持0~255,这是像素的取值。
faltten:是否将输入图像展开为一维数组,否则图像为1x28x28,展开后为784。
one_hot_label:是否将标签保存为one_hot,指的是如原本图像的标签为[1,2,3,4,5],one_hot之后只有1和0,
经过计算后,最符合的标签为1,如识别出这个图像的数字是2,则one_hot表现为[0,1,0,0,0]这种模式。
import os
import pickle
import sys
sys.path.append(os.pardir) #为了导入父目录中的文件进行的设定。
from dataset.mnist import load_mnist
from PIL import Image
import numpy as np
from simple_neural import *"""
此代码基于《深度学习入门 基于pythom的理论与实践》一书所提供的源代码编写,其中使用了该书提供的
dataset中的资料,其中load_mnist是次书中写好的代码,这里是直接调用,如有侵权,请联系删除。
"""
def get_data():(x_train, t_train), (x_test, t_test) = load_mnist(normalize = True, flatten = True,\one_hot_label = False)return x_test, t_test
(2)、这里使用了pickle功能,此功能可以将程序运行中的对象保存为文件,第二次加载时可以快速复原此程序运行中的对象。下面函数调用的内容也基于上述提到的书,init_函数会快速调用之前保存simple_weight.pkl中的
内容,调取其已经学习之后所确定的比较好的权重与偏置参数。
def init_network():with open ("sample_weight.pkl", 'rb') as f:network = pickle.load(f)return networkdef net_predict(network, x):w1, w2, w3 = network['W1'], network['W2'], network['W3']b1, b2, b3 = network['b1'], network['b2'], network['b3']a1 = np.dot(x, w1) + b1z1 = sigmoid(a1)a2 = np.dot(z1, w2) + b2z2 = sigmoid(a2)a3 = np.dot(z2, w3) + b3y = softmax(a3)return yx, t = get_data()
network = init_network()
# 精确度
accuracy = 0;
for i in range(len(x)): # [0,1,.....len(x)]y = net_predict(network, x[i])p = np.argmax(y) # 获取y中返回列表中最大值的所以,对应的就标签。如第三位对应值最大,# 那么对应的标签就是数字2。if p == t[i]:accuracy += 1print("Accuraacy: ", str(float(accuracy) / len(x)))
# Accuraacy: 0.9352 根据之前的权重参数,对测试图像进行分类,共有93.52%的图像正确归类。
2、批处理进行学习:
上述实验中我们每次处理一张图片,中间有两层隐藏层分别有50和100个神经元,我们每次读入一张图片,如果我们改成每次处理100张图片,基于数值计算的库都能够高效处理大型数组的运算,并且在神经网络的运算中当数据传送成为瓶颈时,批处理可以减小数据总线的负荷,将更多资源用于计算上。与前面相同部分:
import os
import pickle
import sys
sys.path.append(os.pardir) #为了导入父目录中的文件进行的设定。
from dataset.mnist import load_mnist
from PIL import Image
import numpy as np
from simple_neural import *def get_data():(x_train, t_train), (x_test, t_test) = load_mnist(normalize = True, flatten = True,\one_hot_label = False)return x_test, t_testdef init_network():with open ("sample_weight.pkl", 'rb') as f:network = pickle.load(f)return networkdef net_predict(network, x):w1, w2, w3 = network['W1'], network['W2'], network['W3']b1, b2, b3 = network['b1'], network['b2'], network['b3']a1 = np.dot(x, w1) + b1z1 = sigmoid(a1)a2 = np.dot(z1, w2) + b2z2 = sigmoid(a2)a3 = np.dot(z2, w3) + b3y = softmax(a3)return y
不同之处:
(1)、range函数 range(start, end, step) 从start到end-1,以step为间隔的一组数据,如range(0, 9, 2)生成[0, 2, 4, 6, 8]的序列。
(2)、axis是对横着方向对数据进行处理,如argmax([0.05, 0.15, 0.8],[0.3, 0.5, 0.2], [0.3, 0.1, 0.6])得到的是[2, 1, 2]。
(3)、最后np.sum(p == t[i:i+batch_num])是比较分类结构与实际的相等情况,p ==t[…]返回的是bool型的数据用np.sum统计True的个数。如x = [1, 2, 3, 4], y = [1, 3, 3, 4],x == y 返回[T, F,T, T]。由np.sum统计T的数据。
x, t = get_data()
network = init_network()
# 精确度
accuracy = 0;
# 批处理一批的数量
batch_num = 100
for i in range(0, len(x), batch_num): # [0-batch_num-1, bctach_num~2*batach_num-1......]x_batch = x[i:i+batch_num] #取出i到i+batcch_num内的数据,不包括i+batch_num这个值。 y_batch = net_predict(network, x_batch)p = np.argmax(y_batch, axis=1) accuracy += np.sum(p == t[i:i+batch_num])print("Accuraacy: ", str(float(accuracy) / len(x)))
# Accuraacy: 0.9352 根据之前的权重参数,对测试图像进行分类,共有93.52%的图像正确归类。
- 此文中用到上述书所提供的函数,如Load_mnist函数,在下篇的博客中再具体分析是如何构造的。
深度学习入门——03 MNIST手写数字图像集识别实验相关推荐
- 深度学习——CNN实现MNIST手写数字的识别
活动地址:CSDN21天学习挑战赛 目录 知识点介绍 MNIST 介绍 下载 数据的简单处理 CNN神经网络 CNN的作用 CNN的主要特征 CNN的神经网络结构 CNN的相关参数 MNIST识别的 ...
- Tensorflow 学习笔记:Mnist 手写训练集调试,准确率变为0.1的解决办法及如何将准确率调高到98%以上
学习笔记:Mnist 手写训练集 加入隐藏层后准确率变为0.1的解决办法 提高神经网络准确率的尝试 提高准确率:调小每次训练的批次大小 提高准确率:使用交叉熵 更改优化器及学习率 小结 提高神经网络准 ...
- 深度学习笔记(MNIST手写识别)
先看了点花书,后来觉得有点枯燥去看了b站up主六二大人的pytorch深度学习实践的课,对深度学习的理解更深刻一点,顺便做点笔记,记录一些我认为重要的东西,便于以后查阅. 一. 机器学习基础 学习的定 ...
- 深度学习(4)手写数字识别实战
深度学习(4)手写数字识别实战 Step0. 数据及模型准备 1. X and Y(数据准备) 2. out=relu{relu{relu[X@W1+b1]@W2+b2}@W3+b3}out=relu ...
- 深度学习(3)手写数字识别问题
深度学习(3)手写数字识别问题 1. 问题归类 2. 数据集 3. Image 4. Input and Output 5. Regression VS Classification 6. Compu ...
- 深度学习 卷积神经网络-Pytorch手写数字识别
深度学习 卷积神经网络-Pytorch手写数字识别 一.前言 二.代码实现 2.1 引入依赖库 2.2 加载数据 2.3 数据分割 2.4 构造数据 2.5 迭代训练 三.测试数据 四.参考资料 一. ...
- DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别
DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别 目录 输出结果 实现代码 输出结果 实现代码 from _ ...
- Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集简介、下载、使用方法(包括数据增强)之详细攻略
Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集简介+数据增强(将已有MNIST数据集通过移动像素上下左右的方法来扩大数据集为初始数据集的5倍) 目录 MNIST ...
- Dataset之MNIST:MNIST(手写数字图片识别及其ubyte.gz文件)数据集简介、下载、使用方法(包括数据增强)之详细攻略
Dataset之MNIST:MNIST(手写数字图片识别及其ubyte.gz文件)数据集简介.下载.使用方法(包括数据增强,将已有MNIST数据集通过移动像素上下左右的方法来扩大数据集为初始数据集的5 ...
- 【AI参赛经验】深度学习入门指南:从零开始TinyMind汉字书法识别——by:Link
各位人工智能爱好者,大家好! 由TinyMind发起的#第一届汉字书法识别挑战赛#正在火热进行中,比赛才开始3周,已有数只黑马冲进榜单.目前TOP54全部为90分以上!可谓竞争激烈,高手如林.不是比赛 ...
最新文章
- git 提交的时候报错:error: 'flutter_app/' does not have a commit checked out
- C++ VARIANT 学习小记录
- React 中 $$typeof 的作用
- 006_FastDFS文件上传
- Linux中yum和apt-get
- Python实现快速排序(非递归实现)
- Solarized ----vim配色方案
- CSDN×易观算法大赛火热进行中~
- Python3.x爬虫教程:爬网页、爬图片、自己主动登录
- RedHat Linux安装Informix v10.x(图文详解)
- 【转】cron表达式详解
- wincc嵌入式excel报表 该报表系统能够读取WINCC中历史归档数据,产生出EXCEL报表文件,同时在画面中EXCEL控件实时显示
- 应聘总经理的答卷,供大家打分!(二)
- 升级Windows7到旗舰版
- 用户画像业务数据调研及ETL(二)持续更新中...
- 代码的坏味道之十七 :Inappropriate Intimacy(狎昵关系)
- python68个内置函数_Python中68个内置函数的总结
- 矩阵 的逆、 迹、 秩
- Apache Flink_JZZ_MBY
- Advanced IP Scanner教程 详细使用方法