模型结构:

4层bp模型如下

代码:

使用方法;
python bp_train_use_matrix.py 0.16

# coding: utf-8

运行结果如下:

推导

一共4层:x,h,m,y
模型公式为:(@表示矩阵乘法)
h=torch.tanh(x@Wx)
m=torch.tanh(h@Wh)
y=torch.tanh(m@Wm)

x=[x0,x1,x2]
h=[h0,h1,h2,h3]
m=[m0,m1,m2]
y=[y0,y1]

loss = 0.5(y^-y*)
y^:模型预测
y*:标注(真实值)

Loss对Wm求导:DWm

Wm = [ [wm00,wm01],
[wm10,wm11],
[wm20,wm21]]

DL/Dwm00 = DL/Dy0*Dy0/Dwm01 = (y0^-y0*) * (1-y0**2)*m0
DL/Dwm01 = DL/Dy1*Dy1/Dwm02 = (y1^-y1*) * (1-y1**2)*m0
DL/Dwm10 = DL/Dy0*Dy0/Dwm11 = (y0^-y0*) * (1-y0**2)*m1
DL/Dwm11 = DL/Dy1*Dy1/Dwm12 = (y1^-y1*) * (1-y1**2)*m1
DL/Dwm20 = DL/Dy0*Dy0/Dwm21 = (y0^-y0*) * (1-y0**2)*m2
DL/Dwm21 = DL/Dy1*Dy1/Dwm22 = (y1^-y1*) * (1-y1**2)*m2

若:EY = [y0^-y0*,y1^-y1*],DY=[1-y0**2,1-y1**2]

则:EY.*DY = [(y0^-y0*)(1-y0**2),(y1^-y1*)*(1-y1**2)]=[e0,e1]

则:

DL/Dwm00 = DL/Dy0*Dy0/Dwm00 = (y0^-y0*) * (1-y0**2)*m0 = e0*m0
DL/Dwm01 = DL/Dy1*Dy1/Dwm00 = (y1^-y1*) * (1-y1**2)*m0 =e1*m0
DL/Dwm10 = DL/Dy0*Dy0/Dwm10 = (y0^-y0*) * (1-y0**2)*m1 = e0*m1
DL/Dwm11 = DL/Dy1*Dy1/Dwm11 = (y1^-y1*) * (1-y1**2)*m1 = e1*m1
DL/Dwm20 = DL/Dy0*Dy0/Dwm20 = (y0^-y0*) * (1-y0**2)*m2 = e0*m2
DL/Dwm21 = DL/Dy1*Dy1/Dwm21 = (y1^-y1*) * (1-y1**2)*m2 = e1*m2

即:

Dwm =[[e0*m0,e1**m0]
[e0*m1,e1**m1]
[e0*m2,e1**m2] ]

=

[e0,e1].*[m0,m0]
[e0,e1].*[m1,m1]
[e0,e1].*[m2,m2]

=

[EY.*DY].*[m0,m0]
[EY.*DY].*[m1,m1]
[EY.*DY].*[m2,m2]

Loss对Wh求导:DWh

Wh = [ [wh00,wh01,wh02],
[wh10,wh11,wh12],
[wh20,wh21,wh22],
[wh30,wh31,wh32]]

则Dwh中各项:

DL/Dwh00= DL/Dm0*Dm0/Dwh00
DL/Dwh01= DL/Dm1*Dm1/Dwh01
DL/Dwh02= DL/Dm2*Dm2/Dwh02
DL/Dwh10= DL/Dm0*Dm0/Dwh10
DL/Dwh11= DL/Dm1*Dm1/Dwh11
DL/Dwh12= DL/Dm2*Dm2/Dwh12
DL/Dwh20= DL/Dm0*Dm0/Dwh20
DL/Dwh21= DL/Dm1*Dm1/Dwh21
DL/Dwh22= DL/Dm2*Dm2/Dwh22

DL/Dwh30= DL/Dm0*Dm0/Dwh30
DL/Dwh31= DL/Dm1*Dm1/Dwh31
DL/Dwh32= DL/Dm2*Dm2/Dwh32

上式子中的DL/Dm需要先求出:

DL/Dm0 = DL/Dy0*Dy0/Dm0 + DL/Dy1*Dy1/Dm0
= (y0^-y0*)*(1-y0^^2)*wm01 + (y1^-y1*)*(1-y1^^2)*wm02
= [EY.*DY] * [wm01,wm02]
= [EY.*DY] * Wm[0,:].t()

DL/Dm1 = DL/Dy0*Dy0/Dm1 + DL/Dy1*Dy1/Dm1
= EY.*DY * Wm[1,:].t()

DL/Dm2 = DL/Dy0*Dy0/Dm2 + DL/Dy1*Dy1/Dm2
= EY.*DY * Wm[2,:].t()

化简为矩阵的形式:

[ DL/Dm0, DL/Dm1, DL/Dm2] =EY.*DY * Wm.t() = [em0,em1,em2]=EM

DM = [1-m0**2,1-m1**2,1-m2**2]

