摘要:本文旨在分享Pytorch->Caffe->om模型转换流程。

标准网络

Baseline:PytorchToCaffe

主要功能代码在:

PytorchToCaffe
+-- Caffe
|   +-- caffe.proto
|   +-- layer_param.py
+-- example
|   +-- resnet_pytorch_2_caffe.py
+-- pytorch_to_caffe.py

直接使用可以参考resnet_pytorch_2_caffe.py,如果网络中的操作Baseline中都已经实现,则可以直接转换到Caffe模型。

添加自定义操作

如果遇到没有实现的操作,则要分为两种情况来考虑。

Caffe中有对应操作

以arg_max为例分享一下添加操作的方式。

首先要查看Caffe中对应层的参数:caffe.proto为对应版本caffe层与参数的定义,可以看到ArgMax定义了out_max_val、top_k、axis三个参数:

message ArgMaxParameter {// If true produce pairs (argmax, maxval)optional bool out_max_val = 1 [default = false];optional uint32 top_k = 2 [default = 1];// The axis along which to maximise -- may be negative to index from the// end (e.g., -1 for the last axis).// By default ArgMaxLayer maximizes over the flattened trailing dimensions// for each index of the first / num dimension.optional int32 axis = 3;
}

与Caffe算子边界中的参数是一致的。

layer_param.py构建了具体转换时参数类的实例,实现了操作参数从Pytorch到Caffe的传递:

def argmax_param(self, out_max_val=None, top_k=None, dim=1):argmax_param = pb.ArgMaxParameter()if out_max_val is not None:argmax_param.out_max_val = out_max_valif top_k is not None:argmax_param.top_k = top_kif dim is not None:argmax_param.axis = dimself.param.argmax_param.CopyFrom(argmax_param)

pytorch_to_caffe.py中定义了Rp类,用来实现Pytorch操作到Caffe操作的变换:

class Rp(object):def __init__(self, raw, replace, **kwargs):self.obj = replaceself.raw = raw
​def __call__(self, *args, **kwargs):if not NET_INITTED:return self.raw(*args, **kwargs)for stack in traceback.walk_stack(None):if 'self' in stack[0].f_locals:layer = stack[0].f_locals['self']if layer in layer_names:log.pytorch_layer_name = layer_names[layer]print('984', layer_names[layer])breakout = self.obj(self.raw, *args, **kwargs)return out

在添加操作时,要使用Rp类替换操作:

torch.argmax = Rp(torch.argmax, torch_argmax)

接下来,要具体实现该操作:

def torch_argmax(raw, input, dim=1):x = raw(input, dim=dim)layer_name = log.add_layer(name='argmax')top_blobs = log.add_blobs([x], name='argmax_blob'.format(type))layer = caffe_net.Layer_param(name=layer_name, type='ArgMax',bottom=[log.blobs(input)], top=top_blobs)layer.argmax_param(dim=dim)log.cnet.add_layer(layer)return x

即实现了argmax操作Pytorch到Caffe的转换。

Caffe中无直接对应操作

如果要转换的操作在Caffe中无直接对应的层实现,解决思路主要有两个:

1)在Pytorch中将不支持的操作分解为支持的操作:

如nn.InstanceNorm2d,实例归一化在转换时是用BatchNorm做的,不支持 affine=True 或者track_running_stats=True,默认use_global_stats:false,但om转换时use_global_stats必须为true,所以可以转到Caffe,但再转om不友好。

InstanceNorm是在featuremap的每个Channel上进行归一化操作,因此,可以实现nn.InstanceNorm2d为:

class InstanceNormalization(nn.Module):def __init__(self, dim, eps=1e-5):super(InstanceNormalization, self).__init__()self.gamma = nn.Parameter(torch.FloatTensor(dim))self.beta = nn.Parameter(torch.FloatTensor(dim))self.eps = epsself._reset_parameters()
​def _reset_parameters(self):self.gamma.data.uniform_()self.beta.data.zero_()
​def __call__(self, x):n = x.size(2) * x.size(3)t = x.view(x.size(0), x.size(1), n)mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)gamma_broadcast = self.gamma.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x)beta_broadcast = self.beta.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x)out = (x - mean) / torch.sqrt(var + self.eps)out = out * gamma_broadcast + beta_broadcastreturn out

但在验证HiLens Caffe算子边界中发现,om模型转换不支持Channle维度之外的求和或求均值操作,为了规避这个操作,我们可以通过支持的算子重新实现nn.InstanceNorm2d:

class InstanceNormalization(nn.Module):def __init__(self, dim, eps=1e-5):super(InstanceNormalization, self).__init__()self.gamma = torch.FloatTensor(dim)self.beta = torch.FloatTensor(dim)self.eps = epsself.adavg = nn.AdaptiveAvgPool2d(1)
​def forward(self, x):n, c, h, w = x.shapemean = nn.Upsample(scale_factor=h)(self.adavg(x))var = nn.Upsample(scale_factor=h)(self.adavg((x - mean).pow(2)))gamma_broadcast = self.gamma.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x)beta_broadcast = self.beta.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x)out = (x - mean) / torch.sqrt(var + self.eps)out = out * gamma_broadcast + beta_broadcastreturn out

经过验证,与原操作等价,可以转为Caffe模型

2)在Caffe中通过利用现有操作实现:

在Pytorch转Caffe的过程中发现,如果存在featuremap + 6这种涉及到常数的操作,转换过程中会出现找不到blob的问题。我们首先查看pytorch_to_caffe.py中add操作的具体转换方法:

def _add(input, *args):x = raw__add__(input, *args)if not NET_INITTED:return xlayer_name = log.add_layer(name='add')top_blobs = log.add_blobs([x], name='add_blob')if log.blobs(args[0]) == None:log.add_blobs([args[0]], name='extra_blob')else:layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',bottom=[log.blobs(input),log.blobs(args[0])], top=top_blobs)layer.param.eltwise_param.operation = 1 # sum is 1log.cnet.add_layer(layer)return x

可以看到对于blob不存在的情况进行了判断,我们只需要在log.blobs(args[0]) == None条件下进行修改,一个自然的想法是利用Scale层实现add操作:

def _add(input, *args):x = raw__add__(input, *args)if not NET_INITTED:return xlayer_name = log.add_layer(name='add')top_blobs = log.add_blobs([x], name='add_blob')if log.blobs(args[0]) == None:layer = caffe_net.Layer_param(name=layer_name, type='Scale',bottom=[log.blobs(input)], top=top_blobs)layer.param.scale_param.bias_term = Trueweight = torch.ones((input.shape[1]))bias = torch.tensor(args[0]).squeeze().expand_as(weight)layer.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())log.cnet.add_layer(layer)else:layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)layer.param.eltwise_param.operation = 1  # sum is 1log.cnet.add_layer(layer)return x

类似的,featuremap * 6这种简单乘法也可以通过同样的方法实现。

踩过的坑

  • Pooling:Pytorch默认 ceil_mode=false,Caffe默认 ceil_mode=true,可能会导致维度变化,如果出现尺寸不匹配的问题可以检查一下Pooling参数是否正确。另外,虽然文档上没有看到,但是 kernel_size > 32 后模型虽然可以转换,但推理会报错,这时可以分两层进行Pooling操作。
  • Upsample :om边界算子中的Upsample 层scale_factor参数必须是int,不能是size。如果已有模型参数为size也会正常跑完Pytorch转Caffe的流程,但此时Upsample参数是空的。参数为size的情况可以考虑转为scale_factor或用Deconvolution来实现。
  • Transpose2d:Pytorch中 output_padding 参数会加在输出的大小上,但Caffe不会,输出特征图相对会变小,此时反卷积之后的featuremap会变大一点,可以通过Crop层进行裁剪,使其大小与Pytorch对应层一致。另外,om中反卷积推理速度较慢,最好是不要使用,可以用Upsample+Convolution替代。
  • Pad:Pytorch中Pad操作很多样,但Caffe中只能进行H与W维度上的对称pad,如果Pytorch网络中有h = F.pad(x, (1, 2, 1, 2), "constant", 0)这种不对称的pad操作,解决思路为:
  1. 如果不对称pad的层不存在后续的维度不匹配的问题,可以先判断一下pad对结果的影响,一些任务受pad的影响很小,那么就不需要修改。
  2. 如果存在维度不匹配的问题,可以考虑按照较大的参数充分pad之后进行Crop,或是将前后两个(0, 0, 1, 1)与(1, 1, 0, 0)的pad合为一个(1, 1, 1, 1),这要看具体的网络结构确定。
  3. 如果是Channel维度上的pad如F.pad(x, (0, 0, 0, 0, 0, channel_pad), "constant", 0),可以考虑零卷积后cat到featuremap上:
zero = nn.Conv2d(in_channels, self.channel_pad, kernel_size=3, padding=1, bias=False)
nn.init.constant(self.zero.weight, 0)
pad_tensor = zero(x)
x = torch.cat([x, pad_tensor], dim=1)
  • 一些操作可以转到Caffe,但om并不支持标准Caffe的所有操作,如果要再转到om要对照文档确认好边界算子。

本文分享自华为云社区《Pytorch->Caffe模型转换》,原文作者:杜甫盖房子 。

点击关注,第一时间了解华为云新鲜技术~

