一、实现过程

残差网络(Residual Network)的特点是容易优化,并且能够通过增加相当的深度来提高准确率。其内部的残差块使用了跳跃连接,缓解了在深度神经网络中增加深度带来的梯度消失问题。
本文实现如图1所示的两层残差模块用于识别MNIST数据集,其中每一层均是卷积层。

图1 残差构建模块

残差构建模块封装成类,代码如下:

class ResidualBlock(torch.nn.Module):def __init__(self,channels):super(ResidualBlock,self).__init__()self.channels = channelsself.conv1 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)self.conv2 = torch.nn.Conv2d(channels,channels,kernel_size=3,padding=1)def forward(self, x):y = F.relu(self.conv1(x))y = self.conv2(y)return F.relu(x+y)

嵌入残差模块的网络模型代码如下:

class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()self.conv1 = torch.nn.Conv2d(1,16,kernel_size=5)self.conv2 = torch.nn.Conv2d(16,32,kernel_size=5)self.mp = torch.nn.MaxPool2d(2)self.rblock1 = ResidualBlock(16)self.rblock2 = ResidualBlock(32)self.fc = torch.nn.Linear(512,10)def forward(self,x):# Flatten data from (n,1,28,28) to (n,784)in_size = x.size(0)x = self.mp(F.relu(self.conv1(x)))x = self.rblock1(x)x = self.mp(F.relu(self.conv2(x)))x = self.rblock2(x)x = x.view(in_size,-1)  # flatten
#         print(x.size(1))return self.fc(x)
model = Net()

运行结果如下:

[1,300] loss: 0.486
[1,600] loss: 0.143
[1,900] loss: 0.103
Accuracy on test set: 97.34 % [9734/10000]
[2,300] loss: 0.082
[2,600] loss: 0.074
[2,900] loss: 0.066
Accuracy on test set: 98.37 % [9837/10000]
[3,300] loss: 0.058
[3,600] loss: 0.052
[3,900] loss: 0.051
Accuracy on test set: 98.68 % [9868/10000]
[4,300] loss: 0.044
[4,600] loss: 0.047
[4,900] loss: 0.038
Accuracy on test set: 98.81 % [9881/10000]
[5,300] loss: 0.037
[5,600] loss: 0.035
[5,900] loss: 0.038
Accuracy on test set: 98.8 % [9880/10000]
[6,300] loss: 0.030
[6,600] loss: 0.034
[6,900] loss: 0.032
Accuracy on test set: 98.89 % [9889/10000]
[7,300] loss: 0.029
[7,600] loss: 0.030
[7,900] loss: 0.026
Accuracy on test set: 98.83 % [9883/10000]
[8,300] loss: 0.026
[8,600] loss: 0.028
[8,900] loss: 0.021
Accuracy on test set: 99.04 % [9904/10000]
[9,300] loss: 0.021
[9,600] loss: 0.023
[9,900] loss: 0.022
Accuracy on test set: 99.05 % [9905/10000]
[10,300] loss: 0.019
[10,600] loss: 0.019
[10,900] loss: 0.022
Accuracy on test set: 99.05 % [9905/10000]


可以看出:带残差的深度网络比普通的深度网络的学习效果更好。

二、参考文献

[1] K. He, X. Zhang, S. Ren and J. Sun. Deep Residual Learning for Image Recognition[C]. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 770-778.
[2] https://www.bilibili.com/video/BV1Y7411d7Ys?p=11

