结合深度置信网络(DBN)在提取特征和处理高维、非线性数据等方面的优势,提出一种基于深度置信网络的分类方法。该方法通过深度学习利用原始时域信号的傅里叶频谱(FFT)训练深度置信网络,其优势在于该方法对信号进行FFT时无需设置参数,且直接采用所有频谱分量进行建模,因此无需复杂的特征选择方法,具有较强的通用性和适应性。最后,为了进一步增强DBN的分类精度,采用麻雀搜索算法(SSA)对DBN各权重参数进行优化,实验结果表明:本文提出的方法能够有效地提高分类识别精度。

这个流程十分常见,之所以写这个博客,是因为麻雀搜索算法(SSA)是20年出来的优化算法,目前知网上还没有相关文章,因此对于有需要的人来说,可以水水文章。

1,数据处理

分类数据采用凯斯西楚的轴承数据,OHP/48KHz,数据处理程序如下:

%% 数据预处理(训练集 验证集 测试集划分)
clc;clear;close all%% 加载原始数据
load 0HP/48k_Drive_End_B007_0_122;    a1=X122_DE_time'; %1
load 0HP/48k_Drive_End_B014_0_189;    a2=X189_DE_time'; %2
load 0HP/48k_Drive_End_B021_0_226;    a3=X226_DE_time'; %3
load 0HP/48k_Drive_End_IR007_0_109;   a4=X109_DE_time'; %4
load 0HP/48k_Drive_End_IR014_0_174 ;  a5=X173_DE_time';%5
load 0HP/48k_Drive_End_IR021_0_213 ;  a6=X213_DE_time';%6
load 0HP/48k_Drive_End_OR007@6_0_135 ;a7=X135_DE_time';%7
load 0HP/48k_Drive_End_OR014@6_0_201 ;a8=X201_DE_time';%8
load 0HP/48k_Drive_End_OR021@6_0_238 ;a9=X238_DE_time';%9
load 0HP/normal_0_97                 ;a10=X097_DE_time';%10%%
N=100;
L=864;% 每种状态取N个样本  每个样本长度为L
data=[];label=[];
for i=1:10if i==1;ori_data=a1;endif i==2;ori_data=a2;endif i==3;ori_data=a3;endif i==4;ori_data=a4;endif i==5;ori_data=a5;endif i==6;ori_data=a6;endif i==7;ori_data=a7;endif i==8;ori_data=a8;endif i==9;ori_data=a9;endif i==10;ori_data=a10;endfor j=1:Nstart_point=randi(length(ori_data)-L);%随机取一个起点end_point=start_point+L+1;data=[data ;ori_data(start_point:end_point)];label=[label;i];end
end
%% 标签转换 onehot编码
output=zeros(10*N,10);
for i = 1:10*Noutput(i,label(i))=1;
end
%% 划分训练集 验证集与测试集 7:2:1比例
n=randperm(10*N);
m1=round(0.7*10*N);
m2=round(0.9*10*N);
train_X=data(n(1:m1),:);
train_Y=output(n(1:m1),:);valid_X=data(n(m1+1:m2),:);
valid_Y=output(n(m1+1:m2),:);test_X=data(n(m2+1:end),:);
test_Y=output(n(m2+1:end),:);save data_process train_X train_Y valid_X valid_Y test_X test_Y

2,FFT特征提取