一文带你熟悉Pytorch->Caffe->om模型转换流程相关推荐

  1. 一文带你深入理解JVM内存模型

    一文带你深入理解JVM内存模型 一.JAVA的并发模型 共享内存模型 在共享内存的并发模型里面,线程之间共享程序的公共状态,线程之间通过读写内存中公共状态来进行隐式通信 该内存指的是主内存,实际上是物 ...

  2. yolov5 pt->onnx->om yolov5模型转onnx转om模型转换

    yolov5 pt->onnx->om yolov5-6.1版本 models/yolo.py Detect函数修改 class Detect(nn.Module):def forward ...

  3. 独家 | 一文带你熟悉贝叶斯统计

    作者:Matthew Ward 翻译:陈之炎 校对:陈丹 本文约5000字,建议阅读10+分钟 本文为你带来贝叶斯统计的基础示例及全面解释. 标签:贝叶斯统计 图:Unsplash,Chris Liv ...

  4. 用dos复制文件_一文带你熟悉DOS命令操作,CMD从此不再是路人!

    DOS常用命令: 1. 什么是DOS命令,如何打开dos窗口? A:首先:DOS命令是在DOS窗口输入的一系列命令,通过执行这些命令我们可以完成文件的新建.编辑.保存等操作. 打开DOS命令有两种方式 ...

  5. 一文带你熟悉android的smali语法一

    1.smali必须掌握的关键字 .locals 表示方法内使用的v开口的寄存器个数..prologue 表示方法中代码的开始处..line 表示对应java中的行数..annotation/.end ...

  6. 一文带你熟悉简单实用的Python科学计算库NumPy

    Python科学计算库NumPy 安装 数组的创建 array创建 **arange** 创建 **随机数创建** 方法numpy.random.random(size=None) 方法numpy.r ...

  7. 文带你深入了解 Redis 内存模型

    前言 Redis是目前最火爆的内存数据库之一,通过在内存中读写数据,大大提高了读写速度,可以说Redis是实现网站高并发不可或缺的一部分. 我们使用Redis时,会接触Redis的5种对象类型(字符串 ...

  8. jvm 参数_一文带你深入了解JVM内存模型与JVM参数详细配置

    JVM基本是BAT面试必考的内容,今天我们先从JVM内存模型开启详解整个JVM系列,希望看完整个系列后,可以轻松通过BAT关于JVM的考核. 一.JVM内存结构 由上图可以清楚的看到JVM的内存空间分 ...

  9. 一文带你学会数据库测试核心内容和流程

    的软件应用程序,已经离不开数据库的支持. 无论是在Web应用.桌面应用.客户端服务器.企业和个人业务,都需要自己的数据库在后端操作. 随着现在应用的复杂程度增加,应用需要更强大和安全系数高的数据库才可 ...

最新文章

  1. 为什么使用HashMap需要重写hashcode和equals方法_java常见面试题敲黑板了,HashMap最全的整理,大厂必考...
  2. Spring 的 Controller 是单例还是多例?怎么保证并发的安全
  3. python类装饰器详解-python 中的装饰器详解
  4. 洛谷P1280 caioj 1085 动态规划入门(非常规DP9:尼克的任务)
  5. dos下常用网络相关命令解释
  6. fscokopen php,详解PHP fsockopen的使用方法
  7. sql中的indexof,函数介绍
  8. 学计算机的用surface,11个高效利用Surface处理工作学习任务的方法 - Surface 使用教程...
  9. java8 内存设置_Java 8内存分析
  10. 验证集准确率上不去_Python机器学习之“模型验证”
  11. 一个有趣的Java编译问题
  12. 模式识别的发展及应用
  13. 在线作图p图|图片生成|做图HTML源码
  14. 谈谈两个互联网大佬的「认知革命」
  15. 测试服务器带宽的几种常用方法
  16. 计数oracle,SQL数据透视表子组计数
  17. InfluxDB中Line Protocol理解
  18. 旅游网站毕业设计,旅游网站网页设计设计源码,旅游网站设计毕业论文
  19. get、put、post、delete四大请求的含义与区别个人理解和解释
  20. 解决序列长期依赖的法宝——注意力机制

热门文章

  1. 视觉SLAM十四讲学习笔记-第一讲
  2. 代码编辑框控件_某游戏控件遍历
  3. 信息技术课与计算机课有关系吗,信息技术教学与计算机教学的区别与联系
  4. 华为手机媒体音量自动静音_华为手机还能自动清理垃圾,怪不得手机越用越流畅,学到了...
  5. Python小技巧:使用*解包和itertools.product()求笛卡尔积(转)
  6. EntboostChat 0.9(越狱版)公布,iOS免费企业IM
  7. C++中字符数组和字符串string
  8. Linux下性能测量和调试诊断工具Systemtap
  9. IOHelper(自制常用的输入输出的帮助类)
  10. 【异常】INFO: TopologyManager: EndpointListener changed ...