预备知识

python语言基础

目标

导入图片数据集,分析图片的特点、定义变量,构建模型,训练模型并输出中间状态参数,测试、保存、读取模型

如何搞定它

1.1导入图片数据集

首先来看看数据集是什么样的。
MNIST是一个入门级的计算机视觉数据集。当我们开始学习编程时,第一件事往往是学习打印Hello World。在机器学习入门的领域里,我们会用MNIST数据集来实验各种模型。

1.1.1数据集介绍

MNIST里包含各种手写数字图片,如图所示。

它也包含每一张图片对应的标签,告诉我们这个是数字几。例如,上面这4张图片的标签分别是5、0、4、1。

1.1.2下载并安装MNIST数据集

介绍完MNIST数据集后,下面来演示一下如何通过代码来对其操作。

(1)利用TensorFlow代码下载MNIST

TensorFlow提供了一个库,可以直接用来自动下载与安装MNIST,见如下代码:

# MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("MNIST_data/",one_hot=True))
运行上面的代码,会自动下载数据集并将文件解压到当前代码所在同级目录下的MNIST_data文件夹下。
注意:代码中的one_hot=True,表示将样本标签转化为one_hot编码。

举例来解释one_hot编码:
假如一共10类。0的one_hot为1000000000,1的one_hot为0100000000,2的one_hot为0010000000,3的one_hot为0001000000……依此类推。只有一个位为1,1所在的位置就代表着第几类。

MNIST数据集中的图片是28×28像素,所以,每一幅图就是1行784(28×28)列的数据,括号中的每一个值代表一个像素。

  • 如果是黑白的图片,图片中黑色的地方数值为0;有图案的地方,数值为0~255之间的数字,代表其颜色的深度。
  • 如果是彩色的图片,一个像素会由3个值来表示RGB(红、黄、蓝)。在后面讲解其他数据集时会具体讲到。

接下来通过几行代码将MNIST里面的信息打印出来,看看它的具体内容。

# MNIST数据集(续)
print ('输入数据:',mnist.train.images)
print ('输入数据打印shape:',mnist.train.images.shape)
import pylab
im = mnist.train.images[1]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()

运行上面的代码,输出信息如下:

输出结果如图所示

刚开始的打印信息是解压数据集的意思。如果是第一次运行,还会显示下载数据的相关信息。
接着打印出来的是训练集的图片信息,是一个55000行、784列的矩阵。即,训练集里有55000张图片。

(2)MNIST数据集组成

在MNIST训练数据集中,mnist.train.images是一个形状为[55000,784]的张量。其中,第1个维度数字用来索引图片,第2个维度数字用来索引每张图片中的像素点。此张量里的每一个元素,都表示某张图片里的某个像素的强度值,值介于0~255之间。
MNIST里包含3个数据集:第一个是训练数据集,另外两个分别是测试数据集(mnist.test)和验证数据集(mnist.validation)。可使用如下命令查看里面的数据信息:

MNIST数据集(续)
print ('输入数据打印shape:',mnist.test.images.shape)
print ('输入数据打印shape:',mnist.validation.images.shape)

运行完上面的命令,可以发现在测试数据集里有10000条样本图片,验证数据集里有5000个图片。

在实际的机器学习模型设计时,样本一般分为3部分:

  • 一部分用于训练;
  • 一部分用于评估训练过程中的准确度(测试数据集);
  • 一部分用于评估最终模型的准确度(验证数据集)。

训练过程中,模型并没有遇到过验证数据集中的数据,所以利用验证数据集可以评估出模型的准确度。这个准确度越高,代表模型的泛化能力越强。

另外,这3个数据集还有分别对应的3个文件(标签文件),用来标注每个图片上的数字是几。把图片和标签放在一起,称为“样本”。通过样本来就可以实现一个有监督信号的深度学习模型。

