参考   MNIST数据集手写数字分类 - 云+社区 - 腾讯云

目录

0.编程环境

1、下载并解压数据集

2、完整代码

3、数据准备

4、数据观察

4.1 查看变量mnist的方法和属性

4.2 对比三个集合

4.3 mnist.train.images观察

4.4 查看手写数字图

5、搭建神经网络

6、变量初始化

7、模型训练

9、模型测试


MNIST是Mixed National Institue of Standards and Technology database的简称,中文叫做美国国家标准与技术研究所数据库

0.编程环境

安装tensorflow命令:pip install tensorflow
操作系统:Win10
python版本:3.6
集成开发环境:jupyter notebook
tensorflow版本:1.6

1、下载并解压数据集

MNIST数据集下载链接: 百度网盘 请输入提取码 密码: wa9p
下载压缩文件MNIST_data.rar完成后,选择解压到当前文件夹不要选择解压到MNIST_data。
文件夹结构如下图所示:

2、完整代码

此章给读者能够直接运行的完整代码,使读者有编程结果的感性认识。
如果下面一段代码运行成功,则说明安装tensorflow环境成功。
想要了解代码的具体实现细节,请阅读后面的章节。

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
X_holder = tf.placeholder(tf.float32)
y_holder = tf.placeholder(tf.float32)Weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([1,10]))
predict_y = tf.nn.softmax(tf.matmul(X_holder, Weights) + biases)
loss = tf.reduce_mean(-tf.reduce_sum(y_holder * tf.log(predict_y), 1))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)session = tf.Session()
init = tf.global_variables_initializer()
session.run(init)for i in range(500):images, labels = mnist.train.next_batch(batch_size)session.run(train, feed_dict={X_holder:images, y_holder:labels})if i % 25 == 0:correct_prediction = tf.equal(tf.argmax(predict_y, 1), tf.argmax(y_holder, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))accuracy_value = session.run(accuracy, feed_dict={X_holder:mnist.test.images, y_holder:mnist.test.labels})print('step:%d accuracy:%.4f' %(i, accuracy_value))

上面一段代码的运行结果如下:

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
step:0 accuracy:0.4747
step:25 accuracy:0.8553
step:50 accuracy:0.8719
step:75 accuracy:0.8868
step:100 accuracy:0.8911
step:125 accuracy:0.8998
step:150 accuracy:0.8942
step:175 accuracy:0.9050
step:200 accuracy:0.9026
step:225 accuracy:0.9076
step:250 accuracy:0.9071
step:275 accuracy:0.9049
step:300 accuracy:0.9055
step:325 accuracy:0.9101
step:350 accuracy:0.9097
step:375 accuracy:0.9116
step:400 accuracy:0.9102
step:425 accuracy:0.9113
step:450 accuracy:0.9155
step:475 accuracy:0.9151

从上面的运行结果可以看出,经过500步训练,模型准确率到达0.9151左右。

3、数据准备

import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
X_holder = tf.placeholder(tf.float32)
y_holder = tf.placeholder(tf.float32)

第1行代码导入warnings库,第2行代码表示不打印警告信息;
第3行代码导入tensorflow库,取别名tf;
第4行代码人从tensorflow.examples.tutorials.mnist库中导入input_data文件;
本文作者使用anaconda集成开发环境,input_data文件所在路径:C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow\examples\tutorials\mnist,如下图所示:

第6行代码调用input_data文件的read_data_sets方法,需要2个参数,第1个参数的数据类型是字符串,是读取数据的文件夹名,第2个关键字参数ont_hot数据类型为布尔bool,设置为True,表示预测目标值是否经过One-Hot编码;
第7行代码定义变量batch_size的值为100;
第8、9行代码中placeholder中文叫做占位符,将每次训练的特征矩阵X和预测目标值y赋值给变量X_holder和y_holder。

4、数据观察

本章内容主要是了解变量mnist中的数据内容,并掌握变量mnist中的方法使用。

4.1 查看变量mnist的方法和属性

dir(mnist)[-10:]

上面一段代码的运行结果如下:

['_asdict',
'_fields',
'_make',
'_replace',
'_source',
'count',
'index',
'test',
'train',
'validation']

为了节省篇幅,只打印最后10个方法和属性。
我们会用到的是其中test、train、validation这3个方法。

4.2 对比三个集合

train对应训练集,validation对应验证集,test对应测试集。
查看3个集合中的样本数量,代码如下:

print(mnist.train.num_examples)
print(mnist.validation.num_examples)
print(mnist.test.num_examples)

上面一段代码的运行结果如下:

55000
5000
10000

对比3个集合的方法和属性

从上面的运行结果可以看出,3个集合的方法和属性基本相同。
我们会用到的是其中images、labels、next_batch这3个属性或方法。

4.3 mnist.train.images观察

查看mnist.train.images的数据类型和矩阵形状。

