0 引言

很少基于pytorch框架的故障诊断模型搭建,也对自己前期的实验做个总结,欢迎讨论。

框架:pytorch

硬件:Intel(R) Core(TM) i5-10500H CPU @ 2.50GHz   2.50 GHz

1650

1 数据集

数据集:凯斯西储大学(CWRU)滚动轴承数据。

2 实验对象

2.1 数据集说明

故障轴承以激光蚀刻加工制作,人为在轴承上蚀刻单点缺陷模拟轴承的故障,缺陷点的大小设置(7,14,21,28)mils 四种类型。轴承振动数据使用加速度传感器进行信号采集,采样频率为 12000Point/S(12kHz)及 48000point/S(48kHz)。

具体说明可到官网查看。

本实验采用驱动端的轴承振动数据,采样频率选择 12kHz。

2.2 故障标签

2.3  数据对比

标签太多不好对比 ,感觉用matlab画这个要更直观和方便些,数据提取嫌麻烦就没搞了。

一维查看:

t-sne训练集可视化:

data_sne=data_valid_last           #[3385x512] list_target对应的标签
tsne = TSNE(n_components=2,init='pca',learning_rate=200)
X_tsne=tsne.fit_transform(data_sne)
x_min,x_max=X_tsne.min(0),X_tsne.max(0)
X_norm = (X_tsne - x_min) / (x_max - x_min)  # 归一化
plt.figure(figsize=(12, 12))
print(X_norm.shape)
for i in range(X_norm.shape[0]):plt.text(X_norm[i, 0], X_norm[i, 1], str(list_target[i]), color=plt.cm.Set1(list_target[i]),fontdict={'weight': 'bold', 'size': 6})
plt.xticks([])
plt.yticks([])
time_end=time.time()
time_c=time_end-time_start
print("time cost:{}min{}s".format(int(time_c/60),int(time_c%60)))
plt.show()

3 数据处理

3.1 数据预处理

def data_load(path, data_name, cut_num, label):""":param path:        数据地址:param data_name:   数据名称:param cut_num:     每份样本数量,cut_length:param label:       数据标签:return:            data_cut, label_cut"""name_str=str(data_name)data=loadmat(path+name_str+'.mat')                    #数据加载为字典格式,data为字典if data_name<100:                                     #如X097_DE_time,数据提取格式data_name='0'+str(data_name)else:data_name=str(data_name)#原始数据提取org_DE=data['X'+data_name+'_DE_time']org_FE=data['X'+data_name+'_FE_time']# 数据归一化# 归一化DEscaler = MinMaxScaler()list_DE_n = scaler.fit_transform(org_DE)list_FE_n =scaler.fit_transform(org_FE)                #风扇端数据list_DE = []for de in list_DE_n:list_DE.append(de[0])list_FE = []for fe in list_FE_n:list_FE.append(fe[0])list_r = []# 分割数据data_cut = []label_cut = []for i in range(0, int(len(list_DE_n) / cut_num)):   #分成i个data_cut.append(list_DE[i * cut_num: (i + 1) * cut_num])label_cut.append(label)return data_cut, label_cut
data_97,label_97=data_load('data/Normal Baseline Data/',97,cut_length,0)
data_98,label_98=data_load('data/Normal Baseline Data/',98,cut_length,0)
data_99,label_99=data_load('data/Normal Baseline Data/',99,cut_length,0)
data_100,label_100=data_load('data/Normal Baseline Data/',100,cut_length,0)
# print(np.asarray(data_97,dtype = 'float').shape)
data_normal = data_97 + data_98 + data_99 + data_100
label_normal = label_97 + label_98 + label_99 +label_100
print("处理后正常样本shape:",np.asarray(data_normal,dtype = 'float').shape)
print("label数:",len(label_normal))

notes:其中正常样本是其他单个样本的三倍。

3.2 数据加载、划分及处理

