背景介绍:MNIST数据集识别黑白的手写数字图片,不适合彩色模型的RGB三通道图片。用深度残差网络学习多通道图片。

简单介绍一下深度残差网络:普通的深度网络随着网络深度的加深,拟合效果可能会越来越好,也可能会变差,换句话说在不停地学习,但是有可能学歪了。本次介绍的深度残差网络最后输出H(x)=x+f(x)。其中x是本层网络的输入,f(x)是本层网络的输出,H(x)是最终得到的结果。由以上公式可以表明,最终结果包含输入x,也就是说不论怎么学习,起码效果不会变差,不会学歪。x和f(x)之间的变换的网络层就被成为残差模块。

有不懂的地方可以看代码下面的解释与讲解

目录

1.残差模块类的构建:

(1)残差模块:

(2)在残差模块中前向传播:

2.残差网络类:

(1)通用残差网络实现:

(2)调用残差模块建立方法:

(3)在整个网络中向前传播:

4.形成特殊网络:


1.残差模块类的构建:

(1)残差模块:

 # 残差模块def __init__(self, filter_num, stride=1):super(BasicBlock, self).__init__()#输入x经过两个卷积层得到f(x),f(x)+x=H(x),对应元素相加得到残差模块H(x)# 第一个卷积单元 卷积核大小3*3是超参数,需要学习,自己制定self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')self.bn1 = layers.BatchNormalization()self.relu = layers.Activation('relu')# 第二个卷积单元self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')self.bn2 = layers.BatchNormalization()#当x与f(x)形状不同的时候,无法进行相加,新建identity(x)卷积层,完成x的形状转换if stride != 1:# 步长不为1,需要通过1x1卷积完成shape匹配self.downsample = Sequential()self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))else:# shape匹配,直接短接self.downsample = lambda x:x

再重复一遍残差模块是x和f(x)之间的网络变换,包括两个卷积单元。

第一个卷积单元:卷积核的数量由传入参数给定,使用3*3卷积核,步长设定为1,经过卷积变换后形状不变;经过BN层主要对参数进行标准化,对网络有益;最后经过激活层。第二个卷积单元同上,不过不需要激活函数了。

我们上面说过经过残差模块输出f(x)需要与x相加,因此需要保证二者形状相同。如果shape不相同:用1*1的卷积核对矩阵通道数进行变换。在此细说一下x与f(x)的形状:由于padding都是same因此矩阵形状是保持不变的,但是由于卷积层有多个卷积核,则导致最终的矩阵通道维数和卷积核数量一样。因此f(x)和x仅仅在通道维度上不同,则使用1*1卷积核变换。如果shape相同:直接拼接就行。

(2)在残差模块中前向传播:

    def call(self, inputs, training=None):#向前传播# [b, h, w, c],通过第一个卷积单元out = self.conv1(inputs)out = self.bn1(out)out = self.relu(out)# 通过第二个卷积单元out = self.conv2(out)out = self.bn2(out)# 通过identity模块,进行identity转换identity = self.downsample(inputs)# 2条路径输出直接相加;out-f(x),identity-x,实现f(x)+xoutput = layers.add([out, identity])output = tf.nn.relu(output) # 激活函数return output

原始输入x经过两个卷积单元一层一层输出。注意identity模块,需要进行通道数调整,因此输入不是上一个输出,而是原始输入x,要将x的shape修改为f(x)的shape,进行相加。另外注意第二个卷积单元的relu函数从残差模块中提取出来了,放在了H(x)后面,当然这个是不固定的,也可以放在第二个卷积单元内部。

2.残差网络类:

(1)通用残差网络实现:

def __init__(self, layer_dims, num_classes=10): # [2, 2, 2, 2]super(ResNet, self).__init__()# 根网络,预处理    在这个容器中经过卷积层,标准化层,激活函数,池化层减半self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),layers.BatchNormalization(),layers.Activation('relu'),layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')])# 堆叠4个Block,每个block包含了多个BasicBlock,设置步长不一样self.layer1 = self.build_resblock(64,  layer_dims[0])self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)# 通过Pooling层将高宽降低为1x1self.avgpool = layers.GlobalAveragePooling2D()# 最后连接一个全连接层分类self.fc = layers.Dense(num_classes)

先经过一个容器,不是残差网络的根网络:包括卷积层-BN层-激活层-池化层。

下面是四个残差层,每个残差层都利用build_resblock函数进行残差网络层的建立,并传入参数。

通过池化层完成高宽的转化,最后连接一个全连接层,转化为十个属性的输出,判断结果到底是什么。

(2)调用残差模块建立方法:

#通过该函数一次完成多个残差模块的建立def build_resblock(self, filter_num, blocks, stride=1):# 辅助函数,堆叠filter_num个BasicBlockres_blocks = Sequential()# 只有第一个BasicBlock的步长可能不为1,实现下采样res_blocks.add(BasicBlock(filter_num, stride))for _ in range(1, blocks):#其他BasicBlock步长都为1res_blocks.add(BasicBlock(filter_num, stride=1))return res_blocks

调用残差模块类,传入卷积核数量,步长。blocks表示要建立的残差模块的数量。这些参数都由调用该方法的代码传入。

该方法的调用完成表明残差神经网络构建完成。

(3)在整个网络中向前传播:

    def call(self, inputs, training=None):# 通过根网络x = self.stem(inputs)# 一次通过4个模块x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)# 通过池化层x = self.avgpool(x)# 通过全连接层x = self.fc(x)return x

这一部分注意区别于残差模块中的向前传播,先通过根网络,再逐步通过每一个残差模块,最后经过池化层和全连接层得到输出。

4.形成特殊网络:

