《Spatio-Temporal Backpropagation for Training High-performance Spiking Neural Networks》笔记

ABSTRACT

STBP:Spatio-Temporal Backpropagation

1.探索SNN的一个重要原因:spikes的相关编码可以包含很多的时空信息。

2.目前很多(这篇文章发之前)研究都只关注于神经网络的空间域信息,造成了相关研究的瓶颈。也有部分只研究独立时域信息。

3.spike activity是无法微分的,这给SNN的训练造成了很大的困难。

4.本文建立了iterative LIF模型,易于梯度下降算法的训练。训练同时考虑层与层之间的空间域关系和独立的时域关系。

5.提出了spike activity导数的近似处理函数。

6.本文涉及到的训练没有采用任何复杂的训练trick。

7.应用的数据集为静态MNIST、自己建立的一个目标检测数据集、动态N-MNIST。

INTRODUCTION

1.SNN的两点优势:时间动态性&硬件友好性(时间和能量消耗低,神经形态芯片)

2.三种SNN的训练方法:

无监督学习(利用生物突触可塑性)

例如spiking timing dependent plasticity(STDP),但是其只考虑当前神经元随时间的活动所以表现不大好。

间接监督学习(ANN2SNN)

先训练好一个ANN,然后将其转换到SNN的版本。SNN的spiking rate充当ANN神经元的模拟活动。这个方法生物可信性比较差。

直接监督学习(如本文采用的STBP方法)

这种方法在本论文之前很多都只考虑了spatial domain,而没考虑temporal domain。并且需要很多复杂的训练技巧来提高准确率。但是STBP就不需要这些复杂的训练技巧。

3.利用SNN的动态特性建立了一个iterative LIF模型,这个模型对梯度下降算法很友好。介绍了近似导数来解决无法微分的问题。用三种数据集检验。

4.分析了时间动态性以及不同导数近似方法对最后结果的影响。

正文部分

LIF模型如下:

τdu(t)dt=−u(t)+I(t)\tau\frac{du(t)}{dt} = -u(t) + I(t)τdtdu(t)​=−u(t)+I(t)

u(t)u(t)u(t)是神经元膜电位,τ\tauτ是时间常数,I(t)I(t)I(t)是前突触传来的输入,由前神经元活动以及外界干扰和突触权重决定。膜电位超过阈值VthV_{th}Vth​神经元便会激发一个脉冲,然后马上恢复到urestu_{rest}urest​。

相比于DNN的传播,SNN的每个神经元有自反馈注入,从而在时间域产生非易失性的电位积累。

对LIF模型适当简化,并且遵循事件驱动迭代的更新规则,有下式:

u(t)=u(ti−1)eti−1−tτ+I(t)u(t) = u(t_{i-1})e^\frac{t_{i-1}-t}{\tau} + I(t)u(t)=u(ti−1​)eτti−1​−t​+I(t)

膜电位指数级衰减直到神经元接收到前突触的输入,在新的一个更新周期该神经元才有可能激发脉冲。一个神经元的状态由上式右边的两个因素决定,分别代表时域和空间域。

SNN(STBP)的整体(SD&TD)结构图如下:

在SD和TD同时进行迭代,如下式:

xit+1,n=∑j=1l(n−1)wijnojt+1,n−1x_i^{t+1,n} = \sum_{j=1}^{l(n-1)}w_{ij}^no_j^{t+1,n-1}xit+1,n​=∑j=1l(n−1)​wijn​ojt+1,n−1​

uit+1,n=uit,nf(oit,n)+xit+1,n+binu_i^{t+1,n} = u_i^{t,n}f(o_i^{t,n})+x_i^{t+1,n} + b_i^nuit+1,n​=uit,n​f(oit,n​)+xit+1,n​+bin​

Oit+1,n=g(uit+1,n)O_i^{t+1,n} = g(u_i^{t+1,n})Oit+1,n​=g(uit+1,n​)

