参加 2018 AI开发者大会,请点击 ↑↑↑

作者 | Vincent Mühler

译者 | 刘旭坤

整理 | Jane

出品 | AI科技大本营

【导读】TensorFlow.js 的发布可以说是 JS 社区开发者的福音!但是在浏览器中训练一些模型还是会存在一些问题与不同,如何可以让训练效果更好?本文的作者,是一位前端工程师,经过自己不断的经验积累,为大家总结了 18 个 Tips,希望可以帮助大家训练出更好的模型。

TensorFlow.js 发布之后我就把之前训练的目标/人脸检测和人脸识别的模型往 TensorFlow.js 里导,我发现有些模型在浏览器里运行的效果还相当不错。感觉 TensorFlow.js 让我们搞前端的也潮了一把。

虽说浏览器也能跑深度学习模型了,这些模型终归不是为在浏览器里运行设计的,所以很多限制和挑战也就随之而来了。就拿目标检测来说,不说实时检测,就是维持一定的帧率恐怕都很困难。更别提动辄上百兆的模型给用户浏览器和带宽(手机端的话)带来的压力了。

不过只要我们遵循一定的原则,用卷积神经网络 CNN 和 TensorFlow.js 在浏览器里训练个像样的深度学习模型并非痴人说梦。从下面图里可以看到,我训练的这几个模型大小都控制在了 2 MB 以下,最小的才 3 KB。

大家可能心中会有个疑问:你脑残吗?要用浏览器训练模型?对,用自己电脑、服务器、集群或者云来训练深度学习模型肯定是一条正道,但并非人人都有钱用 NVIDIA GTX 1080 Ti 或者Titan X(尤其是显卡集体大涨价之后)。这时,在浏览器中训练深度学习模型的优势就体现出来了,有了 WebGL 和 TensorFLow.js 我用电脑上的 AMD GPU 也能很方便地训练深度学习模型。

对目标识别问题,为了稳妥起见通常都会建议大家用一些现成的架构比如YOLO、SSD、残差网络 ResNet 或 MobileNet ,但我个人认为如果完全照搬的话,在浏览器上训练效果肯定是不好的。在浏览器上训练就要求模型要小、要快、要越容易训练越好。下面我们就从模型架构、训练和调试等几个方面来看看如何才能做到这三点。

模型架构

1. 控制模型大小

控制模型的规模很重要。如果模型架构太大太复杂,训练和运行的速度都会降低,从浏览器载入模型度速度也会变慢。控制模型的规模说起来简单,难的是取得准确率和模型规模之间的平衡。如果准确率达不到要求,模型再小也是废物。

2. 使用深度可分离卷积操作

与标准卷积操作不同,深度可分离卷积先对每个通道进行卷积操作,之后再进行1X1跨通道卷积。这样做的好处是可以大大减小参数个数,所以模型运行速度会有很大提升,资源的消耗和训练速度也会有所提升。深度可分离卷积操作的过程如下图所示:

MobileNet 和 Xception 都使用了深度可分离卷积,TensorFlow.js 版本的 MobileNet 和 PoseNet 中你也能见到深度可分离卷积的身影。虽然深度可分离卷积对模型准确率的影响还有争议,但从我个人的经验来看在浏览器里训练模型用它肯定没错。

第一层我推荐用标准的 conv2d 操作来保持提取完特征的通道之间的关系。因为第一层一般参数不多,所以对性能的影响不大。

其他卷积层就可以都用深度可分离卷积了。比如这里我们就使用了两个过滤器。

这里 tf.separableConv2d 使用的卷积核结构分别是[3,3,32,1]和[1,1,32,64]。

3.运用跳跃连接和密集块

随着网络层数的增加,梯度消失问题出现的可能性也会增大。梯度消失会造成损失函数下降太慢训练时间超长或者干脆失败。ResNet 和 DenseNet 中采用的跳跃连接则能避免这一问题。简单说来跳跃连接就是把某些层的输出跳过激活函数直接传给网络深处的隐藏层作为输入,如下图所示:

这样就避免了因为激活函数和链式求导造成的梯度消失问题,我们也能根据需求增加网络的层数了。

显然跳跃连接隐含的一个要求就是连接的两层输出和输入的格式必须能对应得上。我们要用残差网络的话,那最好保证两层的过滤器数目和填充都一致而且步幅为1(不过肯定有其它做法来保证格式对应)。

一开始我模仿残差网络的思路隔一层加一个跳跃连接(如下图)。不过我发现密集块效果更好,模型收敛的速度比加跳跃连接快得多。

下面我们就来看看具体的代码,这里的密集块有四个深度可分离卷积层,其中第一层我把步幅设为 2 来改变输入的大小。

4.激活函数选ReLU

在浏览器里训练深度网络的话激活函数不用看直接选 ReLU 就行了,主要原因还是梯度消失。不过大家可以试试 ReLU 的不同变种,比如

和 MobileNet 用的 ReLU-6 (y = min(max(x, 0), 6)):

训练过程

5.优化器选Adam

