pytorch:对小鼠的脑电数据进行睡眠状态三分类

  • 1.特征提取
  • 2.用pytorch分类
    • 首先分出训练集和测试集
    • 构建数据迭代器
    • 构建网络
    • 准确率评估
    • 模型训练
  • 3. 分类结果

本文采用PyTorch框架,用实验室采集到老鼠大脑数据来学习一波多分类。

1.特征提取

我的数据是26只老鼠的mat形式的脑电数据,每个老鼠有15个通道,每个老鼠在不同的睡眠状态截取30段数据,每一段数据长为10s。 这里采集的睡眠状态有三种,分别是awake,rem,和sws

这里为了更好的获取时域和频域的特征,分别用相关矩阵和转移熵矩阵作为我们分类的特征,然后用pca方法对相关矩阵和转移熵矩阵各提取前30个特征,总共得到60个特征,最终得到的数据为(2340,60)的矩阵,接下来就可以开始处理了。

2.用pytorch分类

首先分出训练集和测试集

data_size = np.shape(all_sample)[0]
feature_size = np.shape(all_sample)[1]
n = int(data_size / 3) # 每个睡眠状态的样本数
X = all_sample
Y = np.zeros(data_size) # 为每个样本打label
Y[n:2 * n] = 1
Y[2 * n:] = 2
# 用sklearn里面的MinMaxScaler进行归一化
scaler_minmax = MinMaxScaler(feature_range=(-1, 1))  # 设置变换范围
X_minmax = scaler_minmax.fit_transform(X)
# 随机选取30%的数据作为测试集
X_train, X_test, Y_train, Y_test = train_test_split(X_minmax, Y, test_size=0.3)
# 样本数据转tensor float32
# label 转tensor int64
X_train = torch.from_numpy(X_train).to(torch.float).to(device)
X_test = torch.from_numpy(X_test).to(torch.float).to(device)
Y_train = torch.from_numpy(Y_train).to(torch.int64).to(device)
Y_test = torch.from_numpy(Y_test).to(torch.int64).to(device)

构建数据迭代器

pytorch 的迭代器可以很方便分批导入batch块,我们这里设置batch_size为30.

batch_size = 30
train_dataset = torch.utils.data.TensorDataset(X_train, Y_train)
test_dataset = torch.utils.data.TensorDataset(X_test, Y_test)
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

构建网络

用nn.Swquential()函数来构建网络。因为我们用到的神经网络中每一层之前都是全连接,所以可以用线性层加激活函数层来进行正向传播。隐藏层神经元个数设置为20,激活函数直接调用了nn里面的sigmoid函数(ReLU也可以,并且训练速度会快很多)。pytorch中的初始化可以用torch.nn中的init.normal_()函数直接对我们构建的net进行初始化。

# 模型构建
num_inputs, num_outputs, num_hiddens = feature_size, 3, 20
net = nn.Sequential(nn.Linear(num_inputs, num_hiddens),nn.Sigmoid(),# nn.Linear(num_hiddens, num_hiddens),# nn.ReLU(),nn.Linear(num_hiddens, num_outputs)# torch交叉熵的损失函数自带softmax运算,这里就没有再加一层激活函数
)# 参数初始化
for param in net.parameters():init.normal_(param, mean=0, std=0.01)

准确率评估

训练完成后需要对测试集进行评估,我们先构建一个评估函数.这里模型的输出net(x)的形状是(batch_size, 3),所以判断输出结果时,用.argmax(dim=1)返回第二维的最大值的index。然后用.item()提取结算结果。

def acc(test_iter, net, device):acc_sum, n = 0.0, 0for x, y in test_iter:acc_sum += (net(x.to(device)).argmax(dim=1) == y.to(device)).sum().item()n += y.shape[0]return acc_sum / n

用该评估器来评估还未训练的模型对测试集的识别率,输出结果应该分布在1/3左右:

acc(test_iter, net, device)
Out[3]: 0.31196581196581197

模型训练

损失函数使用torch中的交叉熵损失函数

# 交叉熵损失函数
loss = torch.nn.CrossEntropyLoss()

优化器采用torch.optim中的SGD方法(Adam使用方法也类似),这里设置学习率为0.01,正则化参数为0.0001

# 优化器
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.01, weight_decay=0.0001)
#optimizer = torch.optim.Adam(params=net.parameters(), lr=0.01, weight_decay=0.0001)

接下来开始训练,在pytorch中梯度的计算十分方便,我们可以直接在创建tensor时用requires_grad=True来设置是否跟踪梯度,在每个batch中只需用.backward()函数就可获得梯度,然后用optimizer.step()自动更新net中的参数。具体代码如下

num_epochs = 500
for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for x, y in train_iter:y_hat = net(x)l = loss(y_hat, y).sum()optimizer.zero_grad()  # 优化器的梯度清零l.backward()  # 反向传播梯度计算optimizer.step()  # 优化器迭代train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = acc(test_iter, net, device)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, batch_size * train_l_sum / n, train_acc_sum / n, test_acc))

