文章目录:

  • 一、VGG模型简单介绍
  • 二、PyTorch源码分析
  • 三、预训练模型的使用

本文是以VGG模型为例,深入介绍了完整的模型搭建过程,以及预训练模型使用过程,希望本篇博客可以解答一些困惑,同时欢迎大家改错提意见。

一、VGG模型简单介绍

简单来说,VGG(Very Deep Convolutional Networks for Large-Scale Image Recognition)这篇论文的工作就是:通过添加更多的卷积层来增加神经网络的深度(depth),取得了不错的效果。VGG模型获得了2014年在ImageNet分类任务上的第一名。
下面介绍VGG的基本框架,框架配置如下图,加粗部分为新增加的层(下图来源于以上论文):
如上图所示,作者一共构建了6个基本框架,命名为A、A-LRN、B、C、D、E,层数由11到19层,每一列代表一个模型。VGG模型由两大部分组成,图中上部分由多个3x3卷积层堆叠而成,下部分是三个全连接层和一个softmax层。
观察所有模型,发现其中有两个特殊的:模型A-LRN使用了局部响应归一化(Local Response Normalization);模型C使用了1x1卷积层。这两处不同之处本文只指出,不详细展开,读者可自行了解。
本文主要集中于其他几个模型上。还有一点需要指出的是,原论文中并没有使用Batch Normalization,但现在普遍会使用BN来处理。

二、PyTorch源码分析

虽然是VGG模型的源码分析,但更多的是学习代码结构,为以后编写自己的模型打下基础。
原论文中有6个模型,为简单起见,我们只以模型A(vgg11)为例展开。查看源代码发现主要有两个类实现VGG模型:

  • 第一个VGG(nn.Module)类,实现VGG模型的构建。
  • 第二个VGG11_Weights(WeightsEnum)类,说白了,就是原作者将模型预训练权重和其他配置等打包好,放在该类中提供给我们使用。所以,实现我们自己模型一般不用实现该类。

好了,我们的首要任务就是实现第一个类,先看代码。

class VGG(nn.Module):def __init__(self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5) -> None:super().__init__()self.features = featuresself.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(True),nn.Dropout(p=dropout),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(p=dropout),nn.Linear(4096, num_classes),)#############注意################# 权重初始化部分在后面讲解,暂时跳过。if init_weights:for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)# 前向传递def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

乍一看,可能就蒙了,每个变量后面的分号是什么,函数后面的箭头又是什么?其实这是python的一种注解方式,提高可读性,运行时并不会执行。分号(:)后是变量的类型,箭头(->)后是函数的返回值类型。
需要注意,这个类的初始化的第一个参数为feature,类型是nn.Module,所以我们需要传入一个nn.Module,也就是上图VGG模型中的所有卷积层。然后才是 pooling 层(代码使用的是averagepool,论文使用的是maxpool),输出大小为7x7。最后接多个全连接层(classifier)。
所以,至少还得构建一个函数来创建卷积层,如下创建make_layers函数来实现。

def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:layers: List[nn.Module] = []in_channels = 3for v in cfg:if v == "M":layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:v = cast(int, v)conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)if batch_norm:layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]else:layers += [conv2d, nn.ReLU(inplace=True)]in_channels = vreturn nn.Sequential(*layers)

首先看传入参数:

  • cfg:是一个List类型,元素是str和int两种类型,通过传入 cfg 来控制所有的卷积层。下面字典 cfgs 中的 A\B\D\E控制着卷积层参数,字符串 “M” 代表maxpool层,其余数字代表每一个卷积层输出channel数。
    例如要使用模型A(vgg11),就传入参数 cfg = cfgs["A"],一共有8个数字代表共用8个卷积层,最后还有三个全连接层,共11层。