这也是我个人的经验只谈。之前用 SGD 经常会卡在局部极小值或者出现梯度爆炸。我推荐大家一开始把学习速率设为 0.001 然后其他参数都用默认:

6.动态调整学习速率

一般来说当损失函数不再下降的时候我们就该停止训练了,因为再训练就过拟合了。不过如果我们发现损失函数出现上下震荡的情况,则可能通过减小学习速率让损失函数变得更小。

下面这个例子中我们可以看到学习速率一开始设的是 0.01,然后从 32 期开始出现震荡(黄线)。这里通过将学习速率改为 0.001(蓝线)使损失函数又减小了大概 0.3。

7.权重初始化原则

我个人喜欢把偏置量设为 0,权重则用传统的正态分布。我一般用的是 Glorot 正态分布初始化法:

8.把数据集顺序打乱

老生常谈了。TensorFlow.js 中我们可以用 tf.utils.shuffle 来实现。

9. 保存模型

js 可以通过 FileSaver.js 来实现模型的存储(或者叫下载)。比如下面的代码就可以把模型所有的权重保存起来:

保存成什么格式是自己定的,但 FileSaver.js 只管存,所以这里要用JSON.strinfify 把 Blob 转成字符串:

调试

10.保证预处理和后处理的正确性

虽然是句废话但“垃圾数据垃圾结果”实在是至理名言。标记要标对,每层的输入输出也要前后一致。尤其是对图片做过一些预处理和后处理的话更要仔细,有时候这些小问题还比较难发现。所以虽然费些功夫但磨刀不误砍柴工。

11.自定义损失函数

TensorFlow.js 提供了很多现成的损失函数给大家用,而且一般说来也够用了,所以我不太建议大家自己写。如果实在要自己写的话,请一定注意先测试测试。

12.在数据子集试试过拟合

我建议大家模型定义好之后先挑个十几二十张图试试看损失函数有没有收敛。最好能把结果可视化一下,这样就能很明显地看出这个模型有没有成功的潜质。

这样做我们也能早早地发现模型和预处理时的一些低级错误。这其实也就是 11 条里说的测试测试损失函数。

性能

13.内存泄漏

不知道大家知不知道 TensorFlow.js 不会自动帮你进行垃圾回收。张量所占的内存必须自己手动调用 tensor.dispose() 来释放。如果忘记回收的话内存泄漏是早晚的事。

判断有没有内存泄漏很容易。大家把 tf.memory() 每次迭代都输出来看看张量的个数。如果没有一直增加那说明没泄漏。

14.调整画布大小,而不是张量大小

在调用 TF . from pixels 之前,要将画布转换成张量,请调整画布的大小,否则你会很快耗尽 GPU 内存。

如果你的训练图像大小都一样,这将不会是一个问题,但是如果你必须明确地调整它们的大小,你可以参考下面的代码。(注意,以下语句仅在 tfjs - core 的当前状态下有效,我当前正在使用 tfjs - core 版本 0.12.14)

15.慎选批大小

每一批的样本数选多少,也就是批大小显然取决于我们用的什么 GPU 和网络结构,所以大家最好试试不同的批大小看看怎么最快。我一般从 1 开始试,而且有时候我发现增加批大小对训练的效率也没啥帮助。

16.善用IndexedDB

我们训练的数据集因为都是图片所以有时候还是挺大的。如果每次都下载的话肯定效率低,最好是用 IndexedDB 来存储。IndexedDB 其实就是浏览器里嵌入的一个本地数据库,任何数据都能以键值对的形式进行存储。读取和保存数据也只要几行代码就能搞定。

17.异步返回损失函数值

要实时监测损失函数值的话可以用下面的代码这来自己算然后异步返回:

需要注意的是如果每期训练完要把损失函数值存到文件里的话这样的代码就有点问题了。因为现在损失函数的值是异步返回了所以我们得等最后一个 promise 返回才能存。不过我一般都暴力地在一期结束之后直接等个 10 秒再存:

18.权重的量化

为了实现又小又快的目标,在模型训练完成之后我们应该对权重进行量化来压缩模型。权重量化不光能减小模型的体积,对提高模型的速度也很有帮助,而且几乎全是好处没坏处。这一步就让模型又能小又能快,非常适合我们在浏览器里训练深度学习模型。

在浏览器里训练深度学习模型的十八招(实际十七招)就总结到这里,希望大家读了这篇文章能够有所收获。

如果有问题也欢迎在后台给我们留言,大家一起讨论!

原文链接:

https://itnext.io/18-tips-for-training-your-own-tensorflow-js-models-in-the-browser-3e40141c9091

【完】

2018 AI开发者大会

只讲技术,拒绝空谈

2018 AI开发者大会是一场由中美人工智能技术高手联袂打造的AI技术与产业的年度盛会!是一场以技术落地为导向的干货会议!大会设置了10场技术专题论坛,力邀15+硅谷实力讲师团和80+AI领军企业技术核心人物,多位一线经验大咖带你将AI从云端落地。

大会日程以及嘉宾议题请查看下方海报

