实现残差网络(ResNet)

  • 我们一般认为,增加神经网络模型的层数,充分训练后的模型理论上能更有效地降低训练误差。
  • 理论上,原模型解的空间只是新模型解的空间的子空间。也就是说,如果我们能将新添加的层训练成恒等映射f(x)=xf(x) = xf(x)=x,新模型和原模型将同样有效。
  • 由于新模型可能得出更优的解来拟合训练数据集,因此添加层似乎更容易降低训练误差。
  • 然而在实践中,添加过多的层后训练误差往往不降反升。即使利用批量归一化带来的数值稳定性使训练深层模型更加容易,该问题仍然存在。针对这一问题,何恺明等人提出了残差网络(ResNet)。它在2015年的ImageNet图像识别挑战赛夺魁,并深刻影响了后来的深度神经网络的设计。

残差块

残差块的结构在之前的blog中详细解释了,感兴趣的可以去看。

ResNet沿用了VGG全3×33\times 33×3卷积层的设计。

  • 残差块里首先有2个有相同输出通道数的3×33\times 33×3卷积层。
  • 每个卷积层后接一个批量归一化层和ReLU激活函数
  • 然后输入跳过这两个卷积运算后直接加在最后的ReLU激活函数前。
  • 这样的设计要求两个卷积层的输出与输入形状一样,从而可以相加
  • 如果想改变通道数,就需要引入一个额外的1×11\times 11×1卷积层来将输入变换成需要的形状后再做相加运算。

残差块的实现如下。它可以设定输出通道数、是否使用额外的1×11\times 11×1卷积层来修改通道数以及卷积层的步幅。

import time
import torch
from torch import nn, optim
import torch.nn.functional as Fdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')class Residual(nn.Module):  def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):super(Residual, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(out_channels)self.bn2 = nn.BatchNorm2d(out_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)return F.relu(Y + X)

查看输入和输出形状一致的情况。

blk = Residual(3, 3)
X = torch.rand((4, 3, 6, 6))
blk(X).shape # torch.Size([4, 3, 6, 6])

我们也可以在增加输出通道数的同时减半输出的高和宽。

blk = Residual(3, 6, use_1x1conv=True, stride=2)
blk(X).shape # torch.Size([4, 6, 3, 3])

ResNet模型

ResNet的前两层跟GoogLeNet中的一样:

  • 在输出通道数为64、步幅为2的7×77\times 77×7卷积层后接步幅为2的3×33\times 33×3的最大池化层。
  • 每个卷积层后增加批量归一化层。
net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
  • ResNet使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。
  • 第一个模块的通道数同输入通道数一致无须减小高和宽(之前已经使用了步幅为2的最大池化层)。
  • 之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半
def resnet_block(in_channels, out_channels, num_residuals, first_block=False):if first_block:assert in_channels == out_channels # 第一个模块的通道数同输入通道数一致blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))else:blk.append(Residual(out_channels, out_channels))return nn.Sequential(*blk)

接着我们为ResNet加入所有残差块。这里每个模块使用两个残差块

net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
net.add_module("resnet_block2", resnet_block(64, 128, 2))
net.add_module("resnet_block3", resnet_block(128, 256, 2))
net.add_module("resnet_block4", resnet_block(256, 512, 2))

加入全局平均池化层后接上全连接层输出。

class GlobalAvgPool2d(nn.Module):# 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现def __init__(self):super(GlobalAvgPool2d, self).__init__()def forward(self, x):return F.avg_pool2d(x, kernel_size=x.size()[2:])class FlattenLayer(torch.nn.Module):def __init__(self):super(FlattenLayer, self).__init__()def forward(self, x): # x shape: (batch, *, *, ...)return x.view(x.shape[0], -1)
net.add_module("global_avg_pool", GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(512, 10)))
  • 这里每个模块里有4个卷积层(不计算1×11\times 11×1卷积层),加上最开始的卷积层和最后的全连接层,共计18层。这个模型通常也被称为ResNet-18。
  • 通过配置不同的通道数和模块里的残差块数可以得到不同的ResNet模型,例如更深的含152层的ResNet-152。虽然ResNet的主体架构跟GoogLeNet的类似,但ResNet结构更简单,修改也更方便。这些因素都导致了ResNet迅速被广泛使用。
X = torch.rand((1, 1, 224, 224))
for name, layer in net.named_children():X = layer(X)print(name, ' output shape:\t', X.shape)
0  output shape:  torch.Size([1, 64, 112, 112])
1  output shape:     torch.Size([1, 64, 112, 112])
2  output shape:     torch.Size([1, 64, 112, 112])
3  output shape:     torch.Size([1, 64, 56, 56])
resnet_block1  output shape:     torch.Size([1, 64, 56, 56])
resnet_block2  output shape:     torch.Size([1, 128, 28, 28])
resnet_block3  output shape:     torch.Size([1, 256, 14, 14])
resnet_block4  output shape:     torch.Size([1, 512, 7, 7])
global_avg_pool  output shape:   torch.Size([1, 512, 1, 1])
fc  output shape:    torch.Size([1, 10])

获取数据

