目录

  • 一、LeNet5网络介绍
  • 二、环境搭建
  • 三、网络搭建以及训练
    • 3.1、加载数据集
    • 3.2、网络搭建
    • 3.3、模型训练
    • 3.4、模型固化
  • 四、c++ opencv加载模型

一、LeNet5网络介绍

LeNet5 这个网络包含了深度学习的基本模块:卷积层,池化层,全链接层。是其他深度学习模型的基础。LeNet-5共有7层,不包含输入,每层都包含可训练参数;每个层有多个Feature Map,每个FeatureMap通过一种卷积滤波器提取输入的一种特征,然后每个FeatureMap有多个神经元。

二、环境搭建

本人环境配置如下:

pycharm2021
vs2022
Anaconda3
tensorflow=2.3
opencv=4.5.5

前几个安装相对轻松,直接上官网安装即可,tensorflow使用pip命令安装,c++ opencv相对较为麻烦,可以参考本人以前的安装方法:c++ opencv 学习笔记(一) Visual Studio 2019 + OpenCV4.5.5 配置详解

三、网络搭建以及训练

3.1、加载数据集

tensorflow内置了MINST数据集,从tensorflow中导入即可

import tensorflow as tf
mnist = tf.keras.datasets.mnist
train, test = mnist.load_data()

将数据按照batch提供给网络模型

import numpy as npclass MNISTData:def __init__(self, data, need_shuffle, batch_size=128):""":param datas: 数据集,格式为 data,label:param shuffle: 是否随机打乱数据 True or False:param batch_size: 一批数据大小"""self._data = data[0]self._labels = data[1]self.num_examples = self._data.shape[0]self._need_shuffle = need_shuffleself._indicator = 0self._batch_size = batch_sizeif self._need_shuffle:self._shuffle_data()def __iter__(self):return selfdef _shuffle_data(self):p = np.random.permutation(self.num_examples)self._data = self._data[p]self._labels = self._labels[p]def next_batch(self):end_indicator = self._indicator + self._batch_sizeif end_indicator > self.num_examples:if self._need_shuffle:self._shuffle_data()self._indicator = 0end_indicator = self._batch_sizeelse:self._indicator = 0end_indicator = self._batch_sizeif end_indicator > self.num_examples:raise StopIterationbatch_data = self._data[self._indicator: end_indicator] / 255.0 # 归一化batch_labels = self._labels[self._indicator: end_indicator]self._indicator = end_indicatorreturn batch_data, batch_labelsdef __next__(self):return self.next_batch()train_dataset = dataset.MNISTData(train, True)
test_dateset = dataset.MNISTData(test, False)

查看数据集

def display(train_images, train_labels):plt.figure(figsize=(10,10))for i in range(25):plt.subplot(5,5,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(train_labels[i])plt.show()for data in train_dataset:display(*data)

3.2、网络搭建

使用tensorflow中的keras搭建网络结构,激活函数使用Mish

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
from tensorflow.keras.utils import get_custom_objectsclass Mish(Activation):def __init__(self, activate, **kwargs):super(Mish, self).__init__(activate, **kwargs)self.__name__ = "Mish"def mish(inputs):return inputs * tf.math.tanh(tf.math.softplus(inputs))def LeNet5(input_shape=[32, 32, 3]):get_custom_objects().update({'Mish': Mish(mish)})#输入层inputs = Input(shape=input_shape)#第一个卷积-池化层conv1 = Conv2D(6, 5, activation="relu", padding='same')(inputs)pool1 = MaxPooling2D((2, 2))(conv1)#第二个卷积-池化层conv2 = Conv2D(16, 5, activation="relu", padding='same')(pool1)pool2 = MaxPooling2D((2, 2))(conv2)#第三个卷积层conv2 = Conv2D(120, 5, activation="relu", padding='same')(pool2)fc = Flatten()(conv2)#全连接层fc1 = Dense(120, activation="relu")(fc)#输出层fc2 = Dense(10, activation="softmax")(fc1)model = Model(inputs, fc2)return model
model = LeNet5(input_shape=[28, 28, 1])

定义损失函数以及优化器

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy']
)

保存模型

model_filepath = 'model/'
checkpoint_filepath = model_filepath + 'tmp/'
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_filepath,save_best_only=True,save_weights_only=True,monitor='accuracy',mode='max'
)

