MNIST数据集
MNIST数据集是一个大型的手写体数字数据库,通常用于训练各种图像处理系统,也被广泛用于机器学习领域的训练和测试。MNIST数据库中的图像集是NIST(National Institute of Standards and Technology)的两个数据库的组合:专用数据库1和特殊数据库3。数据集是有250人手写数字组成,50%是高中学生,50%是美国人口普查局。
MNIST数据集分为60,000张的训练数据集合10,000张的测试数据集,每张图像的大小为28x28(像素);每张图像都为灰度图像,位深度为8(灰度图像是0-255)

程序代码部分

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model# 加载并准备 MNIST 数据集。
mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# 增加维数
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]# 切分数据集,以混淆数据集
# from_tensor_slices:它的作用是把给定的元组、列表和张量等数据进行特征切片。切片的范围是从最外层维度开始的。
# shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)# 构建tf.keras模型
class MyModel(Model):def __init__(self):super(MyModel, self).__init__()# ReLU(Rectified Linear Unit, 修正线性单元): max(0, x)self.conv1 = Conv2D(32, 3, activation='softmax')self.flatten = Flatten()self.d1 = Dense(128, activation='softmax')# softmax: si = e^i /(Σej)self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.conv1(x)x = self.flatten(x)x = self.d1(x)return self.d2(x)model = MyModel()# 损失函数和优化器
# SparseCategoricalCrossentropy分类交叉熵函数
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
# Adam随机梯度(SGD)下降与动量结合(Momentum)
optimizer = tf.keras.optimizers.Adam()# 设定损失函数和准确率的评估标准
# SparseTopKCategoricalAccuracy (稀疏多分类TopK准确率,要求y_true(label)为序号编码形式)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')# 使用tf.GradientTape来训练模型
# @ :装饰器本质上是一个Python函数,它可以让其他函数在不需要做任何代码变动的前提下增加额外功能.
@tf.function
def train_step(images, labels):# 求解梯度with tf.GradientTape() as tape:# 得到预测结构,和进行损失值的计算pre = model(images)los = loss_obj(labels, pre)# gadient是对los中的variabs自动求梯度,返回结果gradients = tape.gradient(los, model.trainable_variables)# 自动更新模型参数optimizer.apply_gradients(zip(gradients, model.trainable_variables))train_loss(los)train_accuracy(labels, pre)# 测试模型
@tf.function
def test_step(images, labels):pred = model(images)t_loss = loss_obj(labels, pred)test_loss(t_loss)test_accuracy(labels, pred)if __name__ == "__main__":EPOCH = 50for epoch in range(EPOCH):train_loss.reset_states()train_accuracy.reset_states()test_loss.reset_states()test_accuracy.reset_states()for images, labels in train_ds:train_step(images, labels)for test_images, test_labels in test_ds:test_step(test_images, test_labels)template = 'Epoch:{} , loss:{}, Accuracy:{}%, Test Loss:{}, Test Accuracy:{}%'print(template.format(epoch+1, train_loss.result(), train_accuracy.result()*100, test_loss.result(), test_accuracy.result()*100))

