文章目录

  • 1. RNN 介绍
  • 2. 搭建RNN模型进行训练

1. RNN 介绍

循环神经网络RNN的提出主要针对于时间序列数据。

类似于股票、心律失常 ECG 和 电力数据 等数据都是属于时间序列数据。

RNN模型具有记忆功能。时间序列数据前一时刻的数据可能会影响后一时刻的数据;因此,循环神经网络在时间序列数据上有着较好的性能。

简单地说,循环神经网络目的在于探索序列之间的关系!!!它是根据"人的认知是基于过往的经验和记忆"这一观点提出的。

需要注意的是,该类模型主要处理一维数据。鉴于其记忆功能,能够将前后数据进行关联,因此循环神经网络模型在一维数据的处理上获得了较好的表现。

目前,主流的循环神经网络模型有RNN、LSTM、GRU以及基于它们的改进模型。

本篇文章主要以介绍 RNN 处理心律失常ECG数据为主。

参数解释:

  • xt−1,xt,xt+1x_{t-1}, x_{t}, x_{t+1}xt−1​,xt​,xt+1​ 分别表示 t-1, t 和 t+1 时刻的输入
  • ot−1,ot,ot+1o_{t-1}, o_{t}, o_{t+1}ot−1​,ot​,ot+1​ 分别表示 t-1, t 和 t+1 时刻的结果
  • st−1,st,st+1s_{t-1}, s_{t}, s_{t+1}st−1​,st​,st+1​ 分别表示 t-1, t 和 t+1 时刻的记忆或叫隐藏层
  • WWW 表示前一时刻输入的权重, UUU 表示此刻输入的样本的权重, VVV 表示输出的样本权重.
  • 在整个网络中,WWW,UUU 和 VVV 是共享的。需要注意的是: VVV 是需要看情况使用的;若搭建网络时需要每一个隐藏层的输出,这种情况下是需要用到 VVV 的;若不需要每一个隐藏层的输出,可以不使用 VVV 。(本篇文章只取最后一个隐藏层的输出,因此用不到 VVV。 )

在实现代码的过程中,可能有部分同学有疑问:整个训练过程只有一个参数矩阵A啊,不是应该有2个参数 WWW,UUU?

其实很简单,在实现过程中RNN会将上一个状态 st−1s_{t-1}st−1​和当前状态的输入 xtx_txt​ 进行 connect 操作,如上图;那么 W 和 U便是上图左边矩阵A中的参数;W是A中紫色部分,U是A中蓝色部分。也就是说,在实际操作过程中,将这两个参数合并在一起了。

在本实验中,我们举得例子中 x 大小(每个心拍)为300,隐藏层h的大小是 50;那么根据矩阵的乘法运算,我们可以轻易地得到A的大小是 50 * 350

2. 搭建RNN模型进行训练

关于心律失常ECG数据的相关处理之前都已经介绍过了,这里就不展开介绍了。

实验所用数据集:MIT-BIH Arrhythmia Database

主要有几点区别:

  1. 本次代码是用 Pytorch 实现的。
  2. 鉴于目前大多数心律失常分类都是在对N、S、V、F 和 Q 五类进行分类,之前我们是分类N、A、V、L和R五类。本次实验主要是针对 N、S、V、F 和 Q 五类。

直接看代码,相关代码均已在代码中注释。

