提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • 主要过程
    • 导入
    • 加载数据
    • 创建模型和训练
    • 模型应用
  • 总结

前言

非专业程序员,主业PLC单片机,2019年想扩充知识体系,紧跟潮流,带学生参加了人工智能大赛,才开始接触tensorflow以及深度学习的基本过程,非常艰难。后来比赛完了之后,学生也毕业了,因为感觉难度过大,而且自己将来也不准备转到这个行业,干脆就放弃了。最近疫情关在家里,想了一个晚上,对于本专业,自己干过项目,参加过大赛,虽然没有掌握PLC单片机所有的知识,但是掌握了方法论,能够快速的学习新的设备和目前尚未掌握的功能,没有应用的点,学习那些知识也就没有太大的必要。所以就决定利用疫情,继续开拓对自己来说仍然是新的领域,人工智能,从头开始。

现在再看tensorflow,已经改头换面了,1.x版本太过复杂,难以理解,2.x改善了很多,入门容易。


今天还是从手写数字开始,入门代码非常多,大多都是有关模型训练的,官网也有保存模型及加载模型的代码,这里就不再多写


预测部分,我在网上找到的都是使用mnist自己的测试数据来进行预测,有的是使用加载模型的方法测试准确率,有的是预测测试集中的数据,但是没有针对一个自己的图片(可以是摄像头拍的,可以是自己在画图里写的数字)的预测方法,这里主要解决这个问题。


这个想法产生的原因很简单,其实就是需要将我们的工作应用到现实中,整个过程我想是这几个步骤:准备训练数据、数据预处理来适配模型网络、搭建深度学习网络、训练模型、模型保存,到这实际上开发工作已经完成,下面的步骤就是要应用了,准备数据、加载模型、预测结果,预测的结果将用到后面的业务逻辑。

主要过程

导入

import tensorflow as tf
from tensorflow import keras
import cv2
from keras.preprocessing.image import img_to_array
import numpy as np

加载数据

def loadData():mnist = tf.keras.datasets.mnist(x_train, y_train),(x_test, y_test) = mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0

创建模型和训练

此部分大多来自官方文档和网络,就是一层全连接,理解也较为容易,最后将模型存为h5文件。

def create_model():model = tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10, activation='softmax')])model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])return model
def train():model = create_model()model.summary()model.fit(x_train, y_train, epochs=5)model.evaluate(x_test,  y_test, verbose=2)model.save('my_model.h5')

模型应用

现在我们有了训练好的模型,正常逻辑就是考虑如何应用,官方文档有加载模型的方法,加载好之后预测就是一个predict函数,将预测的数据传进去就能得出结果,因为输入的尺寸是28*28,所以我考虑到图片大小不一,需要转换尺寸,这里我想到了用opencv,所以下面的函数就是使用cv处理图片。
为了简化操作过程探寻方法论,我使用黑底图片。

def imgTool():img = cv2.imread("D:/workspace/MNIST_data/1.jpg")img = cv2.resize(img,(28,28))img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)img = img_to_array(img)img = np.expand_dims(img, axis=0)

即使是黑底图片,也一定要转灰度,否则维度不对,在这浪费了不少时间。
步骤:读图片、塑形、灰度、图转矩阵、展开矩阵。
这个函数只用来测试图片处理结果。


最后的预测函数,在主程序里调用即可。

def predict():img = cv2.imread("D:/workspace/MNIST_data/1.jpg")img = cv2.resize(img,(28,28))img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)img = img_to_array(img)img = np.expand_dims(img, axis=0)img = img / 255.0new_model = tf.keras.models.load_model("my_model.h5")new_model.summary()pre= new_model.predict(img)print(np.argmax(pre))

总结

这里不讨论模型的准确度及网络的合理性,主要问题是训练好的模型如何应用。