cfgs: Dict[str, List[Union[str, int]]] = {"A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],"B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],"D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],"E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}
  • bacth_norm:是否使用BN,如果为True就在每一个卷积层后面增加一个BN层。现在一般在训练模型时会使用BN层,在评估时关闭BN层,因为增加BN层可以加速收敛,如果在评估时BN层继续动态计算方差(std)和均值(mean),那么std\mean不能反映整个数据集,所以会导致准确率降低。所以,在评估时BN层的std、mean 和 β、γ 都不会改变了。

    make_layers这个函数使用循环构建每一个layer,然后添加到 list (函数中的变量是layers)中,最后提取出来返回。
    补充一下python知识:
layers = [1, 2, 3, 4, 5]
# 直接输出打印list
print(layers)  # 输出:[1,2,3,4,5]
# 提取出list并打印
print(*layers)  # 输出:1,2,3,4,5
# 后面还会有字典dict的提取,通过两个星号(**dict)

读到这里,我们已经能够构造VGG的基本骨架了,虽然还有些细节还没介绍。我们可以直接实现一下:

feature = make_layer(cfds["A"], batch_norm=True)
# num_classes代表有多少个类别
model = VGG(features, num_classes=1000, init_weights=True, dropout=0.5)
# 然后就可以选择数据进行训练了。。。本文不介绍训练部分

在进行下一部分介绍前,先补充之前遗留下来的部分。上面代码VGG类中的权重初始化还没有介绍,下面简单介绍一下。
首先是一个for循环:for m in self.modules():,其中nn.modules()返回该神经网络中所有的模块(module),重复模块只会返回一次。
需要初始化的有三部分:卷积层(conv)、线性层(linear)和BN层。

  • 当为卷积层时,使用He初始化方法,该方法有CV大佬何恺明在论文Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification中描述的。这里也不展开该方法,在本文中只需要知道是一种初始化方法即可。
  • 当为线性层的时候,使用正态分布N(mean=0,std=0.01)给权重(weight)赋值,偏差(bias)为常数0。
  • 当为BN层时,weight 为1,bias 为0。其实就是 γ = 1,β = 0 ,大小和输入大小一样,这两个参数是需要模型去学习的。

自此,VGG模型的构造就基本完备了!注意:以下部分源码分析可以跳读,不需要完全理解。

下面进入第二个类(VGG11_Weights)的分析。 同理,先上代码:

_COMMON_META = {"min_size": (32, 32),"categories": _IMAGENET_CATEGORIES,"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg","_docs": """These weights were trained from scratch by using a simplified training recipe.""",
}class VGG11_Weights(WeightsEnum):IMAGENET1K_V1 = Weights(url="https://download.pytorch.org/models/vgg11-8a719046.pth",transforms=partial(ImageClassification, crop_size=224),meta={**_COMMON_META,"num_params": 132863336,"_metrics": {"ImageNet-1K": {"acc@1": 69.020,"acc@5": 88.628,}},},)DEFAULT = IMAGENET1K_V1

由于以上代码对我们实际使用帮助不大,说以笔者也未深入挖掘,只提供一个思路。这里,主要就是一个基类WeightsEnum,pytorch中的每一个模型的权重类都继承与该类,就像本文中VGG11_Weights类也继承于WeightsEnum类。代码中的Weights也是一个类,该类就像一个容器一样存储模型的信息,有三个参数:url、transformers、meta。

  • url:预训练权重下载地址(str)
  • transfomers:对模型的预处理方法(callable)。在别人使用你的模型进行预训练 前,需要对图片进行预处理(例如:resize with right resolution/interpolation, apply inference transforms, rescale the values etc)。因为不同的模型有不同的图片预处理方式,而且使用不恰当的预处理会导致准确率下降,所有我们需要给用户提供一个预处理方法。该方法可以通过Weight.transforms()属性获取。
from torchvision.models.vgg import VGG11_BN_Weights
# Initialize the Weight Transforms
weights = VGG11_Weights.DEFAULT
preprocess = weights.transforms()
print(preprocess)# Apply it to the input image
img_transformed = preprocess(img)

打印结果:

ImageClassification(
crop_size=[224]
resize_size=[256]
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
interpolation=InterpolationMode.BILINEAR
)

  • meta:存储模型相关权重和配置(dict[str,any])

