timm 视觉库中的 create_model 函数详解

最近一年 Vision Transformer 及其相关改进的工作层出不穷,在他们开源的代码中,大部分都用到了这样一个库:timm。各位炼丹师应该已经想必已经对其无比熟悉了,本文将介绍其中最关键的函数之一:create_model 函数。

timm简介

PyTorchImageModels,简称timm,是一个巨大的PyTorch代码集合,包括了一系列:

  • image models
  • layers
  • utilities
  • optimizers
  • schedulers
  • data-loaders / augmentations
  • training / validation scripts

旨在将各种 SOTA 模型、图像实用工具、常用的优化器、训练策略等视觉相关常用函数的整合在一起,并具有复现ImageNet训练结果的能力。

源码:https://github.com/rwightman/pytorch-image-models

文档:https://fastai.github.io/timmdocs/

create_model 函数的使用及常用参数

本小节先介绍 create_model 函数,及常用的参数 **kwargs

顾名思义,create_model 函数是用来创建一个网络模型(如 ResNet、ViT 等),timm 库本身可供直接调用的模型已有接近400个,用户也可以自己实现一些模型并注册进 timm (这一部分内容将在下一小节着重介绍),供自己调用。

model_name

我们首先来看最简单地用法:直接传入模型名称 model_name

import timm
# 创建 resnet-34
model = timm.create_model('resnet34')
# 创建 efficientnet-b0
model = timm.create_model('efficientnet_b0')

我们可以通过 list_models 函数来查看已经可以直接创建、有预训练参数的模型列表:

all_pretrained_models_available = timm.list_models(pretrained=True)
print(all_pretrained_models_available)
print(len(all_pretrained_models_available))

输出:

[..., 'vit_large_patch16_384', 'vit_large_patch32_224_in21k', 'vit_large_patch32_384', 'vit_small_patch16_224', 'wide_resnet50_2', 'wide_resnet101_2', 'xception', 'xception41', 'xception65', 'xception71']
452

如果没有设置 pretrained=True 的话有将会输出612,即有预训练权重参数的模型有452个,没有预训练参数,只有模型结构的共有612个。

pretrained

如果我们传入 pretrained=True,那么 timm 会从对应的 URL 下载模型权重参数并载入模型,只有当第一次(即本地还没有对应模型参数时)会去下载,之后会直接从本地加载模型权重参数。

model = timm.create_model('resnet34', pretrained=True)

输出:

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/song/.cache/torch/hub/checkpoints/resnet34-43635321.pth

features_only、out_indices

create_mode 函数还支持 features_only=True 参数,此时函数将返回部分网络,该网络提取每一步最深一层的特征图。还可以使用 out_indices=[…] 参数指定层的索引,以提取中间层特征。

# 创建一个 (1, 3, 224, 224) 形状的张量
x = torch.randn(1, 3, 224, 224)
model = timm.create_model('resnet34')
preds = model(x)
print('preds shape: {}'.format(preds.shape))all_feature_extractor = timm.create_model('resnet34', features_only=True)
all_features = all_feature_extractor(x)
print('All {} Features: '.format(len(all_features)))
for i in range(len(all_features)):print('feature {} shape: {}'.format(i, all_features[i].shape))out_indices = [2, 3, 4]
selected_feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=out_indices)
selected_features = selected_feature_extractor(x)
print('Selected Features: ')
for i in range(len(out_indices)):print('feature {} shape: {}'.format(out_indices[i], selected_features[i].shape))

我们以一个 (1, 3, 224, 224) 形状的张量为输入,在视觉任务中,图像输入张量总是类似的形状。上面例程展示了,创建完整模型 model,创建完整特征提取器 all_feature_extractor,和创建某几层特征提取器 selected_feature_extractor 的具体输出。

可以结合下面 ResNet34 的结构图来理解(图中不同的颜色表示不同的 layer),根据下图分析各层的卷积操作,计算各层最后一个卷积的输入,并与上面例程的输出(附在图后)验证是否一致。

输出:

preds shape: torch.Size([1, 1000])
All 5 Features:
feature 0 shape: torch.Size([1, 64, 112, 112])
feature 1 shape: torch.Size([1, 64, 56, 56])
feature 2 shape: torch.Size([1, 128, 28, 28])
feature 3 shape: torch.Size([1, 256, 14, 14])
feature 4 shape: torch.Size([1, 512, 7, 7])
Selected Features:
feature 2 shape: torch.Size([1, 128, 28, 28])
feature 3 shape: torch.Size([1, 256, 14, 14])
feature 4 shape: torch.Size([1, 512, 7, 7])

