MNIST手写数字数据库的训练集为60,000个示例,而测试集为10,000个示例。
一共4个文件,训练集、训练集标签、测试集、测试集标签,这些数据直接可以用mnist = tf.keras.datasets.mnist导入

1.调用神经网络API代码如下:

import tensorflow as tf
mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()# 对数据进行归一化
x_train, x_test = x_train/225.0, x_test/255.0# 调用API搭建神经网络
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.25),tf.keras.layers.Dense(10, activation='softmax')
])# 设置损失函数和梯度下降
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 开始训练模型
model.fit(x_train, y_train, epochs=5)# 查看准确率
model.evaluate(x_test, y_test, verbose=2)

2.手写简单神经网络代码如下:

import tensorflow as tf
import numpy as np# 获取数据
from tensorflow.keras import datasets
(x_train, y_train),(x_test, y_test) = datasets.mnist.load_data()# 查看数据的基本情况
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
print(x_train[0], type(x_train[0]), y_train[0], type(y_train[0]))# 对数据进行归一化
x_train, x_test = x_train/225.0, x_test/255.0# 将转化数据类型
x_train = tf.cast(x_train,dtype=tf.float32)
y_train = tf.cast(y_train,dtype=tf.int32)
x_test = tf.cast(x_test,dtype=tf.float32)
y_test = tf.cast(y_test,dtype=tf.int32)# 通过tf.data.Dataset.from_tensor_slices函数将特征值和目标值拼接起来,组成样本数据
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(128)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(128)
print(train_db)
print(test_db)# 根据公式创建权重和偏置
'''60000个样本,28 X 28 = 784个特征,所以有784个输入神经元,因为是数字识别,输出是0-9的数字图片,输出神经元是10个,所以权重个数是[784 X 10], [60000,784] dot [784,10] ==> [60000,10]'''
weight = tf.Variable(initial_value=tf.random.truncated_normal([784,10], stddev=0.1))
print(weight)
# 10个输出神经元,所以有10个偏置
bias = tf.Variable(initial_value=(tf.zeros([10])))
print(bias)# 批量抓取数据  x = (60000,28,28)
for epoch in range(10):# 60000图循环10次for step, (x,y) in enumerate(train_db):# 将数据从三维转化为二维x = tf.reshape(x, [-1, 28 * 28])with tf.GradientTape() as taps:# [128,784] dot [784,10] ==> [128,10]y_predict = tf.add(tf.matmul(x,weight), bias)# 把预测值转化为0-9的概率y_predict = tf.nn.softmax(y_predict)# 因为现在真是值是一个值,所以要进行one-hot编码转化为类别y = tf.one_hot(y, depth=10)# 交叉信息熵:信息熵越小则说明误差越小,正确率越高loss = tf.reduce_mean(tf.reduce_sum(- (y * tf.math.log(y_predict))))# 构建模型,并计算梯度下降grade = taps.gradient(loss,[weight,bias])# 定义学习率lr = 0.001# w = w-lr * grade_wweight.assign_sub(lr * grade[0])bias.assign_sub(lr * grade[1])if step % 100 == 0:print(f'第第{epoch}迭代的第{step}次都loss为:{loss}')# 上面循环完毕代表6w张图训练集已训练完成,接下来要通过测试集来查看正确率total_correct, total_num = 0, 0for setp, (x, y) in enumerate(test_db):x = tf.reshape(x, [-1, 28* 28])# 拿着训练好的权重和偏置进行验证y_test_predict = tf.add(tf.matmul(x,weight),bias)# 把预测值输出到0-1之间prob = tf.nn.softmax(y_test_predict,axis=1)# 获取概率最大值的索引位置pred = tf.argmax(prob,axis=1)pred = tf.cast(pred, dtype=tf.int32)correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)correct = tf.reduce_sum(correct)# 把正确率的数量添加到total_correcttotal_correct += int(correct)total_num += x.shape[0]acc = total_correct / total_numprint('测试集的正确率为:', acc)

3.手写深度神经网络代码如下:

其实深度神经网络只是在简单神经网络的基础上添加了多个权重和偏置

