前言: Hello大家好,我是Dream。 今天来学习一下如何基于mnist数据集取得最高的识别准确率,本文是从零开始的,如有需要可自行跳至所需内容~

本文目录:

    • 1.调用库函数
    • 2.调用数据集
    • 3.选择模型,构建网络
    • 4.编译
    • 5.数据增强
    • 6.训练
    • 7.画出图像
    • 8.输出
    • 9.结果
  • 源码获取

说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下这个任务:

  • 基于mnist数据集,尽量取得更好的识别准确率。注意,要使用非训练集内容,通过evaluate方法得出准确率

1.调用库函数

import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Conv2D,MaxPooling2D,BatchNormalization,Flatten,Dense

指定当前程序使用的 GPU

# 指定当前程序使用的 GPU
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:tf.config.experimental.set_memory_growth(gpu, True)

2.调用数据集

# 调用数据集
(train_X, train_y),(test_X, test_y) = tf.keras.datasets.mnist.load_data()
train_X, test_X = train_X / 255.0, test_X / 255.0
train_X = train_X.reshape(-1, 28, 28, 1)
train_y = tf.keras.utils.to_categorical(train_y)
X_train, X_test, y_train, y_test = train_test_split(train_X, train_y, test_size=0.1, random_state=0)

3.选择模型,构建网络

在此我们使用的是CNN网络,以此搭建Conv2D层、MaxPooling2D层网络

# 选择模型,构建网络
model = tf.keras.Sequential()
model.add(Conv2D(64, (3, 3), activation='relu', input_shape=(28, 28, 1))),  #添加Conv2D层
model.add(Conv2D(64, (3, 3), activation='relu')),  #添加Conv2D层
model.add(MaxPooling2D((2, 2), strides=2)),  #添加MaxPooling2D层
model.add(BatchNormalization()),model.add(Conv2D(128, (3, 3), activation='relu')),  #添加Conv2D层
model.add(Conv2D(128, (3, 3), activation='relu')),  #添加Conv2D层
model.add(MaxPooling2D((2, 2), strides=2)),  #添加MaxPooling2D层
model.add(BatchNormalization()),model.add(Conv2D(256, (3, 3), activation='relu')),  #添加Conv2D层
model.add(MaxPooling2D((2, 2), strides=2)),  #添加MaxPooling2D层
model.add(BatchNormalization()),model.add(Flatten()),  #展平
model.add(Dense(512, activation='relu')),
model.add(Dense(10, activation='softmax'))

4.编译

使用交叉熵作为loss函数,指定优化器、损失函数和验证过程中的评估指标

