EEGNet: 神经网络应用于脑电信号
脑机接口(BCI)使用神经活动作为控制信号,实现与计算机的直接通信。这种神经信号通常是从各种研究透彻的脑电图(EEG)信号中挑选出来的。卷积神经网络(CNN)主要用来自动特征提取和分类,其在计算机视觉和语音识别领域中的使用已经很广泛。CNN已成功应用于基于EEG的BCI;但是,CNN主要应用于单个BCI范式,在其他范式中的使用比较少,论文作者提出是否可以设计一个CNN架构来准确分类来自不同BCI范式的EEG信号,同时尽可能地紧凑(定义为模型中的参数数量)。
该论文介绍了EEGNet,这是一种用于基于EEG的BCI的紧凑型卷积神经网络。论文介绍了使用深度和可分离卷积来构建特定于EEG的模型,该模型封装了脑机接口中常见的EEG特征提取概念。论文通过四种BCI范式(P300视觉诱发电位、错误相关负性反应(ERN)、运动相关皮层电位(MRCP)和感觉运动节律(SMR)),将EEGNet在主体内和跨主体分类方面与目前最先进的方法进行了比较。结果显示,在训练数据有限的情况下,EEGNet比参考算法具有更强的泛化能力和更高的性能。同时论文也证明了EEGNet可以有效地推广到ERP和基于振荡的BCI。
实验结果如下图,P300数据集的所有CNN模型之间的差异非常小,但是MRCP数据集却存在显著的差异,两个EEGNet模型的性能都优于所有其他模型。对于ERN数据集来说,两个EEGNet模型的性能都优于其他所有模型(p < 0.05)。
EEGNet网络原理
EEGNet网络结构图:
EEGNet原理架构如下:
EEGNet网络实现
import numpy as np
from sklearn.metrics import roc_auc_score, precision_score, recall_score, accuracy_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
定义网络模型:
class EEGNet(nn.Module):def __init__(self):super(EEGNet, self).__init__()self.T = 120# Layer 1self.conv1 = nn.Conv2d(1, 16, (1, 64), padding = 0)self.batchnorm1 = nn.BatchNorm2d(16, False)# Layer 2self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))self.conv2 = nn.Conv2d(1, 4, (2, 32))self.batchnorm2 = nn.BatchNorm2d(4, False)self.pooling2 = nn.MaxPool2d(2, 4)# Layer 3self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))self.conv3 = nn.Conv2d(4, 4, (8, 4))self.batchnorm3 = nn.BatchNorm2d(4, False)self.pooling3 = nn.MaxPool2d((2, 4))# 全连接层# 此维度将取决于数据中每个样本的时间戳数。# I have 120 timepoints. self.fc1 = nn.Linear(4*2*7, 1)def forward(self, x):# Layer 1x = F.elu(self.conv1(x))x = self.batchnorm1(x)x = F.dropout(x, 0.25)x = x.permute(0, 3, 1, 2)# Layer 2x = self.padding1(x)x = F.elu(self.conv2(x))x = self.batchnorm2(x)x = F.dropout(x, 0.25)x = self.pooling2(x)# Layer 3x = self.padding2(x)x = F.elu(self.conv3(x))x = self.batchnorm3(x)x = F.dropout(x, 0.25)x = self.pooling3(x)# 全连接层x = x.view(-1, 4*2*7)x = F.sigmoid(self.fc1(x))return x
定义评估指标:
acc:准确率
auc:AUC 即 ROC 曲线对应的面积
recall:召回率
precision:精确率
fmeasure:F值
def evaluate(model, X, Y, params = ["acc"]):results = []batch_size = 100predicted = []for i in range(len(X)//batch_size):s = i*batch_sizee = i*batch_size+batch_sizeinputs = Variable(torch.from_numpy(X[s:e]))pred = model(inputs)predicted.append(pred.data.cpu().numpy())inputs = Variable(torch.from_numpy(X))predicted = model(inputs)predicted = predicted.data.cpu().numpy()"""设置评估指标:acc:准确率auc:AUC 即 ROC 曲线对应的面积recall:召回率precision:精确率fmeasure:F值"""for param in params:if param == 'acc':results.append(accuracy_score(Y, np.round(predicted)))if param == "auc":results.append(roc_auc_score(Y, predicted))if param == "recall":results.append(recall_score(Y, np.round(predicted)))if param == "precision":results.append(precision_score(Y, np.round(predicted)))if param == "fmeasure":precision = precision_score(Y, np.round(predicted))recall = recall_score(Y, np.round(predicted))results.append(2*precision*recall/ (precision+recall))return results
构建网络EEGNet,并设置二分类交叉熵和Adam优化器
# 定义网络
net = EEGNet()
# 定义二分类交叉熵 (Binary Cross Entropy)
criterion = nn.BCELoss()
# 定义Adam优化器
optimizer = optim.Adam(net.parameters())
创建数据集
"""
生成训练数据集,数据集有100个样本
训练数据X_train:为[0,1)之间的随机数;
标签数据y_train:为0或1
"""
X_train = np.random.rand(100, 1, 120, 64).astype('float32')
y_train = np.round(np.random.rand(100).astype('float32'))
"""
生成验证数据集,数据集有100个样本
验证数据X_val:为[0,1)之间的随机数;
标签数据y_val:为0或1
"""
X_val = np.random.rand(100, 1, 120, 64).astype('float32')
y_val = np.round(np.random.rand(100).astype('float32'))
"""
生成测试数据集,数据集有100个样本
测试数据X_test:为[0,1)之间的随机数;
标签数据y_test:为0或1
"""
X_test = np.random.rand(100, 1, 120, 64).astype('float32')
y_test = np.round(np.random.rand(100).astype('float32'))
训练并验证
batch_size = 32
# 训练 循环
for epoch in range(10): print("\nEpoch ", epoch)running_loss = 0.0for i in range(len(X_train)//batch_size-1):s = i*batch_sizee = i*batch_size+batch_sizeinputs = torch.from_numpy(X_train[s:e])labels = torch.FloatTensor(np.array([y_train[s:e]]).T*1.0)# wrap them in Variableinputs, labels = Variable(inputs), Variable(labels)# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# 验证params = ["acc", "auc", "fmeasure"]print(params)print("Training Loss ", running_loss)print("Train - ", evaluate(net, X_train, y_train, params))print("Validation - ", evaluate(net, X_val, y_val, params))print("Test - ", evaluate(net, X_test, y_test, params))
Epoch 0
['acc', 'auc', 'fmeasure']Training Loss 1.6107637286186218
Train - [0.52, 0.5280448717948718, 0.6470588235294118]
Validation - [0.55, 0.450328407224959, 0.693877551020408]
Test - [0.54, 0.578926282051282, 0.6617647058823529]Epoch 1
['acc', 'auc', 'fmeasure']
Training Loss 1.5536684393882751
Train - [0.45, 0.41145833333333337, 0.5454545454545454]
Validation - [0.55, 0.4823481116584565, 0.6564885496183207]
Test - [0.65, 0.6530448717948717, 0.7107438016528926]Epoch 2
['acc', 'auc', 'fmeasure']
Training Loss 1.5197088718414307
Train - [0.49, 0.5524839743589743, 0.5565217391304348]
Validation - [0.53, 0.5870279146141215, 0.5436893203883495]
Test - [0.57, 0.5428685897435898, 0.5567010309278351]Epoch 3
['acc', 'auc', 'fmeasure']
Training Loss 1.4534167051315308
Train - [0.53, 0.5228365384615385, 0.4597701149425287]
Validation - [0.5, 0.48152709359605916, 0.46808510638297873]
Test - [0.61, 0.6502403846153847, 0.5517241379310345]Epoch 4
['acc', 'auc', 'fmeasure']
Training Loss 1.3821702003479004
Train - [0.46, 0.4651442307692308, 0.3076923076923077]
Validation - [0.47, 0.5977011494252874, 0.29333333333333333]
Test - [0.52, 0.5268429487179488, 0.35135135135135137]Epoch 5
['acc', 'auc', 'fmeasure']
Training Loss 1.440490186214447
Train - [0.56, 0.516025641025641, 0.35294117647058826]
Validation - [0.36, 0.3801313628899836, 0.2]
Test - [0.53, 0.6113782051282052, 0.27692307692307694]Epoch 6
['acc', 'auc', 'fmeasure']
Training Loss 1.4722238183021545
Train - [0.47, 0.4194711538461539, 0.13114754098360656]
Validation - [0.46, 0.5648604269293925, 0.2285714285714286]
Test - [0.5, 0.5348557692307693, 0.10714285714285714]Epoch 7
['acc', 'auc', 'fmeasure']
Training Loss 1.3460421562194824
Train - [0.51, 0.44871794871794873, 0.1694915254237288]
Validation - [0.44, 0.4490968801313629, 0.2]
Test - [0.53, 0.4803685897435898, 0.14545454545454545]Epoch 8
['acc', 'auc', 'fmeasure']
Training Loss 1.3336675763130188
Train - [0.54, 0.4130608974358974, 0.20689655172413793]
Validation - [0.39, 0.40394088669950734, 0.14084507042253522]
Test - [0.51, 0.5400641025641025, 0.19672131147540983]Epoch 9
['acc', 'auc', 'fmeasure']
Training Loss 1.438510239124298
Train - [0.53, 0.5392628205128205, 0.22950819672131148]
Validation - [0.42, 0.4848111658456486, 0.09375]
Test - [0.56, 0.5420673076923076, 0.2413793103448276]
仅用于学术交流,不用于商业行为,若有侵权及疑问,请后台留言,管理员即时删侵!
更多阅读
脑电数据的Epoching处理
脑电公开数据集汇总
马斯克近日表示:Neuralink脑机接口有望明年用于人类
北师大吴倩课题组与合作者共同揭示人类下丘脑发育的时空动态特征
“读心”神经元能够预测他人的行为与决策
硬核玩家改造《上古卷轴V》,脑机接口控制魔法施放
一种新型脑机接口--集成光子芯片的脑机接口是否可行?
新型脑刺激疗法治疗重度抑郁症
脑电与情绪简介
点个在看祝你开心一整天!
EEGNet: 神经网络应用于脑电信号相关推荐
- 应用深度学习EEGNet来处理脑电信号
目录 EEGNet论文 EEGNet简介 EEGNet代码实现 本分享为脑机学习者Rose整理发表于公众号:脑机接口社区 .QQ交流群:903290195 EEGNet论文 EEGNet简介 脑机接口 ...
- 脑电信号特征提取算法c语言_应用深度学习EEGNet来处理脑电信号
文章来源于"脑机接口社区" 应用深度学习EEGNet来处理脑电信号mp.weixin.qq.com 本篇文章内容主要包括: EEGNet论文: EEGNet的实现. EEGNet ...
- 深度学习应用于脑电信号分析处理的相关论文,更新了......
分享脑机学习论文,欢迎一起讨论 1.(综述性质论文)Deep learning-based electroencephalography analysis: a systematic review 论 ...
- ACM MM:一种基于情感脑电信号时-频-空特征的3D密集连接网络
本文介绍一篇于计算机领域顶级会议ACM MM 2020发表的论文<SST-EmotionNet: Spatial-Spectral-Temporal based Attention 3D Den ...
- 脑电信号预处理--去趋势化(Detrended fluctuation analysis)
点击上面"脑机接口社区"关注我们 更多技术干货第一时间送达 本文由c851038595授权分享 感谢c851038595! 由于脑电信号的不稳定性和不规则性,因此对脑电信号的处理也 ...
- 深度神经网络对脑电信号运动想象动作的在线解码
目录 简介 网络模型 结果比较 结论 本分享为脑机学习者Rose整理发表于公众号:脑机接口社区 QQ交流群:941473018 简介 近年来,深度学习方法的快速发展使得无需任何特征工程的端到端学习成为 ...
- 常见采集脑电信号的四种技术
目录 脑电图(Electroencephalography,EEG) 皮层脑电图(Electrocorticography, ECoG) 深度电极(Depth electrode) 功能磁共振成像(F ...
- 事件相关去同步 (ERD) 和事件相关同步化 (ERS)在脑电信号研究中的应用
作者:周思捷,白红民,广州军区广州总医院神经外科 人脑具有电活动是Hans Berger于1924年首先检测到的,并将这种检测方法命名为脑电图(electroencephalo-graph,EEG). ...
- 《黑镜》黑科技成真 | 解码脑电信号,AI重构脑中的画面
原作 TIM COLLINS Root 编译自 Dailymail 量子位 出品 | 公众号 QbitAI 上周五,一贯借黑科技刻画人性阴暗面的英剧<黑镜>刚出第四季,其中第三集<鳄 ...
最新文章
- 程序员发长贴讲述真实某多多:薪水高,普调高,环境差,厕所少!强制去买菜,全行业竞业,穿拖鞋会被暴力执法!...
- 缓存之 ACache
- lodash源码分析之compact中的遍历
- linux 文件系统cache,终于找到一篇详解Linux文件系统Cache的文章
- php checkbox批量提交,jquery获取多个checkbox的值异步提交给php
- Oracle 如何实现第M行至第N行的有序读取,避免子查询order by出错
- Leetcode每日一题:164.maximum-gap(最大间距)
- CloudStack(二)基础网络模式安装部署
- ssm框架验证码图片加载不出_基于SSM框架的文件图片上传/下载功能实现
- 项目配置urule规则引擎教程详解(带视频资源)
- 区块链赋能供应链金融
- sklearn基于轮廓系数来选择n_clusters
- 我们总是喜欢拿顺其自然来敷衍人生道路上的荆棘坎坷,却很少承认,真正的顺其自然是竭尽所能之后的不强求,而非两手一摊的不作为。
- 页面相似度检测,对SEO起到什么作用?
- ESP32开发日志之AiThinkerIDE_V1.0使用过程中的一个问题
- 仿b站的动漫视频网站
- Altium Designer入门学习笔记和快捷键整理
- 多卡训练中的BN(BatchNorm)
- (原创文章)羊毛党何去何从
- 数学建模 ————统计问题之预测(一)
热门文章
- C# “Thread类Suspend()与Resume()已过时” 解决方法(利用ManualResetEvent类)
- MySQL 5.6 Warning - Using a password on the command line interface can be insecur 解决方案
- php捕获Fatal error错误与异常处理
- Visual Studio中没有为此解决方案配置选中要生成的项目
- 如何从“查找”中排除所有“拒绝权限”消息?
- 检测未定义的对象属性
- 如何在JavaScript中验证电子邮件地址
- AI算法透明不是必须,黑箱和可解释性可简化为优化问题
- BEGINNING SHAREPOINT#174; 2013 DEVELOPMENT 第12章节--SP 2013中远程Event Receivers 总结
- 移动终端app测试点总结