def resnet18():# 通过调整模块内部BasicBlock的数量和配置实现不同的ResNetreturn ResNet([2, 2, 2, 2])def resnet34():# 通过调整模块内部BasicBlock的数量和配置实现不同的ResNetreturn ResNet([3, 4, 6, 3])

ResNet18直接调用产生网络的方法,传入参数[2,2,2,2],表明layer_dims是个一维矩阵,元素都是2。ResNet18是指17层卷积层,1层全连接层的网络。每个layer都调用两次建立残差模块的方法,每个残差模块有两个卷积单元也就是两个卷积层,如此一来四个layer,就有16个卷积层,再加上根网络的一个卷积层和最后的一个全连接层刚好是17+1。

机器学习-卷积神经网络之深度残差网络(三)相关推荐

  1. AI(1)认知 人工智能、机器学习、神经网络、深度学习。

    宽为限 紧用功 功夫到 滞塞通 开篇 AI领域是个水很深的新领域,对于非科学研究专业人士来说更是深不可测.选择自己喜欢的学科,兴趣是最好的老师,攻克下去总会有意想不到的收获.AI时代,我们要更加努力! ...

  2. 经典卷积神经网络(二):VGG-Nets、Network-In-Network和深度残差网络

    上一节我们介绍了LeNet-5和AlexNet网络,本节我们将介绍VGG-Nets.Network-In-Network和深度残差网络(residual network). VGG-Nets网络模型 ...

  3. 深度学习之卷积神经网络(12)深度残差网络

    深度学习之卷积神经网络(12)深度残差网络 ResNet原理 ResBlock实现 AlexNet.VGG.GoogleLeNet等网络模型的出现将神经网络的法阵带入了几十层的阶段,研究人员发现网络的 ...

  4. 04.卷积神经网络 W2.深度卷积网络:实例探究(作业:Keras教程+ResNets残差网络)

    文章目录 作业1:Keras教程 1. 快乐的房子 2. 用Keras建模 3. 用你的图片测试 4. 一些有用的Keras函数 作业2:残差网络 Residual Networks 1. 深层神经网 ...

  5. 深度学习 --- 卷积神经网络CNN(LeNet-5网络详解)

    卷积神经网络(Convolutional Neural Network,CNN)是一种前馈型的神经网络,其在大型图像处理方面有出色的表现,目前已经被大范围使用到图像分类.定位等领域中.相比于其他神经网 ...

  6. 深度残差网络_深度残差收缩网络:(三) 网络结构

    1. 回顾一下深度残差网络的结构 在下图中,(a)-(c)分别是三种残差模块,(d)是深度残差网络的整体示意图.BN指的是批标准化(Batch Normalization),ReLU指的是整流线性单元 ...

  7. 基于FPGA的一维卷积神经网络CNN的实现(三)训练网络搭建及参数导出(附代码)

    训练网络搭建 环境:Pytorch,Pycham,Matlab. 说明:该网络反向传播是通过软件方式生成,FPGA内部不进行反向传播计算. 该节通过Python获取训练数据集,并通过Pytorch框架 ...

  8. 深度残差网络的无人机多目标识别

    深度残差网络的无人机多目标识别 人工智能技术与咨询 来源:<图学学报>.作者翟进有等 摘要:传统目标识别算法中,经典的区域建议网络(RPN)在提取目标候选区域时计算量大,时间复杂度较高,因 ...

  9. 深度残差网络ResNet解析

    ResNet在2015年被提出,在ImageNet比赛classification任务上获得第一名,因为它"简单与实用"并存,之后很多方法都建立在ResNet50或者ResNet1 ...

最新文章

  1. 甲方乙方和站在外包中间的你 | 每日趣闻
  2. Qt导入CMakeLists.txt后无法调试
  3. [转载] 中国好声音120720
  4. rabbitmq多个消费者监听一个队列_RabbitMQ的六种工作模式
  5. 重构——解决过长参数列表(long parameter list)
  6. python执行效率有多低_python – Scapy的低性能
  7. 2021年山西副高考试成绩查询,中国卫生人才网2021年山西卫生资格考试成绩查询...
  8. Mysql学习总结(13)——使用JDBC处理MySQL大数据
  9. 游戏 TRAP(SNRS)AlphaBeta版本
  10. Android SDK4.0(api14)安装
  11. Kali渗透测试:使用Word宏病毒进行渗透攻击
  12. Dempster证据理论python复现
  13. 拿来就能用的前端酷炫登录注册模板
  14. KeyError: 0 与 KeyError: 1(附例子)
  15. 极路由2(HC5761)免云平台开启SSH
  16. 《第五项修炼,学习型组织的艺术与实践》读书笔记
  17. hazelcast 搭建_hazelcast教程 入门
  18. 二维码怎么制作?手把手教你制作生成
  19. 一代宗师威廉·欧奈尔的选股法则详解
  20. NANK南卡lite Pro无线蓝牙耳机上手体验

热门文章

  1. javascript两个数组去重合并
  2. c语言窗口皮肤,MFC 界面美化 Skinmagic
  3. 论山寨手机和Android 【8】 自己动手做XP手机,DIY实战指南
  4. 保险公司“大都会人寿”以新兴企业态度应对大数据挑战
  5. @Scheduled cron 定时任务表达式含义用法及* ?的区别
  6. 01、Vue简易版网易云——项目简介
  7. 关于yield关键字的一些理解
  8. python数据库编程dbf_如何使用Python dbf库读取和创建新的foxpro2.6数据库表
  9. java math.sin()_Java Math sin() 使用方法及示例
  10. go设置后端启动_今日头条内涵段子使用Go语言构建千亿级微服务架构实践