相对应的,MNIST数据集的标签是介于0~9之间的数字,用来描述给定图片里表示的数字。标签数据是“one-hot vectors”:一个one-hot向量,除了某一位的数字是1外,其余各维度数字都是0。例如,标签0将表示为([1,0,0,0,0,0,0,0,0,0,0])。因此,mnist.train.labels是一个[55000,10]的数字矩阵。

1.2分析图片的特点,定义变量

由于输入图片是个55000×784的矩阵,所以先创建一个[None,784]的占位符x和一个[None,10]的占位符y,然后使用feed机制将图片和标签输入进去。具体代码如下。

# MNIST分类
import tensorflow as tf  # 导入tensorflow库
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
import pylab
tf.reset_default_graph()
# 定义占位符
x = tf.placeholder(tf.float32, [None, 784]) # MNIST数据集的维度是  28×28=784
y = tf.placeholder(tf.float32, [None, 10])  # 数字0~9 ,共10个类别
#代码中第8行的None,表示此张量的第一个维度可以是任何长度的。x就代表能够输入任意数量的MNIST图像,每一张图展平成784维的向量。

1.3构建模型

样本完成后就可以构建模型了。下面列出了构建模型的相关步骤。

1.3.1 定义学习参数

模型也需要权重值和偏置量,它们被统一叫做学习参数。在TensorFlow里,使用Variable来定义学习参数。
一个Variable代表一个可修改的张量,定义在TensorFlow的图(一个执行任务)中,其本身也是一种变量。使用Variable定义的学习参数可以用于计算输入值,也可以在计算中被修改。

# MNIST分类(续)
W = tf.Variable(tf.random_normal(([784,10]))
b = tf.Variable(tf.zeros([10]))

在这里赋予tf.Variable不同的初值来创建不同的参数。一般将W设为一个随机值,将b设为0。
注意:W的维度是[784,10],因为想要用784维的图片向量乘以它,以得到一个10维的证据值向量,每一位对应不同数字类。b的形状是[10],所以可以直接把它加到输出上面。

1.3.2 定义输出节点

有了输入和模型参数,接着便可以将它们串起来构建成真正的模型。

# MNIST分类(续)
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax分类

首先,用tf.matmul(x,W)表示x乘以W,这里x是一个二维张量,拥有多个输入。然后再加上b,把它们的和输入到tf.nn.softmax函数里。
至此就构建好了正向传播的结构。也就是表明,只要模型中的参数合适,通过具体的数据输入,就能得到我们想要的分类。

1.3.3 定义反向传播的结构

下面定义一个反向传播的结构,编译训练模型,以得到合适的参数。
这里涉及一个“学习率”的概念。学习率,是指每次改变学习参数的大小。在这里读者只要先有个概念即可,后面章节还会详细介绍。
先看下面代码。

代码1-2 MNIST分类(续)
# 损失函数
cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))# 定义参数
learning_rate = 0.01
# 使用梯度下降优化器
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

上面的代码可以这样来理解:
(1)将生成的pred与样本标签y进行一次交叉熵的运算,然后取平均值。
(2)将这个结果作为一次正向传播的误差,通过梯度下降的优化方法找到能够使这个误差最小化的b和W的偏移量。
(3)更新b和W,使其调整为合适的参数。
整个过程就是不断地让损失值(误差值cost)变小。因为损失值越小,才能表明输出的结果跟标签数据越相近。当cost小到我们的需求时,这时的b和W就是训练出来的合适值。

1.4 训练模型并输出中间状态参数

现在开始真正地训练模型了,先定义训练相关的参数。
下面代码中

  • 第1行中,training_epochs代表要把整个训练样本集迭代25次;
  • 第2行中,batch_size代表在训练过程中一次取100条数据进行训练
  • 第3行中,display_step代表每训练一次就把具体的中间状态显示出来。

注意:batch_size参数代表的意义很关键,在深度学习中,都是将数据按批次地向里面放的。在后面章节中还会详细介绍这么做的目的。
参数定义好后,启动一个session就可以开始训练过程了。session中有两个run,第一个run是运行初始化,第二个run是运行具体的运算模型。模型运算之后便将里面的状态打印出来。

