3.MNIST数据集分类
文章
- 一、MNIST数据集及Softmax
- 1.MNIST数据集
- 2.Softmax
- 二、MNIST数据集分类
- 1.导入第三方库
- 2.加载数据及数据预处理
- 3.训练模型
一、MNIST数据集及Softmax
1.MNIST数据集
大多数示例使用手写数字的MNIST数据集。该数据集包含60,000个用于训练的示例和10,000个用于测试的示例。
每一张图片包含28*28个像素,在MNIST 训练数据集 中是一个 形状为[60000,28,28] 的张量,我们首先需要把数据集转成[60000,784],然后才能放到网络中训练。第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。一般我们还需要把图片中的数据归一化0~1之间。
MNIST数据集的标签是介于0-9的数字,我们要把标签转化为"one-hotvectors"。一个one-hot向量除了一位数字是1外,其余维度数字都是0,比如标签0将表示为([1,0,0,0,0,0,0,0,0,0]),标签3将表示为([0,0,0,1,0,0,0,0,0,0])。
因此,MNIST数据集的标签是一个[60000,10]的数字矩阵。
28*28=784,每张图片有784个像素点,对应着784个神经元。最后输出10个神经元对应着10个数字。
2.Softmax
Softmax作用就是把神经网络的输出转化为概率值。
我们知道MNIST的结果是0-9,我们模型可能推测出一张图片的数字9的概率是80%,是数字8的概率是10%,然后其他数字的概率更小,总体概率加起来等于1。这是一个使用softmax回归模型的经典案例。softmax模型可以用来给不同的对象分配概率。
二、MNIST数据集分类
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.optimizers import SGD
2.加载数据及数据预处理
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.训练模型
# 创建模型 输入784个神经元,输出10个神经元
model = Sequential([# 定义输出是10 输入是784,设置偏置为1,添加softmax激活函数Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax"),
])
# 定义优化器
sgd = SGD(lr=0.2)# 定义优化器,loss_function,训练过程中计算准确率
model.compile(optimizer=sgd,loss="mse",metrics=['accuracy']
)
# 训练模型
model.fit(x_train,y_train,batch_size=32,epochs=10)# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)print("\ntest loss",loss)
print("accuracy:",accuracy)
最终运行结果:
注意
Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax")
这里用到了softmax激活函数。- 这里我们使用的
fit
方法进行的模型训练,之前的线性回归和非线性回归的模型训练方式和这不同。
代码:
model.compile(optimizer=sgd,loss="mse",metrics=['accuracy']
)
中添加metrics=['accuracy']
, 可以在训练过程中计算准确率。
3.MNIST数据集分类相关推荐
- 二隐层的神经网络实现MNIST数据集分类
二隐层的神经网络实现MNIST数据集分类 传统的人工神经网络包含三部分,输入层.隐藏层和输出层.对于一个神经网络模型的确定需要考虑以下几个方面: 隐藏层的层数以及各层的神经元数量 各层激活函数的选择 ...
- 深度学习笔记(2)——pytorch实现MNIST数据集分类(FNN、CNN、RNN、LSTM、GRU)
文章目录 0 前言 1 数据预处理 2 FNN(前馈神经网络) 3 CNN(卷积神经网络) 4 RNN(循环神经网络) 5 LSTM(长短期记忆网络) 6 GRU(门控循环单元) 7 完整代码 0 前 ...
- fashionmnist数据集_Keras实现Fashion MNIST数据集分类
本篇用keras构建人工神经网路(ANN)和卷积神经网络(CNN)实现Fashion MNIST 数据集单个物品分类,并从模型预测的准确性方面对ANN和CNN进行简单比较. Fashion MNIST ...
- MNIST 数据集分类
构建简单的CNN对 mnist 数据集进行分类.同时,还会在实验中学习池化与卷积操作的基本作用. 1. 引入库文件 mport torch import torch.nn as nn import t ...
- Pytorch:手把手教你搭建简单的卷积神经网络(CNN),实现MNIST数据集分类任务
关于一些代码里的解释,可以看我上一篇发布的文章,里面有很详细的介绍!!! 可以依次把下面的代码段合在一起运行,也可以通过jupyter notebook分次运行 第一步:基本库的导入 import n ...
- Tensorflow— MNIST数据集分类简单版本
代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data#载入数据集 #当前路径 m ...
- 构建单层单向RNN网络对MNIST数据集分类
一.导入数据集 1 import tensorflow as tf 2 import numpy as np 3 #清除默认图形堆栈并重置全局默认图形,tf.reset_default_graph函数 ...
- Tensorflow—CNN应用于MNIST数据集分类
代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_datamnist = input_ ...
- 基于Conv3D实现三维立体MNIST数据集分类
前言 大家好,我是阿光. 本专栏整理了<PyTorch深度学习项目实战100例>,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集. 正在更新 ...
- MNIST数据集分类
import numpy as np from keras.datasets import mnist from keras.utils import np_utils from keras.mode ...
最新文章
- (14)某工业生产部门根据国家计划的安排, 拟将某种高效率的5台机器,分配给所属的3个工厂A,B,C,各工厂在获得这种机器后,可以为国家盈利的情况如表4-10所示。
- InnoDB 存储引擎中的表锁和行锁详解
- 由浅入深CIL系列:6.For和Foreach的CIL结构组成以及运行效率
- 【风控体系】互联网反欺诈体系漫谈
- 最佳展示场景kit:顶部视角的展示场景(Mockups)
- react加水印_给网页增加水印的方法,react
- 在河北大学就读是怎样一种体验?
- 中国网络游戏上市突击大事记
- 2019测试指南-web应用程序安全测试(二)指纹Web应用程序
- Elasticsearch入门教程(六):Elasticsearch查询(二)
- MySQL查询函数---为表和字段取别名
- http://www.jb51.net/article/41274.htm
- Cocos2dx 之 cocosbuilder的使用
- 【Java】基础09
- windows系统安装配置mysql8,并设置远程访问
- TP,TN,FP,FN,F1,TPR,FPR (一图看懂)
- React本地化解决版本更迭出现的缓存问题
- codeforces 规则
- C语言运算符号优先级别
- 网络版杀毒软件部署……