示例

import $ from 'jquery';
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs, img2x, file2img } from './utils.js';$(async () => {const { inputs, labels } = await getInputs();console.log(inputs, labels);const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 255 } });const NUM_CLASSES = 3;inputs.forEach(imgEl => {surface.drawArea.appendChild(imgEl);});const MOBILENET_MODEL_PATH = "http://127.0.0.1:8080/mobilenet/web_model/model.json";//加载外部模型const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);mobilenet.summary();const layer = mobilenet.getLayer('conv_pw_13_relu');//截断模型const truncatedMobilenet = tf.model({inputs: mobilenet.inputs,outputs: layer.output});//将处理结果返回给第二个模型const model = tf.sequential();model.add(tf.layers.flatten({//去掉首位图片数目,现在没数据,值为nullinputShape: layer.outputShape.slice(1)}));model.add(tf.layers.dense({units: 10,activation: 'relu'}));model.add(tf.layers.dense({units: NUM_CLASSES,activation: 'softmax'}));//设置损失函数,优化器model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });//把输入数据输入给截断模型const { xs, ys } = tf.tidy(() => {const xs = tf.concat(inputs.map((item) => truncatedMobilenet.predict(img2x(item))));const ys = tf.tensor(labels);return { xs, ys };});model.fit(xs, ys, {epochs: 20,callbacks: tfvis.show.fitCallbacks({ name: '训练效果' },['loss'],{ callbacks: ['onEpochEnd'] })});const BRAND_CLASSES = ['android', 'apple', 'windows'];window.predict = async (file) => {const img = await file2img(file);document.body.appendChild(img);const pred = tf.tidy(() => {const x = img2x(img);const input = truncatedMobilenet.predict(x);return model.predict(input);});const index = pred.argMax(1).dataSync()[0];setTimeout(() => {alert(`预测结果:${BRAND_CLASSES[index]}`);}, 0);};window.download = async () => {await model.save('downloads://model');}
});

html 部分

<!DOCTYPE html>
<html lang="en">
<head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>Document</title>
</head>
<body><div>图标识别</div><input type="file" onchange="predict(this.files[0])"><button onclick="download()">保存模型</button>
</body>
<script src="./t7.js"></script>
</html>

util.js

import * as tf from '@tensorflow/tfjs';//载入测试图片的方法↓↓↓↓↓↓↓↓↓↓↓
const IMAGE_SIZE = 224;const loadImg = (src) => {return new Promise(resolve => {const img = new Image();img.crossOrigin = "anonymous";img.src = src;img.width = IMAGE_SIZE;img.height = IMAGE_SIZE;img.onload = () => resolve(img);});
};export const getInputs = async () => {const loadImgs = [];const labels = [];for (let i = 0; i < 30; i += 1) {['android', 'apple', 'windows'].forEach(label => {const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;const img = loadImg(src);loadImgs.push(img);labels.push([label === 'android' ? 1 : 0,label === 'apple' ? 1 : 0,label === 'windows' ? 1 : 0,]);});}const inputs = await Promise.all(loadImgs);return {inputs,labels,};
}
//载入测试图片的方法↑↑↑↑↑↑↑↑↑↑↑//图片格式转换↓↓↓↓↓↓↓↓↓↓↓
export function img2x(imgEl) {return tf.tidy(() => {const input = tf.browser.fromPixels(imgEl).toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3]);return input;});
}export function file2img(f) {return new Promise(resolve => {const reader = new FileReader();reader.readAsDataURL(f);reader.onload = (e) => {const img = document.createElement('img');img.src = e.target.result;img.width = 224;img.height = 224;img.onload = () => resolve(img);};});
}
//图片格式转换↑↑↑↑↑↑↑↑↑↑↑

执行结果

