pytorch:对小鼠的脑电数据进行睡眠状态三分类
pytorch:对小鼠的脑电数据进行睡眠状态三分类
- 1.特征提取
- 2.用pytorch分类
- 首先分出训练集和测试集
- 构建数据迭代器
- 构建网络
- 准确率评估
- 模型训练
- 3. 分类结果
本文采用PyTorch框架,用实验室采集到老鼠大脑数据来学习一波多分类。
1.特征提取
我的数据是26只老鼠的mat形式的脑电数据,每个老鼠有15个通道,每个老鼠在不同的睡眠状态截取30段数据,每一段数据长为10s。 这里采集的睡眠状态有三种,分别是awake,rem,和sws
这里为了更好的获取时域和频域的特征,分别用相关矩阵和转移熵矩阵作为我们分类的特征,然后用pca方法对相关矩阵和转移熵矩阵各提取前30个特征,总共得到60个特征,最终得到的数据为(2340,60)的矩阵,接下来就可以开始处理了。
2.用pytorch分类
首先分出训练集和测试集
data_size = np.shape(all_sample)[0]
feature_size = np.shape(all_sample)[1]
n = int(data_size / 3) # 每个睡眠状态的样本数
X = all_sample
Y = np.zeros(data_size) # 为每个样本打label
Y[n:2 * n] = 1
Y[2 * n:] = 2
# 用sklearn里面的MinMaxScaler进行归一化
scaler_minmax = MinMaxScaler(feature_range=(-1, 1)) # 设置变换范围
X_minmax = scaler_minmax.fit_transform(X)
# 随机选取30%的数据作为测试集
X_train, X_test, Y_train, Y_test = train_test_split(X_minmax, Y, test_size=0.3)
# 样本数据转tensor float32
# label 转tensor int64
X_train = torch.from_numpy(X_train).to(torch.float).to(device)
X_test = torch.from_numpy(X_test).to(torch.float).to(device)
Y_train = torch.from_numpy(Y_train).to(torch.int64).to(device)
Y_test = torch.from_numpy(Y_test).to(torch.int64).to(device)
构建数据迭代器
pytorch 的迭代器可以很方便分批导入batch块,我们这里设置batch_size为30.
batch_size = 30
train_dataset = torch.utils.data.TensorDataset(X_train, Y_train)
test_dataset = torch.utils.data.TensorDataset(X_test, Y_test)
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
构建网络
用nn.Swquential()函数来构建网络。因为我们用到的神经网络中每一层之前都是全连接,所以可以用线性层加激活函数层来进行正向传播。隐藏层神经元个数设置为20,激活函数直接调用了nn里面的sigmoid函数(ReLU也可以,并且训练速度会快很多)。pytorch中的初始化可以用torch.nn中的init.normal_()函数直接对我们构建的net进行初始化。
# 模型构建
num_inputs, num_outputs, num_hiddens = feature_size, 3, 20
net = nn.Sequential(nn.Linear(num_inputs, num_hiddens),nn.Sigmoid(),# nn.Linear(num_hiddens, num_hiddens),# nn.ReLU(),nn.Linear(num_hiddens, num_outputs)# torch交叉熵的损失函数自带softmax运算,这里就没有再加一层激活函数
)# 参数初始化
for param in net.parameters():init.normal_(param, mean=0, std=0.01)
准确率评估
训练完成后需要对测试集进行评估,我们先构建一个评估函数.这里模型的输出net(x)的形状是(batch_size, 3),所以判断输出结果时,用.argmax(dim=1)返回第二维的最大值的index。然后用.item()提取结算结果。
def acc(test_iter, net, device):acc_sum, n = 0.0, 0for x, y in test_iter:acc_sum += (net(x.to(device)).argmax(dim=1) == y.to(device)).sum().item()n += y.shape[0]return acc_sum / n
用该评估器来评估还未训练的模型对测试集的识别率,输出结果应该分布在1/3左右:
acc(test_iter, net, device)
Out[3]: 0.31196581196581197
模型训练
损失函数使用torch中的交叉熵损失函数
# 交叉熵损失函数
loss = torch.nn.CrossEntropyLoss()
优化器采用torch.optim中的SGD方法(Adam使用方法也类似),这里设置学习率为0.01,正则化参数为0.0001
# 优化器
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.01, weight_decay=0.0001)
#optimizer = torch.optim.Adam(params=net.parameters(), lr=0.01, weight_decay=0.0001)
接下来开始训练,在pytorch中梯度的计算十分方便,我们可以直接在创建tensor时用requires_grad=True来设置是否跟踪梯度,在每个batch中只需用.backward()函数就可获得梯度,然后用optimizer.step()自动更新net中的参数。具体代码如下
num_epochs = 500
for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for x, y in train_iter:y_hat = net(x)l = loss(y_hat, y).sum()optimizer.zero_grad() # 优化器的梯度清零l.backward() # 反向传播梯度计算optimizer.step() # 优化器迭代train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = acc(test_iter, net, device)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, batch_size * train_l_sum / n, train_acc_sum / n, test_acc))
3. 分类结果
最终可以得到97%左右的正确率
pytorch:对小鼠的脑电数据进行睡眠状态三分类相关推荐
- Python协方差矩阵处理脑电数据
在本教程中,我们将介绍传感器协方差计算的基础知识,并构建一个噪声协方差矩阵,该矩阵可用于计算最小范数逆解. 诸如MNE的源估计方法需要从记录中进行噪声估计. 在本教程中,我们介绍了噪声协方差的基础知识 ...
- 脑电数据预处理-ICA去除伪影
ICA/BSS的理论与模型 独立成分分析(ICA)是一种盲信号分离(Blind Signal Separation,BSS)方法.ICA可线性建模如下图所示. 假设X为" ...
- 脑电分析系列[MNE-Python-18]| 生成模拟原始脑电数据
在实验中有时需要原始脑电数据来进行模拟实验,但又限于实验条件的不足,需要构造模拟的原始脑电数据. 本示例通过多次重复所需的源激活来生成原始数据. 案例介绍 # 导入工具包 import numpy a ...
- Python-生成模拟原始脑电数据
在实验中有时需要原始脑电数据来进行模拟实验,但又限于实验条件的不足,需要构造模拟的原始脑电数据. 本示例通过多次重复所需的源激活来生成原始数据. 案例介绍 # 导入工具包 import numpy a ...
- Python脑电数据的Epoching处理
点击上面"脑机接口社区"关注我们 更多技术干货第一时间送达 import os.path as op import numpy as np import mne import ma ...
- 【思维导图】利用LSTM(长短期记忆网络)来处理脑电数据
文章来源| 脑机接口社区群友 认知计算_茂森的授权分享 在此非常感谢 认知计算_茂森! 本篇文章主要通过思维导图来介绍利用LSTM(长短期记忆网络)来处理脑电数据. 文章的内容来源于社区分享的文章&l ...
- Python读取.edf格式脑电数据文件
MNE-python读取.edf文件 EDF,全称是 European Data Format,是一种标准文件格式,用于交换和存储医疗时间序列. 该格式文件能够存储多通道的数据,允许每个信号拥有不同的 ...
- 使用时空-频率模式分析从脑电数据的一些试验中提取N400成分
今天介绍的内容是清华大学高小榕教授团队的研究成果,从脑电数据中提取N400成分. 关于高小榕教授的介绍,可以查看本社区之前分享的<第1期 | 国内脑机接口领域专家教授汇总> 高小榕教授 单 ...
- Python读取保存在hdf5文件中的脑电数据
当脑电数据保存在hdf5文件中如何读取呢? 1.首先需要查看hdf5文件的结构: 2.通过结构来获取数据. import h5py import numpy as np fname='test.hdf ...
最新文章
- 人人都能看懂的EM算法推导
- Android4.2.2中对安全性的改进
- leetcode 687. Longest Univalue Path | 687. 最长同值路径(树形dp)
- 左神算法:猫狗队列(通过给不同实例盖时间戳的方法实现)
- python爬取天气预报源代码_python抓取天气并分析 实例源码
- python实现isodd函数、参数为整数、如果整数为奇数_python 程序练习题
- Maven:maven-shade-plugin, 打包失败, MojoExecutionException: Error creating shaded jar: null
- ios创建自定义控件必须具备的三个方法
- el-upload进度条无效,on-progress无效问题解决方案
- if __name__ == '__main__' 如何正确理解?
- Xcode Developer Tools
- 系统集成项目管理工程师教程重点、笔记和试题大全
- 安装SQL Server 2012过程中出现“启用windows功能NetFx3时出错”
- 软件产品需求分析思维导图
- 【组合逻辑电路】——通用译码器
- C语言100题练习计划 47——查询水果价格
- 给网页加一个全屏转场动画 HTML JS
- 视频教程-项目经理俱乐部-项目实战.职场求生.敏捷.企业管理-敏捷开发
- MySQL错误:Column ‘pno‘ in field list is ambiguous是什么问题呢?
- PHP-FFMpeg 操作视频/音频文件