training_epochs = 25
batch_size = 100
display_step = 1saver = tf.train.Saver()
model_path = "log/521model.ckpt"# 启动session
with tf.Session() as sess:sess.run(tf.global_variables_initializer())# Initializing OP# 启动循环开始训练for epoch in range(training_epochs):avg_cost = 0.total_batch = int(mnist.train.num_examples/batch_size)# 循环所有数据集for i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)# 运行优化器_, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,y: batch_ys})# 计算平均loss值avg_cost += c / total_batch# 显示训练中的详细信息if (epoch+1) % display_step == 0:print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".            format(avg_cost))print( " Finished!")

执行上面的代码,会输出如下信息:

这里输出的中间状态是cost损失值。读者也可以把自己关心的内容打印出来。可以看到,从第1次迭代到第25次迭代的损失值在逐渐减小,最终的误差只有0.8。

1.5 测试模型

还记得MNIST里面有测试数据吗?现在我们使用测试数据来测试一下训练完的模型吧。
与前面的过程类似,也是先将计算测试的网络结构建立起来,然后通过最终节点的eval将测试值运算出来。
注意:这个过程仍然是在session里进行的。
测试错误率的算法是:直接判断预测的结果与真实的标签是否相同,如是相同的就表明是正确的,如是不相同的就表示是错误的。然后将正确的个数除以总个数,得到的值即为正确率。由于是onehot编码,这里使用了tf.argmax函数返回onehot编码中数值为1的那个元素的下标。下面是具体代码。

#MNIST分类(续)correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))# 计算准确率accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

上面代码执行后,显示信息如下:

测试正确率的算法与损失值的算法略有差别,但代表的意义却很类似。当然,也可以直接拿计算损失值的交叉熵结果来代表模型测试的错误率。
注意:
(1)并不是所有模型的测试错误率和训练时的最后一次损失值都很接近,这取决于训练样本和测试样本的分布情况,也取决于模型本身的拟合质量。关于拟合质量问题,将在后面章节详细介绍。
(2)读者自己运行时,得到的值可能和本书中的值不一样。甚至每次运行时,得到的值也不一样。原因是每次初始的权重w都是随机的。由于初始权重不同,而且每次训练的批次数据也不同,所以最终生成的模型也不会完全相同。但如果核心算法保持一致,则会保证最终的结果不会有太大的偏差。

1.6 保存模型

下面开始讲解如何保存模型。
首先要建立一个saver和一个路径,然后通过调用save,自动将session中的参数保存起来,见如下代码。

# MNIST分类(续)   # 保存模型save_path = saver.save(sess, model_path)print("Model saved in file: %s" % save_path)

上面代码的作用是保存模型,并将模型保存的路径打印出来。当然,在这段代码运行之前,需要添加saver和model_path的定义。来到前面session创建之前添加如下代码:

# MNIST分类(续)
saver = tf.train.Saver()
model_path = "log/521model.ckpt"

执行上述的全部代码后,会打印出存储位置

1.7 读取模型

将模型存储好后,下面来做一个实验:读取模型并将两张图片放进去让模型预测结果,然后将两张图片极其对应的标签一并显示出来。
在整个代码执行过程中,对于网络模型的定义不变,只是重新建立一个session而已,所有的操作都在这个新的session中完成。具体细节见代码。

# MNIST分类(续)
print("Starting 2nd session...")
with tf.Session() as sess:# 初始化变量sess.run(tf.global_variables_initializer())# 恢复模型变量saver.restore(sess, model_path)# 测试 modelcorrect_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))# 计算准确率accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.      test.labels}))output = tf.argmax(pred, 1)batch_xs, batch_ys = mnist.train.next_batch(2)outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})print(outputval,predv,batch_ys)im = batch_xs[0]im = im.reshape(-1,28)pylab.imshow(im)pylab.show()im = batch_xs[1]im = im.reshape(-1,28)pylab.imshow(im)pylab.show()