def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):"""Download the fashion mnist dataset and then load into memory."""trans = []if resize:trans.append(torchvision.transforms.Resize(size=resize))trans.append(torchvision.transforms.ToTensor())transform = torchvision.transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iter
batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=96)

训练模型

def train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

(pytorch-深度学习)实现残差网络(ResNet)相关推荐

  1. 深度学习之残差网络的原理

    目录 一. 什么是残差(residual) 二.残差网络的背景 三.残差块(residual block) 四.深度残差学习 五.DenseNet网络和Resnets网络对比 转载:https://b ...

  2. 深度学习之残差网络原理深度刨析

    为什么要加深网络? 深度卷积网络自然的整合了低中高不同层次的特征,特征的层次可以靠加深网络的层次来丰富. 从而,在构建卷积网络时,网络的深度越高,可抽取的特征层次就越丰富. 所以一般我们会倾向于使用更 ...

  3. 深度学习《残差网络简单学习》

    一:残差网络 VGG网络将网络达到了19层的深度,GoogleNet的深度是22层,一般而言,深度越深,月面临如下问题: 1:计算量增大 2:过拟合 3:梯度消失和梯度爆炸 4:网络退化 第一个问题呢 ...

  4. (刘二大人)PyTorch深度学习实践-卷积网络(Advance)

    1. 1x1的卷积核的作用 在width和height不变的基础上改变通道的数量 减少计算量 2. GoogLeNet中Inception Module的实现 2.1 Inception块的代码实现 ...

  5. 深度学习目标检测 RCNN F-RCNN SPP yolo-v1 v2 v3 残差网络ResNet MobileNet SqueezeNet ShuffleNet

    深度学习目标检测--结构变化顺序是RCNN->SPP->Fast RCNN->Faster RCNN->YOLO->SSD->YOLO2->Mask RCNN ...

  6. 【深度学习】深度残差网络ResNet

    文章目录 1 残差网络ResNet 1.1要解决的问题 1.2 残差网络结构 1.3 捷径连接 1.4 总结 1 残差网络ResNet 1.1要解决的问题   在传统CNN架构中,如果我们简单堆叠CN ...

  7. TF2.0深度学习实战(七):手撕深度残差网络ResNet

    写在前面:大家好!我是[AI 菌],一枚爱弹吉他的程序员.我热爱AI.热爱分享.热爱开源! 这博客是我对学习的一点总结与记录.如果您也对 深度学习.机器视觉.算法.Python.C++ 感兴趣,可以关 ...

  8. 何恺明编年史之深度残差网络ResNet

    文章目录 前言 一.提出ResNet原因 二.深度残差模块 1.数学理论基础 2.深度网络结构 三.Pytorch代码实现 四.总结 前言 图像分类是计算机视觉任务的基石,在目标监测.图像分割等任务中 ...

  9. dlibdotnet 人脸相似度源代码_使用dlib中的深度残差网络(ResNet)实现实时人脸识别 - supersayajin - 博客园...

    opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...

  10. 深度残差网络RESNET

    一.残差神经网络--ResNet的综述 深度学习网络的深度对最后的分类和识别的效果有着很大的影响,所以正常想法就是能把网络设计的越深越好, 但是事实上却不是这样,常规的网络的堆叠(plain netw ...

最新文章

  1. 关于C++对象模型的一点理解(2)
  2. 你不懂js系列学习笔记-类型与文法- 04
  3. 机器学习(MACHINE LEARNING) 【周志华版-”西瓜书“-笔记】 DAY15-规则学习
  4. php接收ajax转数组
  5. 评价算法的性能从利用计算机资源角度,计算机专业数据结构课后练习题汇编
  6. css点击a标签显示下划线_好程序员HTML5培训教程-html和css基础知识
  7. 不用L约束又不会梯度消失的GAN,了解一下?
  8. ns-3文件编译出错总结
  9. [转载]Asp.net MVC中Controller返回值类型
  10. android 捕获Home键和ACTION_TIME_TICK广播
  11. [POJ 2503] Babelfish【二分查找】
  12. DataSet和实体类的相互转换
  13. 同时读取两个文件进行while循环
  14. java web中的相对路径和绝对路径
  15. 知名互联网公司网站架构图
  16. java kml_从Java中的KML文件中提取坐标
  17. Xshell 官方免费版下载流程
  18. 基于单片机的温度监测系统proteus仿真
  19. php如何批量发送短信,如何在php中运行批量短信api [关闭]
  20. QCC514x-QCC304x(headset)系列(实战篇)之5.1 tone详解

热门文章

  1. linux用u盘上传文件,linux如何挂载U盘和文件系统(或需要用到).doc
  2. mysql1440秒未活动_phpMyAdmin登陆超时1440秒未活动请重新登录
  3. 如何打开屏幕坏的手机_每天打开手机屏幕20次?打开10次以上的朋友进~
  4. zincrby redis python_【Redis数据结构 序】使用redis-py操作Redis数据库
  5. Android studio的Activity详解
  6. Java程序员越来越多工资反而越高?
  7. 自学python编程基础科学计算_Python基础与科学计算常用方法
  8. java解析excel文件_1.3.1 python解析excel格式文件
  9. php 执行mysql查询_php中执行mysql的常用操作
  10. 免费mysql空间_php+mysql免费空间