环境: Ubuntu 18.04, tensorflow 2.4.1

mnist是Yann Lecun大神的手写数据,数据中的数字都是28X28的图像,每个像素点是[0-255]的值

其中训练数据为60000,测试数据为10000

打印出数字看一下

def show_mnist(images, labels):n = 5m = 5for i in range(n):for j in range(m):plt.subplot(n, m, i*n+j+1)index = i*n+jarray = images[index]plt.title(labels[index])plt.imshow(array, cmap='Greys')plt.show()

准备数据

因为数据是三维的, (60000,28,28),但是我们构建的网络是

# 200 个 epoch
EPOCH = 200
BATCH_SIZE = 128
VERBOSE = 1
# 需要把图像二位数据压缩为一维数据
RESHAPED = 784 # 28*28
# 数字分类类别  0~9 一共10个
NUM_CLASSES = 10# 载入数据,第一次使用时,会稍微花点时间下载
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()# 将图像二维数据压缩为一维,同时将数据转为浮点型
x_train = x_train.reshape(60000, RESHAPED).astype('float32')
x_test = x_test.reshape(10000, RESHAPED).astype('float32')# normalization
# 如果数据不做归一化,其实训练也可以进行,但是会影响最终的loss
x_train /= 255
x_test /= 255# 将label的数字 0~9 转为hothot
y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

构建网络

使用最简单的一层网络

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(NUM_CLASSES, input_shape=(RESHAPED,), activation='softmax'))
model.summary()

训练

## compile network
model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
history = model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCH, verbose=VERBOSE)## validation  92.36
val_loss,val_acc = model.evaluate(x_test, y_test, verbose=VERBOSE)
print("Test loss: ", val_loss)
print("Test accuracy: ", val_acc)

完成代码

import tensorflow as tf
import matplotlib.pyplot as plt
# from PIL import ImageEPOCH = 200
BATCH_SIZE = 128
VERBOSE = 1
RESHAPED = 784 # 28*28
NUM_CLASSES = 10def show_mnist(images, labels):n = 5m = 5for i in range(n):for j in range(m):plt.subplot(n, m, i*n+j+1)index = i*n+jarray = images[index]plt.title(labels[index])plt.imshow(array, cmap='Greys')plt.show()(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()print('training images data shape:', x_train.shape)
print('training labels data shape:', y_train.shape)
print('testing images data shape:', x_test.shape)
print('testing labels data shape:', y_test.shape)#show_mnist(x_train, y_train)x_train = x_train.reshape(60000, RESHAPED).astype('float32')
x_test = x_test.reshape(10000, RESHAPED).astype('float32')# normalization
x_train /= 255
x_test /= 255#
y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)# train_labels = train_labels.reshape(60000, RESHAPED)## build network
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(NUM_CLASSES, input_shape=(RESHAPED,), activation='softmax'))
model.summary()## compile network
model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy'])
history = model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCH, verbose=VERBOSE)## validation  92.36
val_loss,val_acc = model.evaluate(x_test, y_test, verbose=VERBOSE)
print("Test loss: ", val_loss)
print("Test accuracy: ", val_acc)

训练结果