mnist手写数字模型训练、保存、加载及图片预测相关推荐

  1. word2vec模型训练保存加载及简单使用

    目录 word2vec模型训练保存加载及简单使用 一 word2vec简介 二.模型训练和保存及加载 模型训练 模型保存和加载 模型的增量训练 三.模型常用API 四.文本相似度计算--文档级别 wo ...

  2. 手写数字识别--Android Studio 加载tensorflow模型

    思路: 在电脑端调用mnist数据集,构建深度卷积神经网络模型,使用TensorFlow进行训练,达到99%+的测试集数据准确率,继而把模型迁移到App端.具体迁移教程参考之前的文章: https:/ ...

  3. Pytorch模型训练保存/加载(搭建完整流程)

    文章目录 前言 模型训练完整步骤 模型保存与加载 GPU训练 "借鸡生蛋" 模型使用 本博文优先在掘金社区发布! 前言 我们这边还是以CIARF10这个模型为例子. 现在的话先说明 ...

  4. Caffe MNIST 手写数字识别(全面流程)

    目录 1.下载MNIST数据集 2.生成MNIST图片训练.验证.测试数据集 3.制作LMDB数据库文件 4.准备LeNet-5网络结构定义模型.prototxt文件 5.准备模型求解配置文件_sol ...

  5. MOOC网深度学习应用开发1——Tensorflow基础、多元线性回归:波士顿房价预测问题Tensorflow实战、MNIST手写数字识别:分类应用入门、泰坦尼克生存预测

    Tensorflow基础 tensor基础 当数据类型不同时,程序做相加等运算会报错,可以通过隐式转换的方式避免此类报错. 单变量线性回归 监督式机器学习的基本术语 线性回归的Tensorflow实战 ...

  6. 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

    将tensorflow训练好的模型移植到Android (MNIST手写数字识别) [尊重原创,转载请注明出处]https://blog.csdn.net/guyuealian/article/det ...

  7. [tensorflow、神经网络] - 使用tf和mnist训练一个识别手写数字模型,并测试

    参考 包含: 1.层级的计算.2.训练的整体流程.3.tensorboard画图.4.保存/使用模型.5.总体代码(含详细注释) 1. 层级的计算 如上图,mnist手写数字识别的训练集提供的图片是 ...

  8. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  9. pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练

    文章目录 1. MNIST 手写数字识别 2. 聚焦数据集扩充后的模型训练 3. pytorch 手写数字识别基本实现 3.1完整代码及 MNIST 测试集测试结果 3.1.1代码 3.1.2 MNI ...

最新文章

  1. Spring事务--笔记
  2. 小强升职记梗概_解读《小强升职记》——一本关于时间管理的书
  3. oracle里有limit怎么用,[ORACLE]ORACLE 实现mysql中的limit 功能
  4. 线性代数四之动态DP(广义矩阵加速)——Can you answer these queries III,保卫王国
  5. Huawei eNSP 安装教程
  6. Docker系列之一:在线安装docker和下载镜像
  7. 哈啰单车JAVA面经
  8. Oracle 分组求和(特殊处理)
  9. java 数字游戏的方法_java实现猜数字游戏
  10. Java中submit的方法,线程池中 submit()和 execute()方法区别
  11. 递归专题---[2]开根号
  12. 小区人脸识别门禁系统云平台需求分析文档
  13. Neo4j ① <图论>图,节点,关系,属性<知识图谱和图库>图谱,图库,优势<基础>模块,应用场景,环境搭建,浏览器
  14. 如何使用petri网建模工具pipe4.3.0
  15. 教你小小JAVA爬虫爬到HDU首页(只为学习)
  16. 小型机和服务器有何区别
  17. “香约宁波”寻觅城市文化味
  18. c语言做bs架构_BS架构技术方案 Technology
  19. sip 时序图_时序图学习(一)
  20. DMPC-PEG-聚乙烯吡咯烷酮/聚乙烯醇/聚甲基丙烯酸甲酯/聚丙烯酰胺/聚醋酸乙烯酯

热门文章

  1. 枸杞的功效与食用方法
  2. kali下经典的ddos攻击软件_Kali-DDoS工具集合
  3. 3、微信小程序-通信
  4. 不要在学习启动管理器和元编程上浪费时间
  5. 引流复盘:从知乎引流20万粉,我只用了1个月
  6. 视频教程-从入门到精通学全套AI 轻松掌握illustrator基础加实战技能视频课程-Illustrator
  7. 最新版继续教育学习软件下载地址
  8. Mysql8.0设置允许远程连接
  9. (附源码)springboot高校社团管理系统的开发毕业设计231128
  10. 常见License错误代码