处理将transformers提供给用户进行预处理,还要让用户能够获取预训练权重。下面介绍怎么导入预训练权重。


def _vgg(cfg, batch_norm, weights, progress, **kwargs):# 如果传入了预训练权重就不需要初始化if weights is not None:kwargs["init_weights"] = Falseif weights.meta["categories"] is not None:kwargs["num_classes"] = len(weights.meta["categories"])model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)if weights is not None:model.load_state_dict(weights.get_state_dict(progress=progress))return modeldef vgg11_bn(weights, progress=True, **kwargs):weights = VGG11_BN_Weights.verify(weights)return _vgg("A", True, weights, progress, **kwargs)

上述代码定义了两个函数,主要实现在函数_vgg()中,我们直接调用函数vgg11_bn()就创建完成VGG模型了。

# 初始化模型,progress表示是否显示下载进度条
model = vgg11_bn(weights, progress=True)
# 打印模型
print(model)

输出VGG模型:

VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
(7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): ReLU(inplace=True)
(11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(13): ReLU(inplace=True)
(14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(16): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(17): ReLU(inplace=True)
(18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(19): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(20): ReLU(inplace=True)
(21): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(22): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(23): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(24): ReLU(inplace=True)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(27): ReLU(inplace=True)
(28): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)

可以使用tensorboard显示网络模型:

 from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("log")# 随便构造一个图片img = torch.rand([3, 500, 500])img = preprocess(img)print("图片预处理后大小:", img.shape)# 增加 bacthsize 维度,bacthsize = 1img = torch.unsqueeze(img, dim=0)writer.add_graph(model, input_to_model=img, verbose=True, use_strict_trace=False)writer.close()

然后打开Terminal,运行tensorbord,如下图所示。注意:log 是代码中SummaryWriter存放的地址,可以改变。

点击网址打开就可以看到详细的网络结构图:

每一个节点可以点开,查看更详细的结构:

三、预训练模型的使用

  1. 初始化预训练模型(以resnet50为例)
from torchvision.models import resnet50, ResNet50_Weights# 使用预训练的权重:(以下几个都是等价的)
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights=ResNet50_Weights.DEFAULT)
resnet50(weights="IMAGENET1K_V1")resnet50(pretrained=True)  # 过时,不推荐
resnet50(True)  # 过时,不推荐# 不使用预训练权重:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False)  # 过时,不推荐
resnet50(False)  # 过时,不推荐
  1. 使用预训练模型
    同本文第二部分中介绍的一致,需要先对图片进行预处理。
# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()# Apply it to the input image
img_transformed = preprocess(img)

由于某些神经网络模块只有在训练的时候才启用,而在评估的时候不使用(例如 batch normalization、dropout等)。所以,我们需要在这两种模式(mode)中切换,需要训练的时候使用model.train(),评估的时候使用model.eval()

# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)# 当切换到evaluate mode,就会关闭dropout和BN层
model.eval()

下面完成一个完整的例子:

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weightsimg = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()  # 表示不进行训练# Step 2: Initialize the inference transforms
preprocess = weights.transforms()# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)  # 数据需要传入一个bacth# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)  # 对结果进行softmax处理从而分类
class_id = prediction.argmax().item()  # 找到概率最大值的索引,然后取出来
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

PyTorch模型搭建和源码详解相关推荐

  1. 第28课:彻底解密Spark Sort-Based Shuffle排序具体实现内幕和源码详解

    第28课:彻底解密Spark Sort-Based Shuffle排序具体实现内幕和源码详解 本文根据家林大神系列课程编写 http://weibo.com/ilovepains 为什么讲解Sorte ...

  2. Android OkHttp使用和源码详解

    介绍 OkHttp 是一套处理 HTTP 网络请求的依赖库,由 Square 公司设计研发并开源,目前可以在 Java 和 Kotlin 中使用.对于 Android App 来说,OkHttp 现在 ...

  3. C++ priority_queue 的使用和源码详解

    目录 简介 priority_queue 的使用 泛型算法make_heap().push_heap().pop_heap() make_heap() push_heap() pop_heap() 简 ...

  4. Lottie的使用和源码详解

    一.说在前面的话 在Android开发中,Coder要兼顾各个模块的建设维护,当然也少不了动画的制作,为了让界面使用更为友善,一般会由设计狮的一番设计后交由开发者在App重现出来.开发着在开发动画的同 ...

  5. Android setFocusableInTouchMode 方法使用和源码详解

    是什么 一般点击一个button,就会执行onclick 事件, 但是有些情况,我们想要点击button之后, 先获取焦点,然后再次点击一次,才执行onClick 事件.这时候,setFocusabl ...

  6. YOLOv5的模型构建源码详解|CSDN创作打卡

    深度学习入门小菜鸟,希望像做笔记记录自己学的东西,也希望能帮助到同样入门的人,更希望大佬们帮忙纠错啦~侵权立删. 代码分析注释全家桶部分只是为了方便看循环,条件判断的那些缩进对应,与二.三.四讲的东西 ...

  7. 【多输入模型 Multiple-Dimension 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人 (6/10)】

    多输入模型 Multiple-Dimension 数学原理分析以及源码源码详解 深度学习 Pytorch笔记 B站刘二大人(6/10) 数学推导 在之前实现的模型普遍都是单输入单输出模型,显然,在现实 ...

  8. 【 数据集加载 DatasetDataLoader 模块实现与源码详解 深度学习 Pytorch笔记 B站刘二大人 (7/10)】

    数据集加载 Dataset&DataLoader 模块实现与源码详解 深度学习 Pytorch笔记 B站刘二大人 (7/10) 模块介绍 在本节中没有关于数学原理的相关介绍,使用的数据集和类型 ...

  9. 【分类器 Softmax-Classifier softmax数学原理与源码详解 深度学习 Pytorch笔记 B站刘二大人(8/10)】

    分类器 Softmax-Classifier softmax数学原理与源码详解 深度学习 Pytorch笔记 B站刘二大人 (8/10) 在进行本章的数学推导前,有必要先粗浅的介绍一下,笔者在广泛查找 ...

最新文章

  1. Analytic Marching:一种基于解析的三维物体网格生成方法
  2. 查看IIS上面的每个网站分别用了多少内存
  3. Hadoop Writable机制
  4. java 全局变量_Javascript中的局部变量、全局变量的详解与var、let的使用区别
  5. 【转】分布式事务的常见解决方案
  6. OpenSSL加密与证书
  7. html实现验证码效果,js实现验证码功能
  8. 【SpringMVC】返回视图中包含数据(ModelAndView)
  9. PyCharm LicenseServer 破解
  10. Oracle数据库异常--- oracle_10g_登录em后,提示java.lang.Exception_Exception_in_sending_Request__null或Connection
  11. mysql查询语句理解
  12. 2018 CodeM初赛B轮:D.神奇盘子
  13. 物流公司货运配送管理系统设计
  14. 虚拟机服务器断网,Vmware虚拟机断网不能上网的解决方法教程[多图]
  15. mysql忘记密码重新设置步骤详解
  16. 饿了么小程序容器首屏秒开优化实践
  17. lambda分组集合中list和set区别
  18. 用python对excel进行打印操作
  19. 首涂第八套苹果CMSv10自适应视频模板原创4种颜色风格一键切换
  20. verilog行为级建模(1)

热门文章

  1. 基于springboot的少儿识字系统
  2. java new thread()_(一)java多线程之Thread
  3. 用户登录如何给密码加密xxtea.js
  4. 全网最强下载神器IDM使用教程:如何利用IDM加速下载百度网盘大文件
  5. 踩坑记录 PIL与Opencv读取图像的差别
  6. 一汽大众android面试题,一汽大众面试题
  7. 237删除链表中的节点(单链表基本操作)
  8. gitlab下载安装使用,rpm包
  9. 通过bat批处理命令进行adb push批量拉取文件
  10. ROS节点无法读入launch参数问题