PyTorch实现简单的残差网络相关推荐

  1. pytorch实现简单的Resnet网络

    笔者也是最近刚学不久的深度学习,也有很多地方不懂,下面给大家使用pytorch实现一个简单的Resnet网络(残差网络),并且训练MNIST数据集.话不多说,直接上代码.   笔者认为最主要的地方就是 ...

  2. 【PyTorch】Resnet/深度残差网络

    1 模型描述 深度残差网络(Resnet)是由来自Microsoft Research的4位学者(Kaiming He等人)提出的卷积神经网络,在2015年的ImageNet大规模视觉识别竞赛(Ima ...

  3. 简单复现 残差网络、Googlenet、mobilenet、SqueezeNet、ShuffleNet、Densenet

    1.残差网络 1)网络结构 当对x求偏导的时候,F(x)对x求偏导的值很小的时候,对整体求x的偏导会接近于1 这样解决了梯度消失问题,我们可以对离输入很近的层进行很好的更新. 要注意的是F(x)与x的 ...

  4. 【深度学习】ResNet残差网络 ResidualBlock残差块实现(pytorch) | 跟着李沐学AI笔记 | ResNet18进行猫狗分类

    文章目录 前言 一.卷积的相关计算公式(复习) 二.残差块ResidualBlock复现(pytorch) 三.残差网络ResNet18复现(pytorch) 四.直接调用方法 五.具体实践(ResN ...

  5. 【深度学习】深入探讨:残差网络解决了什么,为什么有效?

    转载自 | 极市平台 作者丨LinT@知乎 来源丨https://zhuanlan.zhihu.com/p/80226180 0 『引言』 残差网络是深度学习中的一个重要概念.这篇文章将简单介绍残差网 ...

  6. 深度学习《残差网络简单学习》

    一:残差网络 VGG网络将网络达到了19层的深度,GoogleNet的深度是22层,一般而言,深度越深,月面临如下问题: 1:计算量增大 2:过拟合 3:梯度消失和梯度爆炸 4:网络退化 第一个问题呢 ...

  7. ResNet残差网络Pytorch实现——对花的种类进行训练

    ResNet残差网络Pytorch实现--对花的种类进行训练 上一篇:[结合各个残差块] ✌✌✌✌ [目录] ✌✌✌✌ 下一篇:[对花的种类进行单数据预测] 大学生一枚,最近在学习神经网络,写这篇文章 ...

  8. 通过深度残差网络ResNet进行图像分类(pytorch网络多网络集成配置)

    通过深度残差网络进行图像分类(pytorch网络多网络集成配置) 简介 本项目通过配置文件修改,实现pytorch的ResNet18, ResNet34, ResNet50, ResNet101, R ...

  9. pytorch实现 残差网络 ResNet18 CIFAR-10 分类 计算top1-ACC,top3-ACC

    采用残差网络 ResNet18 或 ResNet34 深度神经网络对CIFAR-10图像数据集实现分类,计算模型预测性能(top1-ACC,top3-ACC),并以友好的方式图示化结果. 目录 1.定 ...

最新文章

  1. python自己做电子词典_python实现电子词典
  2. 身体器官工作表一览,别熬夜
  3. ps自定义形状工具_PS教程——用PS绘制虚线的三种方法
  4. 更改mysql数据库存放位置_更改mysql数据库存放位置
  5. 中英文对照 —— 医学与医院
  6. BZOJ4530:[BJOI2014]大融合
  7. 复制字符串 _strdup _wcsdup _mbsdup
  8. 刚开始学习.NET 怎么样能使自己学习的更快点啊?
  9. php与ununtu通信,Ubuntu 20.04 LTS 已引入 PHP 7.4
  10. transition animation
  11. 题目54:小明的存钱计划
  12. 策略路由 本地策略+接口策略
  13. 53所高校研究生补贴一览表
  14. USRP工作流程及各部分功能
  15. Wireshark 实用篇2:Wireshark 抓包常用过滤命令
  16. mysql dba 工资,好大夫高级mysql dba工资待遇怎么样 - 好大夫在线 - 职友集
  17. H3C 交换机DRNI特性使用介绍
  18. 全能型Mac解压缩软件 MacZip2.0.1(41)中文版 原ezip
  19. SpringBoot 之 数据访问
  20. 学雷锋做好事留名是发挥正能量的勇气!

热门文章

  1. teengamb数据集进行回归分析
  2. 《应用回归分析》何晓群 最新版数据下载
  3. archlinux - W3af
  4. 【python】公考数学
  5. 【SemiDrive源码分析】【Yocto源码分析】07 - core-image-base-x9h_ref_serdes.rootfs.ext4 文件系统是如何生成的
  6. 大家看看这个vmp壳如何下手脱壳?
  7. 组合辛普森公式(数值积分)
  8. 【洛谷P1430】序列取数【dp】
  9. python之自动发送微信消息
  10. 使用SVM分类器做颜色分类走过的坑