这样,我们就可以通过 timm_model 函数及其 features_onlyout_indices 参数将预训练模型方便地转换为自己想要的特征提取器。

接下来我们来看一下这些特征提取器究竟是什么类型:

import timm
feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=[3])print('type:', type(feature_extractor))
print('len: ', len(feature_extractor))
for item in feature_extractor:print(item)

输出:

type: <class 'timm.models.features.FeatureListNet'>
len:  7
conv1
bn1
act1
maxpool
layer1
layer2
layer3

可以看到,feature_extractor 其实也是一个神经网络,在 timm 中称为 FeatureListNet,而我们通过 out_indices 参数来指定截取到哪一层特征。

需要注意的是,ViT 模型并不支持 features_only 选项(0.4.12版本)。

extractor = timm.create_model('vit_base_patch16_224', features_only=True)

输出:

RuntimeError: features_only not implemented for Vision Transformer models.

create_model 函数究竟做了什么

registry

在了解了 create_model 函数的基本使用之后,我们来深入探索一下 create_model 函数的源码,看一下究竟是怎样实现从模型到特征提取器的转换的。

create_model 主体只有 50 行左右的代码,因此所有这些神奇的事情是在其他地方完成的。我们知道 timm.list_models() 函数中的每一个模型名字(str)实际上都是一个函数。以下代码可以测试这一点:

import timm
import random
from timm.models import registrym = timm.list_models()[-1]
print(m)
registry.is_model(m)

输出:

xception71
True

实际上,在 timm 内部,有一个字典称为 _model_entrypoints 包含了所有的模型名称和他们各自的函数。比如说,我们可以通过 model_entrypoint 函数从 _model_entrypoints 内部得到 xception71 模型的构造函数。

constuctor_fn = registry.model_entrypoint(m)
print(constuctor_fn)

输出:

<function timm.models.xception_aligned.xception71(pretrained=False, **kwargs)>

也有可能输出:

<function xception71 at 0x7fc0cba0eca0>

一样的。

如我们所见,在 timm.models.xception_aligned 模块中有一个函数称为 xception71 。类似的,timm 中的每一个模型都有着一个这样的构造函数。事实上,内部的 _model_entrypoints 字典大概长这个样子:

_model_entrypoints
> >
{'cspresnet50':<function timm.models.cspnet.cspresnet50(pretrained=False, **kwargs)>,'cspresnet50d': <function timm.models.cspnet.cspresnet50d(pretrained=False, **kwargs)>,
'cspresnet50w': <function timm.models.cspnet.cspresnet50w(pretrained=False, **kwargs)>,
'cspresnext50': <function timm.models.cspnet.cspresnext50(pretrained=False, **kwargs)>,
'cspresnext50_iabn': <function timm.models.cspnet.cspresnext50_iabn(pretrained=False, **kwargs)>,
'cspdarknet53': <function timm.models.cspnet.cspdarknet53(pretrained=False, **kwargs)>,
'cspdarknet53_iabn': <function timm.models.cspnet.cspdarknet53_iabn(pretrained=False, **kwargs)>,
'darknet53': <function timm.models.cspnet.darknet53(pretrained=False, **kwargs)>,
'densenet121': <function timm.models.densenet.densenet121(pretrained=False, **kwargs)>,
'densenetblur121d': <function timm.models.densenet.densenetblur121d(pretrained=False, **kwargs)>,
'densenet121d': <function timm.models.densenet.densenet121d(pretrained=False, **kwargs)>,
'densenet169': <function timm.models.densenet.densenet169(pretrained=False, **kwargs)>,
'densenet201': <function timm.models.densenet.densenet201(pretrained=False, **kwargs)>,
'densenet161': <function timm.models.densenet.densenet161(pretrained=False, **kwargs)>,
'densenet264': <function timm.models.densenet.densenet264(pretrained=False, **kwargs)>,}

所以说,在 timm 对应的模块中,每个模型都有一个构造器。比如说 ResNets 系列模型被定义在 timm.models.resnet 模块中。因此,实际上我们有两种方式来创建一个 resnet34 模型:

import timm
from timm.models.resnet import resnet34# 使用 create_model
m = timm.create_model('resnet34')# 直接调用构造函数
m = resnet34()

但使用上,我们无须调用构造函数。所用模型都可以通过 create_model 函数来将创建。

Register model

resnet34 构造函数的源码如下:

@register_model
def resnet34(pretrained=False, **kwargs):"""Constructs a ResNet-34 model."""model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)return _create_resnet('resnet34', pretrained, **model_args)