import tensorflow as tf
import numpy as np# 获取数据
from tensorflow.keras import datasets
(x_train, y_train),(x_test, y_test) = datasets.mnist.load_data()# 对数据进行归一化
x_train, x_test = x_train/225.0, x_test/255.0# 将转化数据类型
x_train = tf.cast(x_train,dtype=tf.float32)
y_train = tf.cast(y_train,dtype=tf.int32)
x_test = tf.cast(x_test,dtype=tf.float32)
y_test = tf.cast(y_test,dtype=tf.int32)# 将特征值和目标值拼接起来,组成样本数据
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(128)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(128)
print(train_db)
print(test_db)# 根据公式创建权重和偏置
'''输入的特征是784个,我们希望通过输入层后的第一个隐藏层是500个神经元,500个神经元,所以有500个偏置'''
# 第一层隐藏层 [128,784] dot [784,500] ==> [128,500]
w1 = tf.Variable(initial_value=tf.random.truncated_normal([784,500], stddev=0.1))
b1 = tf.Variable(initial_value=(tf.zeros([500])))
# 第二个隐藏层 [128,500] dot [500,200] ==> [128,200]
w2 = tf.Variable(initial_value=tf.random.truncated_normal([500,200], stddev=0.1))
b2 = tf.Variable(initial_value=(tf.zeros([200])))
# 输出层 [128,200] dot [200,10] ==> [128,10]
w3 = tf.Variable(initial_value=tf.random.truncated_normal([200,10], stddev=0.1))
b3 = tf.Variable(initial_value=(tf.zeros([10])))# 批量抓取数据  x = (60000,28,28)
for epoch in range(10):# 60000图循环10次for step, (x,y) in enumerate(train_db):# 将数据从三维转化为二维x = tf.reshape(x, [-1, 28 * 28])with tf.GradientTape() as taps:# [128,784] dot [784,500] ==> [128,500], 输入层通过w1和b1时,输出到第一隐藏层(500个神经元)的结果r1r1 = tf.add(tf.matmul(x,w1), b1)# 添加激活函数r1 = tf.nn.relu(r1)# 第二隐藏层 [128,500] dot [500,200] ==> [128,200]r2 = tf.nn.relu(r1 @ w2 + b2)# 输出层 [128,200] dot [200,10] ==> [128,10]y_predict = r2 @ w3 + b3# 把预测值转化为0-9的概率y_predict = tf.nn.softmax(y_predict)# 因为现在真是值是一个值,所以要进行one-hot编码转化为类别y = tf.one_hot(y, depth=10)# 交叉信息熵:信息熵越小则说明误差越小,正确率越高loss = tf.reduce_mean(tf.reduce_sum(- (y * tf.math.log(y_predict))))# 构建模型,并计算梯度下降grade = taps.gradient(loss,[w1, b1, w2, b2, w3, b3])# 定义学习率lr = 0.001# w = w-lr * grade_ww1.assign_sub(lr * grade[0])b1.assign_sub(lr * grade[1])w2.assign_sub(lr * grade[2])b2.assign_sub(lr * grade[3])w3.assign_sub(lr * grade[4])b3.assign_sub(lr * grade[5])if step % 100 == 0:print(f'第第{epoch}迭代的第{step}次都loss为:{loss}')# 上面循环完毕代表6w张图训练集已训练完成,接下来要通过测试集来查看正确率total_correct, total_num = 0, 0for setp, (x, y) in enumerate(test_db):x = tf.reshape(x, [-1, 28* 28])# 拿着训练好的权重和偏置进行验证r1 = tf.nn.relu(tf.add(tf.matmul(x,w1), b1))r2 = tf.nn.relu(r1 @ w2 + b2)y_test_predict = r2 @ w3 + b3# 把预测值输出到0-1之间prob = tf.nn.softmax(y_test_predict,axis=1)# 获取概率最大值的索引位置pred = tf.argmax(prob,axis=1)pred = tf.cast(pred, dtype=tf.int32)correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)correct = tf.reduce_sum(correct)# 把正确率的数量添加到total_correcttotal_correct += int(correct)total_num += x.shape[0]acc = total_correct / total_numprint('测试集的正确率为:', acc)

