使用迁移学习和TensorFlow.js在浏览器中进行AI情感检测
目录
KNN分类器
迁移学习
我们的技术栈
配置
使用KNN分类器
将代码放在一起
测试结果
下一步是什么?
- 下载源-10.6 MB
在上一篇文章中,我们已经看到了加载预训练模型有多么容易。在本文中,我们将使用迁移学习(Transfer Learning)扩展预训练模型。我们将使用自己的训练集在模型上建立模型,并使用K最近邻(KNN)模块将面部表情的图像分类为脾气暴躁或中性。
在深入研究任何代码之前,让我们快速讨论一下KNN和迁移学习的工作原理。
KNN分类器
KNN算法是一种简单、易于实现的有监督的机器学习算法,可用于解决分类以及回归预测问题。
该算法假定相似的事物彼此靠近存在。对于一般理解,红色的阴影比黄色或黑色之类的任何其他颜色更相似。KNN使用相同的相似性思想,并通过使用距离函数(即余弦,汉明)将其与预先分类的案例进行比较,从而对新案例进行分类。然后,它为K个最接近的案例中最常见的新案例或所谓的“最近邻居”选择类别。
TensorFlow.js的KNN分类器提供了使用相同算法创建分类器的实用程序。这里要注意的一件事是,它不提供模型,而是提供了一种用于构造KNN模型并使用来自另一个模型或张量的激活的实用程序。您可以在此处了解更多信息。
迁移学习
迁移学习是一种机器学习技术,可让您重用针对特定任务开发的模型作为其他任务模型的起点或基础。
迁移学习在深度学习中特别流行,在深度学习中,您可以使用预训练的模型作为计算机视觉任务的起点。由于开发用于这些平台的神经网络需要大量的计算资源和时间,因此迁移学习非常有用,可以显着提高整个系统的性能。
我们的技术栈
对于此示例,我们将使用以下技术堆栈:
- TensorFlow.js ——一种机器学习框架,使在网络上的客户端进行机器学习成为可能。
- MobileNet模型——用于图像分类的经过预先训练的TensorFlow.js模型。
- KNN分类器——基本的TensorFlow.js分类器,可用于自定义图像分类。
您可以根据需要使用其他技术堆栈,例如React或Angular。也可以随意扩展示例。
配置
让我们从导入所需的模型开始:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
我们需要做的下一件事是定义一个具有特定宽度和高度的canvas元素:
<canvas width="224" height="224"></canvas>
这是因为已经在相同特定尺寸的图像上训练了分类器。我们使用相同的大小来匹配数据格式,因此在将图像输入分类器之前不必调整图像大小。
由于我们正在构建一个分类器,将人脸的图像分类为具有脾气暴躁或中性的表情,因此我们创建了“脾气暴躁”和“中性”按钮以手动对图像进行分类并将其添加到我们的训练数据中,并创建“预测”按钮以预测图像的分类:
<button class="grumpy">Grumpy</button>
<button class="neutral">Neutral</button>
<button class="predict">Predict</button>
现在,我们将事件侦听器附加到按钮:
const grumpy = document.querySelector('.grumpy');
const neutral = document.querySelector('.neutral');grumpy.addEventListener('click', () => addExamples('grumpy'));
neutral.addEventListener('click', () => addExamples('neutral'));document.querySelector('.predict').addEventListener('click', predict);
为了使其简单易用,我们将使画布通过拖放来接受图像:
const canvas = document.querySelector("canvas");
const context = canvas.getContext("2d");
canvas.addEventListener('dragover', e => e.preventDefault(), false);
canvas.addEventListener('drop', onImageDrop, false);
我们需要的最后一件事是处理丢弃文件的功能:
const onImageDrop = e => {e.preventDefault();const imageFile = e.dataTransfer.files[0];const imageReader = new FileReader();imageReader.onload = imageFile => {const image = new Image();image.onload = () => {context.drawImage(image, 0, 0, 224, 224);};image.src = imageFile.target.result;};imageReader.readAsDataURL(imageFile);};
一切就绪后,这就是我们的HTML文档的外观:
<!DOCTYPE html>
<html lang="en"><head><meta charset="UTF-8" /><title>Image classification with Tensorflow.js</title><script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script><script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script><script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script></head><body><h1>Custom Image Classifier using Tensorflow.js</h1><canvas style=" border: 2px dashed #34495e; margin: auto;" width="224" height="224"></canvas><h3>Train classifier with examples</h3><button class="grumpy">Grumpy</button><button class="neutral">Neutral</button><button class="predict">Predict</button><script src="knnClassifier.js"></script><script>const canvas = document.querySelector("canvas");const context = canvas.getContext("2d");const grumpy = document.querySelector('.grumpy');const neutral = document.querySelector('.neutral');const onImageDrop = e => {e.preventDefault();const imageFile = e.dataTransfer.files[0];const imageReader = new FileReader();imageReader.onload = imageFile => {const image = new Image();image.onload = () => {context.drawImage(image, 0, 0, 224, 224);};image.src = imageFile.target.result;};imageReader.readAsDataURL(imageFile);};canvas.addEventListener('dragover', e => e.preventDefault(), false);canvas.addEventListener('drop', onImageDrop, false);grumpy.addEventListener('click', () => addExamples('grumpy'));neutral.addEventListener('click', () => addExamples('neutral'));document.querySelector('.predict').addEventListener('click', predict);</script></body>
</html>
您可能已经注意到我们也在使用knnClassifier.js文件。该文件将包含创建分类器,加载模型和处理预测的功能。让我们首先创建KNN分类器并加载MobileNet模型。
const loadKnnClassifier = async () => {knn = knnClassifier.create();console.log("Model is Loading...")model = await mobilenet.load();console.log("Model Loaded successfully!")
};
使用KNN分类器
如前所述,我们需要在自定义图像上训练分类器。KNN分类器的addExample方法带有两个参数:
- example ——通常是从另一个模型激活以将示例添加到数据集。
- label ——示例的类名。
这是我们添加到训练数据中的功能:
const addExamples = label => {const img = tf.browser.fromPixels(canvas);const attribute = model.infer(img, 'conv_preds');knn.addExample(attribute, label);context.clearRect(0, 0, canvas.width, canvas.height);if(label === 'grumpy'){grumpy.innerText = `Grumpy (${++trainingDataSets[0]})`}else {neutral.innerText = `Neutral (${++trainingDataSets[1]})`}console.log(`Trained classifier with ${label}`)img.dispose();
};
最后但并非最不重要的是我们的预测功能:
const predict = async () => {if (knn.getNumClasses() > 0) {const img = tf.browser.fromPixels(canvas);const attribute = model.infer(img, 'conv_preds');const prediction = await knn.predictClass(attribute);context.clearRect(0, 0, canvas.width, canvas.height);console.log(`Prediction: ${prediction.label}`)img.dispose();}
};
将代码放在一起
我们的代码的最终外观如下:
let knn;
let model;let trainingDataSets = [0, 0];const loadKnnClassifier = async () => {knn = knnClassifier.create();console.log("Model is Loading...")model = await mobilenet.load();console.log("Model Loaded successfully!")
};const addExamples = label => {const img = tf.browser.fromPixels(canvas);const attribute = model.infer(img, 'conv_preds');knn.addExample(attribute, label);context.clearRect(0, 0, canvas.width, canvas.height);if(label === 'grumpy'){grumpy.innerText = `Grumpy (${++trainingDataSets[0]})`}else {neutral.innerText = `Neutral (${++trainingDataSets[1]})`}console.log(`Trained classifier with ${label}`)img.dispose();
};const predict = async () => {if (knn.getNumClasses() > 0) {const img = tf.browser.fromPixels(canvas);const attribute = model.infer(img, 'conv_preds');const prediction = await knn.predictClass(attribute);context.clearRect(0, 0, canvas.width, canvas.height);console.log(`Prediction: ${prediction.label}`)img.dispose();}
};loadKnnClassifier();
测试结果
在浏览器中打开HTML文档,然后将图像文件拖放到画布上,然后单击“脾气暴躁”或“中性”按钮将其分类。
用几幅图像训练分类器后,请拖动另一幅图像,然后单击“预测”按钮以获取预测。
最终的控制台输出应类似于以下内容:
下一步是什么?
在本文中,我们借助使用迁移学习的KNN分类器扩展了预训练的MobileNet模型。我们训练了一个自定义分类器,将图像文件中的人类表情分类为脾气暴躁或中性。我们在浏览器中完成了所有操作,但是我们使用静态图像来训练我们的模型。如果我们对实时自定义分类感兴趣怎么办?
请继续阅读本系列的下一篇文章,我们将扩展模型以使用网络摄像头实时进行自定义分类。
使用迁移学习和TensorFlow.js在浏览器中进行AI情感检测相关推荐
- 使用face-api和Tensorflow.js在浏览器中进行AI年龄估计
目录 性别和年龄检测 下一步是什么? 下载源-10.6 MB 在上一篇文章中,我们学习了如何使用face-api.js和Tensorflow.js在浏览器中对人的情绪进行分类. 如果您尚未阅读该文章, ...
- 使用TensorFlow.js在浏览器中进行深度学习入门
目录 设置TensorFlow.js 创建训练数据 检查点 定义神经网络模型 训练AI 测试结果 终点线 内存使用注意事项 下一步是什么?狗和披萨? 下载TensorFlowJS示例-6.1 MB T ...
- 图像迁移风格保存模型_用TensorFlow.js在浏览器中部署可进行任意图像风格迁移的模型...
风格迁移一直是很多读者感兴趣的内容之一,近日,网友ReiichiroNakano公开了自己的一个实现:用TensorFlow.js在浏览器中部署可进行任意图像风格迁移的模型.让我们一起去看看吧! Gi ...
- 狗和披萨:使用TensorFlow.js在浏览器中实现计算机视觉
目录 起点 托管说明 MobileNet v1 运行物体识别 终点线 下一步是什么?绒毛动物? 下载TensorFlowJS示例-6.1 MB TensorFlow + JavaScript.现在,最 ...
- 使用 Colab 在 tf.keras 中训练模型,并使用 TensorFlow.js 在浏览器中运行
文 / Zaid Alyafeai 我们将创建一个简单的工具来识别图纸并输出当前图纸的名称. 此应用程序将直接在浏览器上运行,无需任何安装.我们会使用 Google Colab 来训练模型,并使用 T ...
- 用 TensorFlow.js 在浏览器中训练一个计算机视觉模型(手写数字分类器)
文章目录 Building a CNN in JavaScript Using Callbacks for Visualization Training with the MNIST Dataset ...
- 有了TensorFlow.js,浏览器中也可以实时人体姿势估计
翻译文章,内容有删减.原文地址:https://medium.com/tensorflow/real-time-human-pose-estimation-in-the-browser-with-te ...
- 用TensorFlow.js在浏览器中进行实时语义分割 | MixLab算法系列
语义分割是监测和描绘图像中每个感兴趣对象的问题 当前,有几种方法可以解决此问题并输出结果 如下图示: 语义分割示例 这种分割是对图像中的每个像素进行预测,也称为密集预测. 十分重要且要注意的是,同一类 ...
- 使用face-api和Tensorflow.js进行预训练的AI情绪检测
目录 设置服务器 设置HTML 获取实时视频源 使用face-api.js进行预测 将代码放在一起 测试结果 下一步是什么? 下载源-10.6 MB 面部表情识别是图像识别中关注的关键领域之一,一直都 ...
最新文章
- [LeetCode 120] - 三角形(Triangle)
- PHP实现文件下载断点续传详解
- 状态压缩DP AcWing算法提高课 (详解)
- 总的来讲safari上面的research gate和canvas会出现奇奇怪怪的bug,但是chrome没问题
- 变量声明和函数声明的意义详解
- 产品文档如何说清楚产品业务?关注这几点就够了
- 一个程序员转产品经理的经验分享
- android+动画+锯齿,Android_rotate--animation 动画旋转两图片,消除动画锯齿现象 android 开发:动画旋转两图片 - 下载 - 搜珍网...
- 怎样免费将Word导出为PDF格式?
- mysql网站倒计时代码_最简单的一个网页倒计时代码 时间到期后会显示出提醒内容 收藏版...
- 如何一键制作DTS Audio DVD、AC3 Audio DVD、WAV Audio DVD纯音乐碟片
- 白胡子不杀黑胡子的真正原因
- 企鹅号转正后是2级账号还有用吗,企鹅号不被系统推荐怎么办
- 支持Genero BDL 4gl语言的编辑器
- 有多少“垃圾”App藏在你的手机里?
- Virtual Box与win10系统不兼容问题
- Python中汉字繁简体互转
- 华为云大数据BI 为中小型企业智慧运营保驾护航
- 王阳明:越是艰难时,越要知行合一[附疫情生存哲学]
- 免费计算机网络同传系统,ghost网络同传系统.doc
热门文章
- 唯有自己变得强大_只有自己变得强大,才够让你的人生一帆风顺
- 防qq页面多边形html5,高仿QQ Xplan的H5页面
- 谭浩强c语言入门_计算机学生为什么学不会C语言?看到这4点原因,学生表示太真实...
- oracle一列有多个约束,在oracle中创建unique唯一约束(单列和多列)
- jquery 毫秒转换成日期_jquery js 秒 毫秒转时分秒
- 让用户感到体贴登录页设计灵感
- 设计灵感|网页建议页面(联系页面)版式案例
- 渐变海报背景素材|潮流2021还将延续
- 设计导航网站|解决寻找合适的字体麻烦
- nuxt webpack配置css,vuecli或nuxt用Webpack的优雅ProgressBar(webpackBar)