我们会发现 timm 中的每个模型都有一个 register_model 装饰器。最开始, _model_entrypoints 是一个空字典。我们是通过 register_model 装饰器来不断地像其中添加模型名称和它对应的构造函数。该装饰器的定义如下:

def register_model(fn):# lookup containing modulemod = sys.modules[fn.__module__]module_name_split = fn.__module__.split('.')module_name = module_name_split[-1] if len(module_name_split) else ''# add model to __all__ in modulemodel_name = fn.__name__if hasattr(mod, '__all__'):mod.__all__.append(model_name)else:mod.__all__ = [model_name]# add entries to registry dict/sets_model_entrypoints[model_name] = fn_model_to_module[model_name] = module_name_module_to_models[module_name].add(model_name)has_pretrained = False  # check if model has a pretrained url to allow filtering on thisif hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:# this will catch all models that have entrypoint matching cfg key, but miss any aliasing# entrypoints or non-matching comboshas_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']if has_pretrained:_model_has_pretrained.add(model_name)return fn

我们可以看到, register_model 函数完成了一些比较基础的步骤,但这里需要指出的是这一句:

_model_entrypoints[model_name] = fn

它将给定的 fn 添加到 _model_entrypoints 其键名为 fn.__name__。所以说 resnet34 函数上的装饰器 @register_model_model_entrypoints 中创建一个新的条目,像这样:

{’resnet34’: <function timm.models.resnet.resnet34(pretrained=False, **kwargs)>}

我们同样可以看到在 resnet34 构造函数的源码中,在设置完一些 model_args 之后,它会随后调用 _create_resnet 函数。让我们再来看一下该函数的源码:

def _create_resnet(variant, pretrained=False, **kwargs):return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)

所以在 _create_resnet 函数之中,会再调用 build_model_with_cfg 函数并将一个构造器类 ResNet 、变量名 resnet34、一个 default_cfg 和一些 **kwargs 传入其中。

default config

timm 中所有的模型都有一个默认的配置,包括指向它的预训练权重参数的URL、类别数、输入图像尺寸、池化尺寸等。

resnet34 的默认配置如下:

{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bilinear',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv1',
'classifier': 'fc'}

此默认配置与其他参数(如构造函数类和一些模型参数)一起传递给 build_model_with_cfg 函数。

build model with config

这个 build_model_with_cfg 函数负责:

  1. 真正地实例化一个模型类来创建一个模型
  2. pruned=True,对模型进行剪枝
  3. pretrained=True,加载预训练模型参数
  4. features_only=True,将模型转换为特征提取器

看一下该函数的源码:

def build_model_with_cfg(model_cls: Callable,variant: str,pretrained: bool,default_cfg: dict,model_cfg: dict = None,feature_cfg: dict = None,pretrained_strict: bool = True,pretrained_filter_fn: Callable = None,pretrained_custom_load: bool = False,**kwargs):pruned = kwargs.pop('pruned', False)features = Falsefeature_cfg = feature_cfg or {}if kwargs.pop('features_only', False):features = Truefeature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))if 'out_indices' in kwargs:feature_cfg['out_indices'] = kwargs.pop('out_indices')model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)model.default_cfg = deepcopy(default_cfg)if pruned:model = adapt_model_from_file(model, variant)# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for featsnum_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))if pretrained:if pretrained_custom_load:load_custom_pretrained(model)else:load_pretrained(model,num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),filter_fn=pretrained_filter_fn, strict=pretrained_strict)if features:feature_cls = FeatureListNetif 'feature_cls' in feature_cfg:feature_cls = feature_cfg.pop('feature_cls')if isinstance(feature_cls, str):feature_cls = feature_cls.lower()if 'hook' in feature_cls:feature_cls = FeatureHookNetelse:assert False, f'Unknown feature class {feature_cls}'model = feature_cls(model, **feature_cfg)model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfgreturn model

我们可以看到,模型在这一步被创建出来:model = model_cls(**kwargs)。本文将不再深入到 prunedadapt_model_from_file 内部查看。

总结

通过本文,我们已经完全了解了 create_model 函数,我们了解到:

  • 每个模型有不同的构造函数,可以传入不同的参数, _model_entrypoints 字典包括了所有的模型名称及其对应的构造函数
  • build_with_model_cfg 函数接收模型构造器类和其中的一些具体参数,真正地实例化一个模型
  • load_pretrained 会加载预训练参数
  • FeatureListNet 类可以将模型转换为特征提取器

Ref:

https://github.com/rwightman/pytorch-image-models

https://fastai.github.io/timmdocs/

https://fastai.github.io/timmdocs/create_model#Turn-any-model-into-a-feature-extractor

