【tensorflow.js学习笔记(1)】tf.js环境搭建及曲线拟合例子
月初TensorFlow开发者大会上,谷歌正式发布了TensorFlow的JS版本tensorflow.js,并演示了几个很有意思的demo,展现了浏览器环境下也能进行深度学习任务的能力。tensorflowjs利用WebGl加速,在浏览器环境下训练、部署机器学习模型。下面我尝试引入tensorflow.js并运行一个曲线拟合的例子。
1、文件形式引入
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.10.0"></script>
<script>const model = tf.sequential();model.add(tf.layers.dense({units: 1, inputShape: [1]}));model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});const x = tf.tensor2d([1, 2, 3, 4], [4, 1]);const y = tf.tensor2d([1, 3, 5, 7], [4, 1]);model.fit(x, y).then(() => {model.predict(tf.tensor2d([5], [1, 1])).print();});
</script>
首先调用tf.sequential()构建模型,损失函数为均方差,优化器为sgd(梯度下降)。待拟合的点序列为(1,1),(2,3),(3,5),(4,7),训练模型,输入x=5。
打开浏览器,输出为:
Tensor[[8.1529675],]
2、使用webpack
npm install @tensorflow/tfjs
首先利用npm安装tensorflow.js(也可用yarn),新建index.js文件,内容如下。
import * as tf from '@tensorflow/tfjs';const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);model.fit(xs, ys).then(() => {model.predict(tf.tensor2d([5], [1, 1])).print();
});
利用webpack可使用import语法引入tf.js。配置webpack.config.js文件如下。
const path = require('path');module.exports={//入口文件的配置项entry:{entry: './index.js'},//出口文件的配置项output:{path: path.resolve(__dirname, 'dist'),filename: 'bundle.js'}
}
运行webpack命令,将在目录下生成dist文件夹。cd进入该文件夹,用node运行bundle.js文件,输出结果。
3、曲线拟合
参考Fitting a Curve to Synthetic Data,这是TensorFlow官方关于曲线拟合的例子,其中使用Vega进行可视化展示。下面我将用Echarts替换Vega进行可视化展示并重写部分程序,代码结构如下所示。
其中index.js为入口文件,dist文件夹下为分发文件,webpack.config.js的内容如下。
const path = require('path');module.exports={mode: 'development',//入口文件的配置项entry:{entry: './src/index.js'},//出口文件的配置项output:{path: path.resolve(__dirname, 'dist'),filename: 'bundle.js'},//控制台报错信息devtool: 'inline-source-map'
}
index.html内容如下。
<!DOCTYPE html>
<html lang="en">
<head><meta charset="UTF-8"><title>Document</title><style>#chart {width: 800px;height: 800px;}</style>
</head>
<body><div id="chart"></div><script src="bundle.js"></script>
</body>
</html>
入口文件index.js内容如下。
import * as tf from '@tensorflow/tfjs';
var echarts = require('echarts');const a = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
const c = tf.variable(tf.scalar(Math.random()));
const d = tf.variable(tf.scalar(Math.random()));function predict(x) {return tf.tidy(() => {return a.mul(x.pow(tf.scalar(3, 'int32'))) .add(b.mul(x.square())).add(c.mul(x)).add(d);});
}function loss(prediction, labels) {const error = prediction.sub(labels).square().mean();return error;
}const numIterations = 75;
const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);async function train(xs, ys, numIterations) {for (let iter = 0; iter < numIterations; iter++) {optimizer.minimize(() => {const pred = predict(xs);return loss(pred, ys);});await tf.nextFrame();}
}function generateData(numPoints, coeff, sigma = 0.04) {return tf.tidy(() => {const [a, b, c, d] = [tf.scalar(coeff.a),tf.scalar(coeff.b),tf.scalar(coeff.c),tf.scalar(coeff.d)];const xs = tf.randomUniform([numPoints], -1, 1);const ys = a.mul(xs.pow(tf.scalar(3, 'int32'))).add(b.mul(xs.square())).add(c.mul(xs)).add(d).add(tf.randomNormal([numPoints], 0, sigma));const ymin = ys.min();const ymax = ys.max();const yrange = ymax.sub(ymin);const ysNormalized = ys.sub(ymin).div(yrange);return {xs,ys: ysNormalized};})
}async function plotData(xs, ys, preds) {const xvals = await xs.data();const yvals = await ys.data();const predVals = await preds.data();const valuesBefore = Array.from(xvals).map((x, i) => {return [xvals[i], yvals[i]];});const valuesAfter= Array.from(xvals).map((x, i) => {return [xvals[i], predVals[i]];});// 二维数组排序valuesAfter.sort(function(x, y) {return x[0] - y[0];});curveChart.setOption({xAxis: {min: -1,max: 1},yAxis: {min: 0,max: 1},series: [{symbolSize: 12,data: valuesBefore,type: 'scatter'},{data: valuesAfter,encode: {x: 0,y: 1},type: 'line'}]});
}async function learnCoefficients() {const trueCoefficients = {a: -0.8, b: -0.2, c: 0.9, d: 0.5};// 生成有误差的训练数据const trainingData = generateData(100, trueCoefficients);// 训练模型await train(trainingData.xs, trainingData.ys, numIterations);// 预测数据const predictionsAfter = predict(trainingData.xs);// 绘制散点图及拟合曲线await plotData(trainingData.xs, trainingData.ys, predictionsAfter);predictionsAfter.dispose();
}const curveChart = echarts.init(document.getElementById('chart'));
learnCoefficients();
首先引入tensorflow.js及echarts,之后定义4个参数a、b、c、d,分别是待拟合曲线y=a*x^3+b*x^2+c*x+d的四个参数,初始设为随机值。
定义函数predict,传入x,返回拟合后的估计值y。函数loss为损失函数,这里定义loss为均方差。
定义优化器optimizer,其中学习率为0.5,学习率过小会导致训练速度慢,学习率过高会造成拟合参数在最优解附近“左右摇摆”。
定义async函数(Generator 函数的语法糖)train,train函数内根据迭代步数及学习率调用优化器并计算损失函数loss。
函数generateData随机生成[-1, 1]范围内的点,并根据传入的a、b、c、d加上一定的随机扰动生成数据点xs,ys,其中ys进行归一化处理。
函数plotData将随机生成的样本点映射为散点图,将根据训练后的参数拟合出的点映射为曲线。
函数renderCoefficients将a、b、c、d的值输出到document内。
函数learnCoefficients是index.js的main函数,函数内先设定预定义的a、b、c、d,再生成有误差的训练数据,利用训练数据训练a、b、c、d参数并参数输出到文档,之后利用训练好的参数拟合x数据,将结果绘制为散点图及曲线,最后通知GC清理。
此时拟合出的曲线图会有bug,如下所示。
原因分析:
传入echarts的点对是按生成顺序排序的,是无序数组,但绘制曲线时是按传入数组的顺序连接各点,因此在传入前需对二维数据进行排序。在curveChart.setOption前加入如下代码。
// 二维数组排序
valuesAfter.sort(function(x, y) {return x[0] - y[0];
});
结果如下。
完整程序见我的github,具体步骤为:
step1 新建文件夹,cmd输入git clone git@github.com:orangecsy/tfjs-exercise.git,cd 1进入文件夹1;
step2 cmd输入webpack,打包;
step3 cd dist进入dist文件夹,cmd中输入http-server(需先npm install http-server)或使用webpack配置开发服务器;
step4 浏览器中输入http://127.0.0.1:8080/,即为结果。
【tensorflow.js学习笔记(1)】tf.js环境搭建及曲线拟合例子相关推荐
- JS学习笔记六:js中的DOM操作
1. JS学习笔记六:js中的DOM操作 文章目录 1. JS学习笔记六:js中的DOM操作 1.1. 获取Dom节点 1.2. 元素属性的操作方式 1.3. DOM节点的创建.插入和删除 1.4. ...
- nginx学习笔记-01nginx入门,环境搭建,常见命令
nginx学习笔记-01nginx入门,环境搭建,常见命令 文章目录 nginx学习笔记-01nginx入门,环境搭建,常见命令 1.nginx的基本概念 2.nginx的安装,常用命令和配置文件 3 ...
- 华芯微特SWM181学习笔记--GPIO应用与环境搭建
华芯微特SWM181 系列 32 位 MCU(以下简称 SWM181)内嵌 ARM® CortexTM-M0 内核, SWM181 支持片上包含精度为 1%以内的 24MHz.48MHz 时钟,并提供 ...
- 迪文屏幕T5L平台学习笔记一:开发环境搭建注意事项
前面一直用T5UID3平台的屏幕开发,但是吐槽下<DWIN C Compiler 1>编译器bug太多,项目能不能做好,全靠运气:售后说T5L平台支持keil开发,我感觉挺好,于是从新学习 ...
- Go+Wails学习笔记(一)环境搭建与配置
前言 Go,又称Golang,是谷歌在21世纪开发的一种新的编程语言,它静态强类型.从语言层面支持并发(Goroutine).支持垃圾回收GC. Go语言有一些笔者很喜欢的特点,譬如跨平台.交叉编译( ...
- ReactNative学习笔记(一)环境搭建
前言 本文开发环境为Windows,目标平台为Android,react-native版本为0.35.0. 环境搭建 注意,本文不是按照官网的教程来的,官网说必须安装什么Chocolatey,我懒得鸟 ...
- Python学习笔记:Day1-2 开发环境搭建
前言 最近在学习深度学习,已经跑出了几个模型,但Pyhton的基础不够扎实,因此,开始补习Python了,大家都推荐廖雪峰的课程,因此,开始了学习,但光学有没有用,还要和大家讨论一下,因此,写下这些帖 ...
- 视觉SLAM十四讲学习笔记-第二讲-开发环境搭建
专栏系列文章如下: 视觉SLAM十四讲学习笔记-第一讲_goldqiu的博客-CSDN博客 视觉SLAM十四讲学习笔记-第二讲-初识SLAM_goldqiu的博客-CSDN博客 lin ...
- STM32上手-STWingSKIT_BC28学习笔记(一)环境搭建和LED灯点亮
嵌入式STM32上手学习笔记(一)LED灯点亮 STM32开发环境的搭建 1. 安装keil5 IDE 2. 下载STM32F1的支持包 3. 在Pack installer中找到F1支持包下载 4. ...
最新文章
- 对接kafka_flume对接kafka多路径同时收集日志,配置怎么写?
- 对勾选的下拉选择进行同步选择
- 半监督学习下的高维图构建
- CTFshow php特性 web96
- python 装饰器有哪些_python之装饰器
- LeetCode 829. 连续整数求和(数学)
- android按钮点击无响应时间,AndroidStudio下的点击事件不响应
- 现代ups电源及电路图集_2020山特UPS电源自动开机200KVA实力
- hadoop+hbase安装
- 【单片机】简单的时钟代码
- 一台变两台,电脑也分身
- 深度linux系统老版本,Deepin Linux15.7下载
- CreateProcess并隐藏窗口
- golang 爆破破解 rar5 压缩文件密码
- ZeroC Ice Hello World
- 【2022牛客多校5 A题 Don‘t Starve】DP
- windows服务器详细安全设置
- android装windows bios,普通安卓平板刷win10图文教程
- 华为智慧屏 鸿蒙,精挑细选的高品质大屏,新一代华为智慧屏V系列不要错过
- From C++ to Objective-C
热门文章
- oracle中nvarchar,SQL中的Nvarchar在oracle中用作varchar2
- 阿里云服务器绑定域名 搭建环境 到 部署项目
- 精彩回顾 | 一张图读懂OPPO应用与数据安全防护
- mysql -%3e卡在_华为nova 3e手机卡怎么办?五个技巧帮你缓解手机卡顿烦恼!
- 易周金融分析 | Q2手机银行活跃用户环比增长2.17%
- 实现获取阿里云STS上传token
- 骨传导蓝牙耳机哪个牌子好?五款骨传导耳机推荐
- 家用路由器网段互通的问题
- imac mysql导入sql_MAMP
- 年薪50w的3D建模师告诉你,互联网行业什么职业最赚钱?