鱼羊 假装发自 凹非寺
量子位 报道 | 公众号 QbitAI

只要网络足够宽,深度学习动态就能大大简化,并且更易于理解。

最近的许多研究结果表明,无限宽度的DNN会收敛成一类更为简单的模型,称为高斯过程(Gaussian processes)。

于是,复杂的现象可以被归结为简单的线性代数方程,以了解AI到底是怎样工作的。

所谓的无限宽度(infinite width),指的是完全连接层中的隐藏单元数,或卷积层中的通道数量有无穷多。

但是,问题来了:推导有限网络的无限宽度限制需要大量的数学知识,并且必须针对不同研究的体系结构分别进行计算。对工程技术水平的要求也很高。

谷歌最新开源的 Neural Tangents,旨在解决这个问题,让研究人员能够轻松建立、训练无限宽神经网络。

甚至只需要5行代码,就能够打造一个无限宽神经网络模型。

这一研究成果已经中了ICLR 2020。戳进文末Colab链接,即可在线试玩。

开箱即用,5行代码打造无限宽神经网络模型

Neural Tangents 是一个高级神经网络 API,可用于指定复杂、分层的神经网络,在 CPU/GPU/TPU 上开箱即用。

该库用 JAX编写,既可以构建有限宽度神经网络,亦可轻松创建和训练无限宽度神经网络。

有什么用呢?举个例子,你需要训练一个完全连接神经网络。通常,神经网络是随机初始化的,然后采用梯度下降进行训练。

研究人员通过对一组神经网络中不同成员的预测取均值,来提升模型的性能。另外,每个成员预测中的方差可以用来估计不确定性。

如此一来,就需要大量的计算预算。

但当神经网络变得无限宽时,网络集合就可以用高斯过程来描述,其均值和方差可以在整个训练过程中进行计算。

而使用 Neural Tangents ,仅需5行代码,就能完成对无限宽网络集合的构造和训练。

from neural_tangents import predict, staxinit_fn, apply_fn, kernel_fn = stax.serial(stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),stax.Dense(1, W_std=1.5, b_std=0.05))y_mean, y_var = predict.gp_inference(kernel_fn, x_train, y_train, x_test, ‘ntk’, diag_reg=1e-4, compute_cov=True)

上图中,左图为训练过程中输出(f)随输入数据(x)的变化;右图为训练过程中的不确定性训练、测试损失。

将有限神经网络的集合训练和相同体系结构的无限宽度神经网络集合进行比较,研究人员发现,使用无限宽模型的精确推理,与使用梯度下降训练整体模型的结果之间,具有良好的一致性。

这说明了无限宽神经网络捕捉训练动态的能力。

不仅如此,常规神经网络可以解决的问题,Neural Tangents 构建的网络亦不在话下。

研究人员在 CIFAR-10 数据集的图像识别任务上比较了 3 种不同架构的无限宽神经网络。

可以看到,无限宽网络模拟有限神经网络,遵循相似的性能层次结构,其全连接网络的性能比卷积网络差,而卷积网络的性能又比宽残余网络差。

但是,与常规训练不同,这些模型的学习动力在封闭形式下是易于控制的,也就是说,可以用前所未有的视角去观察其行为。

对于深入理解机器学习机制来说,该研究也提供了一种新思路。谷歌表示,这将有助于“打开机器学习的黑匣子”。

传送门

论文地址:https://arxiv.org/abs/1912.02803

谷歌博客:https://ai.googleblog.com/2020/03/fast-and-easy-infinitely-wide-networks.html

GitHub地址:https://github.com/google/neural-tangents

Colab地址:https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/neural_tangents_cookbook.ipynb

—完—

@量子位 · 追踪AI技术和产品新动态

深有感触的朋友,欢迎赞同、关注、分享三连վ'ᴗ' ի ❤