(点击查看大图)

点击「阅读原文」,查看大会更多详情。2018 AI开发者大会——摆脱焦虑,拥抱技术前沿。

前端工程师掌握这18招,就能在浏览器里玩转深度学习相关推荐

  1. 【专访英特尔高级首席工程师戴金权】普通数据工程师,如何玩转深度学习?

    记者 | 白羽 几乎每周,人工智能深度学习,总会在某个领域有新的技术突破,新的亮眼成果出来. 不过,这些最新的突破和成果,更多还是在深度学习的各大社区流动,更多是被顶尖教授.学者所掌握和应用,对于普通 ...

  2. 六招教你用Python构建好玩的深度学习应用

    摘要: 导读 深度学习是近来数据科学中研究和讨论最多的话题.得益于深度学习的发展,数据科学在近期得到了重大突破,深度学习也因此得到了很多关注.据预测,在不久的将来,更多的深度学习应用程序会影响人们的生 ...

  3. 一个初级的前端工程师需要知道些什么?

    一个初级的前端工程师需要知道些什么? 按照我的想法,我把前端工程师分为了入门.初级.中级.高级这四个级别入门级别指的是了解什么是前端(前端到底是什么其实很多人还是不清楚的,底什么是前端后端.后台),了 ...

  4. 一个前端工程师的基本修养

    有人说互联网是前端工程师的舞台,先不论这个说法是否有些夸大其词,但前端工程师绝对撑起了互联网应用开发的"半壁江山".随着传统网站.手机应用.桌面应用.微信小程序等次第出现,需要前端 ...

  5. 教你做好web前端工程师简历

    春节前在蓝色理想上发了个"雅虎口碑招聘前端工程师 "的启事,节后收到很多简历,加之HR通过专业招聘网站得到的简历和朋友同事推荐的简历,数量上是相当的多,把这些简历一一看完真是一个漫 ...

  6. 前端工程师是怎样一种职业

    作者:吕大豹 文章链接 : 前端工程师是怎样一种职业 本文略有删减,请尊看原文. 前端工程师的英文名为front-end engineer,简称FE,下文将用FE来代称.现在意义上的前端(并非只制作网 ...

  7. 【杂谈】什么是我心目中深度学习算法工程师的标准

    有三AI平台只专心做原创输出很少扯淡也不蹭热点,不过最近询问的朋友多了,不得不统一写篇文章来回答一下这个大家都很关心的问题,当然,这仅仅是个人观点. 作者&编辑 | 言有三 目前利用深度学习这 ...

  8. 深度学习工程师能力评估标准

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 深度学习工程师能力评估标准 1.范围 2术语和定义 2.1人工智能artificial int ...

  9. 「杂谈」什么是我心目中深度学习算法工程师的标准

    http://blog.sina.com.cn/s/blog_cfa68e330102zoco.html 有三AI平台只专心做原创输出很少扯淡也不蹭热点,不过最近询问的朋友多了,不得不统一写篇文章来回 ...

最新文章

  1. 案例:用JS实现放大镜特效
  2. python美国股票数据api_【美股量化00篇】Python获取新浪接口美股实时数据
  3. Andorid应用去google广告
  4. Codeforces 1336E Chiori and Doll Picking (子集和变换、线性基、阈值算法、状压 DP、组合计数)...
  5. FastJSON的依赖
  6. cargo maven_与Maven 3,Failsafe和Cargo插件的集成测试
  7. 需加装饰——装饰模式
  8. unity 常用函数
  9. 管理者必看!深度剖析BI与数据仓库,企业能否成功转型就看它
  10. 理论基础 —— 索引 —— 2-3 树
  11. [转]CTO谈豆瓣网和校内网技术架构变迁
  12. 【C语言】结构和指针
  13. html没有内容怎么爬,Url没有在网页中返回正确的html(对于我的Java爬虫)
  14. IIS添加对ashx文件的支持
  15. Scikit-learn_回归算法_支持向量机回归
  16. 软件测试员如何进行产品测试?
  17. 中望3d快捷键命令大全_CAD、3D快捷命令
  18. 【蓝桥杯省赛真题24】Scratch哪吒飞行 少儿编程scratch蓝桥杯省赛真题讲解
  19. 蓝牙的四种音频编码:Apt-X、SBC、AAC、LDAC
  20. Java学习---day07_继承及final、Object的介绍

热门文章

  1. 【敏捷3.2】评估价值的方法
  2. js同一页面两个表格table数据显示冲突
  3. 银行信用卡最低还款利息计算方法
  4. 网易视频云郭再荣:视频云服务的未来在于场景化
  5. 7-4 打印倒直角三角形图形(10 分)
  6. ACL 2022 | 字节AI Lab联合UCSB提出MOSST:基于单调切分的端到端同传
  7. matlab点云三维重构,无序点云三维重建方法技术
  8. 帮助理解Java中ThreadLocal的一篇文章
  9. linux要uefi启动mbr,uefi可以引导mbr吗
  10. Keil主题配色方案