Spatio-Temporal Backpropagation for Training High-performance Spiking Neural Networks笔记
《Spatio-Temporal Backpropagation for Training High-performance Spiking Neural Networks》笔记
ABSTRACT
STBP:Spatio-Temporal Backpropagation
1.探索SNN的一个重要原因:spikes的相关编码可以包含很多的时空信息。
2.目前很多(这篇文章发之前)研究都只关注于神经网络的空间域信息,造成了相关研究的瓶颈。也有部分只研究独立时域信息。
3.spike activity是无法微分的,这给SNN的训练造成了很大的困难。
4.本文建立了iterative LIF模型,易于梯度下降算法的训练。训练同时考虑层与层之间的空间域关系和独立的时域关系。
5.提出了spike activity导数的近似处理函数。
6.本文涉及到的训练没有采用任何复杂的训练trick。
7.应用的数据集为静态MNIST、自己建立的一个目标检测数据集、动态N-MNIST。
INTRODUCTION
1.SNN的两点优势:时间动态性&硬件友好性(时间和能量消耗低,神经形态芯片)
2.三种SNN的训练方法:
无监督学习(利用生物突触可塑性)
例如spiking timing dependent plasticity(STDP),但是其只考虑当前神经元随时间的活动所以表现不大好。
间接监督学习(ANN2SNN)
先训练好一个ANN,然后将其转换到SNN的版本。SNN的spiking rate充当ANN神经元的模拟活动。这个方法生物可信性比较差。
直接监督学习(如本文采用的STBP方法)
这种方法在本论文之前很多都只考虑了spatial domain,而没考虑temporal domain。并且需要很多复杂的训练技巧来提高准确率。但是STBP就不需要这些复杂的训练技巧。
3.利用SNN的动态特性建立了一个iterative LIF模型,这个模型对梯度下降算法很友好。介绍了近似导数来解决无法微分的问题。用三种数据集检验。
4.分析了时间动态性以及不同导数近似方法对最后结果的影响。
正文部分
LIF模型如下:
τdu(t)dt=−u(t)+I(t)\tau\frac{du(t)}{dt} = -u(t) + I(t)τdtdu(t)=−u(t)+I(t)
u(t)u(t)u(t)是神经元膜电位,τ\tauτ是时间常数,I(t)I(t)I(t)是前突触传来的输入,由前神经元活动以及外界干扰和突触权重决定。膜电位超过阈值VthV_{th}Vth神经元便会激发一个脉冲,然后马上恢复到urestu_{rest}urest。
相比于DNN的传播,SNN的每个神经元有自反馈注入,从而在时间域产生非易失性的电位积累。
对LIF模型适当简化,并且遵循事件驱动迭代的更新规则,有下式:
u(t)=u(ti−1)eti−1−tτ+I(t)u(t) = u(t_{i-1})e^\frac{t_{i-1}-t}{\tau} + I(t)u(t)=u(ti−1)eτti−1−t+I(t)
膜电位指数级衰减直到神经元接收到前突触的输入,在新的一个更新周期该神经元才有可能激发脉冲。一个神经元的状态由上式右边的两个因素决定,分别代表时域和空间域。
SNN(STBP)的整体(SD&TD)结构图如下:
在SD和TD同时进行迭代,如下式:
xit+1,n=∑j=1l(n−1)wijnojt+1,n−1x_i^{t+1,n} = \sum_{j=1}^{l(n-1)}w_{ij}^no_j^{t+1,n-1}xit+1,n=∑j=1l(n−1)wijnojt+1,n−1
uit+1,n=uit,nf(oit,n)+xit+1,n+binu_i^{t+1,n} = u_i^{t,n}f(o_i^{t,n})+x_i^{t+1,n} + b_i^nuit+1,n=uit,nf(oit,n)+xit+1,n+bin
Oit+1,n=g(uit+1,n)O_i^{t+1,n} = g(u_i^{t+1,n})Oit+1,n=g(uit+1,n)
其中f(x)=τe−xτf(x) = \tau e^{-\frac{x}{\tau}}f(x)=τe−τx,g(x)={1x≥Vth0x<Vthg(x)=\left\{ \begin{aligned} 1 & & x \geq V_{th} \\ 0 & & x < V_{th} \end{aligned} \right.g(x)={10x≥Vthx<Vth
t表示时刻,n表示神经元所处的层数,l(n)l(n)l(n)表示第n层神经元的总数,wijw_{ij}wij是从第j个前突触层的神经元到第i个后突触层得神经元的权重。oj=1o_j = 1oj=1表示产生了脉冲,等于0表示什么也没发生,xix_ixi是第i个神经元的前突触输入和的简要表示,类似于LIF中的I,uiu_iui是第i个神经元的膜电位,bib_ibi是一个与与阈值VthV_{th}Vth相关的参数值。
uit+1,n=uit,nf(oit,n)+xit+1,n+binu_i^{t+1,n} = u_i^{t,n}f(o_i^{t,n})+x_i^{t+1,n} + b_i^nuit+1,n=uit,nf(oit,n)+xit+1,n+bin和Oit+1,n=g(uit+1,n)O_i^{t+1,n} = g(u_i^{t+1,n})Oit+1,n=g(uit+1,n)是由LSTM模型的灵感得到。
τ\tauτ 是一个较小的正时间常数,则有:
f(oit,n)={τoit,n=00oit,n=1f(o_i^{t,n})=\left\{ \begin{aligned} \tau & & o_i^{t,n} = 0 \\ 0 & & o_i^{t,n} = 1 \end{aligned} \right.f(oit,n)={τ0oit,n=0oit,n=1
在时间窗口T上用均方误差制定损失函数如下:
L=12S∑s=1S∣∣ys−1T∑t=1Tost,N∣∣22L = \frac{1}{2S} \sum_{s=1}^{S}||y_s - \frac{1}{T} \sum_{t =1}^{T} o_s^{t,N}||^2_2L=2S1∑s=1S∣∣ys−T1∑t=1Tost,N∣∣22
其中ysy_sys和oso_sos表示标签向量和第N层的输出向量。可以看出L是W和b的函数。假定已经得到每层每个时刻的导数∂L∂oi\frac{\partial L}{\partial o_i}∂oi∂L和∂L∂ui\frac{\partial L}{\partial u_i}∂ui∂L,这对求取∂L∂W\frac{\partial L}{\partial W}∂W∂L和∂L∂b\frac{\partial L}{\partial b}∂b∂L,很关键。根据导数的链式法则推导四种情况下的梯度下降过程暂时忽略(主要是推导∂L∂oit,n\frac{\partial L}{\partial o_i^{t,n}}∂oit,n∂L和∂L∂uit,n\frac{\partial L}{\partial u_i^{t,n}}∂uit,n∂L)。
下面是在STD上误差传播的示意图:
a图表示垂直方向是SD的误差传播,水平方向是TD的误差传播。b图扩展成了网络级别的示意图。
关于∂L∂W\frac{\partial L}{\partial W}∂W∂L和∂L∂b\frac{\partial L}{\partial b}∂b∂L,有下面的式子:
其中和∂L∂ut,n\frac{\partial L}{\partial u^{t,n}}∂ut,n∂L可以根据之前推导的过程得出。
前面推导了基于STBP的梯度信息,下面解决在脉冲激发点无法微分的问题。这个主要是针对g(u)g(u)g(u),它的导数在0处取无限值,其余地方都取0,会在误差传播过程中造成梯度消失或爆炸的问题。目前的一个办法是将这个不连续的点(脉冲时刻)视作噪声,然后宣称它对模型的鲁棒性有益,但是不从根本解决问题。
下面是本文提出的四种曲线来近似g(u)g(u)g(u)的导数:
h1(u)=1a1sign(∣u−Vth∣<a12)h_1(u) = \frac{1}{a_1}sign(|u-V_{th}| < \frac{a_1}{2})h1(u)=a11sign(∣u−Vth∣<2a1)
h2(u)=(a22−a24∣u−Vth∣)sign(2a2−∣u−Vth∣)h_2(u) = (\frac{\sqrt{a_2}}{2}-\frac{a_2}{4}|u-V_{th}|)sign(\frac{2}{\sqrt{a_2}}-|u-V_{th}|)h2(u)=(2a2−4a2∣u−Vth∣)sign(a22−∣u−Vth∣)
h3(u)=1a3eVth−ua3(1+eVth−ua3)2h_3(u) = \frac{1}{a_3}\frac{e^{\frac{V_th-u}{a_3}}}{(1+e^{\frac{V_th-u}{a_3}})^2}h3(u)=a31(1+ea3Vth−u)2ea3Vth−u
h4(u)=12πa4e(u−Vth)2a4h_4(u) = \frac{1}{\sqrt{2\pi a_4}}e^{\frac{(u-V_th)^2}{a_4}}h4(u)=2πa41ea4(u−Vth)2
参数aia_iai决定曲线的尖锐程度。事实上,这四种曲线分别是矩形函数、多项式函数、sigmoid函数和高斯积累分布函数的衍生。调整aia_iai的值,使每个曲线的积分都为1,可以证明下面的式子:
limai=0+hi(u)=dgdu,i=1,2,3,4\lim_{a_i = 0^+}h_i(u) = \frac{dg}{du},i = 1,2,3,4limai=0+hi(u)=dudg,i=1,2,3,4
进一步得到下面的近似:
dgdu≈dhidu\frac{dg}{du} \approx \frac{dh_i}{du}dudg≈dudhi
下面是g(u)g(u)g(u)和它导数以及各个曲线的示意图:
结果
参数初始化
本文固定阈值为一常数,调整weight来控制spiking活动的平衡。
首先将所有的权重参数标准均匀分布在-1到1之间:W∼U[−1,1]W \sim U[-1,1]W∼U[−1,1]。
然后normalize所有参数: wi,jn=wi,jn∑j=1l(n−1)wi,jn2w_{i,j}^n = \frac{ w_{i,j}^n}{\sqrt{\sum_{j=1}^{l(n-1)}{w_{i,j}^n}^2}}wi,jn=∑j=1l(n−1)wi,jn2wi,jn
其他参数见下图:
各个数据集的实验
首先要将静态的数据集转换为spike events。本文采用了伯努利采样的方式从原始像素密度转换为了spike rate。
全连接网络
一、静态数据集的相关实验(主要)
测试该网络的数据集有MNIST和自制的一个目标检测数据集,均为静态数据集。MNIST数据集中用60000个有标签的手写数字进行训练,10000个用来测试,每个数字图片是28×2828\times2828×28的灰度图片;目标检测数据集是一个二类图像数据集,以有无行人分类,包含1509个训练样本和631个测试样本,也是28×2828\times2828×28的灰度图像。
具体网络进行的工作与一些相关操作:
1.限定一个时间窗口,例如20ms,程序中即依次仿真0-19时刻。
2.对于每个时刻,做一次下面的处理:先把输入的(1,28,28)图像转换成784维的列向量并进行标准化操作把向量中的每个元素处理成0~1范围内的数,然后用与随机向量做比较的方法,对784维向量进行0-1赋值,作为输入的spike。
3.网络框架:784-800-10。
4.最后看窗口时间内,最后10个神经元(标签是0-9)的spiking rate的大小决定该图片表示的数字。
注意点:由于目标检测数据集的训练输入脉冲的激发率高于MNIST数据集,所以相应的阈值会有变化。
下面两幅图是实验结果(和其他方法对比):
二、动态数据集的相关实验
动态数据集相较于静态数据集包含更丰富的时域特征。N-MNIST数据集是通过动态视觉传感器(DVS)把静态MNIST数据集转换成它的动态版本即脉冲训练输入。
具体转换的操作:控制DVS按照等腰三角形的三边依次移动,在每个像素点强度变化时触发并生成spike train。示意图如下:
由于对于每个像素点的强度变化有变亮和变暗两种,DVS可以捕捉到两种spike events,记作on-event和off-event。由于N-MNIST允许在扫视过程中图像的相对偏移,它产生 34×3434\times3434×34 像素范围。
从上图中可以明显知道on-event和off-event区别很大,所以用两个通道区分,网络框架为:34×34×2−400−400−1034\times34\times2-400-400-1034×34×2−400−400−10。
下图是用N-MNIST数据集进行各种方法的结果比较(两层400合成了一层800):
ANN相关的方法通常用框架固定的模型,但是由于动态数据集的图像较为模糊,这种方法表现不好,且放弃了硬件友好的事件驱动范式。相比起来,SNN易于处理事件流图片。
卷积网络
SNN的全连接网络扩展到卷积网络可以让网络更深,并且包含更多的时域信息。卷积网络和全连接的主要区别在于对输入图像的处理上。在卷积层上,每个卷积神经元接收卷积的输入,然后通过LIF模型更新网络参数;在池化层,由于SNN的二值编码对max pooling不适用,所以采用average pooling。
同样在MNIST和目标检测数据集进行测试。
在MNIST上,卷积网络包含一个卷积层,其中卷积核的大小为5;包含相互交替的两个平均池化层,然后再接一个隐藏层。类似于传统CNN,用**弹性形变(elastic distortion)**的方法预处理数据。
下面是卷积网络的结果比较(MNIST&objection detection dataset):
一些分析
1.导数近似曲线的选取及相关参数对于结果的影响:不同曲线最终结果的差异较小,参数aia_iai太大或太小会导致结果很差。用曲线取近似脉冲活动的导数关键在于捕捉非线性特征,曲线的形状不是关键。
2.本文相对于SDBP,主要多考虑了时间上的因素,使得结果准确率提升。对比如下图:
3.SNN的训练严重依赖于参数的初始化。但是本文中的方法不大依赖,例如侧抑制、正则化和标准化的一些tricks。想要SNN网络有更好的稳定性和鲁棒性就要多考虑动态时间的影响。
4.在未来的研究中,有两个问题比较关键:一是把本文的框架应用到更多的含有时间特性的问题上,像动态数据处理、视频流鉴定和语音识别等领域上的问题;二是如何加速SNN在GPU/CPU或是神经形态芯片上的训练。
code reference
Spatio-Temporal Backpropagation for Training High-performance Spiking Neural Networks笔记相关推荐
- 论文略读1《Direct training for spiking neural networks:faster,larger,better》
发布时间2019-07-17,AAAI 文章链接Direct Training for Spiking Neural Networks: Faster, Larger, Better | Procee ...
- Mapping Spiking Neural Networks的论文汇总以及思考
首先感谢CSDN平台,发现不是我一个人在SNN Mapping方面纠结着.去年看了Mapping方面的内容后感觉想创新还是有点难度的,毕竟优化就是生物进化算法类似的套路,可是你会发现自己实现的结果就是 ...
- Spiking neural networks 2017 进展
The Brain as an Efficient and Robust Adaptive Learner Training Spiking Neural Networks for Cognit ...
- 论文精翻《Progressive Tandem Learning for Pattern Recognition With Deep Spiking Neural Networks》
目录 0 摘要/Abstract 1 简介/Introduction 2 相关工作/Related Work 3 重新思考ANN-to-SNN的转换/Rethinking ANN-to-SNN Con ...
- Direct Training for Spiking Neural Networks: Faster, Larger, Better
摘要 我们提出一种神经元正则化技术去调整神经元分立,而且发展了一种直接的训练算法对于深层SNN. 通过缩小速率编码窗口和转换LIF模型到精确的迭代版本,我们提出了基于pytorch版本的手段去训练深度 ...
- SpykeTorch: Efficient Simulation of Convolutional Spiking Neural Networks With at Most One Spike per
引言 下面我们介绍一下Spyketorch,这个是基于pytorch的框架的.在相同的学习规则下,可以使用STDP和R-STDP,其他的法则也能够应用.该框架中的计算是基于张量的,完全由PyTorch ...
- Going Deeper in Spiking Neural Networks: VGG and Residual Architectures
摘要 在过去几年中,SNN已经成为最受欢迎的低能耗方式.我们提出了分析稀疏事件驱动计算去展现硬件的冗余,当在脉冲领域. 关键词: 脉冲神经网络,事件驱动,稀疏性,神经元计算,视觉辨认. 主要贡献 1. ...
- A remark on the error-backpropagation learning algorithm for spiking neural networks
关于尖峰神经网络的误差反向传播学习算法的一点评论✩ 摘要 在用于脉冲神经网络的误差反向传播学习算法中,必须将触发时间tαt^\alphatα区分为状态函数x(t)x(t)x(t)的函数.但是这种区分是 ...
- 【阅读】A Comprehensive Survey on Distributed Training of Graph Neural Networks——翻译
转载请注明出处:小锋学长生活大爆炸[xfxuezhang.cn] (本文中,涉及到公式部分的翻译不准确,请看对应原文.) 另一篇:[阅读]Distributed Graph Neural Networ ...
最新文章
- SOA标准发展混乱 国内业务缺少经验
- SNF软件开发机器人-子系统-功能-【列表】自由排序-如何配置?
- 畅销书《深入浅出Vue.js》作者,在阿里淘系1年的收获成长
- 2016年:勒索病毒造成损失预估超过10亿美元
- 全局修改elementui message 右边弹出_ElementUI 只允许 $message 提示一次
- 【CCF】201709-2公共钥匙盒
- libx264进行视频编码的流程
- 代码管理_阿里巴巴如何管理代码分支?
- Flutter基础—绘画效果之不透明度
- 考试倒计时,计算机二级重难点汇总【39套历年考题】
- access to同义替换_access to 用法
- arcmap10.7打开tif文件一片空白 | 解决方法
- linux中屏蔽定时任务,linux中的定时任务
- 公司女同事深夜11点让我去她住处修电脑,原来是C盘爆红,看我一招搞定女同事....的电脑
- 二级路由器设置,二级路由器无法上网
- python f 格式字符串输出
- ANSYS - 表格加载方法
- Chrome谷歌浏览器无法调用摄像头原因及解决办法
- Java之IO流技术详解
- 【舰船数据集格式转换】HRSID数据集VOC转COCO
热门文章
- 什么是iBeacon
- 中文分词是一个伪命题
- Fedora Linux添加Canon打印机驱动
- prism IRegionMemberLifetime(区域成员生命周期)
- matlab y e x,如何使用matlab绘制函数y=xloge(x^2-1)的函数图像,以e为底。
- 【Android Gradle 插件】DexOptions 配置 ⑤ ( additionalParameters 属性配置 | --minimal-main-dex 参数最小化主 dex 字节码 )
- js-跟着鼠标移动的图片
- CANoe C-V2X Demo(V2I+V2V)演示视频
- 我是如何从屌丝程序员逆袭成为大厂总监的?
- Python网络爬虫之Xpath详解