tensorflow.js基本使用 图标识别(八)
示例
import $ from 'jquery';
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs, img2x, file2img } from './utils.js';$(async () => {const { inputs, labels } = await getInputs();console.log(inputs, labels);const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 255 } });const NUM_CLASSES = 3;inputs.forEach(imgEl => {surface.drawArea.appendChild(imgEl);});const MOBILENET_MODEL_PATH = "http://127.0.0.1:8080/mobilenet/web_model/model.json";//加载外部模型const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);mobilenet.summary();const layer = mobilenet.getLayer('conv_pw_13_relu');//截断模型const truncatedMobilenet = tf.model({inputs: mobilenet.inputs,outputs: layer.output});//将处理结果返回给第二个模型const model = tf.sequential();model.add(tf.layers.flatten({//去掉首位图片数目,现在没数据,值为nullinputShape: layer.outputShape.slice(1)}));model.add(tf.layers.dense({units: 10,activation: 'relu'}));model.add(tf.layers.dense({units: NUM_CLASSES,activation: 'softmax'}));//设置损失函数,优化器model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });//把输入数据输入给截断模型const { xs, ys } = tf.tidy(() => {const xs = tf.concat(inputs.map((item) => truncatedMobilenet.predict(img2x(item))));const ys = tf.tensor(labels);return { xs, ys };});model.fit(xs, ys, {epochs: 20,callbacks: tfvis.show.fitCallbacks({ name: '训练效果' },['loss'],{ callbacks: ['onEpochEnd'] })});const BRAND_CLASSES = ['android', 'apple', 'windows'];window.predict = async (file) => {const img = await file2img(file);document.body.appendChild(img);const pred = tf.tidy(() => {const x = img2x(img);const input = truncatedMobilenet.predict(x);return model.predict(input);});const index = pred.argMax(1).dataSync()[0];setTimeout(() => {alert(`预测结果:${BRAND_CLASSES[index]}`);}, 0);};window.download = async () => {await model.save('downloads://model');}
});
html 部分
<!DOCTYPE html>
<html lang="en">
<head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>Document</title>
</head>
<body><div>图标识别</div><input type="file" onchange="predict(this.files[0])"><button onclick="download()">保存模型</button>
</body>
<script src="./t7.js"></script>
</html>
util.js
import * as tf from '@tensorflow/tfjs';//载入测试图片的方法↓↓↓↓↓↓↓↓↓↓↓
const IMAGE_SIZE = 224;const loadImg = (src) => {return new Promise(resolve => {const img = new Image();img.crossOrigin = "anonymous";img.src = src;img.width = IMAGE_SIZE;img.height = IMAGE_SIZE;img.onload = () => resolve(img);});
};export const getInputs = async () => {const loadImgs = [];const labels = [];for (let i = 0; i < 30; i += 1) {['android', 'apple', 'windows'].forEach(label => {const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;const img = loadImg(src);loadImgs.push(img);labels.push([label === 'android' ? 1 : 0,label === 'apple' ? 1 : 0,label === 'windows' ? 1 : 0,]);});}const inputs = await Promise.all(loadImgs);return {inputs,labels,};
}
//载入测试图片的方法↑↑↑↑↑↑↑↑↑↑↑//图片格式转换↓↓↓↓↓↓↓↓↓↓↓
export function img2x(imgEl) {return tf.tidy(() => {const input = tf.browser.fromPixels(imgEl).toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3]);return input;});
}export function file2img(f) {return new Promise(resolve => {const reader = new FileReader();reader.readAsDataURL(f);reader.onload = (e) => {const img = document.createElement('img');img.src = e.target.result;img.width = 224;img.height = 224;img.onload = () => resolve(img);};});
}
//图片格式转换↑↑↑↑↑↑↑↑↑↑↑
执行结果
tensorflow.js基本使用 图标识别(八)相关推荐
- TensorFLow.js实现手写体数字识别
先看最终效果: 一.加载MNIST数据集 使用预先准备好的脚本加载MNIST数据集,脚本可在文章末尾的源码里面获取. 为了避免从国外直接下载数据集花费太多时间,所以脚本文件里面已经将地址改成本地的,因 ...
- 在浏览器中进行深度学习:TensorFlow.js (四)用基本模型对MNIST数据进行识别
2019独角兽企业重金招聘Python工程师标准>>> 在了解了TensorflowJS的一些基本模型的后,大家会问,这究竟有什么用呢?我们就用深度学习中被广泛使用的MINST数据集 ...
- 绒毛动物探测器:通过TensorFlow.js中的迁移学习识别浏览器中的自定义对象
目录 起点 MobileNet v1体系结构上的迁移学习 修改模型 训练新模式 运行物体识别 终点线 下一步是什么?我们可以检测到脸部吗? 下载TensorFlowJS-Examples-master ...
- 用tensorflow.js实现浏览器内的手写数字识别
原文 简介 Tensorflow.js是google推出的一个开源的基于JavaScript的机器学习库,相对与基于其他语言的tersorflow库,它的最特别之处就是允许我们直接把模型的训练和数据预 ...
- Tensorflow.js||使用 CNN 识别手写数字
Tensorflow官方的tesorflow.js实操课程 链接为:link 使用 CNN 识别手写数字 文章目录 使用 CNN 识别手写数字 1. 简介 2. 设置操作 3. 加载数据 4. 定义模 ...
- 利用tensorflow.js在线实现图像要素识别提取
什么是Tensorflow.js? TensorFlow.js是一个开源的基于硬件加速的JavaScript库,用于训练和部署机器学习模型.谷歌推出的第一个基于TensorFlow的前端深度学习框架T ...
- 小白玩机器学习(6)--- 基于Tensorflow.js的在线手写数字识别
一.题目要求 1.三个js文件,分别完成:网络训练以及模型保存.模型加载及准确率测试.在线手写数字识别: 2.模型测试准确率要高于99.3%(尽量): 3.在线手写数字识别需要能够通过鼠标在画布中写入 ...
- TensorFlow.js实现商标识别
在VsCode中利用TensorFlow.js结合迁移学习实现商标识别. 一.加载商标数据并可视化 数据保存在data文件夹下面,需要先在data文件夹下创建一个静态服务器,用于加载图片. http- ...
- 在浏览器中进行深度学习:TensorFlow.js (十二)异常检测算法
2019独角兽企业重金招聘Python工程师标准>>> 异常检测是机器学习领域常见的应用场景,例如金融领域里的信用卡欺诈,企业安全领域里的非法入侵,IT运维里预测设备的维护时间点等. ...
- tensorflow.js基本使用 截断模型、引入外部模型(七)
图标识别 import $ from 'jquery'; import * as tf from '@tensorflow/tfjs'; import { img2x, file2img } from ...
最新文章
- Flutter框架分析(五)-- 动画
- c primer plus第六版电子版_【财经】京东超市PLUS有机联盟:有机品牌提升一站式营销解决方案...
- lua 初接触 --- The first time use Lua for programing
- zabbix 配置mysql_zabbix 配置mysql监控
- Web Hacking 101 中文版 十四、XML 外部实体注入(一)
- 【Flink】Flink 实时去重方案 四种方案 MapState 、SQL方式、HyperLogLog、Bitmap
- win11资源管理器卡顿怎么办 Windows11解决资源管理器卡顿的步骤方法
- C++ msdn 离线版下载地址
- IEEE Access模板caption无法换行,换行后标题不居中解决办法
- 全国短信息中心号码一览
- UVaOJ 12304 2D Geometry 110 in 1!
- 阿里云免费SSL证书申请
- 遍历Lua全局环境变量
- 阿里云携手卫宁健康发布WinCloud智慧医疗云联合解决方案,打造新一代智慧医疗系统
- 行列式基础知识,重要定理和公式
- Leetcode刷题33. 搜索旋转排序数组
- 多种方式实现动态替换Android默认桌面Launcher
- iOS Safari阅读模式研究
- python爬取360手机助手APP信息
- 期末设计(计划进度表)