import tensorflow as tf
from tensorflow import keras
from    tensorflow.keras import datasets, layers, optimizers, Sequential, metricsimport  osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'def preprocess(x, y):#数值归一化,转换数据类型x = tf.cast(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)return x,y#导入数据集
(x,y),(x_test,y_test) = datasets.fashion_mnist.load_data()
print(x.shape,y.shape)db = tf.data.Dataset.from_tensor_slices((x,y))
batchsz = 128
#打乱数据
db = db.map(preprocess).shuffle(10000).batch(batchsz)db_test = tf.data.Dataset.from_tensor_slices((x_test,y_test))
db_test = db_test.map(preprocess).batch(batchsz)db_iter = iter(db)
sample = next(db_iter)
print('batch:',sample[0].shape,sample[1].shape)#取一个五层的网络
model = Sequential([layers.Dense(256, activation=tf.nn.relu), # [b, 784] => [b, 256]layers.Dense(128, activation=tf.nn.relu), # [b, 256] => [b, 128]layers.Dense(64, activation=tf.nn.relu), # [b, 128] => [b, 64]layers.Dense(32, activation=tf.nn.relu), # [b, 64] => [b, 32]layers.Dense(10) # [b, 32] => [b, 10], 330 = 32*10 + 10])model.build(input_shape=[None, 28*28])
model.summary()#打印网络结构
# w = w - lr*grad
optimizer = optimizers.Adam(lr=1e-3)def main():for epoch in range(30):for step,(x,y) in enumerate(db):#x:[b,28,28] => [b, 784]# y: [b]x = tf.reshape(x, [-1, 28*28])with tf.GradientTape() as tape:# [b, 784] => [b, 10]logits = model(x)y_onehot = tf.one_hot(y, depth=10)#做10位的独热编码# [b]loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))loss_ce = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)loss_ce = tf.reduce_mean(loss_ce)grads = tape.gradient(loss_ce, model.trainable_variables)#https://blog.csdn.net/Cerisier/article/details/86523446optimizer.apply_gradients(zip(grads, model.trainable_variables))#一一对应去打包#利用取余方式,每训练100次打印一次if step % 100 == 0:print(epoch, step, 'loss:', float(loss_ce), float(loss_mse))#test#构建计数器,便于计算正确率total_correct = 0total_num = 0for x,y in db_test:# x: [b, 28, 28] => [b, 784]# y: [b]x = tf.reshape(x, [-1, 28*28])# [b, 10]logits = model(x)# logits => prob, [b, 10]prob = tf.nn.softmax(logits, axis=1)# [b, 10] => [b], int64pred = tf.argmax(prob, axis=1)pred = tf.cast(pred, dtype=tf.int32)# pred:[b]# y: [b]# correct: [b], True: equal, False: not equalcorrect = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))#在注意数据结构的基础上进行计算total_correct += int(correct)total_num += x.shape[0]acc = total_correct / total_numprint(epoch, 'test acc:', acc)if __name__ == '__main__':main()

tensorflow2.0基础操作-手写数字识别实战相关推荐

  1. 深度学习(4)手写数字识别实战

    深度学习(4)手写数字识别实战 Step0. 数据及模型准备 1. X and Y(数据准备) 2. out=relu{relu{relu[X@W1+b1]@W2+b2}@W3+b3}out=relu ...

  2. 深度学习数字仪表盘识别_【深度学习系列】手写数字识别实战

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

  3. 【零基础】从零开始学神经网络《python神经网络编程》——手写数字识别实战

    文章目录 前言 一.机器学习是什么,深度学习是什么? 二.对NN,CNN,RNN,GNN,GAN的名词解释 三.详细介绍神经网络(NN) 1.认识神经网络 2.神经元 3.激活函数 4.权重--连接的 ...

  4. 4.1 keras基础实例 手写数字识别

    1)手写数据集 手写数据集是深度学习中,最基础应用最广泛的数据集. 手写数据集内置在keras中 import keras from keras import layers import matplo ...

  5. Pytorch+CNN+MNIST手写数字识别实战

    文章目录 1.MNIST 2.数据预处理 2.1相关包 2.2数据载入和预处理 3.网络结构 4.优化器.损失函数.网络训练以及可视化分析 4.1定义优化器 4.2网络训练 4.3可视化分析 5.测试 ...

  6. TensorFlow 2.0 mnist手写数字识别(CNN卷积神经网络)

    TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络) 源代码/数据集已上传到 Github - tensorflow-tutorial-samples 大白话讲解卷积 ...

  7. 利用Tensorflow实现手写数字识别(附python代码)

    手写识别的应用场景有很多,智能手机.掌上电脑的信息工具的普及,手写文字输入,机器识别感应输出:还可以用来识别银行支票,如果准确率不够高,可能会引起严重的后果.当然,手写识别也是机器学习领域的一个Hel ...

  8. Tensorflow 学习入门(二) 初级图像识别——手写数字识别

    初级图像识别--手写数字识别 背景知识储备 Softmax Regression MNIST 矩阵相乘 One Hot 编码 Cross Entropy(交叉熵) 代码实现 引入数据 设计数据结构 完 ...

  9. 手写数字识别案例、手写数字图片处理

    python_手写数字识别案例.手写数字图片处理 1.手写数字识别案例 步骤: 收集数据 带有标签的训练数据集来源于trainingDigits文件夹里面所有的文件,接近2000个文件,每个文件中有3 ...

最新文章

  1. IDEA Maven项目引入本地外部jar包
  2. c语言数据结构线性表LA和LB,数据结构(C语言版)设有线性表LA(3,5,8,110)和LB(2,6,8,9,11,15,20)求新集合?...
  3. laravel中的多对多关系详解
  4. JavaGC(1)—深入浅出Java垃圾回收机制
  5. mysql update语句卡死_oracle执行update语句时卡住问题分析及解决办法
  6. 百倍加速!Python量化策略的算法性能提升指南
  7. 宽带——选择中国电信
  8. 7年,我从功能测试到测试开发,写给即将进入或者正在做测试的你...
  9. python基本数据类型 整数、小数、字符串、布尔、空值、列表、元组、字典、集合、bytes
  10. web项目接入指纹识别+识别过程信息推送
  11. 微信朋友圈抓取 附近人自动加 附近人朋友圈抓取 最近一直在研究(有兴趣的看网址)...
  12. 集成Fbreader显示空白页
  13. 计算机网络之网络层-网络层拥塞控制
  14. Linux配置定时任务
  15. sencha 安装教程
  16. Win7下使用VirtualBox虚拟机安装OS X 10.9 Mavericks
  17. 公派访问学者申请签证的五点建议
  18. 百度与虚假广告的博弈
  19. CPNTools入门
  20. 重装系统如何保留正版Win10和Office

热门文章

  1. [Vue warn]: Property or method “throttle“ is not defined on the instance but referenced during rende
  2. Activity A 调用Activity B 里的方法探索
  3. 游戏史上30位最有影响力的人
  4. n皇后问题-回溯法求解
  5. PHP接入快递鸟查询快递
  6. 去哪儿网抢票成功率怎么样?
  7. python基础教程Day06
  8. 汽车操作系统攻防综述
  9. java计算机毕业设计青岛地区常见昆虫图鉴与论坛源程序+mysql+系统+lw文档+远程调试
  10. 默孚龙导电滑环的内部结构和使用范围