摘要:

承接上一篇LeNet网络模型的图像分类实践,本次我们再来认识一个新的网络模型:ResNet-50。不同网络模型之间的主要区别是神经网络层的深度和层与层之间的连接方式,正文内容我们就分析下使用ResNet-50进行图像分类有什么神奇之处,以下操作使用MindSpore框架实现。

1.网络:ResNet-50

对于类似LeNet网络模型深度较小并且参数也较少,训练起来会相对简单,也很难会出现梯度消失或爆炸的情况。但ResNet-50的深度较大,训练起来就会比较困难,所以在加深网络深度的同时提出残差学习的结构来减轻深层网络训练的难度。重新构建了网络以便学习包含推理的残差函数,而不是学习未经过推理的函数。实验结果显示,残差网络更容易优化,并且加深网络层数有助于提高正确率。

  • 深度模型的限制

深度卷积网络在图像分类任务上有非常优秀的表现。深度网络依赖于多层端到端的方式,集成了低中高三个层次的特征和分类器,并且这些特征的数量还可以通过堆叠层数来增加。这也展示出了网络深度非常重要。

但是随着网络层数的增加,训练时就会遇到梯度消失或爆炸的情况,这会在一开始就影响收敛。收敛的问题可以通过正则化来得到部分的解决,但也不是通用的方法。并且在深层网络能够收敛的前提下,随着网络深度的增加,正确率开始饱和甚至下降,称之为网络的退化。

图1:56层和20层网络效果图

通过上图1可以发现在不改变网络结构的情况下,仅加深网络深度的56层网络相较于20层在误差上表现都更大。

  • ResNet-50的残差结构

对于网络退化现象并不是过拟合造成的。在给定的网络上增加层数就会增大训练误差。这说明不是所有的系统都很容易优化。我们可以先分析一个浅层的网络架构和在它基础上构建的深层网络,如果增加的所有层都是前一层的直接复制(即y=x),这种情况下深层网络的训练误差应该和浅层网络相等。因此,网络退化的根本原因还是优化问题。为了解决优化的难题,大佬们提出了残差网络,在ResNet-50中残差网络结构可分为Identity Block和Conv Block,下面分别介绍下。

Identity Block:在残差网络中,不是让网络直接拟合原先的映射,而是拟合残差映射。意味着后面的特征层的内容会有一部分由前面的某一层线性贡献。假设原始的映射为 H(x),残差网络拟合的映射为:F(x):=H(x)。输入和输出的维度(通道数和Size)是一样的,所以可以串联,它的主要用处是加深网络的深度。

图2:Identity Block结构

如图2中所示,identity mapping会直接跳过中间一些网络层,建立了一些快捷链接,直接恒等映射过去。这样的快捷链接不会增加模型的复杂度和参数。

Conv Block:在Identity Block的残差结构基础上,又增加了Conv的过程。输入和输出的维度(通道数和Size)是不一样的,所以不能进行连续的串联,它的作用是改变网络的维度,所以残差边上新增了卷积。

图3:Conv Block结构

如图3中所示,Conv Block将在残差的通道上经过一轮卷积处理。再将卷积处理后的结果给到后面的网络层中。

Conv Block的具体设置需要看Block的输入和输出,对照通道数和Size的变化,设定符合需求的Conv。

  • ResNet-50的整体结构

上面了解完了残差结构和用途,现在我们再带入到ResNet-50中看下整体的结构

图4:ResNet结构图

从左到右依次的分析,图4最左边是ResNet-50的步骤图,后面是将每个步骤再拆解Input stem是正常的输入和处理。Stage1->Stage4就是包含了加深网络深度的Identity Block和Conc Block的模块,同时避免了计算训练困难和网络的退化的问题。

  • ResNet-50的调用

MindSpore已上线支持该模型,我们可以直接调用该模型的接口,所以我们在使用过程中传入定义好的超参数和数据即可。

network = resnet50(class_num=10)

如果想要了解下更底层的参数设置,可以查看https://gitee.com/mindspore/models/blob/master/official/cv/resnet/config/resnet50_cifar10_config.yaml。

论文链接:https://arxiv.org/pdf/1512.03385.pdf

2.数据集:CIFAR-10

数据集CIFAR-10由10个类的60000个32x32彩**像组成,每个类有6000个图像。有50000个训练图像和10000个测试图像。

  • 数据集结构

CIFAR-10数据集的原文连接中包含三种类型的数据集,这里可以根据自己的需求进行下载。这里我们使用python版本数据集。

Version                                               Size

CIFAR-10 python version                                     163 MB

CIFAR-10 Matlab version                                     175 MB

