本文是由CSDN用户[Memory逆光]授权分享。主要介绍了使用 1D 卷积和 LSTM 混合模型做 EEG 信号识别。感谢Memory逆光!

内容包括:1. 数据集(1.1 数据集下载、1.2 数据集解释);2. 读取数据;3. 搭建模型;4. 训练模型;5. 展示结果;6. 完整代码。

1. 数据集

1.1 数据集下载:

https://archive.ics.uci.edu/ml/datasets/Epileptic+Seizure+Recognition

打开看一下

1.2 数据集解释:

表头为 X* 的是电信号数据,共有 11500 行,每行有 178 个数据,表示 1s 时间内截取的 178 个电信号;表头为 Y 的一列是该时间段数据的标签,包括 5 个分类:

5-记录大脑的EEG信号时病人睁开了眼睛;

4-记录大脑的EEG信号时患者闭上了眼睛;

3-健康大脑区域的脑电图活动;

2-肿瘤所在区域的脑电图活动;

1-癫痫发作活动;

2. 读取数据

import pandas as pddata = "data.csv"df = pd.read_csv(data, header=0, index_col=0)
df1 = df.drop(["y"], axis=1)
lbls = df["y"].values - 1

这里使用 pandas 库读取 data.csv,df1 保存电位数据,lbls 保存标签;

import numpy as npwave = np.zeros((11500, 178))z = 0
for index, row in df1.iterrows():wave[z, :] = rowz+=1mean = wave.mean(axis=0)
wave -= mean
std = wave.std(axis=0)
wave /= stddef one_hot(y):lbl = np.zeros(5)lbl[y] = 1return lbltarget = []
for value in lbls:target.append(one_hot(value))
target = np.array(target)
wave = np.expand_dims(wave, axis=-1)

我们将数据保存在数组 wave 和 target 中,将点位数据标准化(减去均值后除以方差),并将标签转换成 one hot 的形式;

3. 搭建模型

我们使用 keras 搭建一个模型,包括 1D 卷积层和几个堆叠的 LSTM 层:

from keras.models import Sequential
from keras import layersmodel = Sequential()
model.add(layers.Conv1D(64, 15, strides=2,input_shape=(178, 1), use_bias=False))
model.add(layers.ReLU())
model.add(layers.Conv1D(64, 3))
model.add(layers.Conv1D(64, 3, strides=2))
model.add(layers.ReLU())
model.add(layers.Conv1D(64, 3))
model.add(layers.Conv1D(64, 3, strides=2))  # [None, 54, 64]
model.add(layers.BatchNormalization())
model.add(layers.LSTM(64, dropout=0.5, return_sequences=True))
model.add(layers.LSTM(64, dropout=0.5, return_sequences=True))
model.add(layers.LSTM(32))
model.add(layers.Dense(5, activation="softmax"))
model.summary()

网络结构如图:

即该模型使用 1D 卷积进行特征提取,使用 LSTM 进行时域建模,最后通过一个全连接层预测类别;

4. 训练模型

我们使用 Adam 优化器,并设置学习率衰减来进行训练:

import matplotlib.pyplot as plt
import pandas as pd
from keras.models import Sequential
from keras import layers
from keras import regularizers
import os
import kerasimport keras.backend as Ksave_path = './keras_model.h5'if os.path.isfile(save_path):model.load_weights(save_path)print('reloaded.')adam = keras.optimizers.adam()model.compile(optimizer=adam,loss="categorical_crossentropy", metrics=["acc"])
# 计算学习率
def lr_scheduler(epoch):# 每隔100个epoch,学习率减小为原来的0.5if epoch % 100 == 0 and epoch != 0:lr = K.get_value(model.optimizer.lr)K.set_value(model.optimizer.lr, lr * 0.5)print("lr changed to {}".format(lr * 0.5))return K.get_value(model.optimizer.lr)lrate = LearningRateScheduler(lr_scheduler)history = model.fit(wave, target, epochs=400,batch_size=128, validation_split=0.2,verbose=1, callbacks=[lrate])model.save_weights(save_path)

这样就可以开始训练啦:

训练的模型参数保存在 sace_path 中;

5. 展示结果

print(history.history.keys())
# summarize history for accuracy
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

