残差网络的核心思想是:每个附加层都应该更容易地包含原始函数作为其元素之一。如图是一个正常块(左)和一个残差块(右)的结构,可以看出在残差块中,输入可通过跨层数据线路更快地向前传播。

ResNet沿用了VGG完整的3×3卷积层设计,残差块里首先有2个有相同输出通道数的3×3卷积层。 每个卷积层后接一个批量规范化层和ReLU激活函数。然后通过跨层数据通路,跳过这2个卷积运算,将输入直接加在最后的ReLU激活函数前。 这样的设计要求2个卷积层的输出与输入形状一样,从而使它们可以相加。 如果想改变通道数,就需要引入一个额外的1×1卷积层来将输入变换成需要的形状后再做相加运算。如下图是包含以及不包含 1×1 卷积层的残差块结构:

代码实现如下:

!pip install git+https://github.com/d2l-ai/d2l-zh@release  # installing d2l
!pip install matplotlib_inline
!pip install matplotlib==3.0.0import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):def __init__(self,input_channels,num_channels,use_1x1conv=False,strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels,num_channels,kernel_size=3,stride=strides,padding=1)self.conv2 = nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self,X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)b1 = nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))def resnet_block(input_channels,num_channels,num_residuals,first_block=False):blk = []for i in range(num_residuals):if i==0 and not first_block:blk.append(Residual(input_channels,num_channels,use_1x1conv=True,strides=2))else:blk.append(Residual(num_channels,num_channels))return blkb2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 10))batch_size =  256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)lr, num_epochs = 0.05, 10
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

运行结果:

5--残差网络(ResNet)相关推荐

  1. (pytorch-深度学习)实现残差网络(ResNet)

    实现残差网络(ResNet) 我们一般认为,增加神经网络模型的层数,充分训练后的模型理论上能更有效地降低训练误差. 理论上,原模型解的空间只是新模型解的空间的子空间.也就是说,如果我们能将新添加的层训 ...

  2. dlibdotnet 人脸相似度源代码_使用dlib中的深度残差网络(ResNet)实现实时人脸识别 - supersayajin - 博客园...

    opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...

  3. 残差网络ResNet

    文章目录 ResNet模型 两个注意点 关于x 关于残差单元 核心实验 原因分析 ResNet的效果 题外话 ResNet是由何凯明在论文Deep Residual Learning for Imag ...

  4. 对残差网络resnet shortcut的解释

    重读残差网络--resnet(对百度vd模型解读) 往事如yan 已于 2022-02-25 07:53:37 修改 652 收藏 4 分类专栏: AI基础 深度学习概念 文章标签: 网络 cnn p ...

  5. 深度残差网络RESNET

    一.残差神经网络--ResNet的综述 深度学习网络的深度对最后的分类和识别的效果有着很大的影响,所以正常想法就是能把网络设计的越深越好, 但是事实上却不是这样,常规的网络的堆叠(plain netw ...

  6. 深度学习目标检测 RCNN F-RCNN SPP yolo-v1 v2 v3 残差网络ResNet MobileNet SqueezeNet ShuffleNet

    深度学习目标检测--结构变化顺序是RCNN->SPP->Fast RCNN->Faster RCNN->YOLO->SSD->YOLO2->Mask RCNN ...

  7. 吴教授的CNN课堂:进阶 | 从LeNet到残差网络(ResNet)和Inception Net

    转载自:https://www.jianshu.com/p/841ac51c7961 第二周是关于卷积网络(CNN)进阶部分,学到挺多新东西.因为之前了解过CNN基础后,就大多在用RNN进行自然语言处 ...

  8. 何恺明编年史之深度残差网络ResNet

    文章目录 前言 一.提出ResNet原因 二.深度残差模块 1.数学理论基础 2.深度网络结构 三.Pytorch代码实现 四.总结 前言 图像分类是计算机视觉任务的基石,在目标监测.图像分割等任务中 ...

  9. 【深度学习】深度残差网络ResNet

    文章目录 1 残差网络ResNet 1.1要解决的问题 1.2 残差网络结构 1.3 捷径连接 1.4 总结 1 残差网络ResNet 1.1要解决的问题   在传统CNN架构中,如果我们简单堆叠CN ...

  10. TF2.0深度学习实战(七):手撕深度残差网络ResNet

    写在前面:大家好!我是[AI 菌],一枚爱弹吉他的程序员.我热爱AI.热爱分享.热爱开源! 这博客是我对学习的一点总结与记录.如果您也对 深度学习.机器视觉.算法.Python.C++ 感兴趣,可以关 ...

最新文章

  1. 2018年全国多校算法寒假训练营练习比赛(第三场)
  2. 腾讯云+FFmpeg打造一条完备高效的视频产品链
  3. Wpf Binding.Path设置
  4. Prepare for Mac App Store Submission--为提交到Mac 应用商店做准备
  5. python获取他人的ip_Python获取指定网段正在使用的IP
  6. 人工智能TensorFlow工作笔记007---认识张量
  7. 服务重构理解接口编程的妙处
  8. struts2在action中获取request、session、application,并传递数据
  9. 禅道的下载和安装教程(Linux版)
  10. C#正则表达式提取txt小说目录
  11. 最新一键修改手机MAC地址和路由器wifi物理地址
  12. 简单的交换机下设备连接,路由器互通
  13. 深圳房价链家数据分析
  14. 保健操对颈椎病有辅助治疗。
  15. 首创Domino前后端彻底分离,结合vue、react优美例子
  16. 文件新旧判断和字符串判断
  17. 实战▍利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)
  18. 最新发布:IT行业近5年平均年薪出炉!你在哪个梯队?
  19. CSS3 - footer 固定在底部(无论页面多高始终在底部)
  20. Python 基于微博舆情分析系统的设计与实现,GUI可视化界面(毕业设计,附源码,教程)

热门文章

  1. c语言:简单的客户管理系统
  2. SpringMVC工作原理与工作流程
  3. 医疗AI的dicom图像拉取模块设计
  4. 搞一下SOA | 11 SOA 系统建模
  5. 银联支付之在线网关支付
  6. 四象限原则+番茄时间管理法
  7. web新手之使用easyAR实现WebAR
  8. 有道 - 扇贝 - 海词词典发音链接
  9. 蓝牙技术谈之跳频技术(一)
  10. Netapp存储日常检查及信息收集