深度学习——手写数字识别底层实现
内容再要
手写数字识别,早在20世纪前,杨立昆(Yann LeCun)就完成这项工作,并在1980年左右利用卷积神经网络完善了手写数字识别
代码实现
import tensorflow as tf
import random
import matplotlib.pyplot as plt
# 例子 教程 手写数字 输入数据
from tensorflow.examples.tutorials.mnist import input_data# 设置随机种子
tf.set_random_seed(1)
# 读取数据并进行独热编码
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 占位符
X = tf.placeholder(tf.float32, shape=[None, 784])
Y = tf.placeholder(tf.float32, shape=[None, 10])# 初始化参数Wb(两层神经网络)
w1 = tf.Variable(tf.random_normal([784, 512]))
b1 = tf.Variable(tf.random_normal([512]))
w2 = tf.Variable(tf.random_normal([512, 10]))
b2 = tf.Variable(tf.random_normal([10]))# 前向传播
z1 = tf.matmul(X, w1) + b1
a1 = tf.sigmoid(z1)
z2 = tf.matmul(a1, w2) + b2
# 多分类用softmax
a2 = tf.nn.softmax(z2)# 计算代价
cost = -tf.reduce_mean(tf.reduce_sum(Y * tf.log(a2), axis=1), axis=0)# 反向传播
dz2 = a2 - Y
dw2 = tf.matmul(tf.transpose(a1), dz2) / tf.cast(tf.shape(X)[0], tf.float32)
db2 = tf.reduce_mean(dz2, axis=0)da1 = tf.matmul(dz2, tf.transpose(w2))
dz1 = da1 * a1 * (1 - a1)
dw1 = tf.matmul(tf.transpose(X), dz1) / tf.cast(tf.shape(X)[0], tf.float32)
db1 = tf.reduce_mean(dz1, axis=0)# 参数更新
learning_rate = 0.1
updata = [tf.assign(w2, w2 - learning_rate * dw2),tf.assign(b2, b2 - learning_rate * db2),tf.assign(w1, w1 - learning_rate * dw1),tf.assign(b1, b1 - learning_rate * db1),
]# 准确率
accurary = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(a2, axis=1),tf.argmax(Y, axis=1)), dtype=tf.float32), axis=0)# 大批次
train_times = 15
# 小批量 6万个数据每次训练100个
batch_size = 100# 开启会话
with tf.Session() as sess:sess.run(tf.global_variables_initializer())# 先开始大批次的循环for times in range(train_times):# 求代价,代价先归0avg_cost = 0# 大批量下的小批量次数n_batch = int(mnist.train.num_examples / batch_size)# 小批量训练for i in range(n_batch):# 每次取出训练集和测试集100个train_X, train_Y = mnist.train.next_batch(batch_size)c, _ = sess.run([cost, updata], feed_dict={X: train_X, Y: train_Y})avg_cost += c / n_batchprint('批次', times+1)print('代价', avg_cost)print('训练结束')# 准确率print(sess.run(accurary, feed_dict={X: mnist.test.images, Y: mnist.test.labels}))# 随机预测r = random.randint(0, mnist.test.num_examples - 1)print('labels', sess.run(tf.argmax(mnist.test.labels[r: r + 1], axis=1)))print('prediction:', sess.run(tf.argmax(a2, 1), feed_dict={X: mnist.test.images[r: r+1]}))# 画图
plt.imshow(mnist.test.images[r: r + 1].reshape(28, 28),cmap='Greys',interpolation='nearest'
)
plt.show()
效果实现
批次 1
代价 2.557806055112321
批次 2
代价 1.0456371155652138
批次 3
代价 0.8032650183276691
批次 4
代价 0.6728521455417991
批次 5
代价 0.5812857172976839
批次 6
代价 0.5152520787715907
批次 7
代价 0.46274486312811985
批次 8
代价 0.4205591650916775
批次 9
代价 0.38447663389823633
批次 10
代价 0.35472463608465415
批次 11
代价 0.32847735724327237
批次 12
代价 0.30493005472150747
批次 13
代价 0.2853568002022804
批次 14
代价 0.266866419179873
批次 15
代价 0.2506876078756018
训练结束
0.9179
labels [7]
prediction: [7]
深度学习——手写数字识别底层实现相关推荐
- 深度学习——手写数字识别
深度学习--手写数字问题 前不久入门学习了Tensorflow深度学习框架,了解一下什么是神经网络和Tensorflow的简单使用.下面通过Tensorflow框架来建造神经网络模型来对手写数字进行训 ...
- 百度深度学习--手写数字识别之数据处理
文章目录 概述 前提条件 读入数据并划分数据集 扩展阅读:为什么学术界的模型总在不断精进呢? 训练样本乱序.生成批次数据 校验数据有效性 机器校验 人工校验 封装数据读取与处理函数 异步数据读取 概述 ...
- python-机器学习-手写数字识别
机器学习简单的来说,分为监督式学习和无监督式学习: 对于监督式学习就是需要人为的来告诉计算机这是什么,需要我们给他一个标签(答案). 无监督式学习就是不需要我们给出标签(答案). 图像识别(Image ...
- python手写字体程序_深度学习---手写字体识别程序分析(python)
我想大部分程序员的第一个程序应该都是"hello world",在深度学习领域,这个"hello world"程序就是手写字体识别程序. 这次我们详细的分析下手 ...
- 深度学习 手写字体识别
数据集介绍: mnist数据集使用tensorflow封装好的数据(包含6000张训练数据,1000张测试数据),图片大小为28x28. 在神经网络的结构上,一方面需要使用激活函数去线性化.另一方面需 ...
- 基于深度学习的手写数字识别Matlab实现
基于深度学习的手写数字识别Matlab实现 1.网络设计 2. 训练方法 3.实验结果 4.实验结果分析 5.结论 1.网络设计 1.1 CNN(特征提取网络+分类网络) 随着深度学习的迅猛发展,其应 ...
- 【卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10)】
卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10) 在上一章已经完成了卷积神经网络的结构分析,并通过各个模块理解 ...
- 深度学习100例 | 第25天-卷积神经网络(CNN):中文手写数字识别
大家好,我是『K同学啊』! 接着上一篇文章 深度学习100例 | 第24天-卷积神经网络(Xception):动物识别,我用Xception模型实现了对狗.猫.鸡.马等四种动物的识别,带大家了解了Xc ...
- 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 (zz)
用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 我想写一系列深度学习的简单实战教程,用mxnet做实现平台的实例代码简单讲解深度学习常用的一些技术方向和实战样例.这 ...
最新文章
- DataGrid在分页状态下删除纪录的问题
- Angular2.0 基础: User Input
- 数据结构基础 - 链表的遍历
- linux下exec系列(一)
- 丈夫博士毕业想离婚,妻子要求家务补偿!法院判了
- Solr7.2.1环境搭建和配置ik中文分词器
- 逆向工程-ARM程序
- 多线程java_敞开心扉,一起聊聊Java多线程
- 【联想拯救者R7000】安装nvidia驱动Perform MOK management 界面键盘失灵现象(已解决)
- 智慧医院信息化建设(整体解决方案)
- nexus3的目录介绍
- Python爬取虎扑NBA球员信息
- 如何使用git上传项目至GitHub repository
- Zookeeper之基础知识
- 联想万全r520服务器安装系统,联想(lenovo)万全R520服务器图解
- 认识黑客常用的入侵方法
- 人工智能专家细数AI安全隐患
- 华为电子邮件显示未读邮件1_电子邮件简介已经过去
- 十分钟教你用 svg 做出精美的动画!
- 【Colab】1.Colab基本使用方法及配置