1. 官方代码

FUSE_MODULES
TORCH.AO.QUANTIZATION.FUSE_MODULES的源代码

2. fuse_modules源码解读

仅融合以下序列:

  • conv, bn
  • conv, bn, relu
  • conv, relu
  • linear, relu
  • bn, relu
    网络中所有其他序列保持不变,对于上述序列,用融合的模块替换列表中的第一项,用identity替换其余模块。

fuse_modules

def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
  • model:要进行操作的模型名称
  • modules_to_fuse:要融合的模块名称的列表。如果只有一个要融合的模块列表,可以是一个字符串列表,如:[‘conv1’, ‘bn1’, ‘relu’]
  • inplace:bool类型参数,默认为false。融合发生在模型上,默认会返回一个新模型
  • fuser_func:接收模块列表并输出相同长度的融合模块列表的函数。例如,fuser_func([convModule, BNModule]) 返回 [ConvBNModule, nn.Identity()] 。 默认为 fuse_known_modules
  • fuse_custom_config_dict :自定义配置,默认为none

fuse_known_modules

将给定的模块列表mod_list中的一些常见模块进行融合,返回融合后的模块列表。融合后的模块可以有效地减少模型计算量和内存占用,从而提高模型的计算效率。

参数

  • mod_list:一个包含了一系列PyTorch模块对象的列表,这些模块可以是常见的卷积、线性、批归一化等模块。
  • is_qat:指定模型是否使用量化感知训练(true使用,false不使用)
  • additional_fuser_method_mapping:一个可选的字典,用于指定额外的融合方法。字典的key是要融合的模块类型value是一个融合函数,它将被用于融合指定类型的模块。默认为None。