3.3、模型训练

开始训练(俗称炼丹)

# 是否使用GPU
use_gpu = True
tf.debugging.set_log_device_placement(True)
if use_gpu:gpus = tf.config.experimental.list_physical_devices(device_type='GPU')if gpus:for gpu in gpus:tf.config.experimental.set_memory_growth(device=gpu, enable=True)tf.print(gpu)else:os.environ["CUDA_VISIBLE_DEVICE"] = "-1"else:os.environ["CUDA_VISIBLE_DEVICE"] = "-1"# TensorBoard可视化工具
log_path = 'logging/'
logging = tf.keras.callbacks.TensorBoard(log_dir=log_path)
model_filepath = 'model/'
checkpoint_filepath = model_filepath + 'tmp/'
history = model.fit(train_dataset,epochs=10,steps_per_epoch=train_dataset.num_examples // BATCH_SIZE + 1,validation_data=test_dateset,validation_steps=test_dateset.num_examples // BATCH_SIZE + 1,callbacks=[cp_callback, logging ]
)model.load_weights(checkpoint_filepath)
model.save(model_filepath + 'model')

可视化训练过程
TensorBoard是一个可视化工具,它可以用来展示网络图、张量的指标变化、张量的分布情况等。进入logging文件夹的上一层文件夹,在DOS窗口运行命令:

tensorboard --logdir=./logging

在浏览器输入网址:http://localhost:6006,或者输入上图提示的网址,即可查看生成图。


3.4、模型固化

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2def export_frozen_graph(model, name, input_size) :f = tf.function(lambda x: model(x))f = f.get_concrete_function(x=tf.TensorSpec(shape=[None, input_size[0], input_size[1], input_size[2]], dtype=tf.float32))f2 = convert_variables_to_constants_v2(f)graph_def = f2.graph.as_graph_def()# Export frozen graphwith tf.io.gfile.GFile(name, 'wb') as f:f.write(graph_def.SerializeToString())export_frozen_graph(model, model_filepath + 'frozen_graph.pb', (input_size, input_size, 1))

四、c++ opencv加载模型

