脉冲神经网络大致流程
脉冲神经网络整体流程
1、导入库
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
# 若环境没有相应的包,则通过pip/conda install *** 进行安装
2、定义超参数
#####如下参数为常用参数,可自行添加或删除#####
parser = argparse.ArgumentParser(description='Classify MNIST Use LIF')
parser.add_argument('-device', default='cpu', help='运行的设备,例如“cpu”或“cuda:0”')
parser.add_argument('--dataset-dir', default='./', help='MNIST数据集的位置')
parser.add_argument('-b', '--batch-size', default=64, type=int, help='Batch 大小')
parser.add_argument('-T', '--timesteps', default=100, type=int, dest='T', help='时间窗口')
parser.add_argument('-lr', '--learning-rate', default=1e-3, type=float, metavar='LR', help='学习率')
parser.add_argument('-tau', default=2.0, type=float, help='LIF神经元的时间常数tau')
parser.add_argument('-epochs', default=64, type=int, metavar='N',help='训练轮次')parser.add_argument('-j', default=4, type=int, metavar='N',help='number of data loading workers (default: 4)')
parser.add_argument('-channels', default=128, type=int, help='channels of Conv2d in SNN')
# parser.add_argument('--log-dir', default='./', help='保存日志文件的位置')
# parser.add_argument('--model-output-dir', default='./', help='模型保存路径')
args = parser.parse_args()
print(args)
3、定义网络结构 Net
需要先定义一个 Class,继承自 nn.Module
类,这个 Class 里主要写两个函数,一个是初始化的 __init__
函数,另一个是 forward 函数。
① Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks 论文中的网络模型
class VotingLayer(nn.Module):def __init__(self, voter_num: int):super().__init__()self.voting = nn.AvgPool1d(voter_num, voter_num)def forward(self, x: torch.Tensor):# x.shape = [N, voter_num * C]# ret.shape = [N, C]return self.voting(x.unsqueeze(1)).squeeze(1)# 参考:https://blog.csdn.net/m0_55519533/article/details/119103011class Net(nn.Module):def __init__(self, channels: int):super().__init__()conv = []conv.extend(PythonNet.conv3x3(2, channels))conv.append(nn.MaxPool2d(2, 2))for i in range(4):conv.extend(PythonNet.conv3x3(channels, channels))conv.append(nn.MaxPool2d(2, 2))self.conv = nn.Sequential(*conv)# (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)# (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)# (2): LIFNode(# v_threshold=1.0, v_reset=0.0, detach_reset=True, tau=2.0# (surrogate_function): ATan(alpha=2.0, spiking=True)# )# (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)self.fc = nn.Sequential(nn.Flatten(),layer.Dropout(0.5),nn.Linear(channels * 4 * 4, channels * 2 * 2, bias=False),neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True),layer.Dropout(0.5),nn.Linear(channels * 2 * 2, 110, bias=False),neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True))# (0): Flatten(start_dim=1, end_dim=-1)# (1): Dropout(p=0.5)# (2): Linear(in_features=2048, out_features=512, bias=False)# (3): LIFNode(# v_threshold=1.0, v_reset=0.0, detach_reset=True, tau=2.0# (surrogate_function): ATan(alpha=2.0, spiking=True)# )self.vote = VotingLayer(10) # 平均池化,将输出层的tensor(16,110)采用投票机制,转为tensor(16,11)# (voting): AvgPool1d(kernel_size=(10,), stride=(10,), padding=(0,))def forward(self, x: torch.Tensor):x = x.permute(1, 0, 2, 3, 4) # [N, T, 2, H, W] -> [T, N, 2, H, W]out_spikes = self.vote(self.fc(self.conv(x[0])))for t in range(1, x.shape[0]):out_spikes += self.vote(self.fc(self.conv(x[t])))return out_spikes / x.shape[0]@staticmethoddef conv3x3(in_channels: int, out_channels):return [nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1, bias=False), # c128k3s1nn.BatchNorm2d(out_channels), # BNneuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True) # MPk2s2]
② 一个简单识别 MNIST 数据集的网络结构
net = nn.Sequential(nn.Flatten(), # 将28*28->784nn.Linear(28 * 28, 10, bias=False),neuron.LIFNode(tau=tau)
)
将网络加载到运行设备上
net = Net()
net = net.to(device)
4、选择优化器 Optimizer
# 添加超参数
parser.add_argument('-opt', default='SGD', type=str, help='use which optimizer. SDG or Adam')
parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')# 选择优化器:'SGD、Adam or others'optimizer = Noneif args.opt == 'SGD':optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)elif args.opt == 'Adam':optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)else:raise NotImplementedError(args.opt)# 输出net.parameters()参数,为各层网络权重值for _, param in enumerate(net.parameters()):print(param.shape)
5、(选用)选择学习率衰减策略
# 添加超参数
parser.add_argument('-lr_scheduler', default='CosALR', type=str, help='学习率衰减策略. StepLR or CosALR')
parser.add_argument('-step_size', default=32, type=float, help='step_size for StepLR')
parser.add_argument('-gamma', default=0.1, type=float, help='gamma for StepLR')
parser.add_argument('-T_max', default=32, type=int, help='T_max for CosineAnnealingLR')# 选用学习率衰减策略:'StepLR or CosALR or others'lr_scheduler = Noneif args.lr_scheduler == 'StepLR':lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)elif args.lr_scheduler == 'CosALR':lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.T_max)else:raise NotImplementedError(args.lr_scheduler)
6、(选用)自动混合精度训练
# 添加超参数
parser.add_argument('-amp', action='store_true', help='是否进行自动混合精度训练,可以大幅度提升速度,减少显存消耗')scaler = Noneif args.amp:scaler = amp.GradScaler()
7、加载数据
train_dataset = torchvision.datasets.MNIST(root=dataset_dir,train=True,transform=torchvision.transforms.ToTensor(),download=True)test_dataset = torchvision.datasets.MNIST(root=dataset_dir,train=False,transform=torchvision.transforms.ToTensor(),download=True)train_data_loader = data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,drop_last=True)test_data_loader = data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False,drop_last=False)
8、编码数据为脉冲序列
在定义完之后,开始一次一次的循环:
① 先清空优化器里的梯度信息,optimizer.zero_grad();② 再将data传入,正向传播,output=net(data);③ 计算损失,loss=F.mse_loss(target,output) # 这里target就是识别目标,需要自己准备,和之前传入的input类型一一对应;④ 误差反向传播,loss.backward();⑤ 更新参数,optimizer.step();⑥ 重置网络状态,functional.reset_net(net);
脉冲神经网络大致流程相关推荐
- 强化学习中的脉冲神经网络
简 介: 脉冲强化学习是最近兴起的将脉冲神经网络应用到强化学习中的一个研究领域.固然脉冲神经网络的引入会给强化学习带来一些新的东西,但目前的研究仍然仅仅满足于如何让算法收敛,而没有发挥出脉冲神经网络独 ...
- 各种神经网络的应用领域,脉冲神经网络发展前景
脉冲神经网络的简介 脉冲神经网络 (SNN-Spiking Neuron Networks) 经常被誉为第三代人工神经网络.第一代神经网络是感知器,它是一个简单的神经元模型并且只能处理二进制数据. 第 ...
- 【NEST】脉冲神经网络仿真平台入门手册整理翻译记录
这是国庆前导师让了解的脉冲神经仿真平台NEST的部分介绍手册的翻译和整理,记录一下留个备份,主要内容可以通过查看文档中的链接索引到官网. 如需要手册代码合辑及例程ipynb文件,请查看NEST脉冲神经 ...
- 人工神经网络秒变脉冲神经网络,新技术有望开启边缘AI计算新时代
点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 来自:机器之心 能更好模仿生物神经系统运行机制的脉冲神经网络在发展速度和应用范围上都还 ...
- 边缘AI计算新时代,人工神经网络秒变脉冲神经网络
点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打印版本附pdf下载链接 能 ...
- SNN介绍-来自脉冲神经网络原理
神经科学的一些实验证据表明,视觉与听觉等许多生物神经系统都采用神经元发放的动作电位(即脉冲)的时间来编码信息.针对这些问题,更加符合生物神经系统实际情况的第三代人工神经网络模型--脉冲神经网络应运而生 ...
- 脉冲神经网络在目标检测的首次尝试,性能堪比CNN | AAAI 2020
译者 | VincentLee 来源 | 晓飞的算法工程笔记 脉冲神经网络(Spiking neural network, SNN)将脉冲神经元作为计算单元,能够模仿人类大脑的信息编码和处理过程.不 ...
- 第三十二课.脉冲神经网络SNN
目录 时间驱动与事件驱动 时间驱动 事件驱动 基于时间驱动的脉冲神经元 spikingjelly:LIF神经元 实验仿真 时间驱动与事件驱动 时间驱动 为了便于理解时间驱动,我们可以将SNN(spik ...
- 脉冲神经网络(SNN)概述
https://www.toutiao.com/a6701844289518830091/ 主要讨论脉冲神经网络的拓扑结构.信息的脉冲序列编码方法.脉冲神经网络的学习算法和进化方法等. 一.脉冲神经网 ...
最新文章
- 【Windows Phone】Metro设计语言
- filezilla 定时上传_FileZilla Server安装教程 - FtpCopy数据自动备份软件(FTP定时备份)|FTP自动下载|FTP自动上传|FTP自动备份...
- 模板类 Template Classes 以及模板类编译时的处理
- 利用Asp.net MVC处理文件的上传下载
- 成员变量与局部变量 java 1613807617
- 1、eclipse 使用git提交项目至github进行项目托管
- Windows Server 2012正式版RDS系列⑻
- 通过分析词性进行人名、地名、组织的替换,生成新的狗屁不通文章
- 【Windows】VMware虚拟机安装Windows 10 教程
- 清华上交等发表Nature子刊!分片线性神经网络最新综述!
- 201871010114-李岩松《面向对象程序设计(java)》第四周学习总结
- “esxcli software vib” commands to patch an ESXi 5.x/6.x host (2008939)
- java解析shp文件以及坐标转换(工具类)
- ecshop ecmall shopex
- 计算机基础1杨石答案第五章,计算机基础课程教学改革与实践
- 十大python培训机构
- 直播弹幕系统(五)- 整合Stomp替换原生WebSocket方案探究
- JavaScript 双击禁止选中文字
- 安卓手机可以用python编程软件-有哪些可以在手机上敲Python代码的App
- 最新170个站长在线工具箱网站源码/野兔在线工具系统V2.4.1中文版
热门文章
- 在html插入数学公式,给WordPress的文章插入数学公式
- Python爬虫-Selenium(1)
- bmi计算 python_python tkinter bmi计算
- Java实现 蓝桥杯VIP 算法提高 3000米排名预测
- “我的代码正在被千百万人使用”,MySQL 之父等六大国际数据库掌门人谈如何做数据库!...
- 【C语言】求s = a + aa + aaa + aaaa + aa...a的值,其中a是一个数字
- tarjan算法求解强连通分量问题
- php输出setcookie,PHP函数:setcookie()
- 【0基础强力推荐】R语言快速入门
- 华为怎么设置计算机快捷,使用命令快速设置华为路由器