前言

尝试一下用Mnist数据集训练一个简单的CNN网络,然后搭建一个静态页面,在浏览器端加载模型使用canvas区域的内容预测手写数字。模型使用Pytorch编写,用cpu训了10个epoch之后导出为onnx模型。之后在浏览器端通过onnxruntime-web进行加载,并进行预测。

模型

模型代码其实网络上已经有很多了,原理和细节也不再赘述;需要注意的是,输入是一个Batchsize x 1 x 28 x 28 的矩阵,输出为Batchsize x 10的矩阵也就是说第一维是动态的,这就决定了我们在导出为onnx模型时的写法:

def transformToOnnx(model, batch_size, name='mnist.onnx'):model.eval()x = torch.randn(batch_size, 1, 28, 28)torch.onnx.export(model, x, name, export_params=True, opset_version=11, do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

首先需要使用model.eval()将模型切换为预测模式,接下来我们随机生成一个输入的参数,也就是Batchsize x 1 x 28 x 28大小的一个随机矩阵。在导出时需要指定导出的路径,输入和输出的符号(上边的写法意思是在后续加载模型的时候,输入变量名为input,输出变量名为output)。同时由于输入和输出的第一维都是batchsize,因此把它们指定为动态轴。

前端实现

初始化工程

首先采用vite初始化一个react-ts项目,这一步没有太多注意事项。

模型的加载和预测

为了加载模型,我们需要使用onnxruntime-web。onnxruntime-web是一个可以在浏览器环境下和nodejs环境下加载onnx模型的库,可以在CPU和GPU上运行,CPU使用web assambly来加载模型,而GPU使用Webgl来加载,默认运行在CPU上。两种方案支持的符号集不同,wasm方式支持全部的符号集,而webgl方式仅仅支持一部分符号集(具体的说明参考文献[1]);除此之外,在ios的chrome、edge和safari浏览器中仅支持wasm。本次小实验导出的模型如果采用webgl加载,就会遇到上边提到的符号集的问题,因此采用wasm加载模型。我们只需要:

yarn add onnxruntime-web

便可以在工程中安装这个包了。

接下来我们需要对vite工程进行一些配置。由于vite在启动server时有一个pre-bundle的过程,使用esbuild将各种非标准的模块转化为es6模块。onnxruntime-web中使用到了export namespace xx的写法,这些会在pre-bundle的时候报错,因此我们可以选择通过pre-bundle过程;

同时,即便我们跳过了pre-bundle的过程,我们会发现在项目启动之后,onnxruntime-web会自动的去static/js路径下去找两个wasm文件,而在启动服务和打包的时候并不会自动的加入这两个文件。而如果我们引入cdn上的onnxruntime-web库,我们会发现它会自动地去cdn地址请求wasm文件,cdn上这两个文件自然是存在的。参考onnxruntime给出的demo[2],可以看到,官方在使用webpack打包的时候也是使用了CopyWebpackPlugin将对应的文件拷贝到打包之后的目录中。为了方便开发和打包,建议首先跳过pre-bundle过程,然后采用cdn加载onnxruntime-web包,并在vite.config.js中声明该包为external,即:

// vite.config.js
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
import { viteExternalsPlugin } from 'vite-plugin-externals'// https://vitejs.dev/config/
export default defineConfig({plugins: [react(),viteExternalsPlugin({ // 声明为external'onnxruntime-web': 'ort'})],optimizeDeps: {exclude: [ // 跳过pre-bundle'onnxruntime-web'] },base: '/mnist-demo/'
})

然后在index.html中加上库的cdn地址:

<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>

接下来的过程其实就很简单了,我们成功的引入了onnxruntime-web库,然后需要用它来加载模型,并进行预测:

...// 加载模型const session = await ort.InferenceSession.create(model);// 输入数据,第一个参数是数据类型,第二个参数inputArray是一个一维数组,第三个参数表示的是维度,注意需要和之前模型导出时定义的维度相一致,即:dynamic x 1 x 28 x 28const inputs = new ort.Tensor("float32", inputArray, [1, 1, 28, 28]);// 使用run进行预测,需要注意的是,输入和输出与之前导出时定义的输入输出变量名一致const outputs = await session.run({input: inputs,});// 预测结果console.log(outputs.output.data);

到此为止,在浏览器加载模型的部分就完成了,接下来只需要想个办法获取到用户输入的数据,并使用这些数据进行预测。

获取输入数据

模型输入是1x28x28的图片,而让人在屏幕上手动的去在一个28像素x28像素的区域内绘制肯定是个不现实的事情(太小了)。因此我们需要把输入的canvas放大(这里采用的是300x300),在预测时对画布的输入进行缩小,并转化为单通道。

为了获取到这个300x300区域内的像素数据,我们使用canvas.getImageData()获取到这个区域内的rgba数组。接下来,我们需要将它缩放为28x28的大小。这里引入了pica库,使用pica的resizeBuffer函数对像素区域进行缩放。

由于canvas的默认颜色是黑色透明,因此我们拿到的数组的非画笔区域的rgba值为(0,0,0,0)。同时注意到模型的输入中,灰度的取值范围为-1-1,因此为了保留单通道,我们保留a,并将其根据是否为0,简单地映射到-1和1就够了。

还需要注意的是,由于画布会从300x300缩放到28x28,因此canvas画笔的粗细也是一个影响效果的因素:如果画笔过细,缩放之后画布区域的像素值都是0,也就没有效果了;如果画笔过粗,可能缩放之后,原本隔着很远的两个区域变成了邻居,也会影响效果。

最后,我们只需要根据上述操作,根据缩放、处理过后的数组构建输入的Tensor,并传入模型进行预测就可以了。

总结

到这里,其实模型的加载、预测和如何获取输入数据都已经完成了。最后就是把以上的东西串起来。实际的效果就是最上边两张图的样式,我把它放在了gitee page上,实测网络请求的速度还可以接受:

同时我也把它部署在我的服务器中,模型丢到cdn上,速度也还可以接受(gzip对模型好像压不了多少呀。。):

也就是说,对于一些简单的模型,我们完全可以丢到gitee page上进行使用,还是蛮好玩的。

最后丢个页面地址和仓库地址,有人需要的话我再去补readme,球球点个关注和star吧:

页面地址:

Gitee Page版本

部署到nginx的版本

仓库地址:

Github仓库

Gitee仓库

个人博客

原文地址


参考文献

[1] onnxruntime web: https://www.npmjs.com/package/onnxruntime-web#Operators

[2] onnxruntime-web使用demo https://github.com/microsoft/onnxruntime-inference-examples/tree/main/js/quick-start_onnxruntime-web-bundler

浏览器中加载CNN进行手写数字识别,并部署到Gitee Page相关推荐

  1. 基于CNN的手写数字识别

    基于CNN的手写数字识别 文章目录 基于CNN的手写数字识别 零. 写在之前 壹. 聊聊CNN 01. 什么是CNN 02. 为什么要有CNN 03. CNN模型 3.1 卷积层 3.2 池化层 3. ...

  2. CNN之手写数字识别(Handwriting Recognition)

    CNN之手写数字识别(Handwriting Recognition) 目录 CNN之手写数字识别(Handwriting Recognition) 1.常用的包 2.常见概念 3.手写数字识别器实现 ...

  3. Matlab卷积神经网络(CNN)手写数字识别(一)

    今天买的书到了,开始接触卷积神经网络,展示书中内容~ Matlab卷积神经网络手写数字识别(一) 机器学习的基本流程 加载Matlab自带数据集 机器学习的基本流程 在机器学习中,一般将数据集划分为两 ...

  4. 利用CNN进行手写数字识别

    资源下载地址:https://download.csdn.net/download/sheziqiong/85884967 资源下载地址:https://download.csdn.net/downl ...

  5. 卷积神经网络(cnn) 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里12,卷积运算有两个非常重要特 ...

  6. Keras搭建CNN(手写数字识别Mnist)

    MNIST数据集是手写数字识别通用的数据集,其中的数据是以二进制的形式保存的,每个数字是由28*28的矩阵表示的. 我们使用卷积神经网络对这些手写数字进行识别,步骤大致为: 导入库和模块 我们导入Se ...

  7. 简陋的CNN实现手写数字识别

    文章目录 前言 背景知识 Neural Network Backpropagation CNN pytorch 介绍 代码 CNN模型 训练&测试 前言 日常翘课,但是作业还是要写的. 数据集 ...

  8. Python仿真及应用结课大作业—基于CNN的手写数字识别与涂鸦识别

    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 一.结课文档目录 二.涂鸦识别(篇幅问题只展示其一) 涂鸦识别 引入必要的库函数 导入数据 为各个数据文件添加标签 数 ...

  9. 用CNN实现手写数字识别

    一.模型结构 用户输入的图像是一个784维的向量x,我们按照以下步骤搭建网络: 1.把x整形为[28, 28, 1]的灰度图 2.用一次3x3的卷积操作从x中抽象出32个基本特征,图像形状变成[28, ...

  10. 深蓝学院第三章:基于卷积神经网络(CNN)的手写数字识别实践

    参看之前篇章的用全连接神经网络去做手写识别:https://blog.csdn.net/m0_37957160/article/details/114105389?spm=1001.2014.3001 ...

最新文章

  1. Android自定义旋钮效果,Android自定义悬浮按钮效果实现,带移动效果
  2. 毕业三年,快手总包 90W 值得去吗?
  3. AUTOSAR从入门到精通100讲(二十七)-DoIP远程诊断及与UdsOnCan的比较
  4. 爬格子呀9.17(图论)
  5. 给IT新人的15个建议:程序员的辛酸反省与总结!
  6. linux时间类型localtime_r
  7. linux路由内核实现分析(二)---FIB相关数据结构(1)
  8. PHP面向对象之继承和多态
  9. CSS 3D透视效果 星空穿越
  10. Git版本控制及Goland使用Git教程
  11. 免费下载 客道巴巴文档 教程
  12. SQL基础篇 (增 删 查 改)
  13. 软件测试-兼容性测试
  14. php采集 今日头条链接,火车头按作者采集今日头条全部文章的方法
  15. 深度学习的loss变小梯度是否变小
  16. celery使用post方法解决方案
  17. 大学生应该如何选择服务器
  18. 网页下载工具curl命令简介
  19. 学生专用计算机怎么没声音,详细教你解决电脑突然没声音
  20. 利用Octave做分形几何

热门文章

  1. 历年软件设计师 试卷 参考案例解析
  2. 宝藏机器学习资料分享(超高质量pdf直接下载)
  3. 2021美赛C题思路
  4. 《深入解析Windows操作系统第4版》随笔记录03
  5. 算法设计与分析重点总结
  6. Python语言程序设计----【第8周 程序设计方法学】之8.2 Python程序设计思维
  7. 计算机电子设计论文,计算机毕业设计电子信息毕业设计电子信息毕业论文
  8. tftp64工具使用
  9. Java编写一个小游戏
  10. ie11兼容性问题,jsp在IE11显示不全问题,ie11覆盖内容问题解决方法