'''
导入相关包
'''
import wfdb
import pywt
import seaborn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch
import torch.utils.data as Data
from torch import nn'''
加载数据集
'''# 测试集在数据集中所占的比例
RATIO = 0.2# 小波去噪预处理
def denoise(data):# 小波变换coeffs = pywt.wavedec(data=data, wavelet='db5', level=9)cA9, cD9, cD8, cD7, cD6, cD5, cD4, cD3, cD2, cD1 = coeffs# 阈值去噪threshold = (np.median(np.abs(cD1)) / 0.6745) * (np.sqrt(2 * np.log(len(cD1))))cD1.fill(0)cD2.fill(0)for i in range(1, len(coeffs) - 2):coeffs[i] = pywt.threshold(coeffs[i], threshold)# 小波反变换,获取去噪后的信号rdata = pywt.waverec(coeffs=coeffs, wavelet='db5')return rdata# 读取心电数据和对应标签,并对数据进行小波去噪
def getDataSet(number, X_data, Y_data):ecgClassSet = ['N', 'A', 'V', 'L', 'R']# 读取心电数据记录# print("正在读取 " + number + " 号心电数据...")# 读取MLII导联的数据record = wfdb.rdrecord('./data/MIT-BIH-360/' + number, channel_names=['MLII'])data = record.p_signal.flatten()rdata = denoise(data=data)# 获取心电数据记录中R波的位置和对应的标签annotation = wfdb.rdann('./data/MIT-BIH-360/' + number, 'atr')Rlocation = annotation.sampleRclass = annotation.symbol# 去掉前后的不稳定数据start = 10end = 5i = startj = len(annotation.symbol) - end# 因为只选择NAVLR五种心电类型,所以要选出该条记录中所需要的那些带有特定标签的数据,舍弃其余标签的点# X_data在R波前后截取长度为300的数据点# Y_data将NAVLR按顺序转换为01234while i < j:try:# Rclass[i] 是标签lable = ecgClassSet.index(Rclass[i])# 基于经验值,基于R峰向前取100个点,向后取200个点x_train = rdata[Rlocation[i] - 100:Rlocation[i] + 200]X_data.append(x_train)Y_data.append(lable)i += 1except ValueError:i += 1return# 加载数据集并进行预处理
def loadData():numberSet = ['100', '101', '103', '105', '106', '107', '108', '109', '111', '112', '113', '114', '115','116', '117', '119', '121', '122', '123', '124', '200', '201', '202', '203', '205', '208','210', '212', '213', '214', '215', '217', '219', '220', '221', '222', '223', '228', '230','231', '232', '233', '234']dataSet = []lableSet = []for n in numberSet:getDataSet(n, dataSet, lableSet)# 转numpy数组,打乱顺序dataSet = np.array(dataSet).reshape(-1, 300)lableSet = np.array(lableSet).reshape(-1, 1)train_ds = np.hstack((dataSet, lableSet))np.random.shuffle(train_ds)# 数据集及其标签集X = train_ds[:, :300].reshape(-1, 1,300)Y = train_ds[:, 300]# 测试集及其标签集shuffle_index = np.random.permutation(len(X))# 设定测试集的大小 RATIO是测试集在数据集中所占的比例test_length = int(RATIO * len(shuffle_index))# 测试集的长度test_index = shuffle_index[:test_length]# 训练集的长度train_index = shuffle_index[test_length:]X_test, Y_test = X[test_index], Y[test_index]X_train, Y_train = X[train_index], Y[train_index]return X_train, Y_train, X_test, Y_testX_train, Y_train, X_test, Y_test = loadData()'''
数据处理
'''
train_Data = Data.TensorDataset(torch.Tensor(X_train), torch.Tensor(Y_train)) # 返回结果为一个个元组,每一个元组存放数据和标签
train_loader = Data.DataLoader(dataset=train_Data, batch_size=128)
test_Data = Data.TensorDataset(torch.Tensor(X_test), torch.Tensor(Y_test)) # 返回结果为一个个元组,每一个元组存放数据和标签
test_loader = Data.DataLoader(dataset=test_Data, batch_size=128)'''
模型搭建
'''
class RnnModel(nn.Module):def __init__(self):super(RnnModel, self).__init__()'''参数解释:(输入维度,隐藏层维度,网络层数)输入维度:每个x的输入大小,也就是每个x的特征数隐藏层:隐藏层的层数,若层数为1,隐层只有1层网络层数:网络层的大小'''self.rnn = nn.RNN(300, 50, 3, nonlinearity='tanh')self.linear = nn.Linear(50, 5)def forward(self, x):r_out, h_state = self.rnn(x)output = self.linear(r_out[:,-1,:])return outputmodel = RnnModel()'''
设置损失函数和参数优化方法
'''
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)'''
模型训练
'''
EPOCHS = 5
for epoch in range(EPOCHS):running_loss = 0for i, data in enumerate(train_loader):inputs, label = datay_predict = model(inputs)loss = criterion(y_predict, label.long())optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()# 预测correct = 0total = 0with torch.no_grad():for data in test_loader:inputs, label = datay_pred = model(inputs)_, predicted = torch.max(y_pred.data, dim=1)total += label.size(0)correct += (predicted == label).sum().item()print(f'Epoch: {epoch + 1}, ACC on test: {correct / total}')

分类效果基本上可以达到 98% 左右。

Deep Learning × ECG (5) :利用循环神经网络RNN对心律失常ECG数据进行分类相关推荐

  1. Deep Learning × ECG (4) :利用卷积神经网络CNN对心律失常ECG数据进行分类

    本文主要就是介绍搭建模型和模型训练了!! 文章目录 1. AAMI 标准 2. 模型搭建和训练 3. 模型搭建环境 1. AAMI 标准 根据 AAMI (简称:美国心脏病协会) 提供的标准:将心拍分 ...

  2. dive into deep learning 循环神经网络 RNN 部分 学习

    dive into deep learning 循环神经网络 RNN 部分 学习 到目前为止,我们遇到过两种类型的数据:表格数据和图像数据. 对于图像数据,我们设计了专门的卷积神经网络架构来为这类特殊 ...

  3. 第六章_循环神经网络(RNN)

    文章目录 第六章 循环神经网络(RNN) CNN和RNN的对比 http://www.elecfans.com/d/775895.html 6.1 为什么需要RNN? 6.1 RNN种类? RNN t ...

  4. 循环神经网络(RNN, Recurrent Neural Networks)介绍

    循环神经网络(RNN, Recurrent Neural Networks)介绍   循环神经网络(Recurrent Neural Networks,RNNs)已经在众多自然语言处理(Natural ...

  5. 循环神经网络RNN 2—— attention注意力机制(附代码)

    attention方法是一种注意力机制,很明显,是为了模仿人的观察和思维方式,将注意力集中到关键信息上,虽然还没有像人一样,完全忽略到不重要的信息,但是其效果毋庸置疑,本篇我们来总结注意力机制的不同方 ...

  6. 通过keras例子理解LSTM 循环神经网络(RNN)

    博文的翻译和实践: Understanding Stateful LSTM Recurrent Neural Networks in Python with Keras 正文 一个强大而流行的循环神经 ...

  7. 【NLP】毕设学习笔记(八)“前馈 + 反馈” = 循环神经网络RNN

    前馈神经网络和循环神经网络分别适合处理什么样的任务? 如果分类任务仅仅是进行判断和识别,例如判断照片上的人的性别,识别图片上是否有小狗图案,那么对输入的数据仅仅需要做特征寻找的工作即可,找到满足该任务 ...

  8. Keras 中的循环神经网络 (RNN)

    简介 循环神经网络 (RNN) 是一类神经网络,它们在序列数据(如时间序列或自然语言)建模方面非常强大. 简单来说,RNN 层会使用 ​​for​​ 循环对序列的时间步骤进行迭代,同时维持一个内部状态 ...

  9. 花书+吴恩达深度学习(十五)序列模型之循环神经网络 RNN

    目录 0. 前言 1. RNN 计算图 2. RNN 前向传播 3. RNN 反向传播 4. 导师驱动过程(teacher forcing) 5. 不同序列长度的 RNN 如果这篇文章对你有一点小小的 ...

最新文章

  1. iservice list方法_MyBatis-Plus 通用IService使用详解
  2. 本地也能运行AWS?是的,AWS开始进军混合云领域了
  3. ZOJ 1610 Count the Colors
  4. eigrp配置实验_EIGRP负载均衡的实现
  5. (转)智能投顾的大赢家,仍然会是传统机构
  6. 【0304】密码分类
  7. JavaScript运算符优先级
  8. 野火ISO-V2学习
  9. html中如何出现三重阴影,探索 CSS3 中的 box-shadow 属性
  10. 计算机网络知识之1M宽带下载速度多少?
  11. android:layout_weight权重与warp_content配合使用
  12. 外贸订单支付失败有哪些原因导致?有哪些解决方案?
  13. 请告诉我IT行业缺少怎样的人
  14. 【从0到1搭建LoRa物联网】13、低成本单通道网关(一)
  15. avi文件格式详解(一)
  16. 嘉立创常用叠层结构阻抗计算
  17. 数据全生命周期加密,三未信安参展2018贵阳数博会
  18. php如何连接postgresql,php如何连接和操作PostgreSQL数据库
  19. java毕业设计基于JS的租房网站mybatis+源码+调试部署+系统+数据库+lw
  20. 金庸、古龙笔下48句经典语录

热门文章

  1. Libgdx游戏编程之Touchpad摇杆控制角色行走
  2. AEC行业那些开源的软件在这里
  3. ubuntu无限登录,或者登录进去,界面卡顿无法使用独显,nvidia-setting打开失败
  4. 用C语言写一个日期计算器
  5. 《人类简史》十四、开启未来(上)——智人的灭亡
  6. 编译设备树时报错“arch/arm/boot/dts/imx50.dtsi:14:42:致命错误:dt-bindings/clock/imx5-clock/h:没有那个文件或目录”
  7. oracle计算比例,某字段的百分比
  8. 酷开系统鸿蒙,华为鸿蒙系统首发设备曝光!不是手机
  9. ClassNotFoundException:org.exolab.castor.xml.XMLException
  10. 电源12V稳压5V MP2359从数据手册到布线 经验分享