第一个TensorFlow程序相关推荐

  1. 构建第一个tensorflow神经网络

    朋友们,如需转载请标明出处:https://blog.csdn.net/jiangjunshow 本篇文章我会教大家构建第一个tensorflow程序. PS:因为本篇文章是我在人工智能教学中给我的二 ...

  2. 【转载】使用tf.py_func函数增加Tensorflow程序的灵活性

    转自:https://blog.csdn.net/jiongnima/article/details/80555387 目录 tf.py_func函数接口 tf.py_func在Faster R-CN ...

  3. 一个资深程序猿对Python的总结

    1.环境管理:管理 Python 版本和环境的工具. p – 非常简单的交互式 python 版本管理工具. pyenv – 简单的 Python 版本管理工具. Vex – 可以在虚拟环境中执行命令 ...

  4. 经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

    不知不觉,笔者接触Tensorflow也满一年了.在这一年当中,笔者对Tensorflow的了解程度也逐渐加深.相比笔者接触的第一个深度学习框架Caffe而言,笔者认为Tensorflow更适合科研一 ...

  5. 第一个spring程序

    第一个spring程序: 第一步:导入jar包. 新建maven项目: <dependency><groupId>org.springframework</groupId ...

  6. 快速搭建第一个Mybatis程序

    一.简介 1.初识Mybatis 2.持久化 3.持久层 4.为什么需要Mybatis 二.第一个Mybatis程序 1.搭建环境 2.创建一个新模块 3.编写代码 4.Junit测试 三.可能遇到的 ...

  7. Linux上运行一个c程序

    b站的视频链接:Linux虚拟机运行c程序_哔哩哔哩_bilibili希望对大家有所帮助,不对的地方还请多多指教!https://www.bilibili.com/video/BV18Q4y1r7st ...

  8. Go:分布式学习利器(1) -- 开发环境搭建 + 运行第一个go程序

    文章目录 为什么要学习 go 开发环境搭建 -- MAC 运行第一个go程序 go 函数的返回值设置 go 函数的命令行参数 为什么要学习 go 在如下几个应用场景的需求下产生了go: 超大规模分布式 ...

  9. 用java实现一个计算器程序_1.2第一个java程序——hello world

    第一个java程序--hello world 实现一个java程序,主要有三个步骤:1.编写源代码,2.编译源代码,3.运行.java的源代码必须先编译,然后才能由JVM解析执行.所以我们程序员第一步 ...

  10. C(第一个C程序) 和 C++ (第一个C++程序)对比碰撞

    个人博客首页(点击查看详情) -- https://blog.51cto.com/11495268 1.简介 C++ 是对 C 的继承.扩展,但从语言角度来说,这是 两种变成语言,就一定存在不同,本文 ...

最新文章

  1. POJ-2746:约瑟夫问题(Java版)
  2. linux之cat命令详解
  3. 第八十四期: Java、Web 和移动程序员学习的 12 个框架
  4. android设置图片自适应控件大小
  5. 数字证书转换cer---pem
  6. pb利用datawindow查询符合条件的数据并且过滤掉其他数据_牟宇航:百度OLAP数据库——Palo...
  7. 内存池的设计和实现总结(一)
  8. css字体图标的使用
  9. FastStone Capture无法录制系统声音解决方法(win10)
  10. Butterfly-蝴蝶-主题优化、美化-Lete乐特
  11. 新浪微博分享出现libc++abi.dylib: terminating with uncaught exception of type NSException微博微信SDK运行编译报错
  12. WLAN按钮不见了或者网络适配器不见了导致上不了网
  13. 继续写usb gadget驱动(解决枚举失败问题)
  14. a113 智能音箱芯片方案_智能音箱九大芯片方案商及其生产厂商和代表作品介绍-控制器/处理器-与非网...
  15. 一起打造自己的自动驾驶小车mycar - 4.手柄控制小车移动
  16. 送书 | 哈佛大学单细胞课程:笔记汇总前篇
  17. 3D种类游戏系统开发
  18. 【航拍中国】广东笔记
  19. JS前端取得并解析后台服务器返回的JSON数据的方法
  20. 【180928】魔兽连连看游戏源代码

热门文章

  1. 官网USB读卡器移植(TF卡SDIO模式 + SPI-flash)
  2. airdrop搜不到对方_如何将AirDrop图标添加到您的macOS Dock
  3. Linux配置访问服务器图片路径(防止踩坑)
  4. cubeIDE开发,在LCD显示摄像头抓取的图片数据
  5. VML 魅力初现--美少女图(vml可以这样强大?)
  6. 内存管理基本原理及非ARC环境使用小心得
  7. LOJ#6198. 谢特 SAM+启发式合并+01trie
  8. contents()与children()的用法和区别
  9. wifi mouse linux,WiFi Mouse Pro
  10. AI代替30%班主任工作 尚德机构督学机器人上线