https://fastai.github.io/timmdocs/tutorial_feature_extractor

https://zhuanlan.zhihu.com/p/404107277

timm 视觉库中的 create_model 函数详解相关推荐

  1. python getattr_Python中的getattr()函数详解:

    标签:Python中的getattr()函数详解: getattr(object, name[, default]) -> value Get a named attribute from an ...

  2. python input函数详解_对Python3中的input函数详解

    下面介绍python3中的input函数及其在python2及pyhton3中的不同. python3中的ininput函数,首先利用help(input)函数查看函数信息: 以上信息说明input函 ...

  3. Python中的bbox_overlaps()函数详解

    Python中的bbox_overlaps()函数详解 想要编写自己的目标检测算法,就需要掌握bounding box(边界框)之间的关系.在这之中,bbox_overlaps()函数是一个非常实用的 ...

  4. java的匿名函数_JAVA语言中的匿名函数详解

    本文主要向大家介绍了JAVA语言中的匿名函数详解,通过具体的内容向大家展示,希望对大家学习JAVA语言有所帮助. 一.使用匿名内部类 匿名内部类由于没有名字,所以它的创建方式有点儿奇怪.创建格式如下: ...

  5. 内核中的kmalloc函数详解

    一.kmalloc函数详解 #include <linux/slab.h> void *kmalloc(size_t size, int flags); 给 kmalloc 的第一个参数是 ...

  6. 前端如何设置背景颜色的透明度 css中的 rgba() 函数详解 :background-color: rgba(255,192,203,0.3)

    目录 前言 rgba() 函数 详解 再分享一个小技巧哈哈哈 前言 今天在开发移动端的时候感觉没背景颜色有点丑,再加上自己做的是蛋糕app,觉得暖色应该会很好看,于是就用了这行代码 backgroun ...

  7. Opencv中的imshow函数详解

    前言 使用opencv对图像进行处理之后,通常调用imshow函数来显示处理结果.但是,我们经常会发现显示结果和我们预期的结果有些差别.这是由于opencv经常会涉及到对多种图像数据类型的处理,如果我 ...

  8. linux内核中的hook函数详解,linux内核中的hook函数详解

    在编写linux内核中的网络模块时,用到了钩子函数也就是hook函数.现在来看看linux是如何实现hook函数的. 先介绍一个结构体: struct nf_hook_ops,这个结构体是实现钩子函数 ...

  9. linux hook 任意内核函数,linux内核中的hook函数详解

    在编写linux内核中的网络模块时,用到了钩子函数也就是hook函数.现在来看看linux是如何实现hook函数的. 先介绍一个结构体: struct nf_hook_ops,这个结构体是实现钩子函数 ...

最新文章

  1. shell之字符串操作
  2. matlab 画函数图像
  3. redis -Spring与Jedis集群 Sentinel
  4. [bzoj2055]80人环游世界[网络流,上下界网络流]
  5. 直播 | LiveVideoStack Meet杭州:后直播时代技术
  6. mysql时间戳在某天内_mysql根据时间戳查询指定日期内数据
  7. java打印结果横向排列_Java8排列组合(6行代码实现)
  8. 关于CSS兼容IE与Firefox要点分析
  9. 另一个进程已被死锁在资源上且该事务已被选作死锁牺牲品
  10. exchange 2003 event id 1221
  11. android中getSystemService详解
  12. 微信公众号自定义菜单
  13. SQL 数据库 学习 004 预备知识
  14. PCB线宽过流能力估算
  15. 二、通用、布局、导航组件
  16. 软件暴力破解的原理和破解经验
  17. GBASE 8s 用户标示与鉴别
  18. 两台设备(手动)设置相同的局域网IP地址会怎么样?
  19. 个人号微信二次开发,微信ipad协议
  20. GB28181系列笔记-语音对讲功能

热门文章

  1. 检测到目标FTP服务可匿名访问
  2. VBA GetOpenFilename 方法
  3. python open函数参数newline_Python open() 函数
  4. linux性能分析top iostat vmstat free,linux 性能篇 -- top用法(示例代码)
  5. C语言 函数指针 - C语言零基础入门教程
  6. mysql 如何添加索引_MySQL如何创建一个好索引?创建索引的5条建议【宇哥带你玩转MySQL 索引篇(三)】...
  7. 将字符转换成数字(atoi),将数字转换成字符(itoa)
  8. java applog_java - 通过Logback登录到App Engine request_log - SO中文参考 - www.soinside.com
  9. c语言getch() 头文件,用getch()需要头文件吗?
  10. win10打印机终结点映射器_用了就回不去?微软官方免费“外挂”,让win10好用到飞起...