月初TensorFlow开发者大会上,谷歌正式发布了TensorFlow的JS版本tensorflow.js,并演示了几个很有意思的demo,展现了浏览器环境下也能进行深度学习任务的能力。tensorflowjs利用WebGl加速,在浏览器环境下训练、部署机器学习模型。下面我尝试引入tensorflow.js并运行一个曲线拟合的例子。

1、文件形式引入

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.10.0"></script>
<script>const model = tf.sequential();model.add(tf.layers.dense({units: 1, inputShape: [1]}));model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});const x = tf.tensor2d([1, 2, 3, 4], [4, 1]);const y = tf.tensor2d([1, 3, 5, 7], [4, 1]);model.fit(x, y).then(() => {model.predict(tf.tensor2d([5], [1, 1])).print();});
</script>

首先调用tf.sequential()构建模型,损失函数为均方差,优化器为sgd(梯度下降)。待拟合的点序列为(1,1),(2,3),(3,5),(4,7),训练模型,输入x=5。

打开浏览器,输出为:

Tensor[[8.1529675],]

2、使用webpack

npm install @tensorflow/tfjs

首先利用npm安装tensorflow.js(也可用yarn),新建index.js文件,内容如下。

import * as tf from '@tensorflow/tfjs';const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);model.fit(xs, ys).then(() => {model.predict(tf.tensor2d([5], [1, 1])).print();
});

利用webpack可使用import语法引入tf.js。配置webpack.config.js文件如下。

const path = require('path');module.exports={//入口文件的配置项entry:{entry: './index.js'},//出口文件的配置项output:{path: path.resolve(__dirname, 'dist'),filename: 'bundle.js'}
}

运行webpack命令,将在目录下生成dist文件夹。cd进入该文件夹,用node运行bundle.js文件,输出结果。

3、曲线拟合

参考Fitting a Curve to Synthetic Data,这是TensorFlow官方关于曲线拟合的例子,其中使用Vega进行可视化展示。下面我将用Echarts替换Vega进行可视化展示并重写部分程序,代码结构如下所示。

其中index.js为入口文件,dist文件夹下为分发文件,webpack.config.js的内容如下。

const path = require('path');module.exports={mode: 'development',//入口文件的配置项entry:{entry: './src/index.js'},//出口文件的配置项output:{path: path.resolve(__dirname, 'dist'),filename: 'bundle.js'},//控制台报错信息devtool: 'inline-source-map'
}

index.html内容如下。

<!DOCTYPE html>
<html lang="en">
<head><meta charset="UTF-8"><title>Document</title><style>#chart {width: 800px;height: 800px;}</style>
</head>
<body><div id="chart"></div><script src="bundle.js"></script>
</body>
</html>

入口文件index.js内容如下。

import * as tf from '@tensorflow/tfjs';
var echarts = require('echarts');const a = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
const c = tf.variable(tf.scalar(Math.random()));
const d = tf.variable(tf.scalar(Math.random()));function predict(x) {return tf.tidy(() => {return a.mul(x.pow(tf.scalar(3, 'int32'))) .add(b.mul(x.square())).add(c.mul(x)).add(d);});
}function loss(prediction, labels) {const error = prediction.sub(labels).square().mean();return error;
}const numIterations = 75;
const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);async function train(xs, ys, numIterations) {for (let iter = 0; iter < numIterations; iter++) {optimizer.minimize(() => {const pred = predict(xs);return loss(pred, ys);});await tf.nextFrame();}
}function generateData(numPoints, coeff, sigma = 0.04) {return tf.tidy(() => {const [a, b, c, d] = [tf.scalar(coeff.a),tf.scalar(coeff.b),tf.scalar(coeff.c),tf.scalar(coeff.d)];const xs = tf.randomUniform([numPoints], -1, 1);const ys = a.mul(xs.pow(tf.scalar(3, 'int32'))).add(b.mul(xs.square())).add(c.mul(xs)).add(d).add(tf.randomNormal([numPoints], 0, sigma));const ymin = ys.min();const ymax = ys.max();const yrange = ymax.sub(ymin);const ysNormalized = ys.sub(ymin).div(yrange);return {xs,ys: ysNormalized};})
}async function plotData(xs, ys, preds) {const xvals = await xs.data();const yvals = await ys.data();const predVals = await preds.data();const valuesBefore = Array.from(xvals).map((x, i) => {return [xvals[i], yvals[i]];});const valuesAfter= Array.from(xvals).map((x, i) => {return [xvals[i], predVals[i]];});// 二维数组排序valuesAfter.sort(function(x, y) {return x[0] - y[0];});curveChart.setOption({xAxis: {min: -1,max: 1},yAxis: {min: 0,max: 1},series: [{symbolSize: 12,data: valuesBefore,type: 'scatter'},{data: valuesAfter,encode: {x: 0,y: 1},type: 'line'}]});
}async function learnCoefficients() {const trueCoefficients = {a: -0.8, b: -0.2, c: 0.9, d: 0.5};// 生成有误差的训练数据const trainingData = generateData(100, trueCoefficients);// 训练模型await train(trainingData.xs, trainingData.ys, numIterations);// 预测数据const predictionsAfter = predict(trainingData.xs);// 绘制散点图及拟合曲线await plotData(trainingData.xs, trainingData.ys, predictionsAfter);predictionsAfter.dispose();
}const curveChart = echarts.init(document.getElementById('chart'));
learnCoefficients();