%% 提取FFT谱作为特征向量
clc;close all;clear
%%
load data_process
%% 2、加载数据
x1=train_X(1,:);
% 采样点
L=length(x1);
%采样频率
fs=48000;
%采样间隔
Ts=1/fs;
%采样点数
t=Ts:Ts:L*Ts;
%轴承信号
% 3、对原始信号作图
figure
plot(t,x1)
title('原始信号')
xlabel('采样点/n')
ylabel('幅值')%  4、fft频谱
Y = fft(x1);
P2 = abs(Y/L);
P1 = P2(1:L/2+1);
P1(2:end-1) = 2*P1(2:end-1);
f = fs*(0:(L/2))/L;figure
bar(f,P1)
title('FFT频谱')
xlabel('频率/Hz')
ylabel('幅值')
%% 训练集
TZ=[];
for IIII=1:size(train_X,1) %依次对每个样本进行处理x1=train_X(IIII,:);%轴承信号Y = fft(x1);L=length(x1);P2 = abs(Y/L);P1 = P2(1:L/2+1);P1(2:end-1) = 2*P1(2:end-1);TZ(IIII,:)=P1/max(P1);
end
disp('训练集提取完毕')
train_X=TZ;
%% 验证集
TZ=[];
for IIII=1:size(valid_X,1) %依次对每个样本进行处理x1=valid_X(IIII,:);%轴承信号Y = fft(x1);L=length(x1);P2 = abs(Y/L);P1 = P2(1:L/2+1);P1(2:end-1) = 2*P1(2:end-1);TZ(IIII,:)=P1/max(P1);
end
disp('验证集提取完毕')
valid_X=TZ;%% 测试集
TZ=[];
for IIII=1:size(test_X,1) %依次对每个样本进行处理x1=test_X(IIII,:);%轴承信号Y = fft(x1);L=length(x1);P2 = abs(Y/L);P1 = P2(1:L/2+1);P1(2:end-1) = 2*P1(2:end-1);TZ(IIII,:)=P1/max(P1);
end
disp('测试集提取完毕')
test_X=TZ;%% 保存结果
save data_feature train_X valid_X test_X train_Y valid_Y test_Y

3,DBN分类

