小白玩机器学习(6)--- 基于Tensorflow.js的在线手写数字识别
一、题目要求
1.三个js文件,分别完成:网络训练以及模型保存、模型加载及准确率测试、在线手写数字识别;
2.模型测试准确率要高于99.3%(尽量);
3.在线手写数字识别需要能够通过鼠标在画布中写入0~9数字,并进行实时识别,按空格键清除。测试需具有一定的准确性。
二、实验原理
利用卷积神经网络提高数字识别结果的精度。
假设图像的尺寸是28*28,那么如果我们在下一层有1000个单位,我们就需要学习28*28*1000个单位的权重。像素可能是相关的,因此构建了一个k*k核作为权重学习的过滤器。
池化层没有需要学习的变量。它的作用是对图像进行细分采样,使下一层可以查看更大的空间区域。进一步缩小网络范围,减少需要学习的参数量。
如何进一步提高准确性?添加noise或dropout;添加更多层;使用更多的epochs和更大的batch size;在模型中添加卷积层,使用卷积神经网络强化准确率。
保存并加载 tf.Model的方法:tf.Model和tf.Sequential同时提供了函数 model.save 允许您保存一个模型的拓扑结构和权重。IndexedDB (仅限浏览器):await model.save('indexeddb://my-model');这样会将模型保存到浏览器的IndexedDB存储中。与本地存储一样,它在刷新后仍然存在,同时它往往也对存储的对象的大小有较大的限制。(参考链接:https://blog.csdn.net/Aria_Miazzy/article/details/103793323)
三、设计思路
1. 准备工作
下载MNIST数据集(http://yann.lecun.com/exdb/mnist/)
数据读取需要下载并保存为mnist.js文件:
https://github.com/CodingTrain/Toy-Neural-Network-JS/blob/master/examples/mnist/mnist.js
添加加载数据集的代码:
loadMNIST(function (data) {mnist = data;console.log(mnist);})
三个页面三个js文件分别进行:
begin.html对应train.js完成网络模型的训练以及模型保存;
Load.html对应load.js完成模型的加载以及准确率测试;
Recognition.html对应recognition.js完成手写体在线实时识别
2. 添加网络
根据上图添加神经网络:(train.js)
(1)添加卷积层,大小为28*28,其中卷积核大小为5,使用的激活函数为relu;(2)添加池化层,尺寸为2*2;
(2)添加卷积层,卷积核个数为5,激活函数为relu;
(3)添加池化层;
(4)为了提高准确路,在此处添加dropout,并且rate=0.5;
(5)降维后添加全连接层,激活函数为relu;
(6)使用adam()优化器并设置rate=0.002,损失函数为softmaxCrossEntrop;至此完成了网络的配置。
// 初始化模型
const model = tf.sequential();
// Convolutional layer 二维卷积层
model.add(tf.layers.conv2d({inputShape: [28, 28, 1], // 1:颜色黑白kernelSize: 5, // 卷积核大小为5filters: 16, // 卷积核数量为16strides: 1, // 步长为1activation: 'relu', // 激活函数为relukernelInitializer: 'varianceScaling' // 初始化卷积核
}));
// 经过这层变化[28,28,1]-->[14,14,16]
// Pooling layer 二维池化层
model.add(tf.layers.maxPooling2d({poolSize: [2, 2], // 尺寸strides: [2, 2] // 步长
}));
// Convolutional layer 二维卷积层
model.add(tf.layers.conv2d({kernelSize: 5, // 卷积核filters: 32,strides: 1,activation: 'relu',kernelInitializer: 'varianceScaling'
}));
// Pooling layer 池化层
model.add(tf.layers.maxPooling2d({poolSize: [2, 2],strides: [2, 2]
}));
// 添加dropout rate = 0.5随机去掉一半
model.add(tf.layers.dropout({rate: 0.5
}));
// Flatten layer 降维
model.add(tf.layers.flatten());
// Dense layer
model.add(tf.layers.dense({//全连接层units: 128,activation: 'relu'
}));
model.add(tf.layers.dense({units: 10, // 对应0-10数字}));
const OPT = tf.train.adam(0.002) // 优化器const config = {optimizer: OPT,loss: tf.losses.softmaxCrossEntropy, // 损失函数
}
model.compile(config); //模型设置好配置
3. 加载数据
由于训练集数量比较大,这里选取了前60000个数据进行训练(train.js)
console.log("载入数据")
inputs = tf.tensor2d(mnist.train_images.slice(0, 60000));
outputs_org = tf.tensor1d(mnist.train_labels.slice(0, 60000));// 标签Y
outputs = tf.oneHot((outputs_org), 10);//全部对应到0-9 [0,0,0,0,0,0,0,1]console.log("重组数据") // 归一化除以255 变成0-1
inputs = tf.div(inputs, tf.scalar(255.0));
inputs = inputs.reshape([60000, 28, 28, 1]); // 格式化28* 28* 1
4. 训练模型
这里使用15个epoh迭代,并且实时输出每一轮结果的loss.(train.js)
async function train() {for (let i = 1; i < 15; i++) {const h = await model.fit(inputs, outputs, {atchSize: 200,epochs: 1);console.log("Loss after Epoch " + i + " : " + h.history.loss[0]);}const saveResults = await model.save('indexeddb://my-model-6');console.log("模型已经保存");select('#modelStatus').html('模型已经训练完成并保存');}
5. 其中需要对模型进行保存和重加载
// 保存训练模型到浏览器数据库my-model-5const saveResults = await model.save('indexeddb://my-model');
// 加载已经保存的my-model模型,不需要重新训练
const model = await tf.loadLayersModel('indexeddb://my-model');
6. 测试训练准确率
首先加载测试数据,这里选择前10000个,之后进行训练(load.js)
console.log("加载测试数据。。")inputs_test = tf.tensor2d(mnist.test_images.slice(0, 10000));inputs_test = tf.div(inputs_test,tf.scalar(255.0));inputs_test = inputs_test.reshape([10000, 28, 28, 1]);outputs_test = tf.tensor1d(mnist.test_labels.slice(0, 10000));print(outputs_test.shape);console.log("测试数据加载完成")async function test() {const model = await tf.loadLayersModel('indexeddb://my-model');console.log('加载已经保存的模型');output_tem = model.predict(inputs_test);label = tf.argMax(output_tem, 1);// 打印测试准确率tf.div(tf.sum(outputs_test.equal(label)), mnist.test_labels.length).print();result = tf.div(tf.sum(outputs_test.equal(label)), mnist.test_labels.length);select('#modelStatus').html('模型已经加载完成:' + result);}
7. 手写体识别可视化
实时鼠标在区域画数字,会进行预测,点击空格键删除。(recognition.js)
(参考链接:https://github.com/CodingTrain/Toy-Neural-Network-JS/blob/master/examples/mnist)
let img = user_digit.get();if(!user_has_drawing) {return img;}let inputs = [];img.resize(28, 28);img.loadPixels();for (let i = 0; i < 784; i++) {inputs[i] = img.pixels[i * 4];}inputs = tf.tensor2d([inputs]);inputs = inputs.reshape([1,28,28,1]);let prediction = model.predict(inputs);let guess = tf.argMax(prediction,1);user_guess_ele.html(guess.dataSync());return img;image(user_digit, 0, 0);// 鼠标控制画线,预测数字if (mouseIsPressed) {user_has_drawing = true;user_digit.stroke(255);user_digit.strokeWeight(16);user_digit.line(mouseX, mouseY, pmouseX, pmouseY);}
四、实验结果
1. 网络训练以及模型保存
运行页面加载模型并开始训练,显示每个epoch的loss值,迭代完成后模型保存,页面也显示‘模型已经训练完成并保存’:
2. 模型加载及准确率测试
模型保存完成之后,点击‘测试准确率’按钮,跳转到模型测试页面,加载测试数据并显示准确率。可见当前的准确率为99.39
3. 在线手写数字识别
数据测试完成之后,点击‘开始手写识别’按钮,跳转到手写识别页面,可以随机用鼠标在电脑上画0-9的数值测试结果,猜测的数字会显示在下面,点击空格键重画。首先会显示“正在加载模型”。当模型加载好后会出现“模型已经加载完成”,之后可以进行手写识别,如下图:
五、总结提升
(1)使用Tensorflow.js构建深度模型。使用卷积神经网络提高准确率
(2)把数组数据转换成张量,把标签转换成一种热类型。转换绘图成28*28图像(img。调整大小(28、28))并将其平铺以供测试。
(3)三个js文件,分别完成:网络训练以及模型保存、模型加载及准确率测试、在线手写数字识别。
(4)异步保存和加载模型(异步函数和等待)。
(5)如何进一步提高准确性?添加noise或dropout;添加更多层;使用更多的epochs和更大的batch size;在模型中添加卷积层,使用卷积神经网络强化准确率。
小白玩机器学习(6)--- 基于Tensorflow.js的在线手写数字识别相关推荐
- 深蓝学院第三章:基于卷积神经网络(CNN)的手写数字识别实践
参看之前篇章的用全连接神经网络去做手写识别:https://blog.csdn.net/m0_37957160/article/details/114105389?spm=1001.2014.3001 ...
- MATLAB实现数字识别系统,基于人工神经网络的MATLAB手写数字识别系统
<基于人工神经网络的MATLAB手写数字识别系统>由会员分享,可在线阅读,更多相关<基于人工神经网络的MATLAB手写数字识别系统(8页珍藏版)>请在人人文库网上搜索. 1.基 ...
- TensorFlow 2.0 mnist手写数字识别(CNN卷积神经网络)
TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络) 源代码/数据集已上传到 Github - tensorflow-tutorial-samples 大白话讲解卷积 ...
- 基于随机梯度下降法的手写数字识别、epoch是什么、python实现
基于随机梯度下降法的手写数字识别.epoch是什么.python实现 一.普通的随机梯度下降法的手写数字识别 1.1 学习流程 1.2 二层神经网络类 1.3 使用MNIST数据集进行学习 注:关于什 ...
- 基于朴素贝叶斯的手写数字识别
基于朴素贝叶斯的手写数字识别 关于数据集 关于SIMD 关于python 数据预处理 总结 关于数据集 MNIST数据库(http://www.cs.nyu.edu/~roweis/data.html ...
- 机器学习之KNN结合微信机器人实现手写数字识别终极API
机器学习之KNN结合微信机器人实现手写数字识别终极API 手写数字识别 功能概述 实现步骤 结果展示 改进之处和TIPS 手写数字识别 功能概述 微信机器人接收到的手写数字图片,传送给已经经过机器学习 ...
- 基于matlab BP神经网络的手写数字识别
摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入.灰度化以及二值化等处理,通过神 ...
- tensorflow入门之MINIST手写数字识别
最近在学tensorflow,看了很多资料以及相关视频,有没有大佬推荐一下比较好的教程之类的,谢谢.最后还是到了官方网站去,还好有官方文档中文版,今天就结合官方文档以及之前看的教程写一篇关于MINIS ...
- MOOC网深度学习应用开发1——Tensorflow基础、多元线性回归:波士顿房价预测问题Tensorflow实战、MNIST手写数字识别:分类应用入门、泰坦尼克生存预测
Tensorflow基础 tensor基础 当数据类型不同时,程序做相加等运算会报错,可以通过隐式转换的方式避免此类报错. 单变量线性回归 监督式机器学习的基本术语 线性回归的Tensorflow实战 ...
最新文章
- 生信和植物领域最新资讯合集
- 在vmware esx平台创建windows 2003 server群集时无法找到共享磁盘的解决方法
- ASP.NET知识点:母版页的路径问题
- undo系统参数详解
- 重温强化学习之OpenAI经典场景
- 随机数排列JAVA_随机数生成器,按排序顺序
- 03_TF2 Guide、文档清单(数据输入、估计器、保存模型、加速器、性能调优等)、TF2库和扩展库(TensorBoard、数据集、TensorFlow Hub、概率和统计分析库、图像处理库)
- Oracle在开源Mission Control后将其开发团队解散
- leetcode3 无重复字符最长子串
- 深化美国分布式光伏领域合作 苏美达辉伦向美企供应7兆瓦组件
- JAVA开发面试常问问题总结2
- zookeeper编程入门系列之zookeeper实现分布式进程监控和分布式共享锁(图文详解)...
- java static是单例_JAVA基础-static关键字及单例设计模式
- mybatis批量操作
- 面向对象组件开发一个弹窗
- Oracle sql中的正则表达式
- Testbed软件下载安装使用试用
- HTML+CSS静态页面网页设计作业 仿天猫购物商城(7页) 网页设计作业,网页制作作业, 学生网页作业, 网页作业成品, 网页作业模板
- 如何高效对接第三方支付
- 【Servlet入门】一篇文章让你从没听过到了熟于心
热门文章
- mysql下载好压缩包如何安装_Mysql下载压缩包安装及Navicat连接
- vue---H5--获取短信验证码及完整登陆流程
- 【附源码】计算机毕业设计JAVA疫苗药品批量扫码识别追溯系统
- wifi显示请求服务器超时,wifi服务器链接超时怎么回事啊
- 每周值得关注的人工智能头条:Google让我自动做AI,Julia让我学好数学
- flowable 候选人候选组同时使用
- 【JZOJ6379】小w与密码(password)
- 财经365热点:牛股票层出不穷 基金提前“埋伏”获丰收
- java视频转换语音,视频转换成音频方法,avi格式视频怎么转换为MP3格式
- Leetcode 保持城市天际线