以上代码可以替代原来的session,也可以直接放到代码后面,将前面的session注释掉。
输出结果

  • 第一行是模型的准确率,接下来是3个数组。
  • 第一个数组是输出的预测结果[3,6]
  • 第二个大的数组比较大,是预测出来的真实输出值,哪一项数值越大,代表对应的概率越大.
  • 第三个大的数组元素都是0和1,是图片实际的标签值onehot编码表示的数字

完整代码:

import tensorflow.compat.v1 as tf  # 导入tensorflow库#
tf.disable_v2_behavior()from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)# print ('输入数据:',mnist.train.images)
# print ('输入数据打印shape:',mnist.train.images.shape)import pylab
im = mnist.train.images[1]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()
# print ('输入数据打印shape:',mnist.test.images.shape)
# print ('输入数据打印shape:',mnist.validation.images.shape)tf.reset_default_graph()
# 定义占位符
x = tf.placeholder(tf.float32, [None, 784]) # MNIST数据集的维度是  28×28=784
y = tf.placeholder(tf.float32, [None, 10])  # 数字0~9 ,共10个类别
W = tf.Variable(tf.random_normal([784,10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax分类# 损失函数
cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))# 定义参数
learning_rate = 0.01
# 使用梯度下降优化器
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1saver = tf.train.Saver()
model_path = "log/521model.ckpt"# 启动session
with tf.Session() as sess:sess.run(tf.global_variables_initializer())# Initializing OP# 启动循环开始训练for epoch in range(training_epochs):avg_cost = 0.total_batch = int(mnist.train.num_examples/batch_size)# 循环所有数据集for i in range(total_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)# 运行优化器_, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,y: batch_ys})# 计算平均loss值avg_cost += c / total_batch# 显示训练中的详细信息if (epoch+1) % display_step == 0:print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".            format(avg_cost))print( " Finished!")# 测试 modelcorrect_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))# 计算准确率accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))#     # 保存模型save_path = saver.save(sess, model_path)print("Model saved in file: %s" % save_path)print("Starting 2nd session...")
with tf.Session() as sess:# 初始化变量sess.run(tf.global_variables_initializer())# 恢复模型变量saver.restore(sess, model_path)# 测试 modelcorrect_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))# 计算准确率accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print ("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.      test.labels}))output = tf.argmax(pred, 1)batch_xs, batch_ys = mnist.train.next_batch(2)outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})print(outputval,predv,batch_ys)im = batch_xs[0]im = im.reshape(-1,28)pylab.imshow(im)pylab.show()im = batch_xs[1]im = im.reshape(-1,28)pylab.imshow(im)pylab.show()

识别图中模糊的手写数字(菜鸟做法)相关推荐

  1. 基于TensorFlow深度学习框架,运用python搭建LeNet-5卷积神经网络模型和mnist手写数字识别数据集,设计一个手写数字识别软件。

    本软件是基于TensorFlow深度学习框架,运用LeNet-5卷积神经网络模型和mnist手写数字识别数据集所设计的手写数字识别软件. 具体实现如下: 1.读入数据:运用TensorFlow深度学习 ...

  2. 基于python的手写数字识别knn_KNN分类算法实现手写数字识别

    需求: 利用一个手写数字"先验数据"集,使用knn算法来实现对手写数字的自动识别: 先验数据(训练数据)集: ♦数据维度比较大,样本数比较多. ♦ 数据集包括数字0-9的手写体. ...

  3. 深度学习数字仪表盘识别_深度学习之手写数字识别项目(Sequential方法amp;Class方法进阶版)...

    此项目使用LeNet模型针对手写数字进行分类.项目中我们分别采用了顺序式API和子类方法两种方式构建了LeNet模型训练mnist数据集,并编写了给图识物应用程序用于手写数字识别. 一.LeNet模型 ...

  4. pythonsklearn做手写识别_Python scikit-learn 学习笔记—手写数字识别

    这是一个手写数字的识别实验,是一个sklearn在现实中使用的案例.原例网址里有相应的说明和代码. 首先实验的数据量为1797,保存在sklearn的dataset里.我们可以直接从中获取.每一个数据 ...

  5. linux手写数字识别opencv,opencv实现KNN手写数字的识别

    人工智能是当下很热门的话题,手写识别是一个典型的应用.为了进一步了解这个领域,我阅读了大量的论文,并借助opencv完成了对28x28的数字图片(预处理后的二值图像)的识别任务. 预处理一张图片: 首 ...

  6. 基于python的手写数字识别实验报告_联机手写数字识别实验报告

    1 联机手写数字识别设计 一.设计论述 模式识别是六十年代初迅速发展起来的一门学科. 由于它研究的是如何用机 器来实现人 ( 及某些动物 ) 对事物的学习. 识别和判断能力, 因而受到了很多科技 领域 ...

  7. python手写汉字识别_用python实现手写数字识别

    前言 在之前的学习中,已经对神经网络的算法具体进行了学习和了解.现在,我们可以用python通过两种方法来实现手写数字的识别.这两种方法分别是多元逻辑回归和神经网络方法. 用多元逻辑回归手写数字识别 ...

  8. mnist手写数字识别python_Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】...

    本文实例讲述了Python tensorflow实现mnist手写数字识别.分享给大家供大家参考,具体如下: 非卷积实现 import tensorflow as tf from tensorflow ...

  9. tensorflow+python flask进行手写识别_使用tensorflow进行手写数字识别

    首先要在对应的目录下安装好手写数字识别数据集. 编写代码如下所示: import tensorflow as tf from tensorflow.examples.tutorials.mnist i ...

最新文章

  1. lamp不解析php,LAMP环境下不能解析php原因及排查步骤
  2. flutter怎么手动刷新_如何手动刷新或重新加载Flutter Firestore StreamBuilder?
  3. Masonry('couldn't find a common superview for)
  4. [na]华为acl(traffic-filter)和dhcp管理
  5. html5游戏制作入门系列教程(三)
  6. 在使用DelphiXE3和SQLite3进行程序开发时,解决最后一个字符乱码的问题
  7. iPhone不送充电器?工信部发话了
  8. Sql Server系列:数据类型转换函数
  9. 基于STM32的自由度云台运动姿态控制系统
  10. Ubuntu下读取Xbox360手柄输出
  11. PDF文档如何用关键字精确查找?
  12. c语言课程设计图像处理,摄影与图像处理课程设计
  13. AutomationAnywhere(AA)实现读取Excel文件
  14. 【修真院pm小课堂】登录注册的触发场景
  15. 企业申请3C认证,需要提交哪些资料?
  16. 微信小程序 - 下载文件到本地、打开文档
  17. ip地址转换数字函数 iton_esp8266 inet_ntoa函数实现 ip地址转换为字符串 MAC地址转字符串...
  18. 红旗桌面版本最新行使流动和结果解答100例-6
  19. n.m8yun.com/list/sy.php?p=1,开胃菜:冰蝎2.0流量分析
  20. Linux下为知笔记和蚂蚁笔记测评,推荐蚂蚁笔记!(非广告)

热门文章

  1. 互联网通信基础与Ajax篇
  2. 实用的twitter客户端:Twitterrific for Mac
  3. 全国计算机一级课件,2017全国计算机一级考试习题及答案课件.doc
  4. 邮箱“邮件备份”功能详解【申请企业邮箱】
  5. Axure中怎么画斜线和波浪曲线?
  6. 终年57岁!中国科学院院士因病逝世
  7. Java对List中的中文属性按照拼音排序
  8. vue : 无法将“vue”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。请检查名称的拼写,如果包括路径,请确保路径正确, 然后再试一次 必解决技巧
  9. 穿过任意防火墙NAT的远程控制软件TeamViewer
  10. majicMIX realistic 模型