文章

  • 一、MNIST数据集及Softmax
    • 1.MNIST数据集
    • 2.Softmax
  • 二、MNIST数据集分类
    • 1.导入第三方库
    • 2.加载数据及数据预处理
    • 3.训练模型

一、MNIST数据集及Softmax

1.MNIST数据集

大多数示例使用手写数字的MNIST数据集。该数据集包含60,000个用于训练的示例和10,000个用于测试的示例。

每一张图片包含28*28个像素,在MNIST 训练数据集 中是一个 形状为[60000,28,28] 的张量,我们首先需要把数据集转成[60000,784],然后才能放到网络中训练第一个维度数字用来索引图片第二个维度数字用来索引每张图片中的像素点。一般我们还需要把图片中的数据归一化0~1之间

MNIST数据集的标签介于0-9的数字,我们要把标签转化为"one-hotvectors"。一个one-hot向量除了一位数字是1外,其余维度数字都是0比如标签0将表示为([1,0,0,0,0,0,0,0,0,0]),标签3将表示为([0,0,0,1,0,0,0,0,0,0])。
因此,MNIST数据集的标签是一个[60000,10]的数字矩阵。

28*28=784,每张图片784个像素点,对应着784个神经元。最后输出10个神经元对应着10个数字

2.Softmax

Softmax作用就是把神经网络的输出转化为概率值。
我们知道MNIST的结果是0-9,我们模型可能推测出一张图片的数字9的概率是80%,是数字8的概率是10%,然后其他数字的概率更小,总体概率加起来等于1。这是一个使用softmax回归模型的经典案例。softmax模型可以用来给不同的对象分配概率。


二、MNIST数据集分类

代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。

1.导入第三方库

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.optimizers import SGD

2.加载数据及数据预处理

# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
#  (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,)  还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784)  reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)

3.训练模型

# 创建模型 输入784个神经元,输出10个神经元
model = Sequential([# 定义输出是10 输入是784,设置偏置为1,添加softmax激活函数Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax"),
])
# 定义优化器
sgd = SGD(lr=0.2)# 定义优化器,loss_function,训练过程中计算准确率
model.compile(optimizer=sgd,loss="mse",metrics=['accuracy']
)
# 训练模型
model.fit(x_train,y_train,batch_size=32,epochs=10)# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)print("\ntest loss",loss)
print("accuracy:",accuracy)

最终运行结果:

注意

  • Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax")这里用到了softmax激活函数
  • 这里我们使用的fit方法进行的模型训练,之前的线性回归和非线性回归的模型训练方式和这不同。

代码:

model.compile(optimizer=sgd,loss="mse",metrics=['accuracy']
)

添加metrics=['accuracy'], 可以在训练过程中计算准确率。

3.MNIST数据集分类相关推荐

  1. 二隐层的神经网络实现MNIST数据集分类

    二隐层的神经网络实现MNIST数据集分类 传统的人工神经网络包含三部分,输入层.隐藏层和输出层.对于一个神经网络模型的确定需要考虑以下几个方面: 隐藏层的层数以及各层的神经元数量 各层激活函数的选择 ...

  2. 深度学习笔记(2)——pytorch实现MNIST数据集分类(FNN、CNN、RNN、LSTM、GRU)

    文章目录 0 前言 1 数据预处理 2 FNN(前馈神经网络) 3 CNN(卷积神经网络) 4 RNN(循环神经网络) 5 LSTM(长短期记忆网络) 6 GRU(门控循环单元) 7 完整代码 0 前 ...

  3. fashionmnist数据集_Keras实现Fashion MNIST数据集分类

    本篇用keras构建人工神经网路(ANN)和卷积神经网络(CNN)实现Fashion MNIST 数据集单个物品分类,并从模型预测的准确性方面对ANN和CNN进行简单比较. Fashion MNIST ...

  4. MNIST 数据集分类

    构建简单的CNN对 mnist 数据集进行分类.同时,还会在实验中学习池化与卷积操作的基本作用. 1. 引入库文件 mport torch import torch.nn as nn import t ...

  5. Pytorch:手把手教你搭建简单的卷积神经网络(CNN),实现MNIST数据集分类任务

    关于一些代码里的解释,可以看我上一篇发布的文章,里面有很详细的介绍!!! 可以依次把下面的代码段合在一起运行,也可以通过jupyter notebook分次运行 第一步:基本库的导入 import n ...

  6. Tensorflow— MNIST数据集分类简单版本

    代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data#载入数据集 #当前路径 m ...

  7. 构建单层单向RNN网络对MNIST数据集分类

    一.导入数据集 1 import tensorflow as tf 2 import numpy as np 3 #清除默认图形堆栈并重置全局默认图形,tf.reset_default_graph函数 ...

  8. Tensorflow—CNN应用于MNIST数据集分类

    代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_datamnist = input_ ...

  9. 基于Conv3D实现三维立体MNIST数据集分类

    前言 大家好,我是阿光. 本专栏整理了<PyTorch深度学习项目实战100例>,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集. 正在更新 ...

  10. MNIST数据集分类

    import numpy as np from keras.datasets import mnist from keras.utils import np_utils from keras.mode ...

最新文章

  1. (14)某工业生产部门根据国家计划的安排, 拟将某种高效率的5台机器,分配给所属的3个工厂A,B,C,各工厂在获得这种机器后,可以为国家盈利的情况如表4-10所示。
  2. InnoDB 存储引擎中的表锁和行锁详解
  3. 由浅入深CIL系列:6.For和Foreach的CIL结构组成以及运行效率
  4. 【风控体系】互联网反欺诈体系漫谈
  5. 最佳展示场景kit:顶部视角的展示场景(Mockups)
  6. react加水印_给网页增加水印的方法,react
  7. 在河北大学就读是怎样一种体验?
  8. 中国网络游戏上市突击大事记
  9. 2019测试指南-web应用程序安全测试(二)指纹Web应用程序
  10. Elasticsearch入门教程(六):Elasticsearch查询(二)
  11. MySQL查询函数---为表和字段取别名
  12. http://www.jb51.net/article/41274.htm
  13. Cocos2dx 之 cocosbuilder的使用
  14. 【Java】基础09
  15. windows系统安装配置mysql8,并设置远程访问
  16. TP,TN,FP,FN,F1,TPR,FPR (一图看懂)
  17. React本地化解决版本更迭出现的缓存问题
  18. codeforces 规则
  19. C语言运算符号优先级别
  20. 网络版杀毒软件部署……

热门文章

  1. Repast——参数栏实现下拉列表对应不同的功能实现
  2. Moodle 安装出现访问空白和open_basedir问题
  3. 欺骗的艺术.资料搜集自互联网.
  4. chartControl
  5. 学校后勤物资管理系统
  6. 意凡社:盘点那些令网赚者疯狂的时代!
  7. eXeScope之类的程序资源修改器的使用,很牛!
  8. 结合eXeScope打造个性flash发布后的应用程序exe文件
  9. flash mx拖拽实例_Flash MX 2004的注释添加器面板
  10. Delphi android 开发视频教程