径向基神经网络_谷歌开源Neural Tangents:5行代码打造无限宽神经网络模型,帮助“打开ML黑匣子”...相关推荐

  1. 谷歌重磅开源新技术:5行代码打造无限宽神经网络模型,帮助“打开ML黑匣子”...

    鱼羊 假装发自 凹非寺 量子位 报道 | 公众号 QbitAI 只要网络足够宽,深度学习动态就能大大简化,并且更易于理解. 最近的许多研究结果表明,无限宽度的DNN会收敛成一类更为简单的模型,称为高斯 ...

  2. 径向基神经网络(rbfn)进行函数插值,代码实现

    1.例题:(第一个式子里的cos2.4π掉了一个π) 求解问题:使用精确插值方法,并确定 RBFN 的权重.假设 RBF 是标准差为 0.1 的高斯函数.使用测试集评估得到的 RBFN 的近似性能 2 ...

  3. 径向基神经网络(RBFNN)的实现(Python,附源码及数据集)

    文章目录 一.理论基础 1.径向基神经网络结构 2.前向传播过程 3.反向传播过程 4.建模步骤 二.径向基神经网络的实现 1.训练过程(RBFNN.py) 2.测试过程(test.py) 3.测试结 ...

  4. 径向基神经网络及MATLAB实现

    应用背景:我们知道,在使用BP神经网络时,由于其采用负梯度下降法对权值进行调节而具有收敛速度慢和容易陷入局部最小值等缺点,为了克服这些缺点,人们提出了径向基神经网络(Radial  Basis  Fu ...

  5. 【零散知识】径向基函数,径向基神经网络和其与BP神经网络的区别

    前言: { 最近在重新看傅立叶变换,感觉这简直是打开新世界的大门.都怪我之前没学好,现在看起来比较费劲,花了不少时间,所以这次还是零散知识. 这次的主要内容都是围绕径向基神经网络展开的. } 正文: ...

  6. 主流神经网络(3)——径向基神经网络

    (2)径向基神经网络 径向基神经网络*(radial basis function networks, RBF)本质上就是FFNN,结构没有任何改变,只不过使用"径向基函数(radial b ...

  7. 径向基神经网络(实例故障分类)

    径向神经网络的创建: 调用格式: net=newrbe(p,t,spread) -------------------p  t分别为输入和输出样本,spread  为径向神经网络的散布常数 或者更高效 ...

  8. 【水位预测】基于matlab径向基神经网络地下水位预测【含Matlab源码 1939期】

    一.获取代码方式 获取代码方式1: 完整代码已上传我的资源:[水位预测]基于matlab径向基神经网络地下水位预测[含Matlab源码 1939期] 点击上面蓝色字体,直接付费下载,即可. 获取代码方 ...

  9. 数据拟合 | MATLAB实现RBF径向基神经网络多输入数据拟合

    数据拟合 | MATLAB实现RBF径向基神经网络多输入数据拟合 目录 数据拟合 | MATLAB实现RBF径向基神经网络多输入数据拟合 基本介绍 程序设计 模型差异 参考资料 基本介绍 RBF神将网 ...

最新文章

  1. 认知学习法-学习笔记
  2. ubuntu20输入法qiehuan_ubuntu20.04中文输入法安装步骤
  3. Tomcat的安装和环境变量配置
  4. js中执行到一个if就停止的代码_Node.JS实战64:ES6新特性:Let和Const。
  5. Android细节问题总结(二)
  6. (todo)数组名 有存储空间吗?
  7. 信息奥赛一本通(1231:最小新整数)
  8. iOS6,7,8,9新特性汇总
  9. TCP流量控制和滑动窗口
  10. 【转载】MySQL5.6.27 Release Note解读(innodb及复制模块)
  11. PHP has encountered an Access Violation at
  12. 【微信页面】移动端微信页面禁止字体放大
  13. ubuntu下配置安装PyQt4
  14. oracle 排序性能优化,Oracle优化之: 利用索引的有序性减少排序
  15. 程序设计比赛WBS图
  16. 思科路由器如何强行中断命令
  17. MFC中Combo控件的使用
  18. 虚拟机如何安装优麒麟19.10
  19. 03 【前端笔试】- 2020 搜狗校招笔试题
  20. Java入门-机票购买、座舱等级、淡旺季计算价格

热门文章

  1. 学习Spring Boot:(十二)Mybatis 中自定义枚举转换器
  2. Java面试题整理(附参考答案)
  3. MyBatis JdbcType介绍
  4. python类的定义和创建_Python类对象的创建和使用
  5. hive mysql类型,(二)Hive数据类型、数据定义、数据操作和查询
  6. python绘制饼图程序_python使用Matplotlib绘制饼图
  7. Git初学札记(九)————EGit检出远程分支
  8. 工业互联网智能智造-工业企业大数据汇聚通道-产品设计
  9. HTML+CSS+JS实现 ❤️touchSlider图片滚动图片轮播❤️
  10. 语言中要输出表格_C语言 | 表格输出若干人的信息