首先引入tensorflow.js及echarts,之后定义4个参数a、b、c、d,分别是待拟合曲线y=a*x^3+b*x^2+c*x+d的四个参数,初始设为随机值。

定义函数predict,传入x,返回拟合后的估计值y。函数loss为损失函数,这里定义loss为均方差。

定义优化器optimizer,其中学习率为0.5,学习率过小会导致训练速度慢,学习率过高会造成拟合参数在最优解附近“左右摇摆”。

定义async函数(Generator 函数的语法糖)train,train函数内根据迭代步数及学习率调用优化器并计算损失函数loss。

函数generateData随机生成[-1, 1]范围内的点,并根据传入的a、b、c、d加上一定的随机扰动生成数据点xs,ys,其中ys进行归一化处理。

函数plotData将随机生成的样本点映射为散点图,将根据训练后的参数拟合出的点映射为曲线。

函数renderCoefficients将a、b、c、d的值输出到document内。

函数learnCoefficients是index.js的main函数,函数内先设定预定义的a、b、c、d,再生成有误差的训练数据,利用训练数据训练a、b、c、d参数并参数输出到文档,之后利用训练好的参数拟合x数据,将结果绘制为散点图及曲线,最后通知GC清理。

此时拟合出的曲线图会有bug,如下所示。

原因分析:

传入echarts的点对是按生成顺序排序的,是无序数组,但绘制曲线时是按传入数组的顺序连接各点,因此在传入前需对二维数据进行排序。在curveChart.setOption前加入如下代码。

// 二维数组排序
valuesAfter.sort(function(x, y) {return x[0] - y[0];
});

结果如下。

完整程序见我的github,具体步骤为:

step1 新建文件夹,cmd输入git clone git@github.com:orangecsy/tfjs-exercise.git,cd 1进入文件夹1;

step2 cmd输入webpack,打包;

step3 cd dist进入dist文件夹,cmd中输入http-server(需先npm install http-server)或使用webpack配置开发服务器;

step4 浏览器中输入http://127.0.0.1:8080/,即为结果。

