基于 PyTorch 实现残差神经网络 ResNet

文章目录

  • 基于 PyTorch 实现残差神经网络 ResNet
    • 0. 概述
    • 1. 数据集介绍
      • 1.1 数据集准备
      • 1.2 分析分类难度:CIFAR-10 vs MNIST
    • 2. 残差神经网络
      • 2.1 残差神经网络基础
      • 2.2 构建两种 Residual Blocks
      • 2.3 构建完整的残差神经网络
      • 2.4 训练与测试

0. 概述

在本节实验中,我们将基于 PyTorch 实现残差神经网络 ResNet,并在一个难度稍大的图片数据集(CIFAR-10)上进行训练和测试。

具体包括如下几个部分:

(1) 熟悉新数据集 CIFAR-10,并和 MNIST 对比分类难度;

(2) 学习残差神经网络,特别是 Block 的概念;

(3) 构建残差神经网络,并基于此实现 CIFAR-10 的训练与测试。

Ref: https://arxiv.org/pdf/1512.03385.pdf

https://zhuanlan.zhihu.com/p/106764370

1. 数据集介绍

CIFAR-10 数据集样例和 10 个类别如下所示:

官方说明及下载地址:http://www.cs.toronto.edu/~kriz/cifar.html

1.1 数据集准备

我们首先来准备数据集,方法与 MNIST 类似。CIFAR-10 数据集。

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimprint(torch.manual_seed(1))
<torch._C.Generator object at 0x0000015FF0CD6B50>
batch_size = 250  # 设置训练集和测试集的 batch size,即每批次将参与运算的样本数# 训练集
train_set = torchvision.datasets.CIFAR10('./dataset_cifar10', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.4914,0.4822,0.4465), (0.2023,0.1994,0.2010))])
)# 测试集
test_set = torchvision.datasets.CIFAR10('./dataset_cifar10', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.4914,0.4822,0.4465), (0.2023,0.1994,0.2010))]))train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=True)

1.2 分析分类难度:CIFAR-10 vs MNIST

下面我们用实验二中定义过的卷积神经网络在 CIFAR-10 数据集上训练并测试。请注意:由于 CIFAR-10 的图片格式与 MNIST 稍有不同,所以网络结构中 conv1 的输入通道数和 fc1 的输入向量长度都进行了调整。调整后的神经网络比原先拥有更多的参数,理论上有助于增加网络的学习能力。

原先的卷积神经网络在 MNIST 数据集上取得的测试准确率在 98.9% 左右。通过如下对比,我们可以看出 CIFAR-10 的分类难度相对于 MNIST 来说有了显著增加。

class CNN5(nn.Module):def __init__(self):super(CNN5, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)  # in_channels 由 1 改变为 3self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)self.fc1 = nn.Linear(in_features=12*5*5, out_features=120)  # in_features 由 12*4*4 改变为 12*5*5self.fc2 = nn.Linear(in_features=120, out_features=60)self.out = nn.Linear(in_features=60, out_features=10)def forward(self, t):# conv1t = self.conv1(t)t = F.relu(t)  t = F.max_pool2d(t, kernel_size=2, stride=2)  # conv2t = self.conv2(t)t = F.relu(t)t = F.max_pool2d(t, kernel_size=2, stride=2)      t = t.reshape(batch_size, 12*5*5)  # dim1 由 12*4*4 改变为 12*5*5# fc1t = self.fc1(t)t = F.relu(t)# fc2t = self.fc2(t)t = F.relu(t)# output layert = self.out(t)return t
network = CNN5()
network.cuda()loss_func = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.SGD(network.parameters(), lr=0.1)  # 优化器def get_num_correct(preds, labels):  # get the number of correct timesreturn preds.argmax(dim=1).eq(labels).sum().item()

开始训练

