PyTorch: VGG16 Network Architecture Explained, with a Source-Code Walkthrough
Paper: Very Deep Convolutional Networks for Large-Scale Image Recognition
Overview
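Significance: showed that increasing the network depth by stacking more small convolution kernels improves classification accuracy.
Preprocessing: the mean RGB value, computed over the training set, is subtracted from each channel of every pixel.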
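A minimal sketch of this preprocessing step in plain Python. It assumes a hypothetical toy layout in which each image is a nested list of shape H × W × 3; the helper names channel_means and subtract_mean are made up for illustration. The mean of each channel is taken over every pixel of every training image and then subtracted from each pixel.

```python
def channel_means(images):
    """Mean of each RGB channel over all pixels of all training images.

    `images` is a list of H x W x 3 nested lists (toy layout for illustration).
    """
    sums = [0.0, 0.0, 0.0]
    count = 0
    for img in images:
        for row in img:
            for pixel in row:
                for c in range(3):
                    sums[c] += pixel[c]
                count += 1
    return [s / count for s in sums]


def subtract_mean(img, means):
    """Return a copy of `img` with the per-channel mean removed from every pixel."""
    return [[[pixel[c] - means[c] for c in range(3)] for pixel in row] for row in img]


# Two tiny 1 x 2 "training images"
train = [
    [[[10, 20, 30], [30, 40, 50]]],
    [[[20, 30, 40], [20, 30, 40]]],
]
means = channel_means(train)               # [20.0, 30.0, 40.0]
centered = subtract_mean(train[0], means)  # [[[-10.0, -10.0, -10.0], [10.0, 10.0, 10.0]]]
print(means, centered)
```

The mean is computed once on the training set and then reused unchanged for validation and test images.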
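Characteristics:
1) Stacks of small (3 × 3) convolution kernels replace large kernels (5 × 5 or 7 × 7).
2) Convolutional layers do not change the spatial size of the feature maps (3 × 3 kernels with stride 1 and padding 1); the size is reduced only by max pooling.
3) The network is comparatively deep.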
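Point 2) can be checked with the usual output-size formula out = ⌊(in + 2p − k) / s⌋ + 1. The sketch below traces a 224 × 224 input (the ImageNet crop size used by VGG) through the VGG16 layer sequence, assuming the settings from the torchvision source discussed later: 3 × 3 convolutions with stride 1 and padding 1, and 2 × 2 max pooling with stride 2. The helper name conv_out is made up for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (square input)."""
    return (size + 2 * padding - kernel) // stride + 1


# VGG16 (configuration "D"): numbers are conv output channels, 'M' is 2x2 max pooling
cfg_d = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M']

size = 224
sizes_after_pool = []
for v in cfg_d:
    if v == 'M':
        size = conv_out(size, kernel=2, stride=2)             # pooling halves the size
        sizes_after_pool.append(size)
    else:
        size = conv_out(size, kernel=3, stride=1, padding=1)  # 3x3 conv keeps the size

print(sizes_after_pool)  # [112, 56, 28, 14, 7]
```

The final 7 × 7 map is what the classifier's first fully connected layer (512 * 7 * 7 inputs) expects.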
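Advantages:
1) Simple, uniform architecture: the whole network uses the same convolution kernel size (3 × 3) and the same max-pooling size (2 × 2).
2) Replacing a large kernel with a stack of small ones adds more ReLU non-linearities and needs fewer parameters for the same receptive field, which gives more representational power and better performance.
Disadvantages:
1) The large number of parameters makes training slow and hyperparameter tuning difficult.
2) The storage requirement is large, which hinders deployment; for example, the VGG16 weight file is over 500 MB.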
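The parameter count behind both drawbacks can be reproduced by hand. This sketch assumes the torchvision layout shown later (3 × 3 convolutions with biases, a 512 × 7 × 7 feature map flattened into fully connected layers of 4096, 4096 and 1000 units) and 4-byte float32 weights; vgg_param_count is a hypothetical helper, not part of torchvision.

```python
cfg_d = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M']


def vgg_param_count(cfg, num_classes=1000):
    """Return (conv_params, fc_params) for a VGG configuration list."""
    conv_params = 0
    in_ch = 3
    for v in cfg:
        if v == 'M':
            continue                               # max pooling has no parameters
        conv_params += in_ch * v * 3 * 3 + v       # 3x3 weights plus one bias per filter
        in_ch = v
    dims = [512 * 7 * 7, 4096, 4096, num_classes]  # flattened features -> FC layers
    fc_params = sum(a * b + b for a, b in zip(dims, dims[1:]))
    return conv_params, fc_params


conv, fc = vgg_param_count(cfg_d)
total = conv + fc
print(conv, fc, total)                 # 14714688 123642856 138357544
print(f"{total * 4 / 2**20:.1f} MiB")  # 527.8 MiB as float32
```

Under these assumptions about 89% of the parameters sit in the three fully connected layers, and the float32 total lands in the 500+ MB range quoted above.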
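Note that two stacked 3 × 3 convolution layers can replace a single 5 × 5 convolution layer, because both cover a 5 × 5 region of the input (see the schematic figure).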
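A small calculation makes the 3 × 3 versus 5 × 5 comparison concrete. It assumes two standard formulas: a stack of n stride-1 convolutions with kernel size k has a receptive field of 1 + n(k − 1), and a k × k layer with C input and C output channels has k²C² weights (biases ignored). The helper names are made up for illustration.

```python
def receptive_field(num_layers, kernel=3):
    """Receptive field of `num_layers` stacked stride-1 convolutions."""
    return 1 + num_layers * (kernel - 1)


def stacked_weights(num_layers, kernel, channels):
    """Weights (no biases) of `num_layers` k x k convs with C input and C output channels."""
    return num_layers * kernel * kernel * channels * channels


C = 64
# two 3x3 layers vs one 5x5 layer: same 5x5 receptive field, 18C^2 vs 25C^2 weights
print(receptive_field(2), stacked_weights(2, 3, C), stacked_weights(1, 5, C))  # 5 73728 102400
# three 3x3 layers vs one 7x7 layer: same 7x7 receptive field, 27C^2 vs 49C^2 weights
print(receptive_field(3), stacked_weights(3, 3, C), stacked_weights(1, 7, C))  # 7 110592 200704
```

The stack also places a ReLU after each 3 × 3 layer, so it gains extra non-linearities on top of the parameter savings.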
Network Architecture
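As shown in the figure above, every configuration consists of 5 blocks. The VGG family includes vgg11, vgg13, vgg16 and vgg19; the number in each name is the total count of convolutional and fully connected layers, so vgg16, for example, has 13 convolutional layers and 3 fully connected layers. The LRN in the 11-layer network stands for Local Response Normalization.
Source Code Walkthrough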
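First install torchvision. The source code is in the models folder of the installed torchvision package, in a file named vgg.py.
torchvision is an important and convenient package in the PyTorch ecosystem. It consists mainly of three subpackages: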
torchvision.datasets
torchvision.models
torchvision.transforms
1) Import the required packages
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
2) All network names and the download URLs of their pretrained weight files
__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
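3) Definition of the VGG class. features holds all the convolutional and pooling layers; avgpool is an average-pooling layer (pooling comes in two kinds, average pooling and max pooling), here an adaptive one that always outputs a 7 × 7 map; classifier holds the three fully connected layers; and _initialize_weights initializes the network parameters.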
class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features                     # conv + max-pool layers from make_layers
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))  # always outputs a 7 x 7 feature map
        self.classifier = nn.Sequential(             # the three fully connected layers
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (N, 512, 7, 7) -> (N, 512 * 7 * 7)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
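nn.AdaptiveAvgPool2d((7, 7)) is what lets the classifier always receive 512 × 7 × 7 values. As far as I understand PyTorch's implementation, each spatial axis of length in_size is split into out_size windows, window i covering indices ⌊i·in/out⌋ up to (but excluding) ⌈(i+1)·in/out⌉. The pure-Python sketch below reproduces that windowing along one axis; adaptive_bins and adaptive_avg_1d are hypothetical helpers, an illustration of the idea rather than the library code.

```python
def adaptive_bins(in_size, out_size):
    """(start, end) windows that average an axis of in_size cells into out_size cells."""
    bins = []
    for i in range(out_size):
        start = (i * in_size) // out_size                     # floor(i * in / out)
        end = ((i + 1) * in_size + out_size - 1) // out_size  # ceil((i + 1) * in / out)
        bins.append((start, end))
    return bins


def adaptive_avg_1d(values, out_size):
    """Average-pool a 1-D list into out_size cells using adaptive_bins."""
    return [sum(values[s:e]) / (e - s) for s, e in adaptive_bins(len(values), out_size)]


print(adaptive_bins(7, 7))   # 7x7 map (224x224 input): every window holds one cell
print(adaptive_bins(14, 7))  # 14x14 map (448x448 input): windows of two cells
print(adaptive_avg_1d(list(range(1, 28, 2)), 7))  # [2.0, 6.0, 10.0, 14.0, 18.0, 22.0, 26.0]
```

With a 224 × 224 input the feature map is already 7 × 7, so the pooling leaves it unchanged; with larger inputs it averages neighbouring cells, and when in_size is not a multiple of out_size the windows overlap.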
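4) make_layers builds the convolutional and pooling layers. nn.Sequential is an ordered container: the modules passed to it are executed one after another, in the order in which they were passed.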
def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
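To see what make_layers actually appends, here is a dependency-free dry run of the same loop that records each layer as a string instead of creating an nn module. describe_layers and its string format are made up for illustration; the branching mirrors make_layers.

```python
def describe_layers(cfg, batch_norm=False):
    """Dry run of make_layers: return layer descriptions instead of nn modules."""
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers.append('MaxPool2d(2, 2)')
        else:
            layers.append(f'Conv2d({in_channels}, {v}, 3, padding=1)')
            if batch_norm:
                layers.append(f'BatchNorm2d({v})')
            layers.append('ReLU')
            in_channels = v  # the next conv consumes this layer's output channels
    return layers


cfg_d = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M']

plain = describe_layers(cfg_d)
with_bn = describe_layers(cfg_d, batch_norm=True)
print(len(plain), len(with_bn))  # 31 44
print(plain[:5])
```

The 31 entries for configuration "D" (13 conv + 13 ReLU + 5 pooling) correspond to the length of the features Sequential in vgg16; batch normalization adds one more module per convolution, giving 44.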
5) A, B, D and E correspond to vgg11, vgg13, vgg16 and vgg19. Each number is the number of output channels of a convolutional layer, and 'M' denotes a max-pooling layer.
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
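The layer counts in the model names can be checked against this table: count the convolution entries (everything except 'M'), add the 3 fully connected layers, and count the 'M' entries to get the number of blocks. The dictionary copies cfgs from the source; count_layers is a hypothetical helper for illustration.

```python
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def count_layers(cfg, num_fc=3):
    """Return (conv layers, blocks, conv + fully connected layers) for a configuration."""
    num_conv = sum(1 for v in cfg if v != 'M')
    num_blocks = cfg.count('M')  # each block ends with one max-pooling layer
    return num_conv, num_blocks, num_conv + num_fc


for name, key in [('vgg11', 'A'), ('vgg13', 'B'), ('vgg16', 'D'), ('vgg19', 'E')]:
    print(name, count_layers(cfgs[key]))
# vgg11 (8, 5, 11)
# vgg13 (10, 5, 13)
# vgg16 (13, 5, 16)
# vgg19 (16, 5, 19)
```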
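6) The public constructors for the different VGG variants. Each one calls _vgg with a configuration key and a batch_norm flag; when pretrained=True, the ImageNet weights are downloaded from model_urls and loaded into the model.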
def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def vgg11(pretrained=False, progress=True, **kwargs):
    r"""VGG 11-layer model (configuration "A") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)


def vgg11_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 11-layer model (configuration "A") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)


def vgg13(pretrained=False, progress=True, **kwargs):
    r"""VGG 13-layer model (configuration "B")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)


def vgg13_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 13-layer model (configuration "B") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)


def vgg16(pretrained=False, progress=True, **kwargs):
    r"""VGG 16-layer model (configuration "D")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)


def vgg16_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 16-layer model (configuration "D") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)


def vgg19(pretrained=False, progress=True, **kwargs):
    r"""VGG 19-layer model (configuration "E")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)


def vgg19_bn(pretrained=False, progress=True, **kwargs):
    r"""VGG 19-layer model (configuration 'E') with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
Reference: https://zhuanlan.zhihu.com/p/41423739