在https://blog.csdn.net/zx_ros/article/details/125897256中有调用_get_convert_map获取onnx算子到tvm relay ir的转换接口,以及调用_convert_operator将onnx node转换为tvm relay ir。我们将详细分析其中的转换过程。

1 算子映射表

获取映射表的_get_convert_map接口定义如下:

# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
def _get_convert_map(opset):return {# defs/experimental"Identity": Renamer("copy"),"Affine": Affine.get_converter(opset),"BitShift": BitShift.get_converter(opset),"ThresholdedRelu": ThresholdedRelu.get_converter(opset),"ScaledTanh": ScaledTanh.get_converter(opset),"ParametricSoftplus": ParametricSoftPlus.get_converter(opset),"Constant": Constant.get_converter(opset),"ConstantOfShape": ConstantOfShape.get_converter(opset),...}

从注释看,当前支持两种映射:

1. onnx算子到tvm算子一对一映射。这种情况是双方算子仅仅名字不同,其他都一致。算子映射接口为Renamer,返回对应的tvm算子表示;再使用AttrCvt将onnx属性转换tvm属性即可;

2.onnx算子在tvm中需要多个算子组合来表示,此时需要实现特定的转换函数。

代码中get_converter即第二种情况。

2 算子转换

在处理onnx节点时,调用_convert_operator将onnx node转换为tvm relay ir。函数实现:

    def _convert_operator(self, op_name, inputs, attrs, opset):"""Convert ONNX operator into a Relay operator.The converter must specify conversions explicitly for incompatible name, andapply handlers to operator attributes.Parameters----------op_name : strOperator name, such as Convolution, FullyConnectedinputs : list of tvm.relay.function.FunctionList of inputs.attrs : dictDict of operator attributesopset : intOpset versionReturns-------sym : tvm.relay.function.FunctionConverted relay function"""convert_map = _get_convert_map(opset)if op_name in _identity_list:sym = get_relay_op(op_name)(*inputs, **attrs)elif op_name in convert_map:sym = convert_map[op_name](inputs, attrs, self._params)else:raise NotImplementedError("Operator {} not implemented.".format(op_name))return sym

这里:

1. 首先获取算子映射表;

2. 如果算子在_identity_list表中,调用get_relay_op得到转换后的算子表达;

3. 否则,如果在算子转换映射表中,调用映射接口转换算子;

4. 否则认为转换异常;

5. 返回转换后的表达式。

2.1 _identity_list表和get_relay_op

在python/tvm/relay/frontend/onnx.py中,_identity_list表为空,所以_convert_operator中这个分支是走不到的。所有支持的框架里面,只有mxnet里面该表不为空:

# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = ["abs","log","exp","erf","sqrt","floor","ceil","round","trunc","sign","sigmoid","negative","reshape_like","zeros_like","ones_like","cos","cosh","sin","sinh","tan","tanh","where",
]

从注释看是因为这些算子的属性转换限制,才单列了这些算子。get_relay_op函数:

def get_relay_op(op_name):"""Get the callable function from Relay based on operator name.Parameters----------op_name : strThe Relay operator name."""if "." in op_name:# explicit hierarchical modulesop = _optry:for opn in op_name.split("."):op = getattr(op, opn)except AttributeError:op = Noneelse:# try search op in various modulesfor candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib):op = getattr(candidate, op_name, None)if op is not None:breakif not op:raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))return op

从注释看是基于算子名称获取一个可调用的函数。

getattr(objectname[, default])

Return the value of the named attribute of objectname must be a string. If the string is the name of one of the object’s attributes, the result is the value of that attribute. For example, getattr(x, 'foobar') is equivalent to x.foobar. If the named attribute does not exist, default is returned if provided, otherwise AttributeError is raised.

因为_op:

from .. import op as _op

所以_op是python/tvm/relay/op模块。这个下面有所有的relay算子,并且做了归类,例如nn,image, vision,contrib。

get_relay_op的if分支检查下传入的op_name是不是用点号形式给出的,比如relay.op.abs;else分支是到nn,image, vision,contrib目录下去找是否有名为op_name的算子。两个分支下,任一找到,都会返回算子的定义接口。所以返回的是跟传入的op_name同名的函数地址。例如op_name为abs时,对应的函数定义(python/tvm/relay/op/tensor.py):