total_epochs = 10for epoch in range(total_epochs):total_loss = 0total_train_correct = 0for batch in train_loader:         images, labels = batchimages = images.cuda()labels = labels.cuda()preds = network(images)loss = loss_func(preds, labels)optimizer.zero_grad()loss.backward()optimizer.step()       total_loss += loss.item()total_train_correct += get_num_correct(preds, labels)print("epoch:", epoch, "correct times:", total_train_correct,f"training accuracy:", "%.3f" %(total_train_correct/len(train_set)*100), "%", "total_loss:", "%.3f" %total_loss)
epoch: 0 correct times: 12042 training accuracy: 24.084 % total_loss: 412.031
epoch: 1 correct times: 19445 training accuracy: 38.890 % total_loss: 337.925
epoch: 2 correct times: 22616 training accuracy: 45.232 % total_loss: 304.583
epoch: 3 correct times: 24415 training accuracy: 48.830 % total_loss: 286.545
epoch: 4 correct times: 25799 training accuracy: 51.598 % total_loss: 271.240
epoch: 5 correct times: 26835 training accuracy: 53.670 % total_loss: 261.250
epoch: 6 correct times: 27906 training accuracy: 55.812 % total_loss: 249.286
epoch: 7 correct times: 28499 training accuracy: 56.998 % total_loss: 242.985
epoch: 8 correct times: 29324 training accuracy: 58.648 % total_loss: 235.174
epoch: 9 correct times: 29982 training accuracy: 59.964 % total_loss: 227.551

测试结果(测试准确率约 56% 左右)

total_test_correct = 0
total_loss = 0for batch in test_loader:images, labels = batchimages = images.cuda()labels = labels.cuda()preds = network(images)loss = loss_func(preds, labels)total_loss += losstotal_test_correct += get_num_correct(preds, labels)print("correct times:", total_test_correct, f"test accuracy:", "%.3f" %(total_test_correct/len(test_set)*100), "%","total_loss:", "%.3f" %total_loss)
correct times: 5671 test accuracy: 56.710 % total_loss: 49.122

2. 残差神经网络

2.1 残差神经网络基础

从以上结果可以看出,CIFAR-10 数据集的分类难度远高于 MNIST 数据集。理论上,增加其准确率的一个有效方法即增加神经网络的深度(层数),例如从上面的 6 层神经网络增加至 20 层左右。网络的深度越深,可抽取的特征层次就越丰富越抽象。

然而,事实证明有时网络层数并不是越深越好。如下图所示,是两个普通的深层卷积神经网络 (plain CNN) 在 CIFAR-10 上的训练和测试结果。两个神经网络的深度分别是 20 层和 56 层。

图片来源:Kaiming He et al. “Deep Residual Learning for Image Recognition”, 2015.

我们选择加深神经网络的层数是希望深层网络的表现能比浅层好,或者是希望它的表现至少和浅层网络持平,可实际的结果却不是这样的。从结果中可以看到,56 层的神经网络在训练集和测试集上的表现均明显差于 20 层的神经网络。这一现象被称为退化问题(degradation problem)。退化问题出现的原因是随着网络变深,网络优化变得更加困难。

深度残差网络 (Deep residual network, ResNet) 正是为了解决这个问题而提出的,它的提出是计算机视觉领域的一件里程碑式的事件。残差网络解决退化问题的关键即引入恒等映射 (identity mapping)。什么是恒等映射呢?我们来看一个简单的例子:

上图中,右边的神经网络可以理解为左边的浅层网络增加了三层框起来的部分。假如我们希望右边的深层网络与左边的浅层网络相比准确率可以持平,那么额外加上的三个神经层应当输入等于输出。我们假设这三层的输入为