EM.*DM = [em0*(1-m0**2),em1*(1-m1**2),em2*(1-m2**2)]=[dm0,dm1,dm2]

则Dwh中各项:

DL/Dwh00= DL/Dm0*Dm0/Dwh00 = dm0*h0
DL/Dwh01= DL/Dm1*Dm1/Dwh01 =dm1*h0
DL/Dwh02= DL/Dm2**Dm2/Dwh02 =dm2*h0
DL/Dwh10= DL/Dm0*Dm0/Dwh10 =dm0*h1
DL/Dwh11= DL/Dm1*Dm1/Dwh11 =dm1*h1
DL/Dwh12= DL/Dm2*Dm2/Dwh12 =dm2*h1
DL/Dwh20= DL/Dm0*Dm0/Dwh20 =dm0*h2
DL/Dwh21= DL/Dm1*Dm1/Dwh21 =dm1*h2
DL/Dwh22= DL/Dm2*Dm2/Dwh22. =dm2*h2

DL/Dwh30= DL/Dm0*Dm0/Dwh30 =dm1*h3
DL/Dwh31= DL/Dm1*Dm1/Dwh31 =dm2*h3
DL/Dwh32= DL/Dm2*Dm2/Dwh32. =dm3*h3

即:

Dwh =[[dm0*h0,dm1*h0,dm2*h0],
[dm0*h1,dm1*h1,dm2*h1],
[dm0*h2,dm1*h2,dm2*h3],
[dm0*h3,dm1*h3,dm2*h3]]

=

[dm0,dm1,dm2].*[h0,h0,h0]
[dm0,dm1,dm2].*[h1,h1,h1]
[dm0,dm1,dm2].*[h2,h2,h2]
[dm0,dm1,dm2].*[h3,h3,h3]
=

[EM.*DM].*[h0,h0,h0]
[EM.*DM].*[h1,h1,h1]
[EM.*DM].*[h2,h2,h2]
[EM.*DM].*[h3,h3,h3]

最后Loss对Wx求导:DWx

由EM=EY.*DY*Wm.t(),类比可得:
EH=EM.*DM*Wh.t()

m层的规律可适用于h层:DH = [1-h0**2,1-h1**2,1-h2**2,1-h3**2]DWx =
[EH.*DH].*[x0,x0,x0,x0]
[EH.*DH].*[x1,x1,x1,x1]
[EH.*DH].*[x2,x2,x2,x2]

致此,已求得梯度值:

DWm = D(loss)/D(Wm)DWh = D(loss)/D(Wh)DWx = D(loss)/D(Wx)

可用于更新参数:

Wm = Wm - DWm*learn_rate Wh = Wh - DWh*learn_rate Wx = Wx - DWx*learn_rate

总结:

拿到Y^ 与Y*后,
Y^ = [y0^,y1^]
Y* = [y0*,y0*]

算m层:

先算:Y^-Y*=[y0^-y0*,y1^-y1*]=EY

然后:1-Y^^2 = [1-y0^^2,1-y1^^2]=DY

EY.*DY = [Y^-Y*].*[1-Y^^2]

= [(y0^-y0*)*(1-y0^^2),(y1^-y1*)*(1-y1^^2)]

=[dy0,dy1]

DWm=
[EY.*DY].*[m0,m0]
[EY.*DY].*[m1,m1]
[EY.*DY].*[m2,m2]

若有bias的话:
DBm = EY.*DY =[dbm0,dbm1,dbm2]

算h层:

EM = EY.*DY*Wm.t
DM =[1-m0^^2,1-m1^^2,1-m2^^2]

DWh=

[EM.*DM].*[h0,h0,h0]
[EM.*DM].*[h1,h1,h1]
[EM.*DM].*[h2,h2,h2]
[EM.*DM].*[h3,h3,h3]

若有bias的话:
DBh = EM.*DM

算x层:

EH = EM.*DM*Wh.t
DH = [1-h0^^2,1-h1^^2,1-h2^^2,1-h3^^2]

DWx =
[EH.*DH].*[x0,x0,x0,x0]
[EH.*DH].*[x1,x1,x1,x1]
[EH.*DH].*[x2,x2,x2,x2]

若有bias的话:
DBx = EH.*DH

附录:

符号说明:

DL:D(Loss)
Dwm01:D(wm01)

DL/Dy1 = 2*0.5(y1^-y1*)=(y1^-y1*)
DL/Dy2 = 2*0.5(y2^-y2*)=(y2^-y2*)

y1 = tanh(wm01*m0+wm11*m1+m21*m2)

tanh的导数为:1-tanh**2

Dy1/Dwm01 = (1-y1**2)*m0

Dy2/Dwm02 = (1-y2**2)*m0