def abs(data):"""Compute element-wise absolute of data.Parameters----------data : relay.ExprThe input dataReturns-------result : relay.ExprThe computed result."""return _make.abs(data)

这里get_relay_op返回了abs()接口(地址)。

所以_convert_operator中get_relay_op(op_name)(*inputs, **attrs)就是调用了_make.abs(*inputs, **attrs),_make.abs()执行的是src/relay/op/op_common.h中lambda函数体

#define RELAY_REGISTER_UNARY_OP(OpName)                                        \TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr data) { \static const Op& op = Op::Get(OpName);                                     \return Call(op, {data}, Attrs(), {});                                      \});                                                                          \

详细的调用机制可以参考https://blog.csdn.net/zx_ros/article/details/122931616

回到前面,因为onnx的_identity_list表为空,所以算子转换不会走到get_relay_op。

2.2 get_converter

get_converter是类OnnxOpConverter的方法。而其他各种算子在tvm/relay/frontend/onnx.py中,定义自己的算子转换类时都是继承了OnnxOpConverter。例如:

class Conv(OnnxOpConverter):"""Operator converter for Conv."""@classmethoddef _impl_v1(cls, inputs, attr, params):# Use shape of input to determine convolution type.data = inputs[0]kernel = inputs[1]...

调用的get_converter方法也就是OnnxOpConverter的。我们看下OnnxOpConverter.get_converter的实现:

    @classmethoddef get_converter(cls, opset):"""Get converter matches given opset.Parameters----------opset: intopset from model.Returns-------converter, which should be `_impl_vx`. Number x is the biggestnumber smaller than or equal to opset belongs to all support versions."""# 当在继承自OnnxOpConverter的各算子转换类调用get_convertver的时候,这里的cls就是子类本身了。# dir(cls)是获取子类的属性,# for d in dir(cls) if "_impl_v" in d 就是遍历子类的属性,查找名称包含字符串_impl_v的属性和方法.#int(d.replace("_impl_v", ""))是将找到的属性或者方法名中_impl_v部分去掉,并将剩余的部分转换为int类型versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d]# version是一个list,将当前传入的版本号opset加入到version表中,并从小到大排序versions = sorted(versions + [opset])# 遍历versions表,i为表单元序号,v为对应的单元值.找到所有版本号为opset的单元的下标.# 因为表中至少有一个opset, 所以减1就得到的是和opset相等或者仅比opset小的那个版本号的下标.# 所以这里就是找到和opset相等或者比opset小但是最接近opset的版本号version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1]# 返回该版本的_impl_v方法if hasattr(cls, "_impl_v{}".format(version)):return getattr(cls, "_impl_v{}".format(version))raise NotImplementedError("opset version {} of {} not implemented".format(version, cls.__name__))

因为各算子的转换类定义了多个版本的转换函数,这些函数的函数名都是"_impl_v" + "版本号"的形式。这里get_converter是找到一个最接近但是不高于opset的版本的_impl_v方法,返回该方法的地址,也就是返回一个函数。

2.3 算子转换接口_impl_vxx

每个需要转换的算子都有一个或者多个版本的转换接口。我们以卷积算子为例,Conv类支持的_impl_vx方法:

class Conv(OnnxOpConverter):"""Operator converter for Conv."""@classmethoddef _impl_v1(cls, inputs, attr, params):# Use shape of input to determine convolution type.# 从传入的inputs参数中获取输入和卷积核数据,并推导各自的形状data = inputs[0]kernel = inputs[1]input_shape = infer_shape(data)ndim = len(input_shape)kernel_type = infer_type(inputs[1])kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]# 如果onnx卷积属性中没有给出卷积核的形状,就使用inputs里面推导出来的形状if "kernel_shape" not in attr:attr["kernel_shape"] = kernel_shapes[0][2:]# 如果onnx卷积算子设置了auto_pad属性if "auto_pad" in attr:# 对用的tvm卷积算子也使用onnx设置的auto_pad属性值attr["auto_pad"] = attr["auto_pad"].decode("utf-8")# 根据auto_pad属性值对数据进行填充处理if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):# Warning: Convolution does not yet support dynamic shapes,# one will need to run dynamic_to_static on this model after import# 对输入数据进行填充,得到填充后的数据data = autopad(data,attr.get("strides", [1] * (ndim - 2)),attr["kernel_shape"],attr.get("dilations", [1] * (ndim - 2)),mode=attr["auto_pad"],)elif attr["auto_pad"] == "VALID":attr["pads"] = [0 for i in range(ndim - 2)]elif attr["auto_pad"] == "NOTSET":passelse:msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))attr.pop("auto_pad")attr["channels"] = kernel_shapes[0][0]out = AttrCvt(# 返回的op_name是一个函数,返回当前算子对应的tvm算子名称.在AttrCvt.__call__方法中调用该函数,根据当前attr中kernel_shape# 属性得到对应的TVM conv1d/conv2d/conv3d算子接口;然后算子接收([data, kernel], attr, params)# 参数, 返回转换后的TVM表示outop_name=dimension_picker("conv"),# 参数转换表transforms={# 当前属性名 : 转换后的属性名"kernel_shape": "kernel_size",# 当前属性名 : (转换后的属性名, 转换后的默认值)"dilations": ("dilation", 1),# 当前属性名 : (转换后的属性名, 转换后的默认值)"pads": ("padding", 0),# 当前属性名 : (转换后的属性名, 转换后的默认值)"group": ("groups", 1),},custom_check=dimension_constraint(),)([data, kernel], attr, params)use_bias = len(inputs) == 3# 如果输入中有偏置参数,则在表达式中添加偏置运算if use_bias:out = _op.nn.bias_add(out, inputs[2])return out

在_impl_v1中对卷积的输入数据,卷积核参数,以及填充做了初步的处理,然后创建一个AttrCvt实例。传入的参数op_name是一个函数,在AttrCvt.__call__方法中会调用该方法,参数为当前卷积的attr。根据attr中的kernel_shape参数,判断当前是1d/2d/3d卷积,得到对应的tvm算子名称conv1d/conv2d/conv3d;传入的transforms参数,用作AttrCvt.__call__中对当前attr和权重参数转换,会转换为tvm的卷积需要的参数形式;custom_check参数用于检查参数,这里对于卷积来说,是检查当前卷积维度是否合法(1d/2d/3d)。

2.4 算子属性转换AttrCvt

AttrCvt.__call__方法大致流程是对参数进行检查,转换,然后调用get_relay_op得到算子对应的tvm接口函数,将当前算子的输入和变换后的参数输入接口,得到onnx node对应的tvm relay ir。

AttrCv的实现:

class AttrCvt(object):def __init__(self,op_name,transforms=None,excludes=None,disables=None,ignores=None,extras=None,custom_check=None,):# 算子的新名字,op_name可以是一个字符串,也可以是一个返回字符串的函数self._op_name = op_name# 属性转换表,表项为属性转换字典,形式为"attr_name : new_attr_name", # 或者"attr_name : (new_name, default_value, transform function)"self._transforms = transforms if transforms else {}# 不允许出现的属性集合,如果出现会抛出异常self._excludes = excludes if excludes else []# 转换后会被disable的属性集合self._disables = disables if disables else []# 转换过程中会被忽略的属性集合self._ignores = ignores if ignores else []# 转换后会被额外返回的属性self._extras = extras if extras else {}# 转换执行的检测函数,返回False会抛出异常self._custom_check = custom_checkdef __call__(self, inputs, attrs, *args):# 忽略待转换算子的这些属性self._ignores.append("_output_shapes")self._ignores.append("_input_shapes")self._ignores.append("T")self._ignores.append("use_cudnn_on_gpu")self._ignores.append("_node_name")self._ignores.append("is_training")self._ignores.append("_target_layout")# apply custom check# 如果算子转换传入了检测函数,则执行该检测函数if self._custom_check:func, msg = self._custom_checkif not func(attrs):raise RuntimeError("Check failed: {}".format(msg))# get new op_name# 得到算子转换后的名字if isinstance(self._op_name, str):op_name = self._op_nameelse:assert callable(self._op_name), "op_name can either be string or callable"op_name = self._op_name(attrs)# ignore 'tvm_custom' always# 忽略tvm_custom属性self._ignores.append("tvm_custom")# convert attributesnew_attrs = {}# 遍历传入的待转换算子的属性for k in attrs.keys():# 如果属性在排除表中, 抛出异常if k in self._excludes:raise NotImplementedError("Attribute %s in operator %s is not" + " supported.", k, op_name)# 如果属性是要求disable的,打印debug日志if k in self._disables:logger.debug("Attribute %s is disabled in relay.sym.%s", k, op_name)# 如果属性是要求忽略的,打印debug日志elif k in self._ignores:if k != "tvm_custom":logger.debug("Attribute %s is ignored in relay.sym.%s", k, op_name)# 如果属性在转换表中elif k in self._transforms:# 从转换表中该属性对应的转换dict,得到属性的新名字,新默认值和转换操作函数# 如果转换表中没有给出转换函数,则将转换函数设置为lambda x: x,也就是直接返回参数new_name, defaults, transform = self._parse_default(self._transforms[k])# 如果没有给出默认值if defaults is None:# 那么必须是"attr_name:new_attr_name"形式,获取新属性名new_attr = self._required_attr(attrs, k)else:# 从原始的属性表中查找该属性的值,如果没找到,则为新属性为Nonenew_attr = attrs.get(k, None)if new_attr is None:# 如果新属性为None,在新的属性表中添加该属性,值为转换表中得到的默认值new_attrs[new_name] = defaultselse:# 在新的属性表中添加该属性,调用转换函数得到新的属性值new_attrs[new_name] = transform(new_attr)else:# copy# 如果属性不在转换表中,直接原封不动的加入新属性表new_attrs[k] = attrs[k]# add extras# 更新额外的属性new_attrs.update(self._extras)# 将输入和新属性表传入算子转换接口,返回转换后tvm relay irreturn get_relay_op(op_name)(*inputs, **new_attrs)