def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):r"""Returns a list of modules that fuses the operations specifiedin the input module list.Fuses only the following sequence of modules:conv, bnconv, bn, reluconv, relulinear, bnlinear, reluFor these sequences, the first element in the output module list performsthe fused operation. The rest of the elements are set to nn.Identity()"""types = tuple(type_before_parametrizations(m) for m in mod_list)fuser_method = get_fuser_method(types, additional_fuser_method_mapping)if fuser_method is None:raise NotImplementedError("Cannot fuse modules: {}".format(types))new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)fused = fuser_method(is_qat, *mod_list)# NOTE: forward hooks not processed in the two following for loops will be lost after the fusion# Move pre forward hooks of the base module to resulting fused modulefor handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():fused.register_forward_pre_hook(pre_hook_fn)del mod_list[0]._forward_pre_hooks[handle_id]# Move post forward hooks of the last module to resulting fused modulefor handle_id, hook_fn in mod_list[-1]._forward_hooks.items():fused.register_forward_hook(hook_fn)del mod_list[-1]._forward_hooks[handle_id]new_mod[0] = fusedfor i in range(1, len(mod_list)):identity = nn.Identity()identity.training = mod_list[0].trainingnew_mod[i] = identityreturn new_mod
  • 融合前,首先获取mod_list中每个模块的类型,并将它们作为一个元组存储在types变量中。这个元组中的类型用于选择要使用的模块融合方法。在默认情况下,该函数支持一些特定的模块序列进行融合。如果输入模块序列不符合这些支持的模式,则函数会尝试使用 additional_fuser_method_mapping 中定义的自定义融合函数fuser_method
  • 融合方法fuser_method :使用get_fuser_method() 函数根据types选择一个合适的融合函数
    – 在 get_fuser_method函数中调用了字典DEFAULT_OP_LIST_TO_FUSER_METHOD(定义了元组融合函数之间的映射关系)。下面仅展示部分2d模块融合
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {(nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,(nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,(nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),(nn.Linear, nn.BatchNorm1d): fuse_linear_bn,(nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),(nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
}
  • 如果在特定模块序列的additional_fuser_method_mapping中提供了自定义fuser函数,则将使用该函数来代替默认的fuser函数。如果找不到合适的fuser函数,该函数将引发NotImplementedError
  • 定义new_mod :使用 [None] * len(mod_list)创建一个长度为len(mod_list)的列表,这个列表中,每个元素都是一个nn.Module类型的可选对象,初始值为None。
  • 融合后的新模块fused:使用fuser_method调用对应的融合函数,如 fuse_conv_bn(is_qat, conv, bn)得到一个模块融合后的新的模块(ConvBn2d)。该模块包含了卷积层和BN层的参数,并将其组合成一个新的运算,该融合模块的名称默认为ConvBn2d、ConvBn1d或ConvBn3d。fuse_conv_bn函数在后面进行介绍。
  • 融合后,第一个for循环遍历 mod_list列表中第一个模块(mod_list[0])的handle_id(前向预处理钩子函数的ID)和hook_fn(前向预处理钩子函数,在模块前向传递时会被自动调用,用于执行某些操作,如记录中间结果、打印日志等。)。
    – 然后,将这些钩子函数注册fused模块中,使其能够在后续计算中被调用。
    – 接着,从mod_list[0]._forward_pre_hooks字典中删除这些钩子函数,避免这些钩子函数被重复调用。
  • 第一个for循环的作用是将mod_list列表中第一个模块前向预处理钩子函数从原始模块对象中转移到融合模块对象中,以确保在使用融合模块进行前向传递时,所有需要的操作都能够被执行。
  • 第二个for循环将mod_list列表中最后一个模块前向钩子函数注册到fused模块中,并从原始模块对象的钩子字典中删除这些钩子函数。
  • 前向预处理钩子函数不同,前向钩子函数是在模块的前向传递过程中执行的,通常用于在模块输出计算完成后执行某些操作,如统计模型输出分布、进行可视化等。
  • 最后,将融合好的fused模块赋给前面定义的new_mod 列表的第一个元素,最后使用for循环补充identity()到new_mod列表,使其长度和原始模块长度一致。

fuse_conv_bn

将给定的conv和bn模块融合并返回融合后的模块。

在此函数中构建了一个fused_module_class_map字典,用于指定模块类型与对应的融合模块类型之间的映射关系。

如果其类型在fused_module_class_map字典中有对应的融合模块类型,则将这些模块融合为一个新的模块(ConvBn2d),如果没有对应的融合模块类型,则不对其进行融合处理。

def fuse_conv_bn(is_qat, conv, bn):assert(conv.training == bn.training),\"Conv and BN both must be in the same mode (train or eval)."fused_module_class_map = {nn.Conv1d: nni.ConvBn1d,nn.Conv2d: nni.ConvBn2d,nn.Conv3d: nni.ConvBn3d,}if is_qat:assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'fused_module_class = fused_module_class_map.get((type(conv)), None)if fused_module_class is not None:return fused_module_class(conv, bn)else:raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))else:return nn.utils.fuse_conv_bn_eval(conv, bn)

返回调用的 fuse_conv_bn_eval(conv, bn) 函数如下

返回一个新的融合模块,该模块包含了卷积层和BN层的参数,并将其组合成一个新的运算。

def fuse_conv_bn_eval(conv, bn, transpose=False):assert(not (conv.training or bn.training)), "Fusion only for eval!"fused_conv = copy.deepcopy(conv)fused_conv.weight, fused_conv.bias = \fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)return fused_conv

3. fuse_modules实际测试

3.1 modules_to_fuse参数的使用方法

1. 此参数的列表可以包含多个需要融合的组合,子模块列表也可以,使用方法一

方法一:

modules_to_fuse = [ [‘conv1’, ‘bn1’, ‘relu1’], [‘submodule.conv’, ‘submodule.relu’]]

融合ResNet18中layer1的conv和bn层如下:

print('\n Before fusion \n\n', r18_o.layer1)r18_o.eval()r18 = torch.quantization.fuse_modules(r18_o,[['conv1', 'bn1', 'relu'],['layer1.0.conv1', 'layer1.0.bn1'], # , 'layer1.0.relu'],['layer1.0.conv2', 'layer1.0.bn2'],['layer1.1.conv1', 'layer1.1.bn1'], #, 'layer1.1.relu'],['layer1.1.conv2', 'layer1.1.bn2']]
)
print('\n After fusion\n\n', r18.layer1)

结果:

  • ResNet18融合前:(仅显示ResNet18中layer1的网络结构)

  • ResNet18融合后

    此融合只将Conv2d和BN层进行融合,从上面对比可以看到融合后的 (bn) 变成了 identity(),(conv) 中的Conv2d是原本Conv2d和BN融合的。

2. 如果要融合的module被Sequential封装了,可使用方法二

方法二:

torch.quantization.fuse_modules(m, [‘0’, ‘1’, ‘2’], inplace=True)

1. 使用方法二对ResNet18中模块进行融合操作,融合代码如下:

def fuse_model(self):for m in self.modules():if type(m) == BasicBlock:torch.quantization.fuse_modules(m, [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']], inplace=True)

此处代码是仿pytorch官方写MobileNetV2模块融合,这部分代码写在 class ResNet(nn.Module) 中,后面融合直接使用model.fuse_model(),得到的方法二融合ResNet18结果如下:

此处是分别对(conv2d、bn、relu)和(conv2d、bn)进行融合融合

2. 使用方法二对MobileNetv2中模块进行融合操作

def fuse_model(self):for m in self.modules():if type(m) == ConvBNReLU:torch.quantization.fuse_modacules(m, ['0', '1', '2'], inplace=True)if type(m) == InvertedResidual:for idx in range(len(m.conv)):if type(m.conv[idx]) == nn.Conv2d:torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)

结果

  • MobileNetv2融合前(下面结果展示的是第一个残差模块,因此没有第一个1x1的卷积)

  • MobileNetv2融合后

    从此对比可以看到,融合前的conv2d、bn、relu融合成了ConvRelu2d(Conv2d,ReLU),这里面的Conv2d是融合前的Conv2d和BN融合的。

pytorch中fuse_modules相关推荐

  1. pytorch中调整学习率的lr_scheduler机制

    pytorch中调整学习率的lr_scheduler机制 </h1><div class="clear"></div><div class ...

  2. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  3. PyTorch中的MIT ADE20K数据集的语义分割

    PyTorch中的MIT ADE20K数据集的语义分割 代码地址:https://github.com/CSAILVision/semantic-segmentation-pytorch Semant ...

  4. PyTorch中nn.Module类中__call__方法介绍

    在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...

  5. 利用 AssemblyAI 在 PyTorch 中建立端到端的语音识别模型

    作者 | Comet 译者 | 天道酬勤,责编 | Carol 出品 | AI 科技大本营(ID:rgznai100) 这篇文章是由AssemblyAI的机器学习研究工程师Michael Nguyen ...

  6. 实践指南 | 用PyTea检测 PyTorch 中的张量形状错误

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨陈萍.泽南 来源丨机器之心 编辑丨极市平台 导读 韩国首尔大学 ...

  7. 实践教程 | 浅谈 PyTorch 中的 tensor 及使用

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | xiaopl@知乎(已授权) 来源 | https://z ...

  8. 详解PyTorch中的ModuleList和Sequential

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨小占同学@知乎(已授权) 来源丨https://zhuanla ...

  9. 在PyTorch中进行双线性采样:原理和代码详解

    ↑ 点击蓝字 关注视学算法 作者丨土豆@知乎 来源丨https://zhuanlan.zhihu.com/p/257958558 编辑丨极市平台 在pytorch中的双线性采样(Bilinear Sa ...

最新文章

  1. 路由器学习之静态路由实验
  2. WCF系列之.net(3.0/3.5)Rest使用示例
  3. 看过这么多爆文,依旧走不好异步编程这条路?​
  4. 学生档案c语言编程,学生档案管理问题
  5. 字符串匹配算法(BF RK)
  6. hdu 4381(背包变形)
  7. iis8.5限速没有效果怎么回事_电梯为何会发生坠梯?没有安全措施吗?能在井道底安装大弹簧吗?...
  8. 【Spring第七篇】Java配置类:JavaConfig
  9. 关于进程与线程的讲解 最最最生动的理解
  10. iOS 通讯录编程【总结】
  11. 电商金额计算的 4 个坑,千万注意了!
  12. msvcp71.dll、msvcr71.dll丢失解决方法
  13. 为什么应该学好软件工程?
  14. 台式机安装centos7系统
  15. 微信小程序自定义组件制作图表动画
  16. C语言:strcpy()---字符串复制
  17. sql语句,sql文件加注释
  18. Matlab:二维傅里叶变换
  19. BZOJ1064 NOI2008假面舞会
  20. 【王道训练营 C/C++方向基础 60 题(1-10)】

热门文章

  1. UI组件库Form表单_数字类型验证之坑实现数字框
  2. 时序分析基础(1)----寄存器时序分析模型
  3. c++ set使用(增删查遍历)
  4. 如何解决error: failed to push some refs to ‘git@github.com:......git pull冲突问题
  5. Thonny连接PiPico出现Device is busy or does not respond.解决方法
  6. DDR SDRAM原理介绍
  7. C语言直接驱动硬件实现PC机的串口操作
  8. javascript 方法 一直提示 对象不支持此属性或方法
  9. 可微和可导的关系,全微分、偏微分、偏导数
  10. linux下ffmpeg库 ARM交叉编译