bp神经网络matlab代码_4层bp神经网络详细推导以及代码(矩阵化运算)相关推荐

  1. 蚁群算法优化神经网络matlab源程序,粒子群优化神经网络的程序大集合

    粒子群程序集合 866867259psobp psobp.m pso(粒子群算法)优化神经网络 粒子群算法(PSO)应用于神经网络优化[matlab] PSOt A Particle Swarm Op ...

  2. 不用工具箱的神经网络matlab程序_MATLAB中的神经网络工具箱(2)函数命令及模型搭建...

    前面介绍了神经网络工具箱GUI的使用,它功能强大可以直接生成脚本.但是函数命令的灵活性是GUI所不及的.也应该有所了解. 神经网络函数命令 1.网络创建函数 函数名称 功能 fitnet 创建函数拟合 ...

  3. html网页div框架代码,div层仿网页框架布局特效代码

    脚本代码(For Alixixi.com)如下: div层仿网页框架布局特效代码 - by 阿里西西 js.alixixi.com * { margin:0; padding:0; list-styl ...

  4. 机器学习——从线性回归到逻辑回归【附详细推导和代码】

    本文始发于个人公众号:TechFlow,原创不易,求个关注 在之前的文章当中,我们推导了线性回归的公式,线性回归本质是线性函数,模型的原理不难,核心是求解模型参数的过程.通过对线性回归的推导和学习,我 ...

  5. 递归神经网络 matlab,机器学习系列:递归神经网络

    原标题:机器学习系列:递归神经网络 前言 BP 神经网络,训练的时候,给定一组输入和输出,不断的对权值进行训练,使得输出达到稳定.但 BP 神经网络并不是适合所有的场景,并不真正的体现出某些场景的真正 ...

  6. 单隐层神经网络 python_单隐层前馈神经网络,single hidden-layer feeclforward neural networks,在线英语词典,英文翻译,专业英语...

    补充资料:Hopfield神经网络模型 Hopfield神经网络模型 Hopfield neural network model 收敛于稳定状态或Han加Ing距离小于2的极限环. 上述结论保证了神经 ...

  7. 复数值神经网络matlab,【原创】复数神经网络的反向传播算法,及pytorch实现方法...

    复函数的可导性 复变函数按照是否可导,分为全纯函数holomothic和nonholomophic,判断条件为Cauchy-Riemann方程. 对于不可导的nonholomophic函数: Wirt ...

  8. 常微分方程数值解——差商、欧拉公式详细推导及代码实现

    引言 在自然科学的许多领域特别是科学与工程计算中,经常会遇到常微分方程的求解问题.然而只有非常少数且十分简单的微分方程可以用初值等方法求得它们的解,多数只能近似方法求解. 一.预备知识 (差商的推导) ...

  9. 神经网络家族为何 BP 网络一枝独秀?谁能谈谈神经网络家族的兴衰史?

    来源:https://www.zhihu.com/question/364755843 编辑:深度学习与计算机视觉 声明:仅做学术分享,侵删 从两个方面来问这个问题. 从网络结构上说,神经网络有五种典 ...

  10. 吴恩达机器学习笔记——含一个隐藏层的神经网络

    含一个隐藏层的神经网络 含一个隐藏层的神经网络构造如下图所示: 其中记号用a上标的方括号a[n]a^{[n]}a[n]代表是第n层的a,用下标表示是某一层下面的某一个神经元,如图中的a1[2]a^{[ ...

最新文章

  1. “神经网络”的逆袭:80年AI斗争史
  2. mysql数据库添加索引和去重
  3. ibatis学习笔记
  4. 的微波感知_上海交大彭志科教授团队研发:微波微动监测与智能感知技术
  5. C/C++中volatile关键字的作用
  6. 12306外包给阿里巴巴、IBM等大企业做是否可行?
  7. 在less中不能正常使用css3的calc属性的解决方法
  8. 10人勾结苹果外包公司员工窃个人信息 涉案900万
  9. JAVA_JSP考勤带请假的管理系统
  10. matlab 最舒适的背景配色
  11. 什么是网站跳出率?一招教你如何处理高跳出率?
  12. 一次系统宕机认识系统日志
  13. 交换机级联,堆叠,集群技术介绍
  14. 数据可视化Matplotlib使用5-改变坐标轴的默认显示方式
  15. jQuery+js+css实现键盘按键呼吸灯效果
  16. 使用Spring Boot + Resilience 4j实现断路器
  17. python3web开发教程_三、Python web开发入门
  18. 机器学习中分类与聚类的本质区别
  19. gammaray报Error: gdb: Yama security extension is blocking runtime attaching, see /proc/sys/kernel/yam
  20. vue中eslint报错的解决方案

热门文章

  1. windows server 2008 远程桌面(授权、普通用户登录)
  2. SQL server 数据导入导出BCP工具使用详解
  3. JAVA OOP(二)——方法的重载、构造方法以及this关键字
  4. Pytorch:Tensor(张量)的使用
  5. Luogu1169 [ZJOI2007]棋盘制作
  6. [2018.07.17 T2] Palindromes
  7. 第一部分 第五章 数组 1102-1149
  8. eslint+prettier+husky的配置说明
  9. Express框架学习笔记-get请求中参数的获取
  10. spython_spython