数字识别实例两种实现方式(tensorflow2.x):1.调用高级API 2.手写简单神经网络 3.手写深度神经网络(DNN)相关推荐

  1. 手写体数字识别的两种方法

    基于贝叶斯模型和KNN模型分别对手写体数字进行识别 首先,我们准备了0~9的训练集和测试集,这些手写体全部经过像素转换,用0,1表示,有颜色的区域为0,没有颜色的区域为1.实现代码如下: # 图片处理 ...

  2. 浅谈POE供电系统中PSE两种供电方式——终端跨度、中间跨度

    标准的五类网线有四对双绞线但是在10M BASE-T和100M BASE-T中只用到其中的两对. IEEE80 2.3af允许两种用法: 1. 中间跨度法,信号线(1,2,3,6).电源线(4,5,7 ...

  3. halcon颜色识别的两种简单方式

    颜色识别的两种简单方式: 1.单通道方式: 原理:通过不同颜色在灰度图中的阈值范围不同来区分颜色(理论上这种方式不推荐,但在一定情况下适用) 材料: halcon代码: dev_close_windo ...

  4. MongoDB——MongoDB分片集群(Sharded Cluster)两种搭建方式

    MongoDB分片集群(Sharded Cluster)两种搭建方式 MongoDB分片的概念 分片集群包含的组件 分片集群架构目标 MongoDB分片集群搭建 第一套副本集 第二套副本集 配置节点副 ...

  5. android对象序列化的方法,Android 进阶-两种序列化方式 Serializable 和 Parcelable

    [导读]序列化 (Serialization)将对象的状态信息转换为可以存储或传输的形式的过程.在序列化期间,对象将其当前状态写入到临时或持久性存储区.以后,可以通过从存储区中读取或反序列化对象的状态 ...

  6. JavaSE学习总结(八)常用类(上)Object类==与equals方法的区别浅克隆的特点Scanner类String类String两种创建对象方式的区别String类的各种功能

    JavaSE学习总结(八)常用类(上)/Object类/==与equals方法的区别/浅克隆的特点/Scanner类/String类/String两种创建对象方式的区别/String类的各种功能 常用 ...

  7. Vivado使用心得(一)Vivado IP的两种综合方式:Global 和 Out-Of-Context

    ​在最新的Vivado的版本中,定制IP的时候,会有一个综合方式的选择,如下图所示.可以看到一种叫做"Global",一种叫"Out-Of-Context (OOC)&q ...

  8. Java两种排序方式快慢比较

    2019独角兽企业重金招聘Python工程师标准>>> Java中List的排序方式有两种,现在我们测试下这两种排序方式的快慢吧,我们需要用到两个类, 一个是运行程序的Main类,另 ...

  9. Vivado IP的两种综合方式:Global 和 Out-Of-Context

    在最新的Vivado的版本中,定制IP的时候,会有一个综合方式的选择,如下图所示.可以看到一种叫做"Global",一种叫"Out-Of-Context (OOC)&qu ...

最新文章

  1. 移动端导航页面html,swiper4实现移动端导航切换
  2. jQuery判断当前元素显示状态并控制元素的显示与隐藏
  3. 微信小程序----日期时间选择器(自定义精确到分秒或时段)
  4. mysql校验字符集
  5. h3c防火墙u200配置命令_网络设备配置——H3C命令行基本操作【分级】
  6. MySQL索引的创建、删除和查看
  7. ps、grep和kill联合使用杀掉进程(转)
  8. 一步一步写算法(之二叉树深度遍历)
  9. json解析库go-simplejson使用
  10. IDEA访问不到SpringBoot项目webapp下的内容
  11. 具有system权限的进程无法访问sdcard
  12. 淘宝客高手必备的14大WordPress插件
  13. c语言yuv图片cb,YUV格式图像基础
  14. 面试分享|机械行业面试常见问题有哪些
  15. XMLConstants.FEATURE_SECURE_PROCESSING错误
  16. Win10右下角提示“已禁用IME”的多种解决方法汇总
  17. ubuntu恢复被rm误删的数据及原理
  18. 学计算机网络技术遇到问题,维护计算机网络教室的常见问题及解决方案
  19. java 超卖_Java生鲜电商平台-秒杀系统如何防止超买与超卖?(小程序/APP)
  20. 以Listener和Talker为例ROS1和ROS2代码对比

热门文章

  1. HttpWatch及文件上传
  2. AI代写计划总结怎么做?分享一个代写计划总结小工具
  3. cxfreeze打包工程文件生成.exe,运行exe出现闪退问题,相关解决办法总结
  4. 软件工程复习 第一章 概述 软件定义 软件危机 软件三要素
  5. Uipath 百度OCR发票识别
  6. 如何缓解上台演讲的紧张
  7. pandas中的pct_change的用法简介
  8. go分析和kegg分析_GO分析和KEGG分析都是啥?
  9. 如何提高用户对短信消息的打开率?
  10. 大数据让人看到更真实的历史