【Pytorch(七)】基于 PyTorch 实现残差神经网络 ResNet相关推荐

  1. 深度学习——残差神经网络ResNet在分别在Keras和tensorflow框架下的应用案例

    原文链接:https://blog.csdn.net/loveliuzz/article/details/79117397 一.残差神经网络--ResNet的综述 深度学习网络的深度对最后的分类和识别 ...

  2. Pytorch实现残差神经网络(ResNet)

    1. 残差块 输入X,经过两次次卷积,一次ReLU,得到F(X),在将X与F(X)相加,在经过一个ReLU,即为最后的结果.残差神经网络就是基于残差块的一个深度神经网络. 2. 代码 这篇博客理论涉及 ...

  3. julia有 pytorch包吗_用 PyTorch 实现基于字符的循环神经网络 | Linux 中国

    导读:在过去的几周里,我花了很多时间用 PyTorch 实现了一个 char-rnn 的版本.我以前从未训练过神经网络,所以这可能是一个有趣的开始. 本文字数:7201,阅读时长大约: 9分钟 htt ...

  4. hot编码 字符one_用 PyTorch 实现基于字符的循环神经网络 | Linux 中国

    在过去的几周里,我花了很多时间用 PyTorch 实现了一个 char-rnn 的版本.我以前从未训练过神经网络,所以这可能是一个有趣的开始. 来源:https://linux.cn/article- ...

  5. pytorch dataloader_基于pytorch的DeepLearning入门流程

    基于pytorch的DeepLearning学习笔记 最近开始学深度学习框架pytorch,从最简单的卷积神经网络开始了解pytorch的框架.以下涉及到的代码完整版请查看https://github ...

  6. 残差神经网络Resnet(MNIST数据集tensorflow实现)

    简述: 残差神经网络(ResNet)主要是用于搭建深度的网络结构模型 (一)优势: 与传统的神经网络相比残差神经网络具有更好的深度网络构建能力,能避免因为网络层次过深而造成的梯度弥散和梯度爆炸. (二 ...

  7. 残差神经网络(ResNet)

    残差神经网络的主要贡献是发现了退化现象,并针对退化现象发明了快捷连接(shortcut connection),极大的消除了深度过大的神经网络训练困难问题. 1.神经网络越深准确率越高 假设一个层数较 ...

  8. 残差神经网络 ResNet

    上图为ResNet残差神经网络,目的是为了防止出现过优化的问题 比如上图中,已经达到了最优化的情况下,这时候已经最优状态了,在进行卷积会出现退化现象,所以这时候输出的H(x) = F(x) + x [ ...

  9. 深度学习笔记(三十五)残差神经网络ResNet

    训练深层神经网络时,如果深度特别大,其实是很难训练下去的,因为会遇到梯度消失和梯度爆炸的问题.残差网络可以帮助我们更好地训练深层神经网络. 一.残差块 在神经网络的两层中,会执行如下运算过程(主路径) ...

最新文章

  1. pytorch 卷积分组
  2. Python 根据地址获取经纬度及求距离
  3. JavaFX将会留下来!
  4. COMET彗星(三)构建自己的COMET核心
  5. 关于大数据和互联网的一点想法
  6. Github上如何找到自己想要的开源项目(小技巧:精确搜索)
  7. MySQL常用日期时间函数
  8. ikm java_ikm(IKM在线)
  9. vsphere报错: 连接到虚拟机控制台失败并显示错误:VMRC 控制台的连接已断开。正在尝试重新连接
  10. 贪心算法哈夫曼java_贪心算法_哈夫曼编码问题(Huffman Coding)
  11. 网站开发的需求分析报告
  12. 概率论 方差公式_考研冲刺篇|数学概率论
  13. 社区拼团赛道的突然火爆,究竟是受何因素影响?
  14. TM4C123G学习记录(3)--外部中断
  15. 数据库表设计-第三方登录用户表结构设计
  16. android 7双排设置菜单,联想拯救者电竞手机优化横屏UI 设置菜单呈左右双排显示...
  17. 人从哪里来又到哪里去
  18. 机器学习算法之二KD树
  19. unity shader 边缘光,内发光,外发光,轮廓边缘光,轮廓内边缘光,轮廓外边缘光
  20. Java 编写在线考试系统-049 窗体程序 完整源码

热门文章

  1. 阿里云提示安全组与 VPC 不匹配问题解决方案
  2. 往期精彩,爬取10亿票房的《西虹市首富》热评,一起来解读吧!
  3. 基于Python实现的业务数据分析系统
  4. Alien Skin Exposure X4官方版下载
  5. PowerBI开发 第九篇:修改查询
  6. 如何正确的使用Photoshop进行图像的二值化(详细步骤)刘博士
  7. R语言ggplot2 | PCA分析及其可视化
  8. PHP寻找文体多个关键字,grep同时抓取多个关键字或抓取多个关键字之一
  9. iMeta | 国际标准刊号ISSN在线版正式确认
  10. 《途客圈创业记:不疯魔,不成活》一一2.10 天使投资