其中f(x)=τe−xτf(x) = \tau e^{-\frac{x}{\tau}}f(x)=τe−τx​,g(x)={1x≥Vth0x<Vthg(x)=\left\{ \begin{aligned} 1 & & x \geq V_{th} \\ 0 & & x < V_{th} \end{aligned} \right.g(x)={10​​x≥Vth​x<Vth​​

t表示时刻,n表示神经元所处的层数,l(n)l(n)l(n)表示第n层神经元的总数,wijw_{ij}wij​是从第j个前突触层的神经元到第i个后突触层得神经元的权重。oj=1o_j = 1oj​=1表示产生了脉冲,等于0表示什么也没发生,xix_ixi​是第i个神经元的前突触输入和的简要表示,类似于LIF中的I,uiu_iui​是第i个神经元的膜电位,bib_ibi​是一个与与阈值VthV_{th}Vth​相关的参数值。

uit+1,n=uit,nf(oit,n)+xit+1,n+binu_i^{t+1,n} = u_i^{t,n}f(o_i^{t,n})+x_i^{t+1,n} + b_i^nuit+1,n​=uit,n​f(oit,n​)+xit+1,n​+bin​和Oit+1,n=g(uit+1,n)O_i^{t+1,n} = g(u_i^{t+1,n})Oit+1,n​=g(uit+1,n​)是由LSTM模型的灵感得到。

τ\tauτ 是一个较小的正时间常数,则有:

f(oit,n)={τoit,n=00oit,n=1f(o_i^{t,n})=\left\{ \begin{aligned} \tau & & o_i^{t,n} = 0 \\ 0 & & o_i^{t,n} = 1 \end{aligned} \right.f(oit,n​)={τ0​​oit,n​=0oit,n​=1​

在时间窗口T上用均方误差制定损失函数如下:

L=12S∑s=1S∣∣ys−1T∑t=1Tost,N∣∣22L = \frac{1}{2S} \sum_{s=1}^{S}||y_s - \frac{1}{T} \sum_{t =1}^{T} o_s^{t,N}||^2_2L=2S1​∑s=1S​∣∣ys​−T1​∑t=1T​ost,N​∣∣22​

其中ysy_sys​和oso_sos​表示标签向量和第N层的输出向量。可以看出L是W和b的函数。假定已经得到每层每个时刻的导数∂L∂oi\frac{\partial L}{\partial o_i}∂oi​∂L​和∂L∂ui\frac{\partial L}{\partial u_i}∂ui​∂L​,这对求取∂L∂W\frac{\partial L}{\partial W}∂W∂L​和∂L∂b\frac{\partial L}{\partial b}∂b∂L​,很关键。根据导数的链式法则推导四种情况下的梯度下降过程暂时忽略(主要是推导∂L∂oit,n\frac{\partial L}{\partial o_i^{t,n}}∂oit,n​∂L​和∂L∂uit,n\frac{\partial L}{\partial u_i^{t,n}}∂uit,n​∂L​)。

下面是在STD上误差传播的示意图:

a图表示垂直方向是SD的误差传播,水平方向是TD的误差传播。b图扩展成了网络级别的示意图。

关于∂L∂W\frac{\partial L}{\partial W}∂W∂L​和∂L∂b\frac{\partial L}{\partial b}∂b∂L​,有下面的式子:

其中和∂L∂ut,n\frac{\partial L}{\partial u^{t,n}}∂ut,n∂L​可以根据之前推导的过程得出。

前面推导了基于STBP的梯度信息,下面解决在脉冲激发点无法微分的问题。这个主要是针对g(u)g(u)g(u),它的导数在0处取无限值,其余地方都取0,会在误差传播过程中造成梯度消失或爆炸的问题。目前的一个办法是将这个不连续的点(脉冲时刻)视作噪声,然后宣称它对模型的鲁棒性有益,但是不从根本解决问题。

下面是本文提出的四种曲线来近似g(u)g(u)g(u)的导数:

h1(u)=1a1sign(∣u−Vth∣<a12)h_1(u) = \frac{1}{a_1}sign(|u-V_{th}| < \frac{a_1}{2})h1​(u)=a1​1​sign(∣u−Vth​∣<2a1​​)

h2(u)=(a22−a24∣u−Vth∣)sign(2a2−∣u−Vth∣)h_2(u) = (\frac{\sqrt{a_2}}{2}-\frac{a_2}{4}|u-V_{th}|)sign(\frac{2}{\sqrt{a_2}}-|u-V_{th}|)h2​(u)=(2a2​​​−4a2​​∣u−Vth​∣)sign(a2​​2​−∣u−Vth​∣)

h3(u)=1a3eVth−ua3(1+eVth−ua3)2h_3(u) = \frac{1}{a_3}\frac{e^{\frac{V_th-u}{a_3}}}{(1+e^{\frac{V_th-u}{a_3}})^2}h3​(u)=a3​1​(1+ea3​Vt​h−u​)2ea3​Vt​h−u​​

h4(u)=12πa4e(u−Vth)2a4h_4(u) = \frac{1}{\sqrt{2\pi a_4}}e^{\frac{(u-V_th)^2}{a_4}}h4​(u)=2πa4​​1​ea4​(u−Vt​h)2​

参数aia_iai​决定曲线的尖锐程度。事实上,这四种曲线分别是矩形函数、多项式函数、sigmoid函数和高斯积累分布函数的衍生。调整aia_iai​的值,使每个曲线的积分都为1,可以证明下面的式子:

lim⁡ai=0+hi(u)=dgdu,i=1,2,3,4\lim_{a_i = 0^+}h_i(u) = \frac{dg}{du},i = 1,2,3,4limai​=0+​hi​(u)=dudg​,i=1,2,3,4

进一步得到下面的近似:

dgdu≈dhidu\frac{dg}{du} \approx \frac{dh_i}{du}dudg​≈dudhi​​

下面是g(u)g(u)g(u)和它导数以及各个曲线的示意图:

结果

参数初始化

本文固定阈值为一常数,调整weight来控制spiking活动的平衡。

首先将所有的权重参数标准均匀分布在-1到1之间:W∼U[−1,1]W \sim U[-1,1]W∼U[−1,1]。

然后normalize所有参数: wi,jn=wi,jn∑j=1l(n−1)wi,jn2w_{i,j}^n = \frac{ w_{i,j}^n}{\sqrt{\sum_{j=1}^{l(n-1)}{w_{i,j}^n}^2}}wi,jn​=∑j=1l(n−1)​wi,jn​2​wi,jn​​

其他参数见下图:

各个数据集的实验

首先要将静态的数据集转换为spike events。本文采用了伯努利采样的方式从原始像素密度转换为了spike rate。

全连接网络

一、静态数据集的相关实验(主要)

测试该网络的数据集有MNIST和自制的一个目标检测数据集,均为静态数据集。MNIST数据集中用60000个有标签的手写数字进行训练,10000个用来测试,每个数字图片是28×2828\times2828×28的灰度图片;目标检测数据集是一个二类图像数据集,以有无行人分类,包含1509个训练样本和631个测试样本,也是28×2828\times2828×28的灰度图像。

具体网络进行的工作与一些相关操作:

1.限定一个时间窗口,例如20ms,程序中即依次仿真0-19时刻。

2.对于每个时刻,做一次下面的处理:先把输入的(1,28,28)图像转换成784维的列向量并进行标准化操作把向量中的每个元素处理成0~1范围内的数,然后用与随机向量做比较的方法,对784维向量进行0-1赋值,作为输入的spike。

3.网络框架:784-800-10。

4.最后看窗口时间内,最后10个神经元(标签是0-9)的spiking rate的大小决定该图片表示的数字。

注意点:由于目标检测数据集的训练输入脉冲的激发率高于MNIST数据集,所以相应的阈值会有变化。

下面两幅图是实验结果(和其他方法对比):

二、动态数据集的相关实验

动态数据集相较于静态数据集包含更丰富的时域特征。N-MNIST数据集是通过动态视觉传感器(DVS)把静态MNIST数据集转换成它的动态版本即脉冲训练输入。

具体转换的操作:控制DVS按照等腰三角形的三边依次移动,在每个像素点强度变化时触发并生成spike train。示意图如下:

由于对于每个像素点的强度变化有变亮和变暗两种,DVS可以捕捉到两种spike events,记作on-event和off-event。由于N-MNIST允许在扫视过程中图像的相对偏移,它产生 34×3434\times3434×34 像素范围。

从上图中可以明显知道on-event和off-event区别很大,所以用两个通道区分,网络框架为:34×34×2−400−400−1034\times34\times2-400-400-1034×34×2−400−400−10。

下图是用N-MNIST数据集进行各种方法的结果比较(两层400合成了一层800):

ANN相关的方法通常用框架固定的模型,但是由于动态数据集的图像较为模糊,这种方法表现不好,且放弃了硬件友好的事件驱动范式。相比起来,SNN易于处理事件流图片。

卷积网络

SNN的全连接网络扩展到卷积网络可以让网络更深,并且包含更多的时域信息。卷积网络和全连接的主要区别在于对输入图像的处理上。在卷积层上,每个卷积神经元接收卷积的输入,然后通过LIF模型更新网络参数;在池化层,由于SNN的二值编码对max pooling不适用,所以采用average pooling。

同样在MNIST和目标检测数据集进行测试。

在MNIST上,卷积网络包含一个卷积层,其中卷积核的大小为5;包含相互交替的两个平均池化层,然后再接一个隐藏层。类似于传统CNN,用**弹性形变(elastic distortion)**的方法预处理数据。

下面是卷积网络的结果比较(MNIST&objection detection dataset):

一些分析

1.导数近似曲线的选取及相关参数对于结果的影响:不同曲线最终结果的差异较小,参数aia_iai​太大或太小会导致结果很差。用曲线取近似脉冲活动的导数关键在于捕捉非线性特征,曲线的形状不是关键。

2.本文相对于SDBP,主要多考虑了时间上的因素,使得结果准确率提升。对比如下图:

3.SNN的训练严重依赖于参数的初始化。但是本文中的方法不大依赖,例如侧抑制、正则化和标准化的一些tricks。想要SNN网络有更好的稳定性和鲁棒性就要多考虑动态时间的影响。

4.在未来的研究中,有两个问题比较关键:一是把本文的框架应用到更多的含有时间特性的问题上,像动态数据处理、视频流鉴定和语音识别等领域上的问题;二是如何加速SNN在GPU/CPU或是神经形态芯片上的训练。

code reference

Spatio-Temporal Backpropagation for Training High-performance Spiking Neural Networks笔记相关推荐

  1. 论文略读1《Direct training for spiking neural networks:faster,larger,better》

    发布时间2019-07-17,AAAI 文章链接Direct Training for Spiking Neural Networks: Faster, Larger, Better | Procee ...

  2. Mapping Spiking Neural Networks的论文汇总以及思考

    首先感谢CSDN平台,发现不是我一个人在SNN Mapping方面纠结着.去年看了Mapping方面的内容后感觉想创新还是有点难度的,毕竟优化就是生物进化算法类似的套路,可是你会发现自己实现的结果就是 ...

  3. Spiking neural networks 2017 进展

     The Brain as an Efficient and Robust Adaptive Learner Training Spiking Neural Networks for Cognit ...

  4. 论文精翻《Progressive Tandem Learning for Pattern Recognition With Deep Spiking Neural Networks》

    目录 0 摘要/Abstract 1 简介/Introduction 2 相关工作/Related Work 3 重新思考ANN-to-SNN的转换/Rethinking ANN-to-SNN Con ...

  5. Direct Training for Spiking Neural Networks: Faster, Larger, Better

    摘要 我们提出一种神经元正则化技术去调整神经元分立,而且发展了一种直接的训练算法对于深层SNN. 通过缩小速率编码窗口和转换LIF模型到精确的迭代版本,我们提出了基于pytorch版本的手段去训练深度 ...

  6. SpykeTorch: Efficient Simulation of Convolutional Spiking Neural Networks With at Most One Spike per

    引言 下面我们介绍一下Spyketorch,这个是基于pytorch的框架的.在相同的学习规则下,可以使用STDP和R-STDP,其他的法则也能够应用.该框架中的计算是基于张量的,完全由PyTorch ...

  7. Going Deeper in Spiking Neural Networks: VGG and Residual Architectures

    摘要 在过去几年中,SNN已经成为最受欢迎的低能耗方式.我们提出了分析稀疏事件驱动计算去展现硬件的冗余,当在脉冲领域. 关键词: 脉冲神经网络,事件驱动,稀疏性,神经元计算,视觉辨认. 主要贡献 1. ...

  8. A remark on the error-backpropagation learning algorithm for spiking neural networks

    关于尖峰神经网络的误差反向传播学习算法的一点评论✩ 摘要 在用于脉冲神经网络的误差反向传播学习算法中,必须将触发时间tαt^\alphatα区分为状态函数x(t)x(t)x(t)的函数.但是这种区分是 ...

  9. 【阅读】A Comprehensive Survey on Distributed Training of Graph Neural Networks——翻译

    转载请注明出处:小锋学长生活大爆炸[xfxuezhang.cn] (本文中,涉及到公式部分的翻译不准确,请看对应原文.) 另一篇:[阅读]Distributed Graph Neural Networ ...

最新文章

  1. SOA标准发展混乱 国内业务缺少经验
  2. SNF软件开发机器人-子系统-功能-【列表】自由排序-如何配置?
  3. 畅销书《深入浅出Vue.js》作者,在阿里淘系1年的收获成长
  4. 2016年:勒索病毒造成损失预估超过10亿美元
  5. 全局修改elementui message 右边弹出_ElementUI 只允许 $message 提示一次
  6. 【CCF】201709-2公共钥匙盒
  7. libx264进行视频编码的流程
  8. 代码管理_阿里巴巴如何管理代码分支?
  9. Flutter基础—绘画效果之不透明度
  10. 考试倒计时,计算机二级重难点汇总【39套历年考题】
  11. access to同义替换_access to 用法
  12. arcmap10.7打开tif文件一片空白 | 解决方法
  13. linux中屏蔽定时任务,linux中的定时任务
  14. 公司女同事深夜11点让我去她住处修电脑,原来是C盘爆红,看我一招搞定女同事....的电脑
  15. 二级路由器设置,二级路由器无法上网
  16. python f 格式字符串输出
  17. ANSYS - 表格加载方法
  18. Chrome谷歌浏览器无法调用摄像头原因及解决办法
  19. Java之IO流技术详解
  20. 【舰船数据集格式转换】HRSID数据集VOC转COCO

热门文章

  1. 什么是iBeacon
  2. 中文分词是一个伪命题
  3. Fedora Linux添加Canon打印机驱动
  4. prism IRegionMemberLifetime(区域成员生命周期)
  5. matlab y e x,如何使用matlab绘制函数y=xloge(x^2-1)的函数图像,以e为底。
  6. 【Android Gradle 插件】DexOptions 配置 ⑤ ( additionalParameters 属性配置 | --minimal-main-dex 参数最小化主 dex 字节码 )
  7. js-跟着鼠标移动的图片
  8. CANoe C-V2X Demo(V2I+V2V)演示视频
  9. 我是如何从屌丝程序员逆袭成为大厂总监的?
  10. Python网络爬虫之Xpath详解