仍然以conv2d为例,这里get_relay_op(conv2d)将返回nn.conv2d。nn.conv2d代码如下:

def conv2d(data,weight,strides=(1, 1),padding=(0, 0),dilation=(1, 1),groups=1,channels=None,kernel_size=None,data_layout="NCHW",kernel_layout="OIHW",out_layout="",out_dtype="",
):if isinstance(kernel_size, int):kernel_size = (kernel_size, kernel_size)if isinstance(strides, int):strides = (strides, strides)if isinstance(dilation, int):dilation = (dilation, dilation)# TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged# convert 2-way padding to 4-way paddingpadding = get_pad_tuple2d(padding)return _make.conv2d(data,weight,strides,padding,dilation,groups,channels,kernel_size,data_layout,kernel_layout,out_layout,out_dtype,)

_make.conv2d会调用到C++代码src/relay/op/nn/convolution_make.h中实现的MakeConv接口:

template <typename T>
inline Expr MakeConv(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,Array<IndexExpr> dilation, int groups, IndexExpr channels,Array<IndexExpr> kernel_size, std::string data_layout,std::string kernel_layout, std::string out_layout, DataType out_dtype,std::string op_name) {auto attrs = make_object<T>();attrs->strides = std::move(strides);attrs->padding = std::move(padding);attrs->dilation = std::move(dilation);attrs->groups = groups;attrs->channels = std::move(channels);attrs->kernel_size = std::move(kernel_size);attrs->data_layout = std::move(data_layout);attrs->kernel_layout = std::move(kernel_layout);attrs->out_layout = std::move(out_layout);attrs->out_dtype = std::move(out_dtype);const Op& op = Op::Get(op_name);return Call(op, {data, weight}, Attrs(attrs), {});
}

所以最终是MakeConv返回了卷积算子的tvm relay ir。

_make.conv2d是如何调用到C++的MakeConv可以参考https://blog.csdn.net/zx_ros/article/details/122931616。