CIFAR-10 binary version (suitable for C programs)                   162 MB

数据集中所包含的类别

图5:CIFAR-10类别图

  • 数据加载和处理

数据加载:下载完成后将数据集放在一个文件目录下,将目录传入到数据的加载过程中。

cifar_ds = ds.Cifar10Dataset(data_home)

数据增强:是对数据进行归一化和丰富数据样本数量。常见的数据增强方式包括裁剪、翻转、色彩变化等等。MindSpore通过调用map方法在图片上执行增强操作。

resize_height = 224resize_width = 224rescale = 1.0 / 255.0shift = 0.0# define map operationsrandom_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANTrandom_horizontal_op = C.RandomHorizontalFlip()resize_op = C.Resize((resize_height, resize_width)) # interpolation default BILINEARrescale_op = C.Rescale(rescale, shift)normalize_op = C.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))changeswap_op = C.HWC2CHW()type_cast_op = C2.TypeCast(mstype.int32)c_trans = []if training:c_trans = [random_crop_op, random_horizontal_op]c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]# apply map operations on imagescifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label")cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image")

最后通过数据混洗(shuffle)随机打乱数据的顺序,并按batch读取数据,进行模型训练。

# apply shuffle operationscifar_ds = cifar_ds.shuffle(buffer_size=10)# apply batch operationscifar_ds = cifar_ds.batch(batch_size=args_opt.batch_size, drop_remainder=True)# apply repeat operationscifar_ds = cifar_ds.repeat(repeat_num)

3.损失函数:SoftmaxCrossEntropyWithLogits

本次训练调用的损失函数是:SoftmaxCrossEntropyWithLogits。那为什么是SoftmaxCrossEntropyWithLogits损失函数呢?

  • 损失函数的选择

我们上面提到,为什么是使用SoftmaxCrossEntropyWithLogits损失函数呢,这要从我们本次的实验目的分析。

本次项目的:实现CIFAR-10图像数据集的分类。既然是分类,那么分类中的损失函数是怎么计算的,它是计算logits和标签之间的softmax交叉熵。使用交叉熵损失测量输入概率(使用softmax函数计算)与类别互斥(只有一个类别为正)的目标之间的分布误差,具体公式可以表示成

图6:SoftmaxCrossEntropyWithLogits表达式

  • 损失函数参数分析
  • logits (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32.
  • labels (Tensor) - Tensor of shape (N, ). If sparse is True, The type of labels is int32 or int64. Otherwise, the type of labels is the same as the type of logits.

第一个参数logits:就是神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes],单样本的话,大小就是num_classes;第二个参数labels:实际的标签,大小同上。

  • 损失函数的使用
#在主函数处调用如下if __name__ == '__main__':ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")model = Model(loss_fn=ls)

更详细的使用请参考SoftmaxCrossEntropyWithLogits API链接:https://mindspore.cn/docs/api/zh-CN/master/api_python/nn/mindspore.nn.SoftmaxCrossEntropyWithLogits.html#mindspore.nn.SoftmaxCrossEntropyWithLogits

4.优化器:Momentum

本次训练中我们使用的是Momentum,也叫动量优化器。为什么是它?下面我们了解下它的计算原理。

  • 优化器的计算

图7:Momentum表达式

上面表达式中的grad、lr、p、v 和 u 分别表示梯度、learning_rate、参数、矩和动量。其中的梯度是通过损失函数求导得出的,在训练过程中得到的Loss是一个连续值,那么它就有梯度可求,并反向传播给每个参数。Momentum优化器的主要思想就是利用了类似移动指数加权平均的方法来对网络的参数进行平滑处理的,让梯度的摆动幅度变得更小。

  • 优化器的使用
#在主函数处调用如下if __name__ == '__main__':opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)model = Model(optimizer=opt)

更详细的使用请参考Momentum API链接:https://mindspore.cn/docs/api/zh-CN/master/api_python/nn/mindspore.nn.Momentum.html#mindspore.nn.Momentum

5.评价指标:Accuracy

损失函数的值虽然可以反应网络的性能,但对于图片分类的任务,使用精度可以更加准确的表示最终的分类结果。

  • 精度指标的选择

基于分类任务的考虑,我们使用简单的`分类正确数量/总数量`来表示,也就是Accuracy。精度表达式比较简单,也好理解。

图8:Accuracy达式

  • 精度的使用
#在主函数处调用即可if __name__ == '__main__':model = Model(metrics={'acc'})

更详细的使用请参考Accuracy API链接:https://mindspore.cn/docs/api/zh-CN/master/api_python/nn/mindspore.nn.Accuracy.html#mindspore.nn.Accuracy

总结:

本次内容是以图像分类任务为例,首先要了解下我们本次使用的模型结构以及要完成的目标,本次内容和LeNet网络图像分类的区别是使用网络和数据集的不同,所以可重点对照下两种网络结构。然后是选择设置失函数、优化器和精度这几部分,构成完整的训练。谢谢赏阅。

使用ResNet-50实现图像分类任务相关推荐

  1. resnet模型的图像分类结构图_ResNet - 2015年 ILSVRC 的赢家(图像分类,定位及检测)...

    本文为 AI 研习社编译的技术博客,原标题 : Review: ResNet - Winner of ILSVRC 2015 (Image Classification, Localization, ...

  2. Resnet实现CIFAR-10图像分类

    1.概述 计算机视觉是当前深度学习研究最广泛.落地最成熟的技术领域,在手机拍照.智能安防.自动驾驶等场景有广泛应用.从2012年AlexNet在ImageNet比赛夺冠以来,深度学习深刻推动了计算机视 ...

  3. 使用resNet网络 进行图像分类(jupyter notebook)

    这学期做了三次的CV把他贴出来, resNet网络的结构 import torch.nn as nn import torchclass BasicBlock(nn.Module):expansion ...

  4. Resnet 50 残差网络

    1.简述 resnet50是何凯明提出,能有效解决深度网络退化问题的一种结构,将输入的多重非线性变化拟合变成了拟合输入与输出的残差,变为恒等映射,50即50层 膜拜巨神:https://github. ...

  5. Resnet 50 完整可跑代码 pytorch

    import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision ...

  6. resnet 50 网络分析

  7. MXNet预训练模型下载 ResNet 50 101

    imagenet11k resnet-50-symbol.json resnet-50-0000.params resnet-101-symbol.json resnet-101-0000.param ...

  8. 图像分类篇:pytorch实现ResNet

    一.ResNet详解 ResNet网络是在2015年由微软实验室提出的,斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名,获得COCO数据集中目标检测第一名,图像分类第一名. 在ResNe ...

  9. pytorch图像分类篇:6. ResNet网络结构详解与迁移学习简介

    前言 最近在b站发现了一个非常好的 计算机视觉 + pytorch 的教程,相见恨晚,能让初学者少走很多弯路. 因此决定按着up给的教程路线:图像分类→目标检测→-一步步学习用pytorch实现深度学 ...

  10. 【完结】16篇图像分类干货文章总结,从理论到实践全流程大盘点!

    专栏<图像分类>正式完结啦!我们从数据集展开讲解,由最基本的多类别图像分类一步步深入到细粒度图像分类.多标签图像分类,再到更加有难度的无监督图像分类,随后我们又对图像分类中面临的各种问题展 ...

最新文章

  1. OpenCV 【四】————Watershed Algorithm(图像分割)——分水岭算法的原理及实现
  2. 科学养猪的真正奥义,培养拯救人类的医学英雄
  3. 利用Servlet生成动态验证码
  4. linux+synaptics+驱动程序,Synaptics
  5. 处理Akka应用程序中的每个事件
  6. GeoServer自动发布地图服务
  7. 如何将dataset中的值赋值给datatable_金融行业实战项目:如何理解业务?
  8. 基础组件完善的今天,如何通过业务组件提效?
  9. 把关与服务的关系_泉州代做投标书-电子标书值得信赖 - 泉州广告服务
  10. 阿里云宗志刚:云网一体,新一代洛神云网络平台
  11. django优化--ORM优缺点
  12. AD14一般使用流程
  13. prolog与python_python中prolog事实词法分析器
  14. 《烈烈先秦》7、大秦的克星——侠将公子信陵君
  15. Welcome to MySQL Workbench:MySQL 复制表
  16. Python中统一快速更换变量的名称
  17. ccs: memory range overlaps existing memory range
  18. 信息学奥赛一本通 1947:【09NOIP普及组】细胞分裂 | 洛谷 P1069 [NOIP2009 普及组] 细胞分裂
  19. 已知直线方程,计算直线对应的向量
  20. vsftpd参数cmds_allowed

热门文章

  1. 【从零开始的大数据学习】Flink官方教程学习笔记(一)
  2. 玩转英伟达jetson系列(一)刷系统
  3. 15张超详细的Python学习路线图,纯良心分享,零基础学习宝典
  4. 弘辽科技:白象方便面也被野性消费了吗?
  5. AIX磁盘管理基础知识
  6. 2022最新微信小程序游戏:一起来找茬
  7. Spring_AOP(execution表达式)
  8. Adobe系列软件大全
  9. MGCtoken与IMtoken哪个好?安全吗?
  10. ApplePay对接java后台详细代码