# 编译(使用交叉熵作为loss函数)
model.compile(optimizer='adam',  #指定优化器loss="categorical_crossentropy",   #指定损失函数metrics=['accuracy'])   #指定验证过程中的评估指标
# 展示训练的过程
display(model.summary())

这里是输出的结果:

5.数据增强

在这里我们使用数据增强方法,更好的提高准确率

# 数据增强
datagen = ImageDataGenerator(rotation_range=15,zoom_range = 0.01,width_shift_range=0.1,height_shift_range=0.1)
train_gen = datagen.flow(X_train, y_train, batch_size=128)
test_gen = datagen.flow(X_test, y_test, batch_size=128)

6.训练

首先我们批量输入的样本个数,然后经过我们测试分析,此模型训练到30轮之前变化趋于静止,我们可以只进行30个epoch。

# 批量输入的样本个数
batch_size = 128
train_steps = X_train.shape[0] // batch_size
valid_steps = X_test.shape[0] // batch_size# 经过我们测试分析,此模型训练到30轮之前变化趋于静止,我们可以只进行30个epoch
es = tf.keras.callbacks.EarlyStopping(monitor="val_accuracy",patience=10,verbose=1,mode="max",restore_best_weights=True)rp = tf.keras.callbacks.ReduceLROnPlateau(monitor="val_accuracy",factor=0.2,patience=5,verbose=1,mode="max",min_lr=0.00001)# 训练(训练30个epoch)
history = model.fit(train_gen,epochs = 30,steps_per_epoch = train_steps,validation_data = test_gen,validation_steps = valid_steps,callbacks=[es, rp])

这里是输出的结果:

7.画出图像

使用plt模块进行数据可视化处理

# 画出图像
fig, ax = plt.subplots(2,1, figsize=(14, 10))
ax[0].plot(history.history['loss'], color='red', label="Loss")
ax[0].legend(loc='best', shadow=False)
ax[1].plot(history.history['accuracy'], color='red', label="Accuracy")
ax[1].legend(loc='best', shadow=False)
plt.show()

这里是输出的结果:

8.输出

最后在测试集上进行模型评估,输出测试集上的预测准确率

score = model.evaluate(X_test, y_test) # 在测试集上进行模型评估
print('测试集预测准确率:', score[1]) # 打印测试集上的预测准确率

这里是输出的结果:

9.结果

最后的结果:mnist数据集最终的准确率是: 0.996833

源码获取

关注此公众号:人生苦短我用Pythons,回复 神经网络实验获取源码,快点击我吧

神经网络--基于mnist数据集取得最高的识别准确率相关推荐

  1. 神经网络——实现MNIST数据集的手写数字识别

    由于官网下载手写数字的数据集较慢,因此提供便捷下载地址如下 手写数字的数据集MNIST下载:https://download.csdn.net/download/gaoyu1253401563/108 ...

  2. 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

    机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)

  3. [Pytorch系列-41]:卷积神经网络 - 模型参数的恢复/加载 - 搭建LeNet-5网络与MNIST数据集手写数字识别

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  4. [转载] 卷积神经网络做mnist数据集识别

    参考链接: 卷积神经网络在mnist数据集上的应用 Python TensorFlow是一个非常强大的用来做大规模数值计算的库.其所擅长的任务之一就是实现以及训练深度神经网络. 在本教程中,我们将学到 ...

  5. 基于Python实现的卷积神经网络分类MNIST数据集

    卷积神经网络分类MNIST数据集 目录 人工智能第七次实验报告 1 卷积神经网络分类MNIST数据集 1 一 .问题背景 1 1.1 卷积和卷积核 1 1.2 卷积神经网络简介 2 1.3 卷积神经网 ...

  6. 基于Python实现的神经网络分类MNIST数据集

    神经网络分类MNIST数据集 目录 神经网络分类MNIST数据集 1 一 .问题背景 1 1.1 神经网络简介 1 前馈神经网络模型: 1 1.2 MINST 数据说明 4 1.3 TensorFlo ...

  7. GAN生成对抗网络基本概念及基于mnist数据集的代码实现

    本文主要总结了GAN(Generative Adversarial Networks) 生成对抗网络的基本原理并通过mnist数据集展示GAN网络的应用. GAN网络是由两个目标相对立的网络构成的,在 ...

  8. DL之CNN可视化:利用SimpleConvNet算法【3层,im2col优化】基于mnist数据集训练并对卷积层输出进行可视化

    DL之CNN可视化:利用SimpleConvNet算法[3层,im2col优化]基于mnist数据集训练并对卷积层输出进行可视化 导读 利用SimpleConvNet算法基于mnist数据集训练并对卷 ...

  9. 基于MNIST数据集的最优参数的方法的比较

    前面章节我们知道神经网络的目的是寻找最优参数,介绍了四种以及两种改进的方法来寻找最优参数,并画图进行了比较,具体可参阅 神经网络技巧篇之寻找最优参数的方法https://blog.csdn.net/w ...

最新文章

  1. Cisco2620路由器的密码恢复和灾难性恢复
  2. JZOJ 5443. 【NOIP2017提高A组冲刺11.2】字典序
  3. SQLite 3 一些基本的使用
  4. 8.Eclipse中创建Maven Web项目
  5. saltstack之基础入门系列文章简介
  6. 赚钱真的要抓住风口,抓住风口猪都能飞
  7. 解决jsp页面乱码问题
  8. 【图像分割】基于形态学重建和过滤改进FCM算法(FRFCM)的图像分割【含Matlab源码 085期】
  9. 股票交易数据下载 | 下载股票历史交易数据到本地Excel
  10. java怎么打不开vos_JAVA如何调用VOS2009接口
  11. Gson将json转Map的那些坑
  12. python format是什么意思_python的format什么意思
  13. 介绍PS工具“仿制图章工具”和“图案图章工具”
  14. 用python程序计算勾股数,用Python程序计算勾股数
  15. python读取lst文件
  16. Ubuntu16.04系统+GTX1050TI显卡的tensorflow1.6(GPU版)安装-详细图文
  17. PCIe系列专题之二:2.3 TLP结构解析
  18. 四足机器人champ项目和高程图构建elevation_mapping联合使用(Ubuntu18.04)
  19. Tailwind 初识
  20. SourceTree解决冲突的三种情形

热门文章

  1. 区块链与物联网技术结合为传统行业发展带来全新机遇
  2. 关于 URLLC场景下的 the Short Blocklength Regime
  3. 四舍五入保留一位小数
  4. yubikey复制_将YubiKey与ISAM一起使用
  5. 群雄逐鹿 浏览器之战将进入HTML 5时代
  6. PCIE总线理解笔记
  7. Element UI 之table表格表头过长使用点点…显示,并添加鼠标移入悬浮显示
  8. Java 里 NonNull 和 NotNull 区别
  9. 使用C语言实现汉诺塔问题——递归
  10. web导入excel(利用POI解析)