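LSTM Recurrent Neural Network Principles: A Worked Example Implemented in MATLAB
I have recently been studying recurrent neural networks. The many blog posts online are of mixed quality, and almost all of them either implement the network in Python or rely on MATLAB toolbox functions, which is inconvenient for anyone working directly in MATLAB. So in this post I use a worked example to write a recurrent neural network (LSTM) from scratch in the MATLAB language. I am still a beginner, so if you spot mistakes, corrections are very welcome. The problem and data are the same as in my earlier post on BPTT, "MATLAB Implementation of BPTT for Recurrent Neural Networks", so the two methods can be compared directly. The theoretical derivation and algorithm steps will be added in a few days when I have time.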
1. Problem Description
(Figure: problem description)
2. Data
(Figure: data)
3. Code
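Since the derivation is deferred to a later update, it may help to read the listings against the forward pass they compute. The block below is only a summary read off the code in this section, not a derivation. The symbols are shorthand introduced here (row vectors, as in the code): x_t is `train_data(:,m)'`, c_t is `cell_state(:,m)'`, h_t is `h_state(:,m)'`, y_t is `test_data(:,m)'`, W_{xg} and W_{hg} are `weight_input_x` and `weight_input_h`, W_{xi}, W_{xf}, W_{xo} are the `*gate_x` weights, W_{ci}, W_{cf}, W_{co} are the `*gate_c` weights, W_{hy} is `weight_preh_h`, and b_i, b_f, b_o are the gate biases.

```latex
\begin{aligned}
g_t &= \tanh\left(x_t W_{xg} + h_{t-1} W_{hg}\right) \\
i_t &= \sigma\left(x_t W_{xi} + c_{t-1} W_{ci} + b_i\right) \\
f_t &= \sigma\left(x_t W_{xf} + c_{t-1} W_{cf} + b_f\right) \\
o_t &= \sigma\left(x_t W_{xo} + c_{t-1} W_{co} + b_o\right) \\
c_t &= i_t \odot g_t + f_t \odot c_{t-1} \\
h_t &= \left(\tanh(c_t) \odot o_t\right) W_{hy} \\
E_t &= \lVert h_t - y_t \rVert^2
\end{aligned}
```

For the first sample (t = 1) the code drops every term involving h_{t-1} or c_{t-1} and sets the forget gate to zero. Note that the gates here read the previous cell state rather than the previous output, which is different from the most common LSTM formulation.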
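LSTM_mian.m
%%% LSTM network simulation on a worked example
%%% Author: xd.wp
%%% Date: 2016.10.08 12:06
%% Program notes
% 1. The data are air-conditioner power readings at four time points per day over 7 days.
%    Three consecutive days are used to predict the following day, and so on.
%    Day 7 is held out for testing.
% 2. The LSTM network has 12 input nodes, 4 output nodes and 18 hidden nodes.
clear all;
clc;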
%% Load the data and normalize it
[train_data,test_data]=LSTM_data_process();
data_length=size(train_data,1);
data_num=size(train_data,2);
%% Initialize the network parameters
% Node counts
input_num=12;
cell_num=18;
output_num=4;
% Gate biases
bias_input_gate=rand(1,cell_num);
bias_forget_gate=rand(1,cell_num);
bias_output_gate=rand(1,cell_num);
% ab=1.2;
% bias_input_gate=ones(1,cell_num)/ab;
% bias_forget_gate=ones(1,cell_num)/ab;
% bias_output_gate=ones(1,cell_num)/ab;
% Weight initialization (small random values)
ab=20;
weight_input_x=rand(input_num,cell_num)/ab;
weight_input_h=rand(output_num,cell_num)/ab;
weight_inputgate_x=rand(input_num,cell_num)/ab;
weight_inputgate_c=rand(cell_num,cell_num)/ab;
weight_forgetgate_x=rand(input_num,cell_num)/ab;
weight_forgetgate_c=rand(cell_num,cell_num)/ab;
weight_outputgate_x=rand(input_num,cell_num)/ab;
weight_outputgate_c=rand(cell_num,cell_num)/ab;
% Hidden-to-output weights
weight_preh_h=rand(cell_num,output_num);
% Initial network state
cost_gate=1e-6;                       % training stops once the squared error falls below this
h_state=rand(output_num,data_num);
cell_state=rand(cell_num,data_num);
%% Train the network
for iter=1:3000
    yita=0.01; % learning rate: step size of each weight update
    for m=1:data_num
        % Forward pass
        if(m==1)
            % First sample: there is no previous state, so the forget gate is zero
            gate=tanh(train_data(:,m)'*weight_input_x);
            input_gate_input=train_data(:,m)'*weight_inputgate_x+bias_input_gate;
            output_gate_input=train_data(:,m)'*weight_outputgate_x+bias_output_gate;
            for n=1:cell_num
                input_gate(1,n)=1/(1+exp(-input_gate_input(1,n)));
                output_gate(1,n)=1/(1+exp(-output_gate_input(1,n)));
            end
            forget_gate=zeros(1,cell_num);
            forget_gate_input=zeros(1,cell_num);
            cell_state(:,m)=(input_gate.*gate)';
        else
            gate=tanh(train_data(:,m)'*weight_input_x+h_state(:,m-1)'*weight_input_h);
            input_gate_input=train_data(:,m)'*weight_inputgate_x+cell_state(:,m-1)'*weight_inputgate_c+bias_input_gate;
            forget_gate_input=train_data(:,m)'*weight_forgetgate_x+cell_state(:,m-1)'*weight_forgetgate_c+bias_forget_gate;
            output_gate_input=train_data(:,m)'*weight_outputgate_x+cell_state(:,m-1)'*weight_outputgate_c+bias_output_gate;
            for n=1:cell_num
                input_gate(1,n)=1/(1+exp(-input_gate_input(1,n)));
                forget_gate(1,n)=1/(1+exp(-forget_gate_input(1,n)));
                output_gate(1,n)=1/(1+exp(-output_gate_input(1,n)));
            end
            cell_state(:,m)=(input_gate.*gate+cell_state(:,m-1)'.*forget_gate)';
        end
        pre_h_state=tanh(cell_state(:,m)').*output_gate;
        h_state(:,m)=(pre_h_state*weight_preh_h)';
        % Error for this sample (test_data holds the target columns)
        Error=h_state(:,m)-test_data(:,m);
        Error_Cost(1,iter)=sum(Error.^2);
        if(Error_Cost(1,iter)<cost_gate)
            flag=1;
            break;
        else
            [ weight_input_x,...
                weight_input_h,...
                weight_inputgate_x,...
                weight_inputgate_c,...
                weight_forgetgate_x,...
                weight_forgetgate_c,...
                weight_outputgate_x,...
                weight_outputgate_c,...
                weight_preh_h ]=LSTM_updata_weight(m,yita,Error,...
                weight_input_x,...
                weight_input_h,...
                weight_inputgate_x,...
                weight_inputgate_c,...
                weight_forgetgate_x,...
                weight_forgetgate_c,...
                weight_outputgate_x,...
                weight_outputgate_c,...
                weight_preh_h,...
                cell_state,h_state,...
                input_gate,forget_gate,...
                output_gate,gate,...
                train_data,pre_h_state,...
                input_gate_input,...
                output_gate_input,...
                forget_gate_input);
        end
    end
    if(Error_Cost(1,iter)<cost_gate)
        break;
    end
end
%% Plot the Error-Cost curve
% for n=1:1:iter
% text(n,Error_Cost(1,n),'*');
% axis([0,iter,0,1]);
% title('Error-Cost curve');
% end
for n=1:1:iter
    semilogy(n,Error_Cost(1,n),'*');
    hold on;
    title('Error-Cost curve');
end
%% Test with the day-7 data
% Load the data: days 4 to 6 are the input, day 7 is the value to predict
test_final=[0.4557 0.4790 0.7019 0.8211 0.4601 0.4811 0.7101 0.8298 0.4612 0.4845 0.7188 0.8312]';
test_final=test_final/sqrt(sum(test_final.^2));
test_output=test_data(:,4);
% Forward pass, continuing from the state left by the last training sample
m=4;
gate=tanh(test_final'*weight_input_x+h_state(:,m-1)'*weight_input_h);
input_gate_input=test_final'*weight_inputgate_x+cell_state(:,m-1)'*weight_inputgate_c+bias_input_gate;
forget_gate_input=test_final'*weight_forgetgate_x+cell_state(:,m-1)'*weight_forgetgate_c+bias_forget_gate;
output_gate_input=test_final'*weight_outputgate_x+cell_state(:,m-1)'*weight_outputgate_c+bias_output_gate;
for n=1:cell_num
    input_gate(1,n)=1/(1+exp(-input_gate_input(1,n)));
    forget_gate(1,n)=1/(1+exp(-forget_gate_input(1,n)));
    output_gate(1,n)=1/(1+exp(-output_gate_input(1,n)));
end
cell_state_test=(input_gate.*gate+cell_state(:,m-1)'.*forget_gate)';
pre_h_state=tanh(cell_state_test').*output_gate;
% No trailing semicolons below, so both vectors are printed
h_state_test=(pre_h_state*weight_preh_h)'
test_output
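The per-cell sigmoid loops in the script can also be written with element-wise operators. The following is a minimal sketch of one forward step in that style, not a replacement for the script above: the weights are random stand-ins packed into arrays (`Wx`, `Wc`, `Wh`, `Wy`, `b` are names made up here for brevity), and the sizes are the ones used in this post.

```matlab
% Sketch: one vectorized forward step, using hypothetical random weights.
rng(0);                                   % reproducible random numbers
input_num = 12; cell_num = 18; output_num = 4;

sigmoid = @(z) 1 ./ (1 + exp(-z));        % element-wise logistic function

x      = rand(1, input_num);              % current input (row vector)
h_prev = rand(1, output_num);             % previous network output
c_prev = rand(1, cell_num);               % previous cell state

Wx = rand(input_num, cell_num, 4) / 20;   % input weights: candidate, input, forget, output
Wc = rand(cell_num, cell_num, 3) / 20;    % cell-state weights: input, forget, output gates
Wh = rand(output_num, cell_num) / 20;     % feedback weights into the candidate
Wy = rand(cell_num, output_num);          % hidden-to-output weights
b  = rand(3, cell_num);                   % gate biases: input; forget; output

g      = tanh(x * Wx(:,:,1) + h_prev * Wh);
i_gate = sigmoid(x * Wx(:,:,2) + c_prev * Wc(:,:,1) + b(1,:));
f_gate = sigmoid(x * Wx(:,:,3) + c_prev * Wc(:,:,2) + b(2,:));
o_gate = sigmoid(x * Wx(:,:,4) + c_prev * Wc(:,:,3) + b(3,:));
c      = i_gate .* g + f_gate .* c_prev;  % new cell state
h      = (tanh(c) .* o_gate) * Wy;        % network output, 1 x output_num

% Compare the vectorized input gate with the loop form used in the script
z = x * Wx(:,:,2) + c_prev * Wc(:,:,1) + b(1,:);
i_loop = zeros(1, cell_num);
for n = 1:cell_num
    i_loop(1, n) = 1 / (1 + exp(-z(1, n)));
end
disp(max(abs(i_loop - i_gate)));          % largest difference between the two forms
```

Both forms compute the same gate values; the vectorized one is simply shorter and avoids growing `input_gate` one element at a time.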
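LSTM_data_process.m
function [train_data,test_data]=LSTM_data_process()
%% Load the data and apply the initial normalization
% Each column (after the transpose) is one sample: three consecutive days, four readings per day.
% Note: the value 0.4911 in the third sample appears as 0.4811 for the same day elsewhere
% in this post; it may be a typo and is kept here as published.
train_data_initial= [0.4413 0.4707 0.6953 0.8133 0.4379 0.4677 0.6981 0.8002 0.4517 0.4725 0.7006 0.8201;
    0.4379 0.4677 0.6981 0.8002 0.4517 0.4725 0.7006 0.8201 0.4557 0.4790 0.7019 0.8211;
    0.4517 0.4725 0.7006 0.8201 0.4557 0.4790 0.7019 0.8211 0.4601 0.4911 0.7101 0.8298]';
% train_data_initial=[ 0.4413 0.4707 0.6953 0.8133;
% 0.4379 0.4677 0.6981 0.8002;
% 0.4517 0.4725 0.7006 0.8201;
% 0.4557 0.4790 0.7019 0.8211;
% 0.4601 0.4811 0.7101 0.8298;
% 0.4612 0.4845 0.7188 0.8312]';
% Targets: columns 1 to 3 are used during training, column 4 is the day-7 value used for the final test.
% Note: judging by the daily table above, columns 2 and 3 hold days 6 and 5 respectively,
% which is not chronological order; kept here as published.
test_data_initial=[0.4557 0.4790 0.7019 0.8211;
    0.4612 0.4845 0.7188 0.8312;
    0.4601 0.4811 0.7101 0.8298;
    0.4615 0.4891 0.7201 0.8330]';
data_length=size(train_data_initial,1); % length of each sample
data_num=size(train_data_initial,2); % number of samples
%% Normalization: scale each column to unit Euclidean norm
for n=1:data_num
    train_data(:,n)=train_data_initial(:,n)/sqrt(sum(train_data_initial(:,n).^2));
end
for m=1:size(test_data_initial,2)
    test_data(:,m)=test_data_initial(:,m)/sqrt(sum(test_data_initial(:,m).^2));
end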
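The hard-coded sample matrix can also be built from the daily table that appears in the commented-out block. The sketch below shows one way to do that, assuming a sliding window of three days and the following day as the target; the names `daily`, `window`, `inputs`, `targets` and `normalize_cols` are made up here for illustration. The targets come out in chronological order (days 4, 5, 6), so compare them with the column order of `test_data_initial` in the listing, where the second and third columns appear in the opposite order.

```matlab
% Daily table from the commented-out block: one row per day, four readings per day.
daily = [0.4413 0.4707 0.6953 0.8133;
         0.4379 0.4677 0.6981 0.8002;
         0.4517 0.4725 0.7006 0.8201;
         0.4557 0.4790 0.7019 0.8211;
         0.4601 0.4811 0.7101 0.8298;
         0.4612 0.4845 0.7188 0.8312];

window = 3;                                  % days per input sample (assumption)
num_samples = size(daily, 1) - window;       % samples: days 1-3, 2-4, 3-5
inputs  = zeros(window * size(daily, 2), num_samples);
targets = zeros(size(daily, 2), num_samples);
for k = 1:num_samples
    block = daily(k:k+window-1, :)';         % 4 x 3, one column per day
    inputs(:, k)  = block(:);                % stack the days into a 12 x 1 column
    targets(:, k) = daily(k + window, :)';   % readings of the following day
end

% Column-wise scaling to unit Euclidean norm, without an explicit loop
normalize_cols = @(X) bsxfun(@rdivide, X, sqrt(sum(X.^2, 1)));
inputs_n  = normalize_cols(inputs);
targets_n = normalize_cols(targets);
disp(size(inputs_n));
```

Building the windows from one table keeps each reading in a single place, which makes inconsistencies between repeated values easier to avoid.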
LSTM_updata_weight.m
function [ weight_input_x,weight_input_h,weight_inputgate_x,weight_inputgate_c,...
    weight_forgetgate_x,weight_forgetgate_c,weight_outputgate_x,weight_outputgate_c,weight_preh_h ]=...
    LSTM_updata_weight(n,yita,Error,...
    weight_input_x, weight_input_h, weight_inputgate_x,weight_inputgate_c,...
    weight_forgetgate_x,weight_forgetgate_c,weight_outputgate_x,weight_outputgate_c,weight_preh_h,...
    cell_state,h_state,input_gate,forget_gate,output_gate,gate,train_data,pre_h_state,...
    input_gate_input, output_gate_input,forget_gate_input)
%%% Weight update function (one gradient step for sample n; the gate biases are not updated)
input_num=12;
cell_num=18;
output_num=4;
data_length=size(train_data,1);
data_num=size(train_data,2);
% The output weights are updated on a copy so that the other updates use the old values
weight_preh_h_temp=weight_preh_h;
%% Update weight_preh_h
for m=1:output_num
    delta_weight_preh_h_temp(:,m)=2*Error(m,1)*pre_h_state;
end
weight_preh_h_temp=weight_preh_h_temp-yita*delta_weight_preh_h_temp;
%% Update weight_outputgate_x
% exp(-z).*(gate.^2) is the derivative of the sigmoid gate with respect to its input z
for num=1:output_num
    for m=1:data_length
        delta_weight_outputgate_x(m,:)=(2*weight_preh_h(:,num)*Error(num,1).*tanh(cell_state(:,n)))'.*exp(-output_gate_input).*(output_gate.^2)*train_data(m,n);
    end
    weight_outputgate_x=weight_outputgate_x-yita*delta_weight_outputgate_x;
end
%% Update weight_inputgate_x
for num=1:output_num
    for m=1:data_length
        delta_weight_inputgate_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*gate.*exp(-input_gate_input).*(input_gate.^2)*train_data(m,n);
    end
    weight_inputgate_x=weight_inputgate_x-yita*delta_weight_inputgate_x;
end
if(n~=1)
    %% Update weight_input_x
    % The derivative of tanh(temp) is 1-tanh(temp).^2 (the original listing had tanh(temp.^2))
    temp=train_data(:,n)'*weight_input_x+h_state(:,n-1)'*weight_input_h;
    for num=1:output_num
        for m=1:data_length
            delta_weight_input_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp).^2)*train_data(m,n);
        end
        weight_input_x=weight_input_x-yita*delta_weight_input_x;
    end
    %% Update weight_forgetgate_x
    for num=1:output_num
        for m=1:data_length
            delta_weight_forgetgate_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*cell_state(:,n-1)'.*exp(-forget_gate_input).*(forget_gate.^2)*train_data(m,n);
        end
        weight_forgetgate_x=weight_forgetgate_x-yita*delta_weight_forgetgate_x;
    end
    %% Update weight_inputgate_c
    for num=1:output_num
        for m=1:cell_num
            delta_weight_inputgate_c(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*gate.*exp(-input_gate_input).*(input_gate.^2)*cell_state(m,n-1);
        end
        weight_inputgate_c=weight_inputgate_c-yita*delta_weight_inputgate_c;
    end
    %% Update weight_forgetgate_c
    for num=1:output_num
        for m=1:cell_num
            delta_weight_forgetgate_c(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*cell_state(:,n-1)'.*exp(-forget_gate_input).*(forget_gate.^2)*cell_state(m,n-1);
        end
        weight_forgetgate_c=weight_forgetgate_c-yita*delta_weight_forgetgate_c;
    end
    %% Update weight_outputgate_c
    for num=1:output_num
        for m=1:cell_num
            delta_weight_outputgate_c(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*tanh(cell_state(:,n))'.*exp(-output_gate_input).*(output_gate.^2)*cell_state(m,n-1);
        end
        weight_outputgate_c=weight_outputgate_c-yita*delta_weight_outputgate_c;
    end
    %% Update weight_input_h
    temp=train_data(:,n)'*weight_input_x+h_state(:,n-1)'*weight_input_h;
    for num=1:output_num
        for m=1:output_num
            delta_weight_input_h(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp).^2)*h_state(m,n-1);
        end
        weight_input_h=weight_input_h-yita*delta_weight_input_h;
    end
else
    %% Update weight_input_x (first sample: no previous output or cell state)
    temp=train_data(:,n)'*weight_input_x;
    for num=1:output_num
        for m=1:data_length
            delta_weight_input_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp).^2)*train_data(m,n);
        end
        weight_input_x=weight_input_x-yita*delta_weight_input_x;
    end
end
weight_preh_h=weight_preh_h_temp;
end
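Hand-written gradients like the ones above are easy to get subtly wrong, so a finite-difference check is a cheap safeguard. The sketch below checks only the output-layer rule `2*Error(m,1)*pre_h_state` against a numerical gradient of the squared error, and also checks that the `exp(-z).*(gate.^2)` form used for the gates equals the usual sigmoid derivative. All values are random stand-ins, and the variable names (`p`, `W`, `y`, `cost`, `step`) are made up for this illustration.

```matlab
% Sketch: finite-difference check of the output-layer gradient.
rng(1);                              % reproducible random numbers
cell_num = 18; output_num = 4;
p = rand(1, cell_num);               % stands in for pre_h_state
W = rand(cell_num, output_num);      % stands in for weight_preh_h
y = rand(output_num, 1);             % stands in for the target column

cost = @(Wq) sum(((p * Wq)' - y).^2);   % squared error as a function of the weights

% Analytic gradient, written in the same form as the update function
Err = (p * W)' - y;
grad_analytic = zeros(cell_num, output_num);
for m = 1:output_num
    grad_analytic(:, m) = 2 * Err(m, 1) * p';
end

% Numerical gradient by central differences
step = 1e-6;
grad_numeric = zeros(size(W));
for r = 1:cell_num
    for c = 1:output_num
        W_plus  = W; W_plus(r, c)  = W_plus(r, c)  + step;
        W_minus = W; W_minus(r, c) = W_minus(r, c) - step;
        grad_numeric(r, c) = (cost(W_plus) - cost(W_minus)) / (2 * step);
    end
end
max_diff = max(abs(grad_analytic(:) - grad_numeric(:)));
disp(max_diff);                      % should be very small if the rule is right

% The sigmoid derivative written two ways: exp(-z).*s.^2 and s.*(1-s)
z = linspace(-2, 2, 9);
s = 1 ./ (1 + exp(-z));
sig_diff = max(abs(exp(-z) .* s.^2 - s .* (1 - s)));
disp(sig_diff);
```

The same pattern (perturb one weight, recompute the cost, compare with the analytic value) could be applied to the gate weights, at the price of re-running the forward pass for each perturbation.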
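4. Results
(Figure: Error_Cost plot)
(Figure: day-7 predicted values versus actual values; the first group is the prediction, the second group is the actual values)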