images = mnist.train.images
type(images), images.shape

上面一段代码的运行结果如下:

(numpy.ndarray, (55000, 784))

从上面的运行结果可以看出,在变量mnist.train中总共有55000个样本,每个样本有784个特征。
原图片形状为28*28,28*28=784,每个图片样本展平后则有784维特征。
选取1个样本,用3种作图方式查看其图片内容,代码如下:

import matplotlib.pyplot as pltimage = mnist.train.images[1].reshape(-1, 28)
plt.subplot(131)
plt.imshow(image)
plt.axis('off')
plt.subplot(132)
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.subplot(133)
plt.imshow(image, cmap='gray_r')
plt.axis('off')
plt.show()

上面一段代码的运行结果如下图所示:

从上面的运行结果可以看出,调用plt.show方法时,参数cmap指定值为graygray_r符合正常的观看效果。

4.4 查看手写数字图

从训练集mnist.train中选取一部分样本查看图片内容,即调用mnist.train的next_batch方法随机获得一部分样本,代码如下:

import matplotlib.pyplot as plt
import math
import numpy as npdef drawDigit(position, image, title):plt.subplot(*position)plt.imshow(image.reshape(-1, 28), cmap='gray_r')plt.axis('off')plt.title(title)def batchDraw(batch_size):images,labels = mnist.train.next_batch(batch_size)image_number = images.shape[0]row_number = math.ceil(image_number ** 0.5)column_number = row_numberplt.figure(figsize=(row_number, column_number))for i in range(row_number):for j in range(column_number):index = i * column_number + jif index < image_number:position = (row_number, column_number, index+1)image = images[index]title = 'actual:%d' %(np.argmax(labels[index]))drawDigit(position, image, title)batchDraw(196)
plt.show()

上面一段代码的运行结果如下图所示,本文作者对难以辨认的数字做了红色方框标注:

5、搭建神经网络

Weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([1,10]))
predict_y = tf.nn.softmax(tf.matmul(X_holder, Weights) + biases)
loss = tf.reduce_mean(-tf.reduce_sum(y_holder * tf.log(predict_y), 1))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

该神经网络只有输入层和输出层,没有隐藏层。
第1行代码定义形状为784*10的权重矩阵Weights;
第2行代码定义形状为1*10的偏置矩阵biases;
第3行代码定义先通过矩阵计算,再使用激活函数softmax得出的每个分类的预测概率predict_y;
第4行代码定义损失函数loss,多分类问题使用交叉熵作为损失函数。
交叉熵的函数如下图所示,其中p(x)是实际值,q(x)是预测值


第5行代码定义优化器optimizer,使用梯度下降优化器;
第6行代码定义训练步骤train,即最小化损失。

6、变量初始化

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

对于神经网络模型,重要是其中的W、b这两个参数。
开始神经网络模型训练之前,这两个变量需要初始化。
第1行代码调用tf.global_variables_initializer实例化tensorflow中的Operation对象。

第2行代码调用tf.Session方法实例化会话对象;
第3行代码调用tf.Session对象的run方法做变量初始化。

7、模型训练

for i in range(500):images, labels = mnist.train.next_batch(batch_size)session.run(train, feed_dict={X_holder:images, y_holder:labels})if i % 25 == 0:correct_prediction = tf.equal(tf.argmax(predict_y, 1), tf.argmax(y_holder, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))accuracy_value = session.run(accuracy, feed_dict={X_holder:mnist.test.images, y_holder:mnist.test.labels})print('step:%d accuracy:%.4f' %(i, accuracy_value))

第1行代码表示模型迭代训练500次;
第2行代码调用mnist.train对象的next_batch方法,选出数量为batch_size的样本;
第3行代码是模型训练,每运行1次此行代码,即模型训练1次;
第4-8行代码是每隔25次训练打印模型准确率。
上面一段代码的运行结果如下:

step:0 accuracy:0.3161
step:25 accuracy:0.8452
step:50 accuracy:0.8668
step:75 accuracy:0.8860
step:100 accuracy:0.8906
step:125 accuracy:0.8948
step:150 accuracy:0.9008
step:175 accuracy:0.9027
step:200 accuracy:0.8956
step:225 accuracy:0.9102
step:250 accuracy:0.9022
step:275 accuracy:0.9097
step:300 accuracy:0.9039
step:325 accuracy:0.9076
step:350 accuracy:0.9137
step:375 accuracy:0.9111
step:400 accuracy:0.9069
step:425 accuracy:0.9097
step:450 accuracy:0.9150
step:475 accuracy:0.9105

9、模型测试