clc;clear;close all;
tic
%% 加载数据
% load('data_process.mat');
load('data_feature.mat');trainX=double(train_X);
trainYn=double(train_Y);
testX=double(test_X);
testYn=double(test_Y);
clear train_X train_Y test_X test_Y valid_X valid_Y%% DBN参数设置
rng(0)
% 网络各层节点
input_num=size(trainX,2);%输入层
hidden_num=[50 20];%隐含层,两个数就是两个隐含层 3个数就是3个隐含层
class=size(trainYn,2);%输出层
nodes = [input_num hidden_num class]; %节点数
% 初始化网络权值
dbn = randDBN(nodes);%调用randDBN
nrbm=numel(dbn.rbm);
opts.MaxIter =100;                % 迭代次数
% opts.BatchSize = round(length(trainYn)/4);  % batch规模为四分之一的训练集trainY的长度进行四舍五入取整
opts.BatchSize = 32;  % batch规模
opts.Verbose = 0;               % 是否展示中间过程
opts.StepRatio = 0.1;             % 学习速率
% opts.InitialMomentum = 0.9;%opts.InitialMomentum为0.7
% opts.FinalMomentum = 0.1;%opts.FinalMomentum为0.8
% opts.WeightCost = 0.005;%opts.WeightCost为0
%opts.InitialMomentumIter = 10;%% RBM逐层预训练
dbn = pretrainDBN(dbn, trainX, opts);%进行dbn的预训练
%% 线性映射-将训练好的各RBM 堆栈初始化DBN网络
dbn= SetLinearMapping(dbn, trainX, trainYn);%调用SetLinearMapping函数%% 训练DBN-微调整个DBN
opts.MaxIter =100;                % 迭代次数
% opts.BatchSize = round(length(trainYn)/4);  % batch规模为四分之一的训练集trainY的长度进行四舍五入取整
opts.BatchSize =32;
opts.Verbose = 0;               % 是否展示中间过程
opts.StepRatio = 0.1;             % 学习速率
opts.Object = 'CrossEntropy';            % 目标函数: Square CrossEntropy
%opts.Layer = 1;
dbn = trainDBN(dbn, trainX, trainYn, opts);%dbn调用trainDBN函数%% 测试
% 对训练集进行预测
trainYn_out = v2h( dbn, trainX );%trainYn_out调用v2h函数
[~,trainY] = max(compet(trainYn'));
[~,trainY_out] = max(compet(trainYn_out'));
%compet是神经网络的竞争传递函数,用于指出矩阵中每列的最大值。对应最大值的行的值为1,其他行的值都为0。
%分类
% 计算准确率
accTrain = sum(trainY==trainY_out)/length(trainY);%accTrain为TrainY==trainY_out'的总和除以length(trainY)% 画训练集预测结果
figure%图形
plot(trainY,'r o')%画一个名为trainY,红色的圆圈
hold on%hold on 是当前轴及图形保持而不被刷新,准备接受此后将绘制
plot(trainY_out,'g +')%画一个名为trainY_out,绿色的加号
legend('真实值','预测值')%legend(图例1,图例2,)
grid on%画网格
xlabel('样本','fontsize',13)%xlabel(x轴说明)
ylabel('类别','fontsize',13)%ylabel(y轴说明)
title(['原始数据迭代100次训练集准确率:' num2str(accTrain*100) '%'],'fontsize',13)%title(图形名称)% 对测试集进行预测
testYn_out = v2h( dbn, testX );%testYn_out为调用v2h函数
[~,testY] = max(compet(testYn'));
[~,testY_out] = max(compet(testYn_out'));%compet是神经网络的竞争传递函数,用于指出矩阵中每列的最大值。对应最大值的行的值为1,其他行的值都为0。% 计算准确率
accTest = sum(testY==testY_out)/length(testY);%accTest为testY==testY_out'的总和除以length(testY)% 画测试集预测结果
figure%图形
plot(testY,'r o')%画一个名为testY,红色的圆圈
hold on%hold on 是当前轴及图形保持而不被刷新,准备接受此后将绘制
plot(testY_out,'g *')%画一个名为testY_out,绿色的加号
legend('真实值','预测值')%legend(图例1,图例2,)
grid on%画网格
xlabel('样本','fontsize',13)%xlabel(x轴说明)
ylabel('类别','fontsize',13)%ylabel(y轴说明)
title(['原始数据迭代100次测试集准确率:' num2str(accTest*100) '%'],'fontsize',13)%title(图形名称)
toc

4.SSA-DBN分类

clc;clear;close all;format compact
tic
%% 加载数据
% load('data_process.mat');
load('data_feature.mat');
trainX=double(train_X);
trainYn=double(train_Y);
validX=double(valid_X);
validYn=double(valid_Y);
testX=double(test_X);
testYn=double(test_Y);
clear train_X train_Y test_X test_Y valid_X valid_Y%% 麻雀算法优化DBN
% 网络各层节点设置
input_num=size(trainX,2);%输入层
hidden_num=[50 30];%隐含层,两个数就是两个隐含层 3个数就是3个隐含层
class=size(trainYn,2);%输出层
nodes = [input_num hidden_num class]; %节点数% [x,trace]=ssafordbn(trainX,validX,trainYn,validYn,nodes); %ssa优化隐含层节点数
% save best_para x trace
%%
load best_para
%将优化结果放到h中
figure
plot(trace)
title('适应度进化曲线')
xlabel('迭代次数')
ylabel('目标函数值/错误率')

采用训练集进行DBN训练,优化时,SSA以验证集的错误率为适应度函数,进行优化,目标函数如图,SSA的目的就是找到一组权重参数,使得训练出来的DBN在验证集上拥有最低的错误率。

采用优化得到的这组超参数,重新训练模型,并对测试集分类,结果如图:

基于麻雀搜索算法优化深度置信网络的分类方法(SSA-DBN)相关推荐

  1. 【DBN分类】基于麻雀算法优化深度置信网络SSA-DBN实现数据分类附matlab代码

    ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信.

  2. 【DBN分类】基于matlab麻雀算法优化深度置信网络SSA-DBN数据分类【含Matlab源码 2318期】

    ⛄一.DBN DBN由数个RBM堆叠构成,通常会在顶层加入一个BPNN来实现有监督的分类,DBN中下一层的隐藏层就是上一层的可见层.图1所示的DBN即由两个RBM和顶层一个BPNN构成. 图1 深度置 ...

  3. 基于麻雀搜索算法优化的lssvm回归预测

    基于麻雀搜索算法优化的lssvm回归预测 - 附代码 文章目录 基于麻雀搜索算法优化的lssvm回归预测 - 附代码 1.数据集 2.lssvm模型 3.基于麻雀算法优化的LSSVM 4.测试结果 5 ...

  4. 基于麻雀搜索算法优化的支持向量机回归预测-附代码

    基于麻雀搜索算法优化的支持向量机预测及其MATLAB代码实现 文章目录 基于麻雀搜索算法优化的支持向量机预测及其MATLAB代码实现 1. 基于麻雀搜索算法优化的支持向量机预测简介 1.1 支持向量机 ...

  5. 基于Teager-Kaiser能量算子和深度置信网络的往复式压缩机阀门故障诊断方法

    原文:An approach to fault diagnosis of reciprocating compressor valves using Teager–Kaiser energy oper ...

  6. 单目标应用:基于麻雀搜索算法优化灰色神经网络(grey neural network)的数据预测(提供MATLAB代码)

    一.麻雀搜索算法 麻雀搜索算法(sparrow search algorithm,SSA)由Jiankai Xue等人于2020年提出,该算法是根据麻雀觅食并逃避捕食者的行为而提出的群智能优化算法.S ...

  7. 【DBN分类】基于哈里斯鹰算法优化深度置信网络HHO-DBN实现数据分类附matlab代码

    ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信.

  8. 【DBN分类】基于粒子群算法优化深度置信网络PSO-DBN实现数据分类附matlab代码

    ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信.

  9. 【预测模型-ELM预测】基于麻雀算法优化极限学习机预测附matlab代码

    1 内容介绍 一种基于麻雀搜索算法优化极限学习机的风电功率预测方法,具体包括如下步骤:步骤1,确定影响风电功率的主导影响因子:步骤2,构建麻雀搜索算法优化核极限学习机预测模型,通过该模型对风电功率进行 ...

  10. 基于深度学习的安卓恶意应用检测----------android manfest.xml + run time opcode, use 深度置信网络(DBN)...

    基于深度学习的安卓恶意应用检测 from:http://www.xml-data.org/JSJYY/2017-6-1650.htm 苏志达, 祝跃飞, 刘龙     摘要: 针对传统安卓恶意程序检测 ...

最新文章

  1. Android开发——回调(Callback)
  2. 技术解析系列 阿里 PouchContainer 资源管理探秘
  3. 语音特征提取: MFCC的理解
  4. JFinal问题整理
  5. 必读:Java Java
  6. android power 按键,Android Framework层Power键关机流程(一,Power长按键操作处理)
  7. (二叉树DFS)下落的树叶
  8. FrameBuffer编程二(简单的程序上)
  9. Linux 系统使用之 VMware Tools安装
  10. FFMPEG类库打开流媒体的方法(需要传参数的时候)
  11. Python爬取网易云热歌榜所有音乐及其热评
  12. 面对SDN,我们该怎么办?
  13. 信息管理系统项目前端界面设计
  14. linux vi 应用
  15. c语言关于数组排序法和插入一个数的详细讲解
  16. 论坛刷访客神器-Header自定义工具
  17. 知乎App加密流量分析初探
  18. Python 用户输入和循环的学习
  19. 100uF,10uF,100nF,10nF不同的容值,这些参数是如何确定的?
  20. Git的基本使用方法教程(入门级)

热门文章

  1. 你的Android HTTPS真的安全吗?(转载)
  2. VMware虚拟机体验koolshare论坛LEDE固件
  3. ASA 5520 ASDM 配置
  4. 李永乐2021线代讲义练习题答案
  5. 原生JS实现canvas移动端电子签名板/画板
  6. J2EE框架技术(持续更新)
  7. 用Java开发手机Andriod系统Apk软件
  8. AForge.net获取摄像头
  9. 照片格式怎么快速转JPG或JPEG格式
  10. three.js入门——写个小车