#include <opencv2/opencv.hpp>
#include <iostream>
#include <vector>using namespace std;//多分类问题用这个函数判断类别,二分类的话不用也行
std::vector<int> Argmax(cv::Mat x)
{std::vector<int> res;for (int i = 0; i < x.rows; i++){int maxIdx = 0;float maxNum = 0.0;for (int j = 0; j < x.cols; j++){float tmp = x.at<float>(i, j);if (tmp > maxNum){maxIdx = j; //更新最优值序号maxNum = tmp; //更新最优值}}res.push_back(maxIdx); //最优预测值的序号}return res;
}int main()
{//cv加载模型cv::dnn::Net net = cv::dnn::readNetFromTensorflow("frozen_graph.pb");//加载图片cv::Mat src = cv::imread("8.jpg", cv::IMREAD_COLOR);cv::Mat img = src;cv::cvtColor(img, img, cv::COLOR_BGR2GRAY);//调整图片大小cv::resize(img, img, cv::Size(28, 28));//归一化 0-1之间img.convertTo(img, CV_32FC1, 1.f / 255.f, -1.f);//格式转化cv::dnn::blobFromImage(img, img, 1.0, cv::Size(), cv::Scalar(), false, false, CV_32F);//将数据喂给网络net.setInput(img);//前向传播,得到传播结果cv::Mat pred = net.forward();//输出结果vector<int> res = Argmax(pred);//输出标签stringstream ss;string str;ss << "label:" << res[0];ss >> str;//放大图片便于观察cv::resize(src, src, cv::Size(280, 280));cv::putText(src, str, cv::Size(0, 40), cv::FONT_HERSHEY_COMPLEX, 1, cv::Scalar(0, 255, 0), 1);cv::imshow("", src);cv::waitKey();
}

结果如下:

有需要的可以下载完整项目链接进行测试:

GitHub:https://github.com/small-guang/LeNet5
CSDN:https://download.csdn.net/download/qq_45723275/77992089
其他项目链接:tensorflow2.3 搭建 vgg16训练cifar10数据集

tensorflow2 搭建LeNet5训练MINST手写数字数据集并用c++ opencv4.5.5 DNN加载模型预测结果相关推荐

  1. tensorflow2实现yolov3并使用opencv4.5.5 DNN加载模型预测

    目录 综述 一.什么是YOLO 二.YOLOv3 网络 1.网络结构 2.网络输出解读(前向过程) 2.1.输出特征图尺寸 2.2.锚框和预测 3.训练策略与损失函数(反向过程) 三.tensorfl ...

  2. BP算法实现--minst手写数字数据集识别

    实验步骤 初始化网络架构 网络层数,每层神经元数,连接神经元的突触权重,每个神经元的偏置 构造bp算法函数 对于一个输入数据,前向计算每层的输出值,保存未激活的输出和激活过的输出值,这里用的激活函数是 ...

  3. minst手写数字识别(带界面)

    minst手写数字识别(带界面) 目录 minst手写数字识别(带界面) 一.项目简介 二.项目结构及环境 三.网络结构介绍 四.程序文件介绍 五.使用介绍 六.源代码获取 一.项目简介 1)概述:手 ...

  4. Educoder 机器学习 神经网络 第四关:使用pytorch搭建卷积神经网络识别手写数字

    任务描述 相关知识 卷积神经网络 为什么使用卷积神经网络 卷积 池化 全连接网络 卷积神经网络大致结构 pytorch构建卷积神经网络项目流程 数据集介绍与加载数据 构建模型 训练模型 保存模型 加载 ...

  5. Python学习记录 搭建BP神经网络实现手写数字识别

    搭建BP神经网络实现手写数字识别 通过之前的文章我们知道了,构建一个简单的神经网络需要以下步骤 准备数据 初始化假设 输入神经网络进行计算 输出运行结果 这次,我们来通过sklearn的手写数字数据集 ...

  6. tkinter+socket&MySQL+keras识别minst手写数字

    tkinter + socket + keras + MySQL识别Minst手写数字 环境配置 代码 服务端 客户端 主函数main.py 类Window.py 实验报告部分 一.总体功能说明 1. ...

  7. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...

  8. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别并预测(超过99%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别并预测(超过99%) 目录 输出结果 设计思路 核心代码 输出结果 准确度都在99%以上 1.出错记录 ...

  9. 在MNIST数据集上训练一个手写数字识别模型

    使用Pytorch在MNIST数据集上训练一个手写数字识别模型, 代码和参数文件 可下载 1.1 数据下载 import torchvision as tvtraining_sets = tv.dat ...

  10. svm对未知数据的分类_SVM对sklearn自带手写数字数据集进行分类

    sklearn自带一些数据集,其中手写数字数据集可通过load_digits加载,我找到load_digits里头是这样 def load_linnerud(): """ ...

最新文章

  1. 高频面试题:如何保证缓存与数据库的双写一致性?
  2. Maven工程的分类
  3. 海信新机F30S即将发布:搭载紫光展锐虎贲T310处理器
  4. Solidity safesub防止溢出
  5. 2018/11/22工作日志
  6. selenium 如何处理table
  7. 什么是芯片加速器 Accelerator
  8. 局域网,手机与电脑文件共享
  9. 证书服务器,及申请证书。
  10. java中的Cipher类
  11. 第八章:善于利用指针
  12. 由pytorch中的super().__init__到python中的测试
  13. LeetCode——缺失数字(C语言)
  14. PAT 甲级1003 Emergency 题解
  15. 网页数据抓取之当当网
  16. 谁是女人一生中最重要的人
  17. 获取CARLA插件SCENARIO RUNNER
  18. java毕业设计颜如玉图书销售网站的设计与实现Mybatis+系统+数据库+调试部署
  19. Verilog语言快速入门
  20. 众昂矿业:萤石货源紧张,价格可能上涨

热门文章

  1. Unity3D开发资料
  2. c/c++成长之捷径
  3. UOJ#211. 【UER #6】逃跑 (Dynamic Programming)
  4. 25_多易教育之《yiee数据运营系统》OLAP平台-画像分析篇
  5. 三桥君:如何把SQL Server的数据库导为sql文件
  6. 儿童学计算机编程好处,十个理由告诉你孩子为什么要学习编程?
  7. oppo手机快速截屏的方法
  8. 射频通信接收机设计的主要结构
  9. 《信号与系统学习笔记》—通信系统(一)
  10. php连接mysql MariaDB_PHP+MariaDB数据库操作基本技巧