3. 分类结果

最终可以得到97%左右的正确率

pytorch:对小鼠的脑电数据进行睡眠状态三分类相关推荐

  1. Python协方差矩阵处理脑电数据

    在本教程中,我们将介绍传感器协方差计算的基础知识,并构建一个噪声协方差矩阵,该矩阵可用于计算最小范数逆解. 诸如MNE的源估计方法需要从记录中进行噪声估计. 在本教程中,我们介绍了噪声协方差的基础知识 ...

  2. 脑电数据预处理-ICA去除伪影

    ‍‍‍‍‍‍‍‍‍‍ ICA/BSS的理论与模型 独立成分分析(ICA)是一种盲信号分离(Blind Signal Separation,BSS)方法.ICA可线性建模如下图所示. 假设X为" ...

  3. 脑电分析系列[MNE-Python-18]| 生成模拟原始脑电数据

    在实验中有时需要原始脑电数据来进行模拟实验,但又限于实验条件的不足,需要构造模拟的原始脑电数据. 本示例通过多次重复所需的源激活来生成原始数据. 案例介绍 # 导入工具包 import numpy a ...

  4. Python-生成模拟原始脑电数据

    在实验中有时需要原始脑电数据来进行模拟实验,但又限于实验条件的不足,需要构造模拟的原始脑电数据. 本示例通过多次重复所需的源激活来生成原始数据. 案例介绍 # 导入工具包 import numpy a ...

  5. Python脑电数据的Epoching处理

    点击上面"脑机接口社区"关注我们 更多技术干货第一时间送达 import os.path as op import numpy as np import mne import ma ...

  6. 【思维导图】利用LSTM(长短期记忆网络)来处理脑电数据

    文章来源| 脑机接口社区群友 认知计算_茂森的授权分享 在此非常感谢 认知计算_茂森! 本篇文章主要通过思维导图来介绍利用LSTM(长短期记忆网络)来处理脑电数据. 文章的内容来源于社区分享的文章&l ...

  7. Python读取.edf格式脑电数据文件

    MNE-python读取.edf文件 EDF,全称是 European Data Format,是一种标准文件格式,用于交换和存储医疗时间序列. 该格式文件能够存储多通道的数据,允许每个信号拥有不同的 ...

  8. 使用时空-频率模式分析从脑电数据的一些试验中提取N400成分

    今天介绍的内容是清华大学高小榕教授团队的研究成果,从脑电数据中提取N400成分. 关于高小榕教授的介绍,可以查看本社区之前分享的<第1期 | 国内脑机接口领域专家教授汇总> 高小榕教授 单 ...

  9. Python读取保存在hdf5文件中的脑电数据

    当脑电数据保存在hdf5文件中如何读取呢? 1.首先需要查看hdf5文件的结构: 2.通过结构来获取数据. import h5py import numpy as np fname='test.hdf ...

最新文章

  1. 人人都能看懂的EM算法推导
  2. Android4.2.2中对安全性的改进
  3. leetcode 687. Longest Univalue Path | 687. 最长同值路径(树形dp)
  4. 左神算法:猫狗队列(通过给不同实例盖时间戳的方法实现)
  5. python爬取天气预报源代码_python抓取天气并分析 实例源码
  6. python实现isodd函数、参数为整数、如果整数为奇数_python 程序练习题
  7. Maven:maven-shade-plugin, 打包失败, MojoExecutionException: Error creating shaded jar: null
  8. ios创建自定义控件必须具备的三个方法
  9. el-upload进度条无效,on-progress无效问题解决方案
  10. if __name__ == '__main__' 如何正确理解?
  11. Xcode Developer Tools
  12. 系统集成项目管理工程师教程重点、笔记和试题大全
  13. 安装SQL Server 2012过程中出现“启用windows功能NetFx3时出错”
  14. 软件产品需求分析思维导图
  15. 【组合逻辑电路】——通用译码器
  16. C语言100题练习计划 47——查询水果价格
  17. 给网页加一个全屏转场动画 HTML JS
  18. 视频教程-项目经理俱乐部-项目实战.职场求生.敏捷.企业管理-敏捷开发
  19. MySQL错误:Column ‘pno‘ in field list is ambiguous是什么问题呢?
  20. PHP-FFMpeg 操作视频/音频文件

热门文章

  1. 树莓派安装pytorch环境记录
  2. linux 重新加载dev,/dev/shm修改大小并重新挂载
  3. 【亚马逊运营】应该如何去优化关键词的自然排名?
  4. js将秒转换成几天几小时几分几秒,每秒刷新
  5. 新媒体运营黎想: UGC社区运营技巧!
  6. SQL Server 2008 安装
  7. EasyDSS如何实现定期检测和取读加密狗授权?
  8. delphi7在AdvStringGrid中添加ComboBox方法,记录下来
  9. T---Win10监控软件的GDI数量
  10. CSDN客服联系方式汇总