【tensorflow.js学习笔记(1)】tf.js环境搭建及曲线拟合例子相关推荐

  1. JS学习笔记六:js中的DOM操作

    1. JS学习笔记六:js中的DOM操作 文章目录 1. JS学习笔记六:js中的DOM操作 1.1. 获取Dom节点 1.2. 元素属性的操作方式 1.3. DOM节点的创建.插入和删除 1.4. ...

  2. nginx学习笔记-01nginx入门,环境搭建,常见命令

    nginx学习笔记-01nginx入门,环境搭建,常见命令 文章目录 nginx学习笔记-01nginx入门,环境搭建,常见命令 1.nginx的基本概念 2.nginx的安装,常用命令和配置文件 3 ...

  3. 华芯微特SWM181学习笔记--GPIO应用与环境搭建

    华芯微特SWM181 系列 32 位 MCU(以下简称 SWM181)内嵌 ARM® CortexTM-M0 内核, SWM181 支持片上包含精度为 1%以内的 24MHz.48MHz 时钟,并提供 ...

  4. 迪文屏幕T5L平台学习笔记一:开发环境搭建注意事项

    前面一直用T5UID3平台的屏幕开发,但是吐槽下<DWIN C Compiler 1>编译器bug太多,项目能不能做好,全靠运气:售后说T5L平台支持keil开发,我感觉挺好,于是从新学习 ...

  5. Go+Wails学习笔记(一)环境搭建与配置

    前言 Go,又称Golang,是谷歌在21世纪开发的一种新的编程语言,它静态强类型.从语言层面支持并发(Goroutine).支持垃圾回收GC. Go语言有一些笔者很喜欢的特点,譬如跨平台.交叉编译( ...

  6. ReactNative学习笔记(一)环境搭建

    前言 本文开发环境为Windows,目标平台为Android,react-native版本为0.35.0. 环境搭建 注意,本文不是按照官网的教程来的,官网说必须安装什么Chocolatey,我懒得鸟 ...

  7. Python学习笔记:Day1-2 开发环境搭建

    前言 最近在学习深度学习,已经跑出了几个模型,但Pyhton的基础不够扎实,因此,开始补习Python了,大家都推荐廖雪峰的课程,因此,开始了学习,但光学有没有用,还要和大家讨论一下,因此,写下这些帖 ...

  8. 视觉SLAM十四讲学习笔记-第二讲-开发环境搭建

    专栏系列文章如下: 视觉SLAM十四讲学习笔记-第一讲_goldqiu的博客-CSDN博客 视觉SLAM十四讲学习笔记-第二讲-初识SLAM_goldqiu的博客-CSDN博客 ​​​​​​​ lin ...

  9. STM32上手-STWingSKIT_BC28学习笔记(一)环境搭建和LED灯点亮

    嵌入式STM32上手学习笔记(一)LED灯点亮 STM32开发环境的搭建 1. 安装keil5 IDE 2. 下载STM32F1的支持包 3. 在Pack installer中找到F1支持包下载 4. ...

最新文章

  1. 对接kafka_flume对接kafka多路径同时收集日志,配置怎么写?
  2. 对勾选的下拉选择进行同步选择
  3. 半监督学习下的高维图构建
  4. CTFshow php特性 web96
  5. python 装饰器有哪些_python之装饰器
  6. LeetCode 829. 连续整数求和(数学)
  7. android按钮点击无响应时间,AndroidStudio下的点击事件不响应
  8. 现代ups电源及电路图集_2020山特UPS电源自动开机200KVA实力
  9. hadoop+hbase安装
  10. 【单片机】简单的时钟代码
  11. 一台变两台,电脑也分身
  12. 深度linux系统老版本,Deepin Linux15.7下载
  13. CreateProcess并隐藏窗口
  14. golang 爆破破解 rar5 压缩文件密码
  15. ZeroC Ice Hello World
  16. 【2022牛客多校5 A题 Don‘t Starve】DP
  17. windows服务器详细安全设置
  18. android装windows bios,普通安卓平板刷win10图文教程
  19. 华为智慧屏 鸿蒙,精挑细选的高品质大屏,新一代华为智慧屏V系列不要错过
  20. From C++ to Objective-C

热门文章

  1. oracle中nvarchar,SQL中的Nvarchar在oracle中用作varchar2
  2. 阿里云服务器绑定域名 搭建环境 到 部署项目
  3. 精彩回顾 | 一张图读懂OPPO应用与数据安全防护
  4. mysql -%3e卡在_华为nova 3e手机卡怎么办?五个技巧帮你缓解手机卡顿烦恼!
  5. 易周金融分析 | Q2手机银行活跃用户环比增长2.17%
  6. 实现获取阿里云STS上传token
  7. 骨传导蓝牙耳机哪个牌子好?五款骨传导耳机推荐
  8. 家用路由器网段互通的问题
  9. imac mysql导入sql_MAMP
  10. 年薪50w的3D建模师告诉你,互联网行业什么职业最赚钱?