batch_size=128                #  批量大小
train_data_loader=DataLoader(data_train_last,batch_size,drop_last=True)
train_label_loader=DataLoader(data_train_label,batch_size,drop_last=True)
valid_label_loader=DataLoader(valid_label,batch_size,drop_last=True)
valid_data_loader=DataLoader(data_valid_last,batch_size,drop_last=True)

3.3 数据打包

4 模型构建

超参数设置:

epoch可以设置的大一些,我平常都是500,写这篇文章为了节约时间才设置的150,在后面正确率还可以浅浅的上升。

loss_n=nn.CrossEntropyLoss()optimizer=torch.optim.Adam(params=LsTm.parameters(),lr=0.001,betas=(0.9,0.999),eps=1e-08,weight_decay=1e-4)
epoch=150

4.1 LSTM

输入大小为:[128,512,1]

class mylstm(nn.Module):def __init__(self):super(mylstm, self).__init__()self.modle1 = nn.LSTM(input_size=1,hidden_size=16,num_layers=2, batch_first=True, dropout=0.2)self.modle2=nn.Sequential     (nn.Flatten(),nn.Linear(16*512,256),nn.ReLU(),nn.Dropout(0.2),nn.Linear(256,10))def forward(self, x):h_0 = torch.randn(2, 128,16)c_0 = torch.randn(2,128, 16)h_0=h_0.cuda()c_0=c_0.cuda()x = x.to(torch.float32)x,(h_0,c_0)=self.modle1(x,(h_0,c_0))#x=x[:,-1,:]x=self.modle2(x)return x

4.1.1 结果

train_last_10_epoch_avg_accuracy:0.9739
 valid_last_10_epoch_avg_accuracy:0.9356
 time cost:8min48s

4.2 LSTM-1DCNN

后面就接了个一维池化层。