这时我们可以查看训练结果(因为时间有限,我只训练了 100 个 epoch:

6. 完整代码


import matplotlib.pyplot as plt
import pandas as pd
from keras.models import Sequential
from keras import layers
from keras import regularizers
import os
import kerasimport keras.backend as Kimport numpy as npfrom keras.callbacks import LearningRateSchedulerdata = "data.csv"df = pd.read_csv(data, header=0, index_col=0)
df1 = df.drop(["y"], axis=1)
lbls = df["y"].values - 1wave = np.zeros((11500, 178))z = 0
for index, row in df1.iterrows():wave[z, :] = rowz+=1mean = wave.mean(axis=0)
wave -= mean
std = wave.std(axis=0)
wave /= stddef one_hot(y):lbl = np.zeros(5)lbl[y] = 1return lbltarget = []
for value in lbls:target.append(one_hot(value))
target = np.array(target)
wave = np.expand_dims(wave, axis=-1)model = Sequential()
model.add(layers.Conv1D(64, 15, strides=2,input_shape=(178, 1), use_bias=False))
model.add(layers.ReLU())
model.add(layers.Conv1D(64, 3))
model.add(layers.Conv1D(64, 3, strides=2))
model.add(layers.BatchNormalization())
model.add(layers.Dropout(0.5))
model.add(layers.Conv1D(64, 3))
model.add(layers.Conv1D(64, 3, strides=2))
model.add(layers.BatchNormalization())
model.add(layers.LSTM(64, dropout=0.5, return_sequences=True))
model.add(layers.LSTM(64, dropout=0.5, return_sequences=True))
model.add(layers.LSTM(32))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(5, activation="softmax"))
model.summary()save_path = './keras_model3.h5'if os.path.isfile(save_path):model.load_weights(save_path)print('reloaded.')adam = keras.optimizers.adam()model.compile(optimizer=adam,loss="categorical_crossentropy", metrics=["acc"])
# 计算学习率
def lr_scheduler(epoch):# 每隔100个epoch,学习率减小为原来的0.5if epoch % 100 == 0 and epoch != 0:lr = K.get_value(model.optimizer.lr)K.set_value(model.optimizer.lr, lr * 0.5)print("lr changed to {}".format(lr * 0.5))return K.get_value(model.optimizer.lr)lrate = LearningRateScheduler(lr_scheduler)history = model.fit(wave, target, epochs=400,batch_size=128, validation_split=0.2,verbose=2, callbacks=[lrate])model.save_weights(save_path)print(history.history.keys())
# summarize history for accuracy
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

作者博客:

https://blog.csdn.net/weixin_44936889/article/details/105202661

文章来源于网络,仅用于学术交流,不用于商业行为

若有侵权及疑问,请后台留言,管理员即时删侵!

更多阅读

【脑电信号分类】脑电信号提取PSD功率谱密度特征

EEG伪影类型详解和过滤工具的汇总(一)

临床脑电图常用术语(二)-脑波及相关形态和分布

清华张钹院士专刊文章:迈向第三代人工智能(全文收录)

脑机接口拼写器是否真的安全?华中科技大学研究团队对此做了相关研究

脑机接口和卷积神经网络的初学指南(一)

脑电数据处理分析教程汇总(eeglab, mne-python)

P300脑机接口及数据集处理

快速入门脑机接口:BCI基础(一)

如何快速找到脑机接口社区的历史文章?

脑机接口BCI学习交流QQ群:515148456