import math
import matplotlib.pyplot as plt
import numpy as npdef drawDigit2(position, image, title, isTrue):plt.subplot(*position)plt.imshow(image.reshape(-1, 28), cmap='gray_r')plt.axis('off')if not isTrue:plt.title(title, color='red')else:plt.title(title)def batchDraw2(batch_size):images,labels = mnist.test.next_batch(batch_size)predict_labels = session.run(predict_y, feed_dict={X_holder:images, y_holder:labels})image_number = images.shape[0]row_number = math.ceil(image_number ** 0.5)column_number = row_numberplt.figure(figsize=(row_number+8, column_number+8))for i in range(row_number):for j in range(column_number):index = i * column_number + jif index < image_number:position = (row_number, column_number, index+1)image = images[index]actual = np.argmax(labels[index])predict = np.argmax(predict_labels[index])isTrue = actual==predicttitle = 'actual:%d\npredict:%d' %(actual,predict)drawDigit2(position, image, title, isTrue)batchDraw2(100)
plt.show()

上面一段代码的运行结果如下图所示:

MNIST数据集手写数字分类相关推荐

  1. 基于tensorflow+RNN的MNIST数据集手写数字分类

    2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...

  2. [Pytorch系列-41]:卷积神经网络 - 模型参数的恢复/加载 - 搭建LeNet-5网络与MNIST数据集手写数字识别

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  3. Python深度学习之分类模型示例,MNIST数据集手写数字识别

    MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片. 我们把60000个训练样本分成两部分,前 ...

  4. MNIST数据集手写数字识别

    1 数据集介绍 MNIST 包括6万张28x28的训练样本,1万张测试样本,很多教程都会对它"下手"几乎成为一个 "典范",可以说它就是计算机视觉里面的Hell ...

  5. MNIST数据集手写数字识别(一)

    MNIST数据集是初步学习神经网络的很好的数据集,也是拿来教学,不可多得的好教材,有很多知识点在里面.官网下载地址,可以自己手动下载,当然也可以通过下面的代码自动下载[urllib.request(3 ...

  6. 机器学习入门(07)— MNIST 数据集手写数字的识别

    和求解机器学习问题的步骤(分成学习和推理两个阶段进行)一样,使用神经网络解决问题时,也需要首先使用训练数据(学习数据)进行权重参数的学习:进行推理时,使用刚才学习到的参数,对输入数据进行分类. 1. ...

  7. MNIST数据集手写数字识别(二)

    上一篇对MNIST数据集有了一些了解,数据集包含着60000张训练图片与标签值和10000张测试图片与标签值的数据集,数据集有了,现在我们来构造神经网络,预测下对这测试的10000张图片的正确识别率, ...

  8. PyTorch:MNIST数据集手写数字识别

    MNIST 包括6万张28x28的训练样本,1万张测试样本,很多教程都会对它"下手"几乎成为一个 "典范",可以说它就是计算机视觉里面的Hello World. ...

  9. 基于tensorflow的mnist数据集手写字体分类level-1

    本文属于学些tensorflow框架系列的文章,不是注重于算法- 基于之前博文中的工作,已经安装好tensorflow等等的配置工作,开始学习tensorflow框架的使用,本文参考了以下链接,致以敬 ...

最新文章

  1. [Spring mvc 深度解析(二)] Tomcat分析
  2. LiveVideoStackCon讲师热身分享第一季
  3. 干掉搜狗输入法云代理SogouCloud.exe
  4. java代码 计算器_java代码---------计算器实现
  5. Zookeeper架构及FastLeaderElection机制
  6. Android音频开发(六)音频编解码之初识MediaCodec上
  7. 用AXIS2发布WebService的方法 使用eclipse插件生成服务端和客户端
  8. 【zk开发】让eclipse识别×.zul文件为xml格式
  9. inteli211网卡linux驱动,Windows Server 2019安装Intel I211网卡驱动
  10. mysql类exadata功能_一些有用的Exadata诊断命令
  11. 年货节买什么东西好?2022新年好物推荐
  12. 使用ItextPdf给PDF文件加文字水印和图片水印
  13. VOC2007 2012数据集有多少张图片
  14. KMP,LCA(XJT Love Strings,玲珑杯 Round#8 A lonlife 1079)
  15. 大一新生的第一篇博客
  16. HTML信件-一种奇特的实现方式
  17. 关于MODIS数据说明及简单处理
  18. b站前端老猫总结面试题
  19. 《数据结构与算法》(十一)- 树、森林与二叉树的转换及哈夫曼树详解
  20. 微信小程序-访问豆瓣电影api400错误

热门文章

  1. Java回炉之File
  2. android集成twitter登录
  3. RIM Hong Kong地址和地图
  4. 2023届 计算机毕业设计 选题 计算机专业 毕业设计题目 推荐
  5. 韩国研发人工智能武器 遭30国专家联名抵制:和你绝交!
  6. 今天女朋友问我多线程是什么?送命题?
  7. python批量改名
  8. 对设计模式的总结之工厂方法模式和抽象工厂模式
  9. java fuoco车架_破风硬汉——JAVA FUOCO公路车 评测
  10. Linux服务器上设置全局代理访问外网并验证