class mylstm(nn.Module):def __init__(self):super(mylstm, self).__init__()self.modle1 = nn.LSTM(input_size=1,hidden_size=16,num_layers=2, batch_first=True, dropout=0.2)self.modle2=nn.Sequential     (nn.AvgPool1d(16),nn.Flatten(),nn.Linear(512,256),nn.ReLU(),# nn.Dropout(0.2),nn.Linear(256, 64),nn.ReLU(),nn.Linear(64,10))def forward(self, x):h_0 = torch.randn(2,batch_size,16)c_0 = torch.randn(2,batch_size, 16)h_0=h_0.cuda()c_0=c_0.cuda()x = x.to(torch.float32)x,(h_0,c_0)=self.modle1(x,(h_0,c_0))#print(x.shape)#x=x[:,-1,:]x=self.modle2(x)return x

4.2.1 结果

train_last_10_epoch_avg_accuracy:0.9937
valid_last_10_epoch_avg_accuracy:0.9779
time cost:8min36s

5 总结

过拟合和稳定性还有很大的提升。

dropout要慎重,极大可能丢弃掉一些重要的神经元,导致损失值突变。

加入注意力机制后几乎没啥改变。

LSTM后面可以接二维卷积或者一维卷积用大的卷积核,提高感受视野,效果有提升。

怎么花怎么来吧。

基于LSTM的故障诊断相关推荐

  1. 手把手教你:基于LSTM的股票预测系统

    系列文章 第七章.手把手教你:基于深度残差网络(ResNet)的水果分类识别系统 第六章.手把手教你:人脸识别的视频打码 第五章.手把手教你:基于深度学习的滚动轴承故障诊断 目录 系列文章 一.项目简 ...

  2. 【LSTM】基于LSTM网络的人脸识别算法的MATLAB仿真

    1.软件版本 matlab2021a 2.本算法理论知识 长短时记忆模型LSTM是由Hochreiter等人在1997年首次提出的,其主要原理是通过一种特殊的神经元结构用来长时间存储信息.LSTM网络 ...

  3. lstm 根据前文预测词_干货 | Pytorch实现基于LSTM的单词检测器

    Pytorch实现 基于LSTM的单词检测器 字幕组双语原文: Pytorch实现基于LSTM的单词检测器 英语原文: LSTM Based Word Detectors 翻译: 雷锋字幕组(Icar ...

  4. 基于LSTM的序列预测: 飞机月流量预测

    基于LSTM的序列预测: 飞机月流量预测 循环神经网络,如RNN,LSTM等模型,比较适合用于序列预测,下面以一个比较经典的飞机月流量数据集,介绍LSTM的使用方法和训练过程. 完整的项目代码下载:h ...

  5. bilstm+crf中文分词_基于LSTM的中文分词模型

    中文分词任务是一个预测序列的经典问题,已知的方法有基于HMM[1]的分词方法.基于CRF[2]的方法和基于LSTM的分词方法. 本文介绍Xinchi Chen等人[3]提出的基于LSTM的分词方法.根 ...

  6. 基于LSTM三分类的文本情感分析,采用LSTM模型,训练一个能够识别文本postive, neutral, negative三种

    基于LSTM三分类的文本情感分析,采用LSTM模型,训练一个能够识别文本postive, neutral, negative三种 ,含数据集可直接运行 完整代码下载地址:基于LSTM三分类的文本情感分 ...

  7. tensorflow2.0 基于LSTM模型的文本生成

    春水碧于天,画船听雨眠 基于LSTM模型的唐诗文本生成 实验基本要求 实验背景 实验数据下载 LSTM模型分析 实验过程 文本预处理 编解码模型 LSTM模型设置 实验代码 实验结果 总结 致谢 实验 ...

  8. 深度解析论文 基于 LSTM 的 POI 个性化推荐框架

    基于 LSTM 的 POI 个性化推荐框架 Abstract Question Method Model Introduction 作者为什么研究这个课题? 参考模型介绍 ① word2vec ②CB ...

  9. java lstm pb_在Tensorflow Serving上部署基于LSTM的文本分类模型

    一些重要的概念 Servables Servables 是客户端请求执行计算的基础对象,大小和粒度是灵活的. Servables 不会管理自己的运行周期. 典型的Servables包括: a Tens ...

最新文章

  1. Win7:“找不到该项目”错误解决大法
  2. 关于湖北工业大学图书馆联网配置的方法
  3. go web本地化资源
  4. java虚拟机 第二章Java内存区域与内存溢出异常
  5. C++ class实现Huffman树(完整代码)
  6. j计算机专业英语题库,计算机专业英语单词习题
  7. Python正则表达式,看完这篇文章就够了...#华为云·寻找黑马程序员#
  8. java对csv格式的读写操作
  9. HTML5新标签 w3c
  10. 分布科技荣登海南省实施区块链应用示范揭榜工程名单
  11. Openvswtich 学习笔记
  12. paip.输入法编程----二级汉字2350个常用汉字2350个
  13. 一步步的Abaqus2021版本安装教程+汉化操作
  14. e531网卡驱动linux,联想e531网卡驱动下载-联想e531笔记本无线网卡驱动v6.30.223.201 官方版 - 极光下载站...
  15. 【houdini vex】边界点提取与扩展
  16. 《高性能MySQL》读书笔记(1~6章)
  17. qt程序在win10正常运行win7电脑上崩溃
  18. 解决python -m spacy download en_core_web_sm连接不上服务器的方案
  19. mysql ucase,Node.js MySQL UCASE()用法及代码示例
  20. 数字图像处理学习笔记4第四章 图像变换 附实验

热门文章

  1. Codeforces CodeCraft-20 (Div. 2) C. Primitive Primes
  2. pycharm提示python version 2.7 does not support...
  3. matlab hist函数_超全Matlab绘图方法整理(建议收藏!)
  4. 【MySQL优化(六)】InnoDB索引优化与索引规约
  5. 尚硅谷-MySQL流程控制结构
  6. vb变量名的命名规则
  7. 跟着团子学SAP PS—PS与IM(投资管理)模块的集成——项目群预算管理(投资盘管理) IM01/IM22
  8. 铝型材海报架|立式开启式海报架厂家提供
  9. 记一次破解自己win10登录密码的经历
  10. python dic 字典排序