【TVM源码学习笔记】2.1 onnx算子转换相关推荐

  1. 【TVM源码学习笔记】附录1 TVM python调用C++的机制

    1. tvm relay op python调用C++ 在tvm前端涉及relay算子(比如说外部框架算子转vm relay ir)的时候,会调用到tvm/relay/op下对应算子的接口,进而调用_ ...

  2. Java多线程之JUC包:Semaphore源码学习笔记

    若有不正之处请多多谅解,并欢迎批评指正. 请尊重作者劳动成果,转载请标明原文链接: http://www.cnblogs.com/go2sea/p/5625536.html Semaphore是JUC ...

  3. RocketMQ 源码学习笔记 Producer 是怎么将消息发送至 Broker 的?

    RocketMQ 源码学习笔记 Producer 是怎么将消息发送至 Broker 的? 文章目录 RocketMQ 源码学习笔记 Producer 是怎么将消息发送至 Broker 的? 前言 项目 ...

  4. Vuex 4源码学习笔记 - 通过Vuex源码学习E2E测试(十一)

    在上一篇笔记中:Vuex 4源码学习笔记 - 做好changelog更新日志很重要(十) 我们学到了通过conventional-changelog来生成项目的Changelog更新日志,通过更新日志 ...

  5. Vuex 4源码学习笔记 - Vuex是怎么与Vue结合?(三)

    在上一篇笔记中:Vuex源码学习笔记 - Vuex开发运行流程(二) 我们通过运行npm run dev命令来启动webpack,来开发Vuex,并在Vuex的createStore函数中添加了第一个 ...

  6. jquery源码学习笔记三:jQuery工厂剖析

    jquery源码学习笔记二:jQuery工厂 jquery源码学习笔记一:总体结构 上两篇说过,query的核心是一个jQuery工厂.其代码如下 function( window, noGlobal ...

  7. 雷神FFMpeg源码学习笔记

    雷神FFMpeg源码学习笔记 文章目录 雷神FFMpeg源码学习笔记 读取编码并依据编码初始化内容结构 每一帧的视频解码处理 读取编码并依据编码初始化内容结构 在开始编解码视频的时候首先第一步需要注册 ...

  8. Apache log4j-1.2.17源码学习笔记

    (1)Apache log4j-1.2.17源码学习笔记 http://blog.csdn.net/zilong_zilong/article/details/78715500 (2)Apache l ...

  9. PHP Yac cache 源码学习笔记

    YAC 源码学习笔记 本文地址 http://blog.csdn.net/fanhengguang_php/article/details/54863955 config.m4 检测系统共享内存支持情 ...

最新文章

  1. echarts - 条形图grid设置距离绘图区域的距离
  2. 理解 Neutron FWaaS - 每天5分钟玩转 OpenStack(117)
  3. Android 系统(75)---Android常用的网路框架
  4. data fastboot 擦除_fastboot是什么?如何解锁fastboot?
  5. BigDecimal类型加减乘除运算(Java必备知识)
  6. excel自动求和_excel工作表的行或列怎么自动求和
  7. 实战:第十五章:摸爬滚打这些年的心路历程
  8. Opencv中rect的功能应用
  9. Fedora 17: 安装 perl-Tk
  10. SpringCloud学习笔记(一)【Euraka集群搭建】
  11. linux输入特殊符号密码,linux 输入特殊符号
  12. arcgis农田图例_ArcGIS在高标准农田建设项目图件制作中的应用
  13. python获取当前日期_python获取当前的日期和时间
  14. windox连接电子秤通过COM口获取数据(java)
  15. 浅谈量子量化股票交易的基本原理
  16. 医疗软件实施入门02
  17. zeebe入门课程10-bpmn元素的支持7(exclusive gateway )
  18. mapreduce多目录输出(MultipleOutputFormat和MultipleOutputs)
  19. angular中组件changeDetection为ChangeDetectionStrategy.OnPush时的学习
  20. ASEMI整流模块MDA110-16参数,MDA110-16规格

热门文章

  1. Bean、BeanDefinition、BeanFactory、FactoryBean
  2. 计算机应用基础说课方案,广东省“XX杯”说课大赛计算机应用基础类一等奖作品:PPT写字动画的制作教学设计方案.doc...
  3. npm install 报错:found XXX vulnerabilities (XXX low, X moderate),run `npm audit fix` to fix them, or `
  4. leetcode-数据结构-566. 重塑矩阵
  5. 香帅的北大金融学课笔记7 -- 基金业绩
  6. 华为智慧屏SE55通过FTP远程文件管理-实简FTP v1.6.30
  7. Smarty核心内容:Smarty基本安装与调试
  8. Excel 如何让一列中的很多数 同时加上一个数
  9. 建议收藏:GitHub 上值得收藏的100个精选前端项目!
  10. 【Kubernetes快速实战】