本文主要讲解卷积神经网络(CNN)反向传播过程的matlab代码实现。

01

简介

CNN主要由三种层堆叠而成,即卷积层、池化层和全连接层,在《卷积神经网络(三):反向传播过程》中又推导了这三种层的误差反向传播公式。因此,CNN反向传播的代码主要由这三种层的反向传播代码构成。

02

代码实现

在CNN反向传播时,输出层的误差(目标函数)会依次从后往前经过全连接层、池化层和卷积层。在编写全连接层、池化层和卷积层的代码时,均是先求出误差关于各层输入的导数,然后再计算参数的导数。假设本文研究的是分类问题,输出层采用softmax函数,目标函数定义为交叉熵损失+参数正则化。(一)定义并计算目标函数

梯度下降要优化的目标函数,主要分为两部分:一部分是由于分类器输出结果和真实结果的差异引起的误差,另一部分是对权重w的正则约束。

logp = log(probs);index = sub2ind(size(logp),mb_labels',1:size(probs,2));ceCost = -sum(logp(index));wCost = lambda/2 * (sum(Wd(:).^2)+sum(Wc(:).^2));cost = ceCost/numImages + wCost;

(二)softmax层

交叉熵损失函数关于softmax层输入的导数为:

即直接用预测结果减去真实结果。如果采用是平方差损失函数,则平方差损失函数关于softmax层输入的导数为(需要分情况讨论):

本文采用分类问题常用的交叉熵损失函数。

output = zeros(size(probs));output(index) = 1;DeltaSoftmax = probs - output;

注:笔者做过利用Levenberg–Marquardt算法优化网络结构,此时即需要平方差损失函数。

(三)全连接层

全连接层的误差反向传播是将DeltaSoftmax乘以各层的权重以及点乘激活函数的导数。

Wd_grad = (1./numImages) .* DeltaSoftmax*activationsPooled'+lambda*Wd;bd_grad = (1./numImages) .* sum(DeltaSoftmax,2);

(四)池化层

在求出误差关于第一个全连接层的导数后,需要将该结果还原成最后一个池化层输出的形状。如果采用的是平均池化,则误差在池化区域内的所有元素上均分;如果采用的是最大池化,则误差只由最大元素负责。

DeltaPool = reshape(Wd' * DeltaSoftmax,outputDim,outputDim,numFilters,numImages);DeltaUnpool = zeros(convDim,convDim,numFilters,numImages);for imNum = 1:numImages    for FilterNum = 1:numFilters        unpool = DeltaPool(:,:,FilterNum,imNum);        DeltaUnpool(:,:,FilterNum,imNum) = kron(unpool,ones(poolDim))./(poolDim ^ 2);    endend

(五)卷积层

卷积层的反向传播较为复杂,但是具体的推导细节已经在《卷积神经网络(三):反向传播过程》中解释清楚。

% 在求出误差关于池化层输入的导数后,再点乘激活函数的导数。DeltaConv = DeltaUnpool .* activations .* (1 - activations);% 卷积层偏置的代码bc_grad = zeros(size(bc));for filterNum = 1:numFilters    error = DeltaConv(:,:,filterNum,:);    bc_grad(filterNum) = (1./numImages) .* sum(error(:));end% 卷积层权重的代码Wc_grad = zeros(filterDim,filterDim,numFilters);% 旋转所有DealtaConv:下面的conv2在函数内部会自动旋转180度,% 所以在这里旋转是为了抵消conv2旋转的影响。for filterNum = 1:numFilters    for imNum = 1:numImages        error = DeltaConv(:,:,filterNum,imNum);        DeltaConv(:,:,filterNum,imNum) = rot90(error,2);    endendfor filterNum = 1:numFilters    for imNum = 1:numImages        Wc_grad(:,:,filterNum) = Wc_grad(:,:,filterNum) + conv2(images(:,:,imNum),DeltaConv(:,:,filterNum,imNum),'valid');    endendWc_grad = (1./numImages) .* Wc_grad + lambda*Wc;

matlab卷积神经网络代码_卷积神经网络(四):反向传播过程的代码实现相关推荐

  1. 卷积神经网络前向及反向传播过程数学解析

    卷积神经网络前向及反向传播过程数学解析 文章目录 <center>卷积神经网络前向及反向传播过程数学解析 1.卷积神经网络初印象 2.卷积神经网络性质 3.前向传播 3.1.卷积层层级间传 ...

  2. 机器学习笔记 - 使用python代码实现易于理解的反向传播

    一.反向传播概述 反向传播可以说是神经网络历史上最重要的算法--如果没有有效的反向传播,就不可能将深度学习网络训练到我们今天看到的深度.反向传播可以被认为是现代神经网络和深度学习的基石. 反向传播的最 ...

  3. 卷积神经网络结构_卷积神经网络

    卷积神经网络结构 CNN's are a special type of ANN which accepts images as inputs. Below is the representation ...

  4. 卷积神经网络的反向传播,卷积反向传播过程

    如何对CNN网络的卷积层进行反向传播 在多分类中,CNN的输出层一般都是Softmax.RBF在我的接触中如果没有特殊情况的话应该是"径向基函数"(RadialBasisFunct ...

  5. python深度神经网络量化_深度神经网络数据集大小

    问题描述 我的数据集才一千多个,是不是用深度神经网络的模型,不够,容易欠拟合 问题出现的环境背景及自己尝试过哪些方法 我之前的训练参照了两层的CIFAR卷积层测试了 用1000次迭代 每次10batc ...

  6. 人工神经网络 神经网络区别_人工神经网络概述

    人工神经网络 神经网络区别 Artificial neural networks (ANN) in machine learning (artificial intelligence) are com ...

  7. 深度神经网络回归_深度神经网络

    深度神经网络回归 深度神经网络 (Deep Neural Networks) A deep neural network (DNN) is an ANN with multiple hidden la ...

  8. 机器学习入门(14)— 神经网络学习整体流程、误差反向传播代码实现、误差反向传播梯度确认、误差反向传播使用示例

    1. 神经网络学习整体流程 神经网络学习的步骤如下所示. 前提 神经网络中有合适的权重和偏置,调整权重和偏置以便拟合训练数据的过程称为学习.神经网络的学习分为下面 4 个步骤. 步骤1(mini-ba ...

  9. 膨胀卷积的缺点_卷积、反卷积与膨胀卷积

    卷积(多---->1 的映射) 本质:在对输入做9--->1的映射关系时,保持了输出相对于input中的位置性关系 对核矩阵做以下变形:卷积核的滑动步骤变成了卷积核矩阵的扩增 卷积的矩阵乘 ...

  10. 循环神经网络(RNN)模型与前向反向传播算法

    在前面我们讲到了DNN,以及DNN的特例CNN的模型和前向反向传播算法,这些算法都是前向反馈的,模型的输出和模型本身没有关联关系.今天我们就讨论另一类输出和模型间有反馈的神经网络:循环神经网络(Rec ...

最新文章

  1. removeAllViews()和removeAllViewsInLayout()之间的区别?
  2. asp.net 中ListBox 显示 2 列
  3. C/C++之内存对齐
  4. 如何学好初中计算机,初中生怎么学习方法好 十大方法告诉你
  5. 视觉平衡与物理平衡_设计中的平衡理论为什么这么重要?
  6. 怎样把Image数据放入数据库
  7. k8s安装读取内核modules_kubespray国内云平台一键部署k8s
  8. 教你 7 步快速构建 GitLab 持续集成环境
  9. 《C和C++程序员面试秘笈》——1.4 i++与++i哪个效率更高
  10. 临湘东经子午线经度_地区经度查询_实用查询工具大全 - Powered by Senlon!
  11. 小程序转uni-app——引入组件显示问题
  12. chrome浏览器替换code.jquery.com CDN的加速URL
  13. 最新超唯美情侣网站开源+带后台/亲测可用
  14. teambition、Tower、worktile 、trello 等任务管理工具哪个好?
  15. Eclipse+Pydev详细配置
  16. 当年表白流行写情书,现在流行的是……
  17. Linux7浏览器打不开网页,centos7浏览器打不开网页
  18. OSPF为何需要loopback接口
  19. 2019年记录:java小白级程序员工作一年以来的经历,遇到的坎坷以及当时的心态
  20. Visual Studio2018无法加载pdb文件怎么办

热门文章

  1. iOS-----用LLDB调试,让移动开发更简单(二)
  2. 反射的基础(二):构造器类的使用
  3. 客户端无刷新调用服务器程序
  4. Spring Cloud(Greenwich版)-03-编写高可用Eureka Server(集群)
  5. Python基础--03
  6. 简要说明python的缩进规则_关于python的缩进规则的知识点详解
  7. oracle如何创建基表,创建本地基表的物化视图
  8. gogs仓库代码拉取不需要用户账号验证问题
  9. ThinkPhp报错:thinkphp\library\think\Template.php Line(1243) template not exists:...test\...\index.html
  10. Layui数据表格动态cols(字段)动态变化(2)