直观理解Neural Tangent Kernel
直观理解Neural Tangent Kernel
本文是文章Some Intuition on the Neural Tangent Kernel的翻译整理.
一句话总结:NTK衡量的是,在使用SGD优化参数下,其对应的随机到样本x′\displaystyle x'x′,在参数更新非常一小步η\displaystyle \etaη后,f(x)\displaystyle f( x)f(x)的变化。也就是:
k(x,x′)=limη→0f(x,θ+ηdfθ(x′)dθ)−f(x,θ)ηk(x,x')=\lim _{\eta \rightarrow 0}\frac{f\left( x,\theta +\eta \frac{df_{\theta } (x')}{d\theta }\right) -f(x,\theta )}{\eta } k(x,x′)=η→0limηf(x,θ+ηdθdfθ(x′))−f(x,θ)
热身
考虑最简单的函数f(i)\displaystyle f( i)f(i),在每个点i都有一个不同的取值,这些取值可以用一个参数θi=f(i)\displaystyle \theta _{i} =f( i)θi=f(i)来表示,如果我们初始化θi=3i+2\displaystyle \theta _{i} =3i+2θi=3i+2,那么这个函数大概长这样:
假设现在有一个样本是(x,y)=(10,50)\displaystyle ( x,y) =( 10,50)(x,y)=(10,50),那么根据这个样本,我们需要对这个函数进行梯度更新,显然,这个样本只会影响f(10)\displaystyle f( 10)f(10)这个点的参数θ10\displaystyle \theta _{10}θ10,所以其他参数并不会发现变化,只会在10这个点变化,而这个变化如图上红色箭头所示。
显然,假设我们使用squared error loss,L=(f(10;θ)−50)2\displaystyle L=( f( 10;\theta ) -50)^{2}L=(f(10;θ)−50)2,并且设更新步长η=0.1\displaystyle \eta =0.1η=0.1,那么∂L∂θ10=∂∂θ10(θ10−50)2=2(32−50)=−36\displaystyle \frac{\partial L}{\partial \theta _{10}} =\frac{\partial }{\partial \theta _{10}}( \theta _{10} -50)^{2} =2( 32-50) =-36∂θ10∂L=∂θ10∂(θ10−50)2=2(32−50)=−36,显然为了让loss减少,于是θ10=θ10−η∂L∂θ10=32+0.1∗36=35.6\displaystyle \theta _{10} =\theta _{10} -\eta \frac{\partial L}{\partial \theta _{10}} =32+0.1*36=35.6θ10=θ10−η∂θ10∂L=32+0.1∗36=35.6,我们发现f(10)\displaystyle f( 10)f(10)这个点增加3.6
线性函数
刚才的例子只有一个参数发生变化,过于特殊,现在给一个线性函数的例子,设f(x,θ)=θ1x+θ2\displaystyle f( x,\theta ) =\theta _{1} x+\theta _{2}f(x,θ)=θ1x+θ2. 我们初始化参数为θ1=3,θ1=1\displaystyle \theta _{1} =3,\theta _{1} =1θ1=3,θ1=1,这样,跟上面的例子差不多,不过这是一条直线,同样考虑样本点(x,y)=(10,50)\displaystyle ( x,y) =( 10,50)(x,y)=(10,50),在该样本下,作一次梯度下降更新参数,
我们发现,所有x的取值都会发现变化,而我们关注的点f(x)\displaystyle f( x)f(x)也会离目标值更近了点。
Nerual tangent kernel
考虑某个点x,我们关心该函数在该点下的取值为fθ(x)\displaystyle f_{\theta }( x)fθ(x),在SGD算法中,往往随机抽一个样本x′\displaystyle x'x′,我们想要知道,在这个新样本下,更新一次参数θ\displaystyle \thetaθ,f(x)\displaystyle f( x)f(x)会发生什么变化,而nerual tangent kernelk(x,x′)\displaystyle k( x,x')k(x,x′)正是衡量这种变化的函数:
ηk~θ(x,x′)=f(x,θ+ηfθ(x′)dθ)−f(x,θ)\eta \tilde{k}_{\theta } (x,x')=f\left( x,\theta +\eta \frac{f_{\theta } (x')}{d\theta }\right) -f(x,\theta ) ηk~θ(x,x′)=f(x,θ+ηdθfθ(x′))−f(x,θ)
换句话说,
k(x,x′)=limη→0f(x,θ+ηdfθ(x′)dθ)−f(x,θ)ηk(x,x')=\lim _{\eta \rightarrow 0}\frac{f\left( x,\theta +\eta \frac{df_{\theta } (x')}{d\theta }\right) -f(x,\theta )}{\eta } k(x,x′)=η→0limηf(x,θ+ηdθdfθ(x′))−f(x,θ)
我们对f\displaystyle ff进行泰勒近似,根据泰勒公式f(x+Δx)≈f(x)+f′(x)Δx\displaystyle f( x+\Delta x) \approx f( x) +f'( x) \Delta xf(x+Δx)≈f(x)+f′(x)Δx
f(x,θ+ηdfθ(x′)dθ)≈f(x,θ)+f′(x,θ)ηdfθ(x′)dθf\left( x,\theta +\eta \frac{df_{\theta } (x')}{d\theta }\right) \approx f( x,\theta ) +f'( x,\theta ) \eta \frac{df_{\theta } (x')}{d\theta } f(x,θ+ηdθdfθ(x′))≈f(x,θ)+f′(x,θ)ηdθdfθ(x′)
于是
f(x,θ+ηdfθ(x′)dθ)−f(x,θ)η≈f′(x,θ)dfθ(x′)dθ=<dfθ(x)dθ,dfθ(x′)dθ>\frac{f\left( x,\theta +\eta \frac{df_{\theta } (x')}{d\theta }\right) -f(x,\theta )}{\eta } \approx f'( x,\theta )\frac{df_{\theta } (x')}{d\theta } =\left< \frac{df_{\theta } (x)}{d\theta } ,\frac{df_{\theta } (x')}{d\theta }\right> ηf(x,θ+ηdθdfθ(x′))−f(x,θ)≈f′(x,θ)dθdfθ(x′)=⟨dθdfθ(x),dθdfθ(x′)⟩
我们发现,NTK给予了我们一种“预测”f(x)\displaystyle f( x)f(x)在SGD下变化的能力。那它有一些什么性质呢?
NTK对参数的取值敏感
显然,神经网络可以改变参数,但是保持输出的值不变,那么参数的变化对NTK会有影响吗?答案是有,比如说,上面线性函数的例子将函数改为
fθ(x)=θ1x+10θ2f_{\theta }( x) =\theta _{1} x+10\theta _{2} fθ(x)=θ1x+10θ2
但是设θ1=3,θ2=0.1\displaystyle \theta _{1} =3,\theta _{2} =0.1θ1=3,θ2=0.1,你会发现这个函数跟上面是一致的,只是截距项从1变成10*0.1,然而,使用同样的样本(x,y)=(10,50)\displaystyle ( x,y) =( 10,50)(x,y)=(10,50)更新这么一个函数,你会发现它的函数变化是不同的:
也就是说,NTK对参数是敏感的。
tiny radial basis function network
最后,再来一个小型的神经网络举个例子,考虑函数
fθ(x)=θ1exp(−(x−θ2)230)+θ3exp(−(x−θ4)230)+θ5,f_{\theta } (x)=\theta _{1}\exp\left( -\frac{(x-\theta _{2} )^{2}}{30}\right) +\theta _{3}\exp\left( -\frac{(x-\theta _{4} )^{2}}{30}\right) +\theta _{5} , fθ(x)=θ1exp(−30(x−θ2)2)+θ3exp(−30(x−θ4)2)+θ5,
初始化为(θ1,θ2,θ3,θ4,θ5)=(4.0,−10.0,25.0,10.0,50.0)\displaystyle (\theta _{1} ,\theta _{2} ,\theta _{3} ,\theta _{4} ,\theta _{5} )=(4.0,-10.0,25.0,10.0,50.0)(θ1,θ2,θ3,θ4,θ5)=(4.0,−10.0,25.0,10.0,50.0),同样的,在样本点(x,y)=(10,50)\displaystyle ( x,y) =( 10,50)(x,y)=(10,50)更新这么一个函数,我们得到函数的变化为:
显然,我们发现,在靠近0附近它的变化是很小的,而在10附近它的变化是很大的,之前说过,NTK就是刻画这种变化的,因此,我们可以把NTK画出来:
这里除以了在10,10处标准化了一下(只是个除了个常数可以无视),可以发现,确实在0附近的值很小,而在10附近的值很大,符合我们的观察。值得一提的是,虽然样本是x=10\displaystyle x=10x=10的点,但是变化最大的地方其实是在x=7\displaystyle x=7x=7的地方。
那如果我们不停的更新参数会怎样?以下是更新15次的图
显然,随着参数的变化,kernel大小也在变化,而且越来越平滑,这意味着函数在每个取值下的变化越来越一致。
NTK有什么用?
NTK在无限宽神经网络下有几个非常重要,有用的性质:
- 在无限宽的网络中,如果参数θ0\displaystyle \theta _{0}θ0在以某种合适的分布下初始化,那么在该初始值下的NTKkθ0\displaystyle k_{\theta _{0}}kθ0是一个确定的函数,这意味着,不管我的初始值是多少,最终总会收敛到一个确定的核函数上,它与初始化无关!
- 而且在无限宽网络中,kθt\displaystyle k_{\theta _{t}}kθt并不会随着训练的变化而变化,也就是说,在训练中参数的改变并不会改变该核函数。
以上两个事实告诉我们,在无限宽网络中,训练可以理解成一个简单的交kernel gradient descent的算法,而且kernel还是固定的,只取决于网络的结构还有激活函数之类的。这些性质,加上,Neal,(1994)的结论,使得我们可以将这个用梯度下降收敛的极值的概率分布看做是一个随机过程。
最后要注意的就是,这里的NTK,其实是针对梯度下降法提出来的,以往的无限宽网络与高斯过程的联系其实只是在初始化阶段的时候收敛到高斯过程,它并没有说训练过程也是一个高斯过程。它是没有考虑随机梯度下降这一过程的。
而在NTK这里,我们发现,训练的时候与kernel无关,而且初始化决定了它的取值,也就是说,在训练过程中,我们仍然可以认为它还是一个高斯过程而不仅仅是初始化的时候。
参考资料
Some Intuition on the Neural Tangent Kernel
colab notebook
直观理解Neural Tangent Kernel相关推荐
- Neural Tangent Kernel 理解(一)原论文解读
欢迎关注WX公众号,每周发布论文解析:PaperShare, 点我关注 NTK的理解系列 暂定会从(一)论文解读,(二)kernel method基础知识,(三)神经网络表达能力,(四)GNN表达能力 ...
- kernel方法的直观理解与详述
1.直观理解 通常意义上的kernel method 主要由两种解释:一是相似性的度量:二是特征映射. 在详细引入kernel的工作之前,以支持向量机 SVM 为例说明kernel方法,因为它是SVM ...
- 3.7 注意力模型直观理解-深度学习第五课《序列模型》-Stanford吴恩达教授
注意力模型直观理解 (Attention Model Intuition) 在本周大部分时间中,你都在使用这个编码解码的构架(a Encoder-Decoder architecture)来完成机器翻 ...
- SVM支持向量机【直观理解】
转载文章:https://baijiahao.baidu.com/s?id=1607469282626953830&wfr=spider&for=pc 如果你曾经使用机器学习解决分类问 ...
- 反向传播的直观理解 (以及为什么反向传播是一种快速的算法)
解释:公式 46 其实就是微分的定义公式,"W + εej "代表对于由 j 个 weight 构成的单位向量 W,在其他权重都不变的情况下,使得 Wj 微小的改变 ε(W为单位 ...
- 梯度的直观理解_关于梯度、旋度和散度的直观理解
关于梯度.旋度和散度的直观理解 散度为零,说明是无源场:散度不为零时,则说明是有源场(有正源或负源) 若你的场是一个流速场,则该场的散度是该流体在某一点单位时间流出单位体积的净流量. 如果在某点,某场 ...
- BP反向传播算法的思考和直观理解 -卷积小白的随机世界
https://www.toutiao.com/a6690831921246634504/ 2019-05-14 18:47:24 本篇文章,本来计划再进一步完善对CNN卷积神经网络的理解,但在对卷积 ...
- RNN循环神经网络的直观理解:基于TensorFlow的简单RNN例子
RNN 直观理解 一个非常棒的RNN入门Anyone Can learn To Code LSTM-RNN in Python(Part 1: RNN) 基于此文章,本文给出我自己的一些愚见 基于此文 ...
- 3.10 直观理解反向传播-深度学习-Stanford吴恩达教授
←上一篇 ↓↑ 下一篇→ 3.9 神经网络的梯度下降法 回到目录 3.11 随机初始化 直观理解反向传播 (Backpropagation Intuition (Optional)) 这个视频主要是推 ...
最新文章
- Ubuntu 设置Android adb 环境变量
- VS2012及VS系列怎样屏蔽CMD窗口~
- java中四种常用的引用类型_java中四种引用类型
- 29、Power Query-分支语句的进阶
- oracle上机题库_Oracle数据库考试试题库
- 如何正确的在一个循环中删除ArrayList中的元素。
- HDU 5439 Aggregated Counting
- 深度学习《patchGAN》
- jenkins配置sonar并扫描C#代码
- mysql获取当天,昨天,本周,本月,上周,上月的起始时间
- 教你如何不登陆复制CSDN代码
- easyUI的iconCls
- html 防网页假死,HTML 5 Web开发:防止浏览器假死的方法
- word无法选定图片随文字移动
- Quasi-Newton拟牛顿法(共轭方向法)
- iSCSI target initiator
- 基于springboot的生鲜门店配送管理系统(idea+springboot+html+thymeleaf)
- 稳定匹配 5分钟看懂GS算法 附有常考常见例题及解析
- 苹果和安卓正确的卸载软件方法分享
- 上传图片和得到图片长宽大小的方法
热门文章
- <Android开发> Android vold - 第七篇 vold 的runCommand()方法解析
- [转载]菲尔兹奖历届得主
- 第21节迁移学习原理及实例
- 苏州新导室内定位方案之WIFI RTLS室内定位解决方案
- 项目实训工作总结(2)
- 100代码搞定C语言游戏开发,编程原来如此简单
- 如何从github上下载文件并运行
- 小白都能看懂的实战教程 手把手教你Python Web全栈开发(DAY 3)
- html f12键的作用,电脑键盘中F1-F12每个功能键的作用您都知道吗?
- 新手如何看k线(图) .