mnist学习实例(1)相关推荐

  1. mnist学习实例(2)

    环境: Ubuntu 18.04, tensorflow 2.4.1 该版本是全连接网络的优化版本,采用了卷积神经网络,参考. 可以看到BATCH_SIZE没有改动,而EPOCH明显减少. 但是EPO ...

  2. 涵盖 14 大主题!最完整的 Python 学习实例集来了!

    机器学习.深度学习最简单的入门方式就是基于 Python 开始编程实战.最近闲逛 GitHub,发现了一个非常不错的 Python 学习实例集,完全是基于 Python 来实现包括 ML.DL 等领域 ...

  3. ajax请求返回json实例,Jquery Ajax 学习实例2 向页面发出请求 返回JSon格式数据

    一.AjaxJson.aspx 处理业务数据,产生JSon数据,供JqueryRequest.aspx调用,代码如下: protected void Page_Load(object sender, ...

  4. php实训总结00字,说明的比较细的php 正则学习实例

    说明的比较细的php 正则学习实例 "^The": 匹配以 "The"开头的字符串; "of despair$": 匹配以 "of ...

  5. 深度学习之生成对抗网络(1)博弈学习实例

    深度学习之生成对抗网络(1)博弈学习实例 博弈学习实例  在 生成对抗网络(Generative Adversarial Network,简称GAN)发明之前,变分自编码器被认为是理论完备,实现简单, ...

  6. 从入门到入土:Python爬虫学习|实例练手|爬取猫眼榜单|Xpath定位标签爬取|代码

    此博客仅用于记录个人学习进度,学识浅薄,若有错误观点欢迎评论区指出.欢迎各位前来交流.(部分材料来源网络,若有侵权,立即删除) 本人博客所有文章纯属学习之用,不涉及商业利益.不合适引用,自当删除! 若 ...

  7. 从入门到入土:Python爬虫学习|实例练手|爬取百度翻译|Selenium出击|绕过反爬机制|

    此博客仅用于记录个人学习进度,学识浅薄,若有错误观点欢迎评论区指出.欢迎各位前来交流.(部分材料来源网络,若有侵权,立即删除) 本人博客所有文章纯属学习之用,不涉及商业利益.不合适引用,自当删除! 若 ...

  8. 从入门到入土:Python爬虫学习|实例练手|爬取新浪新闻搜索指定内容|Xpath定位标签爬取|代码注释详解

    此博客仅用于记录个人学习进度,学识浅薄,若有错误观点欢迎评论区指出.欢迎各位前来交流.(部分材料来源网络,若有侵权,立即删除) 本人博客所有文章纯属学习之用,不涉及商业利益.不合适引用,自当删除! 若 ...

  9. 从入门到入土:Python爬虫学习|实例练手|爬取百度产品列表|Xpath定位标签爬取|代码注释详解

    此博客仅用于记录个人学习进度,学识浅薄,若有错误观点欢迎评论区指出.欢迎各位前来交流.(部分材料来源网络,若有侵权,立即删除) 本人博客所有文章纯属学习之用,不涉及商业利益.不合适引用,自当删除! 若 ...

最新文章

  1. 如何在MacOS上创建第一个iOS Flutter应用
  2. Spark HistoryServer日志解析清理异常
  3. 七点建议,帮助你编写出简洁、干练的Java代码
  4. git 回滚 add 操作_炫技!git 优雅回滚一次错误的合并操作!
  5. 参数 携带 跳转_微信小程序:页面跳转及参数传递
  6. 设单片机的晶振频率为6mhz c语言,单片机习题科学出版社.doc
  7. Nacos集群部署说明
  8. 剑指offer刷题感想
  9. struts2中的constant配置详解
  10. php mysql5.7.110安装教程_CentOS7安装配置Nginx1.10、PHP5.6、MySQL5.7教程
  11. Android Uri to Path
  12. 点击选中框 批量删除
  13. mac remix导入本地项目
  14. Git patch的使用方法和场景
  15. PySpark fold foldByKey用法
  16. - 模块“VPMC“启动失败,未能启动虚拟机?
  17. Ambari2.7.4配置HIVE_AUX_JARS_PATH
  18. 大数据背景下的智慧物流:物流行业解决方案
  19. Android Studio查看SQLite数据库方法大全
  20. 中冠百年|工薪族怎么投资理财

热门文章

  1. 第十八课.动态图模型
  2. 第六课.模型评估与模型选择
  3. python使用matplotlib, seaborn画图时候的数据加载
  4. python(matplotlib)画柱状图(1)
  5. 在线作图|小基因组——线粒体基因组圈图
  6. R语言数据格式转换函数、数据类型判断函数(numeric、character、vector、matrix、data.frame、factor、logical)、R语言数据格式类型转换
  7. pandas数据预处理(字段筛选、query函数进行数据筛选、缺失值删除)、seaborn可视化分面图(facet)、seaborn使用Catplot可视化分面箱图(Faceted Boxplot)
  8. R语言dplyr包pull函数抽取dataframe数据列实战
  9. ANTS医学影像配准+Li‘s 核磁共振影像数据处理
  10. 潜在狄利克雷分配(LDA,Latent Dirichlet Allocation)模型(三)