卷积神经网络处理一维信号(故障诊断)

(注:从小白角度出发,刚接触卷积神经网络的小伙伴,很多人和我一样就是想知道这段代码怎么把信号输入进去,怎么看输出结果,怎么把输出结果与测试集的数据进行对比,从而知道测试结果,这些我在下面有解释。本文最后会附有链接,包括我用的数据,源码。大家可以看一下数据格式,我当时就是不知道表格里的数据到底是什么格式,然后搞了好久!!!!如果有问题的小伙伴可以留言,我会尽力解答。。)
编辑器:Anaconda+jupyter
环境    :python :3.7.10
               tensorflow::2.3.0

代码如下

import keras
from scipy.io import loadmat
import matplotlib.pyplot as plt
import glob
import numpy as np
import pandas as pd
import math
import os
from keras.layers import *
from keras.models import *
from keras.optimizers import *//这里是我导入的训练集数据训练集,大家对应自己的信号数据就好,数据我下面会发,大家可以看一下数据的格式;
MANIFEST_DIR = r'C:\Users\Administrator\Desktop\test\frftdata\train\frfttrain1.0.csv'
Batch_size = 30
Long = 800
Lens = 200
def convert2oneHot(index, lens):hot = np.zeros((lens,))hot[int(index)] = 1return(hot)def xs_gen(path=MANIFEST_DIR, batch_size=Batch_size, train=True, Lens=Lens):img_list = pd.read_csv(path)if train:img_list = np.array(img_list)[:Lens]print("Found %s train items." % len(img_list))print("list 1 is", img_list[0, -1])steps = math.ceil(len(img_list) / batch_size)else:img_list = np.array(img_list)[Lens:]print("Found %s test items." % len(img_list))print("list 1 is", img_list[0, -1])steps = math.ceil(len(img_list) / batch_size)while True:for i in range(steps):batch_list = img_list[i * batch_size: i * batch_size + batch_size]np.random.shuffle(batch_list)batch_x = np.array([file for file in batch_list[:, 1:-1]])batch_y = np.array([convert2oneHot(label, 4) for label in batch_list[:, -1]])yield batch_x, batch_y
//这里是导入的我测试集的数据
TEST_MANIFEST_DIR = r'C:\Users\Administrator\Desktop\test\frftdata\test\frfttest1.0.csv'def ts_gen(path=TEST_MANIFEST_DIR, batch_size=Batch_size):img_list = pd.read_csv(path)img_list = np.array(img_list)[:Lens]print("Found %s test items." % len(img_list))print("list 1 is", img_list[0, -1])steps = math.ceil(len(img_list) / batch_size)while True:for i in range(steps):batch_list = img_list[i * batch_size:i * batch_size + batch_size]batch_x = np.array([file for file in batch_list[:, 1:]])yield batch_x
TIME_PERIODS = 5000def build_model(input_shape=(TIME_PERIODS,), num_classes=4):model = Sequential()model.add(Reshape((TIME_PERIODS, 1), input_shape=input_shape))model.add(Conv1D(16, 8, strides=2, activation='relu', input_shape=(TIME_PERIODS, 1)))model.add(Conv1D(16, 8, strides=2, activation='relu', padding="same"))model.add(MaxPooling1D(2))model.add(Conv1D(64, 4, strides=2, activation='relu', padding="same"))model.add(Conv1D(64, 4, strides=2, activation='relu', padding="same"))model.add(MaxPooling1D(2))model.add(Conv1D(256, 4, strides=2, activation='relu', padding="same"))model.add(Conv1D(256, 4, strides=2, activation='relu', padding="same"))model.add(MaxPooling1D(2))model.add(Conv1D(512, 2, strides=1, activation='relu', padding="same"))model.add(Conv1D(512, 2, strides=1, activation='relu', padding="same"))model.add(MaxPooling1D(2))""" model.add(Flatten())model.add(Dropout(0.3))model.add(Dense(256, activation='relu'))"""model.add(GlobalAveragePooling1D())model.add(Dropout(0.3))model.add(Dense(num_classes, activation='softmax'))return(model)
Train = Trueif __name__ == "__main__":if Train == True:train_iter = xs_gen()val_iter = xs_gen(train=False)ckpt = keras.callbacks.ModelCheckpoint(filepath='best_model.{epoch:02d}-{val_loss:.4f}.h5',monitor='val_loss', save_best_only=True, verbose=1)model = build_model()opt = Adam(0.0002)model.compile(loss='categorical_crossentropy',optimizer = opt, metrics=['accuracy'])print(model.summary())train_history = model.fit_generator(generator=train_iter,steps_per_epoch=Lens // Batch_size,epochs=25,initial_epoch=0,validation_data=val_iter,validation_steps=(Long - Lens) // Batch_size,callbacks=[ckpt],)model.save("finishModel.h5")else:test_iter = ts_gen()model = load_model("best_model.49-0.00.h5")pres = model.predict_generator(generator=test_iter, steps=math.ceil(528 / Batch_size), verbose=1)print(pres.shape)ohpres = np.argmax(pres, axis=1)print(ohpres.shape)df = pd.DataFrame()df["id"] = np.arange(1, len(ohpres) + 1)df["label"] = ohpresdf.to_csv("predicts.csv", index=None)test_iter = ts_gen()for x in test_iter:x1 = x[0]breakplt.plot(x1)plt.show()def show_train_history(train_history, train, validation):plt.plot(train_history.history[train])plt.plot(train_history.history[validation])plt.ylabel('Train History')plt.ylabel(train)plt.xlabel('Epoch')plt.legend(['train', 'validation'], loc='upper left')plt.show()show_train_history(train_history, 'accuracy', 'val_accuracy')show_train_history(train_history, 'loss', 'val_loss')plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(6, 4))
plt.plot(train_history.history['accuracy'], "g--", label="训练集准确率")
plt.plot(train_history.history['val_accuracy'], "g", label="验证集准确率")
plt.plot(train_history.history['loss'], "r--", label="训练集损失函数")
plt.plot(train_history.history['val_loss'], "r", label="验证集损失函数")
plt.title('模型的准确率和损失函数', fontsize=14)
plt.ylabel('准确率和损失函数', fontsize=12)
plt.xlabel('世代数', fontsize=12)
plt.ylim(0)
plt.legend()
plt.show()//这里是我导入的测试集的标签表格,用来对比神经网络的测试结果,并且后面生成混淆矩阵;
//这里的标签就是200个测试集的数据的故障标签
file = r"C:\Users\Administrator\Desktop\shiyong22.csv"
all_df = pd.read_csv(file)
ndarray = all_df.values
ndarray[:2]test_iter = ts_gen()
pres = model.predict_generator(generator=test_iter, steps=math.ceil(520 / Batch_size), verbose=1)
print(pres.shape)print(ndarray.shape)ohpres = np.argmax(pres, axis=1)
print(ohpres.shape)
ohpres=ohpres[:200]
ohpresimport matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as npdef cm_plot(original_label, predict_label, pic=None):cm = confusion_matrix(original_label, predict_label)plt.figure()plt.matshow(cm, cmap=plt.cm.GnBu)plt.colorbar()for x in range(len(cm)):for y in range(len(cm)):plt.annotate(cm[x, y], xy=(x, y), horizontalalignment='center', verticalalignment='center')plt.ylabel('Predicted label')plt.xlabel('True label')plt.title('Confusion Matrix')if pic is not None:plt.savefig(str(pic) + '.jpg')plt.show()plt.rcParams['font.sans-serif'] = 'SimHei'
plt.rcParams['axes.unicode_minus'] = False
cm_plot(ndarray, ohpres)from sklearn.metrics import accuracy_score
accuracy_score(ndarray, ohpres)train_history.history['loss']train_history.history['val_loss']train_history.history['val_accuracy']train_history.history['accuracy']

数据下载链接
https://gitee.com/wjj_xiaoxiansheng/cnn_-frft_-data

数据介绍
类别:标签0、1、2、3分别为正常状态、内圈故障、外圈故障、滚动体故障;
信号:每个样本信号5000个数据点,共有1000个样本。从中随机抽取800个样本作为训练集,另外200个样本作为测试集。
:我是对分数阶傅里叶变换做训练和测试,阶次(从0到1,间隔0.05,阶次为0时就是原始时域信号,阶次为1是就是傅里叶变换的数据结果,从文件夹名称可以看出。大家可以只用0.0阶次就是原始时域信号试一下。

卷积神经网络处理一维信号(故障诊断)相关推荐

  1. 一维卷积神经网络结构图,一维卷积神经网络原理

    1.卷积神经网络算法是什么? 一维构筑.二维构筑.全卷积构筑. 卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Fe ...

  2. 基于一维卷积神经网络对机械振动信号进行分类并加以预测

    基于一维卷积神经网络对机械振动信号进行分类并加以预测 *使用一维卷积神经网络训练振动信号进行二分类 2020年7月16日,一学期没等到开学,然而又放假了. 总览CSDN中大多数卷积神经网络都是对二维图 ...

  3. 基于1DCNN(一维卷积神经网络)的机械振动故障诊断

    基于1DCNN(一维卷积神经网络)的机械振动故障诊断 机械振动故障诊断最为经典的还是凯斯西储实验室的轴承故障诊断,开学一周了,上次改编鸢尾花分类的代码可用,但是并不准确.开学一周重新改编了别人的一篇代 ...

  4. (8)卷积神经网络如何处理一维时间序列数据?

    (8)卷积神经网络如何处理一维时间序列数据? 概述 许多文章都关注于二维卷积神经网络(2D CNN)的使用,特别是图像识别.而一维卷积神经网络(1D CNNs)只在一定程度上有所涉及,比如在自然语言处 ...

  5. 卷积神经网络如何处理一维时间序列数据?

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自 | 人工智能与算法学习 概述 许多文章都关注于二维卷积神 ...

  6. 一维卷积神经网络原理,一维卷积神经网络应用

    1.CNN(卷积神经网络).RNN(循环神经网络).DNN(深度神经网络)的内部网络结构有什么区别? 如下: 1.DNN:存在着一个问题--无法对时间序列上的变化进行建模.然而,样本出现的时间顺序对于 ...

  7. 基于卷积神经网络的心音信号识别

    目录 一.引言 1.研究背景 2.研究方案 二.信号预处理 1.小波去噪 2.处理效果 三.特征提取 1.数据筛选 2.特征提取 四.模型搭建 一.引言 1.研究背景 心音信号作为生物医学信号中的重要 ...

  8. 卷积神经网络之一维卷积、二维卷积、三维卷积

    1. 二维卷积 图中的输入的数据维度为14×1414×14,过滤器大小为5×55×5,二者做卷积,输出的数据维度为10×1010×10(14−5+1=1014−5+1=10).如果你对卷积维度的计算不 ...

  9. 刘雪峰卷积神经网络,刘雪峰老师哪里人

    哪些神经网络可以用在图像特征提取上 BP神经网络.离散Hopfield网络.LVQ神经网络等等都可以. 1.BP(BackPropagation)神经网络是1986年由Rumelhart和McCell ...

最新文章

  1. shell基础04 结构化命令
  2. 刻意练习:Python基础 -- Task10. 类与对象
  3. 用VS Code直接浏览GitHub代码 | 12.1K星
  4. Helm 从入门到实践 | 从 0 开始制作一个 Helm Charts
  5. mongoTemplate使用总结
  6. 2.4-2.5、Hive整合(整合Spark、整合Hbase)、连接方式Cli、HiveServer和hivemetastore、Squirrel SQL Client等
  7. 机器学习-LR推导及与SVM的区别
  8. qt4 连接mysql_Qt4访问mysql 数据库的简单教程
  9. *【HDU - 5707】Combine String(dp)
  10. Windows 2003 上使用 Windows Live Writer
  11. 王道 —— 进程通信
  12. smb协议讲解_SMB/CIFS协议解析
  13. UID_PR_01_基础操作
  14. 联想员工亲历联想大裁员:公司不是我的家
  15. 主力吸筹猛攻指标源码_通达信主力吸筹副图指标公式,通达信主力追踪副图源码...
  16. ThinkPHP6集成腾讯云、短信宝短信发送的工具类
  17. 华为:编程实现联想输入法 输入联想功能是非常实用的一个功能,请编程实现类似功能
  18. 多目标跟踪(MOT,Multiple Object Tracking)评价指标
  19. 文本溢出显示省略号效果
  20. 被骂“没前途”,那个996的程序员做错了什么?

热门文章

  1. pip:ffi.h: No such file or directory“
  2. TeamViewer“试用期已到期”解决方法
  3. Social Network之缘分
  4. 关于java操作zebraZT230打印机
  5. 【数据结构Python描述】优先级队列描述“银行VIP客户插队办理业务”及“被插队客户愤而离去”的模型实现
  6. 移动运营商AIS泄漏了83亿条用户数据 容量约4.7 TB
  7. 第九次ScrumMeeting博客
  8. vue使用云信实时通信(聊天室)
  9. html常见盒子居中小结
  10. Visual Studio Professional 2015 简体中文专业版