弱者用泪水安慰自己,强者用汗水磨练自己。

这段时间因为项目中有一块需要用到图像识别,最近就一直在炼丹,宝宝心里苦,但是宝宝不说。。。

能点开这篇文章的朋友估计也已经对TensorFlow有了一定了解,至少知道这是个什么东西,我也就不过多介绍了。

没安装TensorFlow的建议去下一个Anaconda,可以很方便的下载配置好各种科学计算的常用库,对于Anaconda的配置和更新问题可以去搜一些文章去看,这里就不多说了。

实现手写数字识别几乎是所有入手图像识别的入门程序了,TensorFlow库里面也有手写数字识别的示例程序,在这个路径下,你可以对应自己的电脑去找一下C:\ProgramData\Anaconda3\pkgs\tensorflow-base-1.12.0-gpu_py36h6e53903_0\Lib\site-packages\tensorflow\examples\tutorials\mnist。

首先载入MNIST数据集,并创建默认的Interactive Session

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tfmnist=input_data.read_data_sets("MNIST/",one_hot=True)
sess=tf.InteractiveSession()

接下来要实现的这个卷积神经网络会有很多的权重和偏置要创建,因此我们先定义好初始化函数以便重复使用。我们需要给权重制造一些随机的噪声来打破完全对称,比如截断的正太分布噪声,标准差为0.1。同时因为我们使用的ReLU激活函数,也给偏置增加一些大于零的值来避免死亡节点。

def weight_variable(shape):initial=tf.truncated_normal(shape,stddev=0.1)return tf.Variable(initial)def bias_variable(shape):inital=tf.constant(0.1,shape=shape)return tf.Variable(inital)

卷积层、池化层也是接下来要重复使用的,因此也要为他们定义方法。这里的tf.nn.conv2d是TensorFlow中的2维卷积函数,参数中X是输入,W是卷积的参数,比如[5,5,1,32],前面两个数字代表卷积核的尺寸,第三个数字代表通道数量,因为我们只是灰度图,所以是1,如果是RGB彩图那么就应该为3,最后一个数字代表卷积核的数量,也就是这个卷积层要提取多少类的特征。Strides代表卷积模板移动的步长,都是1代表会一个不落的划过每个点,Padding代表边界处理方式,这里的SAME代表给边界加上Padding让卷积的输出和输入保持同样的尺寸。

tf.nn.max_pool是TensorFLow中的最大池化函数,我们这里使用2X2的最大池化,就是说把一个2X2的像素块降到1X1的像素。最大池化会保留原始像素块中灰度值最高的哪一个像素,即保留最显著的特征。因为希望整体上缩小图片尺寸,因此池化层的strides也设为横竖两个方向以2为步长。如果步长还是1,那么我们会得到一个尺寸不变的图片。

def conv2d(x,W):return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')def max_pool_2x2(x):return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

在设计卷积神经网络之前,我们还需要定义输入的placeholder,x是特征,y_是真实的类别,因为卷积神经网咯会利用到哦控件结构信息,因此需要将1D的输入向量转为2D的图片结构,就是说从1X784变成28X28,,又因为只有一个颜色通道,所以最终的尺寸应该是[-1,28,28,1],前面的-1代表样本数量不固定,最后的1代表颜色通道数量。

x=tf.placeholder(tf.float32,[None,784])
y_=tf.placeholder(tf.float32,[None,10])
x_image=tf.reshape(x,[-1,28,28,1])

定义第一个卷积层,先初始化参数,[5,5,1,32]代表卷积核的尺寸为5X5,1个颜色通道,32个不同的卷积核。然后用conv2d进行卷积操作,加上偏置,用ReLU激活函数进行非线性处理。最后,使用最大池化函数max_pool_2x2对卷积的输出结果进行池化操作。

w_conv1=weight_variable([5,5,1,32])
b_conv1=bias_variable([32])
h_conv1=tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)
h_pool1=max_pool_2x2(h_conv1)

定义第二层卷积,大体上和第一层一样,把卷积核数量改为64就OK了

w_conv2=weight_variable([5,5,32,64])
b_conv2=bias_variable([64])
h_conv2=tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
h_pool2=max_pool_2x2(h_conv2)

因为前面经历了两次步长为2X2的最大池化,所以边长已经只有1/4了,图片尺寸从28X28变到了7X7,而第二个卷积层的卷积核数量为64,其输出的tensor尺寸也就是7X7X64。用tf.reshape对第二个卷积层的输出tensor进行变形,转为1D的向量,然后连接一个全连接层,隐藏节点为1024,使用ReLU激活函数。

w_fc1=weight_variable([7*7*64,1024])
b_fc1=bias_variable([1024])
h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)

机器学习中经常出现过拟合问题,为了减轻过拟合,我们可以添加一个Dropout层,该层会随机丢弃一些节点之间的连接,这样可减轻过拟合,提高模型泛化性。

keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)

将Dropout层连接一个Softmax层,得到最后的概率输出。

w_fc2=weight_variable([1024,10])
b_fc2=bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2)+b_fc2)

定义损失函数cross_entropy,优化器用Adam,学习效率尽量往小里设置。

cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv),reduction_indices=[1]))
train_step=tf.train.AdagradOptimizer(1e-4).minimize(cross_entropy)

定义评测准确率