手把手教你使用 1D 卷积和 LSTM 混合模型做 EEG 信号识别相关推荐

  1. 手把手教你使用pytorch实现双向LSTM机器翻译

    目录 前言 1. 数据集 1.1 下载数据集并处理 1.2 将数据集分为source和target 1.3 定义词汇类 1.4 获取训练集 2. 定义模型 2.1 导入相关工具包 2.2. 定义Enc ...

  2. 手把手教你用直方图、饼图和条形图做数据分析(Python代码)

    导读:对数据进行质量分析以后,接下来可通过绘制图表.计算某些特征量等手段进行数据的特征分析. 其中,分布分析能揭示数据的分布特征和分布类型.本文就手把手教你做分布分析. 作者:张良均 谭立云 刘名军 ...

  3. python股票直方图代码_手把手教你用直方图、饼图和条形图做数据分析(Python代码)...

    云栖号资讯:[点击查看更多行业资讯] 在这里您可以找到不同行业的第一手的上云资讯,还在等什么,快来! 导读:对数据进行质量分析以后,接下来可通过绘制图表.计算某些特征量等手段进行数据的特征分析. 其中 ...

  4. mac虚拟机服务器设置u盘启动不了,手把手教你解决win7系统苹果电脑运行虚拟机后无法识别显示U盘的图文方案...

    许多win7系统用户在工作中经常会遇到win7系统苹果电脑运行虚拟机后无法识别显示U盘的情况,比如近日有用户到本站反映说win7系统苹果电脑运行虚拟机后无法识别显示U盘的问题,但是却不知道要怎么解决w ...

  5. 走亲访友不慌!手把手教你怎样用Mask R-CNN和Python做一个抢车位神器

    现在大家都忙着过大年,按照传统习俗,各种走亲访友.这时候的商场.饭馆也都是"人声鼎沸",毕竟走亲戚串门必不可少要带点礼品.聚餐喝茶. 热闹归热闹,这个时候最难的问题可能就是怎样从小 ...

  6. python模块cv2人脸识别_手把手教你使用OpenCV,Python和深度学习进行人脸识别

    使用OpenCV,Python和深度学习进行人脸识别 在本教程中,你将学习如何使用OpenCV,Python和深度学习进行面部识别.首先,我们将简要讨论基于深度学习的面部识别,包括"深度度量 ...

  7. 手把手教你用AI画梵高的《星空》

    导读:有人说,AI会导致失业:也有人说,AI创造大量工作机会,各行各业对AI人才的需求都将日益增加. AI在模仿人类的学习方式,那么,人类又该怎样学习AI?本文就带你了解一本学习AI的神书. 来源:华 ...

  8. 手把手教你:基于LSTM的股票预测系统

    系列文章 第七章.手把手教你:基于深度残差网络(ResNet)的水果分类识别系统 第六章.手把手教你:人脸识别的视频打码 第五章.手把手教你:基于深度学习的滚动轴承故障诊断 目录 系列文章 一.项目简 ...

  9. 3d卷积和2d卷积1d卷积运算-CNN卷积核与通道讲解

    全网最全的卷积运算过程:https://blog.csdn.net/Lucinda6/article/details/115575534?spm=1001.2101.3001.6661.1&u ...

最新文章

  1. docusign文档打不开_怎样查看 docusign pdf 电子签名
  2. 前端学习(579):chrome devtools功能简介
  3. script 标签到底该放在哪里
  4. oracle resize什么意思,Oracle调整表空间大小resize
  5. jmeter 连接 sqlite 进行压力测试
  6. baidumap vue 判断范围_vue中百度地图API的调用
  7. C#.NET 消息机制
  8. 002.FTP配置项详解
  9. OneLedger蓄势待发,引爆跨链热点
  10. twitter注册不了_如何阻止Twitter重点阻止不相关的通知
  11. el-collapse用法
  12. 信息检索平台Terrier的使用
  13. flexbox布局详解
  14. 利用dsp电机测速及详解
  15. Selective Search 学习笔记
  16. ServerSocket和Socket连接
  17. 【已解决】vue安装项目的时候出现了 command failed: pnpm install --reporter silent --shamefully-hoist 很有趣的解密过程
  18. 三维重建 几何方法 深度学习_三维重建算法综述|传统+深度学习方式
  19. python制作会动的表情包_Python自动生成表情包
  20. 电力电子技术笔记(4)——典型全控型器件

热门文章

  1. CORS跨域实现思路及相关解决方案
  2. Visual Studio中没有为此解决方案配置选中要生成的项目
  3. Visual Stdio 无法直接启动带有“类库输出类型”的项目若要调试此项目,请在此解决方案中添加一个引用库项目的可执行项目。将这个可执行项目设置为启动项目!
  4. 怎样解决VMware虚拟机无法连接外网问题
  5. 检查型异常(Checked Exception)与非检查型异常(Unchecked Exception)
  6. 如何在Pandas的DataFrame中的行上进行迭代?
  7. 手把手 | Python代码和贝叶斯理论告诉你,谁是最好的棒球选手
  8. 使用Scrapy构建一个网络爬虫
  9. WindowsServer2012史记4-重复数据删除的魅力
  10. mysql主从及读写分离