tensorflow.js基本使用 图标识别(八)相关推荐

  1. TensorFLow.js实现手写体数字识别

    先看最终效果: 一.加载MNIST数据集 使用预先准备好的脚本加载MNIST数据集,脚本可在文章末尾的源码里面获取. 为了避免从国外直接下载数据集花费太多时间,所以脚本文件里面已经将地址改成本地的,因 ...

  2. 在浏览器中进行深度学习:TensorFlow.js (四)用基本模型对MNIST数据进行识别

    2019独角兽企业重金招聘Python工程师标准>>> 在了解了TensorflowJS的一些基本模型的后,大家会问,这究竟有什么用呢?我们就用深度学习中被广泛使用的MINST数据集 ...

  3. 绒毛动物探测器:通过TensorFlow.js中的迁移学习识别浏览器中的自定义对象

    目录 起点 MobileNet v1体系结构上的迁移学习 修改模型 训练新模式 运行物体识别 终点线 下一步是什么?我们可以检测到脸部吗? 下载TensorFlowJS-Examples-master ...

  4. 用tensorflow.js实现浏览器内的手写数字识别

    原文 简介 Tensorflow.js是google推出的一个开源的基于JavaScript的机器学习库,相对与基于其他语言的tersorflow库,它的最特别之处就是允许我们直接把模型的训练和数据预 ...

  5. Tensorflow.js||使用 CNN 识别手写数字

    Tensorflow官方的tesorflow.js实操课程 链接为:link 使用 CNN 识别手写数字 文章目录 使用 CNN 识别手写数字 1. 简介 2. 设置操作 3. 加载数据 4. 定义模 ...

  6. 利用tensorflow.js在线实现图像要素识别提取

    什么是Tensorflow.js? TensorFlow.js是一个开源的基于硬件加速的JavaScript库,用于训练和部署机器学习模型.谷歌推出的第一个基于TensorFlow的前端深度学习框架T ...

  7. 小白玩机器学习(6)--- 基于Tensorflow.js的在线手写数字识别

    一.题目要求 1.三个js文件,分别完成:网络训练以及模型保存.模型加载及准确率测试.在线手写数字识别: 2.模型测试准确率要高于99.3%(尽量): 3.在线手写数字识别需要能够通过鼠标在画布中写入 ...

  8. TensorFlow.js实现商标识别

    在VsCode中利用TensorFlow.js结合迁移学习实现商标识别. 一.加载商标数据并可视化 数据保存在data文件夹下面,需要先在data文件夹下创建一个静态服务器,用于加载图片. http- ...

  9. 在浏览器中进行深度学习:TensorFlow.js (十二)异常检测算法

    2019独角兽企业重金招聘Python工程师标准>>> 异常检测是机器学习领域常见的应用场景,例如金融领域里的信用卡欺诈,企业安全领域里的非法入侵,IT运维里预测设备的维护时间点等. ...

  10. tensorflow.js基本使用 截断模型、引入外部模型(七)

    图标识别 import $ from 'jquery'; import * as tf from '@tensorflow/tfjs'; import { img2x, file2img } from ...

最新文章

  1. Flutter框架分析(五)-- 动画
  2. c primer plus第六版电子版_【财经】京东超市PLUS有机联盟:有机品牌提升一站式营销解决方案...
  3. lua 初接触 --- The first time use Lua for programing
  4. zabbix 配置mysql_zabbix 配置mysql监控
  5. Web Hacking 101 中文版 十四、XML 外部实体注入(一)
  6. 【Flink】Flink 实时去重方案 四种方案 MapState 、SQL方式、HyperLogLog、Bitmap
  7. win11资源管理器卡顿怎么办 Windows11解决资源管理器卡顿的步骤方法
  8. C++ msdn 离线版下载地址
  9. IEEE Access模板caption无法换行,换行后标题不居中解决办法
  10. 全国短信息中心号码一览
  11. UVaOJ 12304 2D Geometry 110 in 1!
  12. 阿里云免费SSL证书申请
  13. 遍历Lua全局环境变量
  14. 阿里云携手卫宁健康发布WinCloud智慧医疗云联合解决方案,打造新一代智慧医疗系统
  15. 行列式基础知识,重要定理和公式
  16. Leetcode刷题33. 搜索旋转排序数组
  17. 多种方式实现动态替换Android默认桌面Launcher
  18. iOS Safari阅读模式研究
  19. python爬取360手机助手APP信息
  20. 期末设计(计划进度表)

热门文章

  1. 操作系统测试题(第1,2单元)
  2. windows脚本编写及使用方法
  3. 软件安装(一):VS2017安装和使用
  4. 实验三linux进程并发程序设计,实验三Linux进程并发程序设计.doc
  5. w10 http基本原理 Nginx部署
  6. 海南师范大学本科毕业论文答辩PPT模板
  7. Xcode打包ipa的基本步骤
  8. 大一c语言试题及答案解析,大一c语言期末题及参考答案.doc
  9. https 抓包解密
  10. uniapp-微信小程序直播插件小记