correct_prediction=tf.equal(tf.argmax(y_conv,1),tf.arg_max(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

接下里就是训练过程了。开始初始化所有参数,设置训练时Dropout层的丢弃率为0.5.使用大小为50的mini-batch训练20000次,每100次训练输出一次准确率。

tf.global_variables_initializer().run()
for i in range(20000):batch=mnist.train.next_batch(50)if i%100==0:train_accuracy=accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0})print("step %d,training accuracy %g"%(i,train_accuracy))train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0})

最后训练完成后要输出一次最后的测试结果

print("test accury %g"%accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))

程序到这就完事了,就开始跑吧,你大概会跑上20分钟,你的窗口就会一直输出准确率

step 4300,training accuracy 0.88
step 4400,training accuracy 0.92
step 4500,training accuracy 0.82
step 4600,training accuracy 0.84
step 4700,training accuracy 0.94

教你用TensorFlow实现手写数字识别相关推荐

  1. 实战六:手把手教你用TensorFlow进行手写数字识别

    手把手教你用TensorFlow进行手写数字识别 github下载地址 目录 手写体数字MNIST数据集介绍 MNIST Softmax网络介绍 实战MNIST Softmax网络 MNIST CNN ...

  2. 基于tensorflow的手写数字识别

    基于tensorflow的手写数字识别 数据准备 引入包 加载数据 查看数据信息 查看一张图片 数据预处理 搭建网络模型 模型的预测与评价 模型的展示 对一张图片进行预测 准确率 数据准备 引入包 i ...

  3. 利用Tensorflow实现手写数字识别(附python代码)

    手写识别的应用场景有很多,智能手机.掌上电脑的信息工具的普及,手写文字输入,机器识别感应输出:还可以用来识别银行支票,如果准确率不够高,可能会引起严重的后果.当然,手写识别也是机器学习领域的一个Hel ...

  4. OpenCV+TensorFlow图片手写数字识别(附源码)

    初次接触TensorFlow,而手写数字训练识别是其最基本的入门教程,网上关于训练的教程很多,但是模型的测试大多都是官方提供的一些素材,能不能自己随便写一串数字让机器识别出来呢?纸上得来终觉浅,带着这 ...

  5. tensorflow+python flask进行手写识别_使用tensorflow进行手写数字识别

    首先要在对应的目录下安装好手写数字识别数据集. 编写代码如下所示: import tensorflow as tf from tensorflow.examples.tutorials.mnist i ...

  6. tensorflow实现手写数字识别(MNIST)

    手写数字图片数字集       机器学习需要从数据中间学习,因此首先需要采集大量的真实样本数据.以手写的数字图片识别为例,我们需要收集大量的由真人书写的0-9的数字图片,为了便于存储和计算,一般把收集 ...

  7. TensorFlow 教程——手写数字识别

    运行环境 TensorFlow2.0 解决方案 from tensorflow import keras import tensorflow as tf import mnist_reader imp ...

  8. 【机器学习】机器学习从零到掌握之七 -- 教你使用KNN进行手写数字识别

    本文是<机器学习从零到掌握>系列之第7篇 机器学习从零到掌握之一 -- 教你理解K近邻算法 机器学习从零到掌握之二 -- 教你实现K近邻算法 机器学习从零到掌握之三 -- 教你使用K近邻算 ...

  9. Python(TensorFlow框架)实现手写数字识别系统

    手写数字识别算法的设计与实现 本文使用python基于TensorFlow设计手写数字识别算法,并编程实现GUI界面,构建手写数字识别系统.这是本人的本科毕业论文课题,当然,这个也是机器学习的基本问题 ...

最新文章

  1. matlab中 bsxfun函数
  2. PIC中的#pragma idata 和#pragma udata
  3. extjs combobox分页查询
  4. layoutSubviews 详解
  5. javascript练习----复选框全选,全不选,反选
  6. [转载] 【数学问题】利用python求解表达式
  7. 2020年产业互联网发展报告
  8. 编码器 stm32_STM32榨干编码旋钮(第一期)
  9. 日志的打印 —— Java 支持
  10. 【CCCC】L3-009 长城 (30分),计算几何+凸包,极角排序
  11. 我的电脑属性被隐藏 咋能显示
  12. activitymq 登录界面地址
  13. centos安装phpstudy(小皮)
  14. 最有经验的域名注册邮箱运营商:TOM企邮
  15. 结合知识蒸馏的增量学习方法总结
  16. bootstrapr表格父子框_JS组件系列之Bootstrap table表格组件神器【二、父子表和行列调序】...
  17. 专业修图工具:Affinity Photo for mac
  18. ssh @ ssh: Could not resolve hostname : Name or service not known
  19. Linux 路由实现原理
  20. Java实验项目二——打印某年某月日历

热门文章

  1. 期货十三篇 第六篇 加仓篇
  2. Linux驱动开发之蜂鸣器驱动实验
  3. allegro 16.6 导出gerber文件---art文件
  4. 巧妙下载花椒直播视频
  5. 项目管理之周报的好处
  6. 艾美捷NCTC-135培养基化学性质和基本配方
  7. CSS---Chrome 102:新增两个 HTML 属性(-^-)
  8. tomcat 假死现象(转)
  9. 剑指 Offer II 002. 二进制加法
  10. mysql系统变量详解