【TVM源码学习笔记】2.1 onnx算子转换
在https://blog.csdn.net/zx_ros/article/details/125897256中有调用_get_convert_map获取onnx算子到tvm relay ir的转换接口,以及调用_convert_operator将onnx node转换为tvm relay ir。我们将详细分析其中的转换过程。
1 算子映射表
获取映射表的_get_convert_map接口定义如下:
# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
def _get_convert_map(opset):return {# defs/experimental"Identity": Renamer("copy"),"Affine": Affine.get_converter(opset),"BitShift": BitShift.get_converter(opset),"ThresholdedRelu": ThresholdedRelu.get_converter(opset),"ScaledTanh": ScaledTanh.get_converter(opset),"ParametricSoftplus": ParametricSoftPlus.get_converter(opset),"Constant": Constant.get_converter(opset),"ConstantOfShape": ConstantOfShape.get_converter(opset),...}
从注释看,当前支持两种映射:
1. onnx算子到tvm算子一对一映射。这种情况是双方算子仅仅名字不同,其他都一致。算子映射接口为Renamer,返回对应的tvm算子表示;再使用AttrCvt将onnx属性转换tvm属性即可;
2.onnx算子在tvm中需要多个算子组合来表示,此时需要实现特定的转换函数。
代码中get_converter即第二种情况。
2 算子转换
在处理onnx节点时,调用_convert_operator将onnx node转换为tvm relay ir。函数实现:
def _convert_operator(self, op_name, inputs, attrs, opset):"""Convert ONNX operator into a Relay operator.The converter must specify conversions explicitly for incompatible name, andapply handlers to operator attributes.Parameters----------op_name : strOperator name, such as Convolution, FullyConnectedinputs : list of tvm.relay.function.FunctionList of inputs.attrs : dictDict of operator attributesopset : intOpset versionReturns-------sym : tvm.relay.function.FunctionConverted relay function"""convert_map = _get_convert_map(opset)if op_name in _identity_list:sym = get_relay_op(op_name)(*inputs, **attrs)elif op_name in convert_map:sym = convert_map[op_name](inputs, attrs, self._params)else:raise NotImplementedError("Operator {} not implemented.".format(op_name))return sym
这里:
1. 首先获取算子映射表;
2. 如果算子在_identity_list表中,调用get_relay_op得到转换后的算子表达;
3. 否则,如果在算子转换映射表中,调用映射接口转换算子;
4. 否则认为转换异常;
5. 返回转换后的表达式。
2.1 _identity_list表和get_relay_op
在python/tvm/relay/frontend/onnx.py中,_identity_list表为空,所以_convert_operator中这个分支是走不到的。所有支持的框架里面,只有mxnet里面该表不为空:
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = ["abs","log","exp","erf","sqrt","floor","ceil","round","trunc","sign","sigmoid","negative","reshape_like","zeros_like","ones_like","cos","cosh","sin","sinh","tan","tanh","where",
]
从注释看是因为这些算子的属性转换限制,才单列了这些算子。get_relay_op函数:
def get_relay_op(op_name):"""Get the callable function from Relay based on operator name.Parameters----------op_name : strThe Relay operator name."""if "." in op_name:# explicit hierarchical modulesop = _optry:for opn in op_name.split("."):op = getattr(op, opn)except AttributeError:op = Noneelse:# try search op in various modulesfor candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib):op = getattr(candidate, op_name, None)if op is not None:breakif not op:raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))return op
从注释看是基于算子名称获取一个可调用的函数。
getattr
(object, name[, default])Return the value of the named attribute of object. name must be a string. If the string is the name of one of the object’s attributes, the result is the value of that attribute. For example,
getattr(x, 'foobar')
is equivalent tox.foobar
. If the named attribute does not exist, default is returned if provided, otherwise AttributeError is raised.
因为_op:
from .. import op as _op
所以_op是python/tvm/relay/op模块。这个下面有所有的relay算子,并且做了归类,例如nn,image, vision,contrib。
get_relay_op的if分支检查下传入的op_name是不是用点号形式给出的,比如relay.op.abs;else分支是到nn,image, vision,contrib目录下去找是否有名为op_name的算子。两个分支下,任一找到,都会返回算子的定义接口。所以返回的是跟传入的op_name同名的函数地址。例如op_name为abs时,对应的函数定义(python/tvm/relay/op/tensor.py):
def abs(data):"""Compute element-wise absolute of data.Parameters----------data : relay.ExprThe input dataReturns-------result : relay.ExprThe computed result."""return _make.abs(data)
这里get_relay_op返回了abs()接口(地址)。
所以_convert_operator中get_relay_op(op_name)(*inputs, **attrs)就是调用了_make.abs(*inputs, **attrs),_make.abs()执行的是src/relay/op/op_common.h中lambda函数体
#define RELAY_REGISTER_UNARY_OP(OpName) \TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr data) { \static const Op& op = Op::Get(OpName); \return Call(op, {data}, Attrs(), {}); \}); \
详细的调用机制可以参考https://blog.csdn.net/zx_ros/article/details/122931616
回到前面,因为onnx的_identity_list表为空,所以算子转换不会走到get_relay_op。
2.2 get_converter
get_converter是类OnnxOpConverter的方法。而其他各种算子在tvm/relay/frontend/onnx.py中,定义自己的算子转换类时都是继承了OnnxOpConverter。例如:
class Conv(OnnxOpConverter):"""Operator converter for Conv."""@classmethoddef _impl_v1(cls, inputs, attr, params):# Use shape of input to determine convolution type.data = inputs[0]kernel = inputs[1]...
调用的get_converter方法也就是OnnxOpConverter的。我们看下OnnxOpConverter.get_converter的实现:
@classmethoddef get_converter(cls, opset):"""Get converter matches given opset.Parameters----------opset: intopset from model.Returns-------converter, which should be `_impl_vx`. Number x is the biggestnumber smaller than or equal to opset belongs to all support versions."""# 当在继承自OnnxOpConverter的各算子转换类调用get_convertver的时候,这里的cls就是子类本身了。# dir(cls)是获取子类的属性,# for d in dir(cls) if "_impl_v" in d 就是遍历子类的属性,查找名称包含字符串_impl_v的属性和方法.#int(d.replace("_impl_v", ""))是将找到的属性或者方法名中_impl_v部分去掉,并将剩余的部分转换为int类型versions = [int(d.replace("_impl_v", "")) for d in dir(cls) if "_impl_v" in d]# version是一个list,将当前传入的版本号opset加入到version表中,并从小到大排序versions = sorted(versions + [opset])# 遍历versions表,i为表单元序号,v为对应的单元值.找到所有版本号为opset的单元的下标.# 因为表中至少有一个opset, 所以减1就得到的是和opset相等或者仅比opset小的那个版本号的下标.# 所以这里就是找到和opset相等或者比opset小但是最接近opset的版本号version = versions[max([i for i, v in enumerate(versions) if v == opset]) - 1]# 返回该版本的_impl_v方法if hasattr(cls, "_impl_v{}".format(version)):return getattr(cls, "_impl_v{}".format(version))raise NotImplementedError("opset version {} of {} not implemented".format(version, cls.__name__))
因为各算子的转换类定义了多个版本的转换函数,这些函数的函数名都是"_impl_v" + "版本号"的形式。这里get_converter是找到一个最接近但是不高于opset的版本的_impl_v方法,返回该方法的地址,也就是返回一个函数。
2.3 算子转换接口_impl_vxx
每个需要转换的算子都有一个或者多个版本的转换接口。我们以卷积算子为例,Conv类支持的_impl_vx方法:
class Conv(OnnxOpConverter):"""Operator converter for Conv."""@classmethoddef _impl_v1(cls, inputs, attr, params):# Use shape of input to determine convolution type.# 从传入的inputs参数中获取输入和卷积核数据,并推导各自的形状data = inputs[0]kernel = inputs[1]input_shape = infer_shape(data)ndim = len(input_shape)kernel_type = infer_type(inputs[1])kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]# 如果onnx卷积属性中没有给出卷积核的形状,就使用inputs里面推导出来的形状if "kernel_shape" not in attr:attr["kernel_shape"] = kernel_shapes[0][2:]# 如果onnx卷积算子设置了auto_pad属性if "auto_pad" in attr:# 对用的tvm卷积算子也使用onnx设置的auto_pad属性值attr["auto_pad"] = attr["auto_pad"].decode("utf-8")# 根据auto_pad属性值对数据进行填充处理if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):# Warning: Convolution does not yet support dynamic shapes,# one will need to run dynamic_to_static on this model after import# 对输入数据进行填充,得到填充后的数据data = autopad(data,attr.get("strides", [1] * (ndim - 2)),attr["kernel_shape"],attr.get("dilations", [1] * (ndim - 2)),mode=attr["auto_pad"],)elif attr["auto_pad"] == "VALID":attr["pads"] = [0 for i in range(ndim - 2)]elif attr["auto_pad"] == "NOTSET":passelse:msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))attr.pop("auto_pad")attr["channels"] = kernel_shapes[0][0]out = AttrCvt(# 返回的op_name是一个函数,返回当前算子对应的tvm算子名称.在AttrCvt.__call__方法中调用该函数,根据当前attr中kernel_shape# 属性得到对应的TVM conv1d/conv2d/conv3d算子接口;然后算子接收([data, kernel], attr, params)# 参数, 返回转换后的TVM表示outop_name=dimension_picker("conv"),# 参数转换表transforms={# 当前属性名 : 转换后的属性名"kernel_shape": "kernel_size",# 当前属性名 : (转换后的属性名, 转换后的默认值)"dilations": ("dilation", 1),# 当前属性名 : (转换后的属性名, 转换后的默认值)"pads": ("padding", 0),# 当前属性名 : (转换后的属性名, 转换后的默认值)"group": ("groups", 1),},custom_check=dimension_constraint(),)([data, kernel], attr, params)use_bias = len(inputs) == 3# 如果输入中有偏置参数,则在表达式中添加偏置运算if use_bias:out = _op.nn.bias_add(out, inputs[2])return out
在_impl_v1中对卷积的输入数据,卷积核参数,以及填充做了初步的处理,然后创建一个AttrCvt实例。传入的参数op_name是一个函数,在AttrCvt.__call__方法中会调用该方法,参数为当前卷积的attr。根据attr中的kernel_shape参数,判断当前是1d/2d/3d卷积,得到对应的tvm算子名称conv1d/conv2d/conv3d;传入的transforms参数,用作AttrCvt.__call__中对当前attr和权重参数转换,会转换为tvm的卷积需要的参数形式;custom_check参数用于检查参数,这里对于卷积来说,是检查当前卷积维度是否合法(1d/2d/3d)。
2.4 算子属性转换AttrCvt
AttrCvt.__call__方法大致流程是对参数进行检查,转换,然后调用get_relay_op得到算子对应的tvm接口函数,将当前算子的输入和变换后的参数输入接口,得到onnx node对应的tvm relay ir。
AttrCv的实现:
class AttrCvt(object):def __init__(self,op_name,transforms=None,excludes=None,disables=None,ignores=None,extras=None,custom_check=None,):# 算子的新名字,op_name可以是一个字符串,也可以是一个返回字符串的函数self._op_name = op_name# 属性转换表,表项为属性转换字典,形式为"attr_name : new_attr_name", # 或者"attr_name : (new_name, default_value, transform function)"self._transforms = transforms if transforms else {}# 不允许出现的属性集合,如果出现会抛出异常self._excludes = excludes if excludes else []# 转换后会被disable的属性集合self._disables = disables if disables else []# 转换过程中会被忽略的属性集合self._ignores = ignores if ignores else []# 转换后会被额外返回的属性self._extras = extras if extras else {}# 转换执行的检测函数,返回False会抛出异常self._custom_check = custom_checkdef __call__(self, inputs, attrs, *args):# 忽略待转换算子的这些属性self._ignores.append("_output_shapes")self._ignores.append("_input_shapes")self._ignores.append("T")self._ignores.append("use_cudnn_on_gpu")self._ignores.append("_node_name")self._ignores.append("is_training")self._ignores.append("_target_layout")# apply custom check# 如果算子转换传入了检测函数,则执行该检测函数if self._custom_check:func, msg = self._custom_checkif not func(attrs):raise RuntimeError("Check failed: {}".format(msg))# get new op_name# 得到算子转换后的名字if isinstance(self._op_name, str):op_name = self._op_nameelse:assert callable(self._op_name), "op_name can either be string or callable"op_name = self._op_name(attrs)# ignore 'tvm_custom' always# 忽略tvm_custom属性self._ignores.append("tvm_custom")# convert attributesnew_attrs = {}# 遍历传入的待转换算子的属性for k in attrs.keys():# 如果属性在排除表中, 抛出异常if k in self._excludes:raise NotImplementedError("Attribute %s in operator %s is not" + " supported.", k, op_name)# 如果属性是要求disable的,打印debug日志if k in self._disables:logger.debug("Attribute %s is disabled in relay.sym.%s", k, op_name)# 如果属性是要求忽略的,打印debug日志elif k in self._ignores:if k != "tvm_custom":logger.debug("Attribute %s is ignored in relay.sym.%s", k, op_name)# 如果属性在转换表中elif k in self._transforms:# 从转换表中该属性对应的转换dict,得到属性的新名字,新默认值和转换操作函数# 如果转换表中没有给出转换函数,则将转换函数设置为lambda x: x,也就是直接返回参数new_name, defaults, transform = self._parse_default(self._transforms[k])# 如果没有给出默认值if defaults is None:# 那么必须是"attr_name:new_attr_name"形式,获取新属性名new_attr = self._required_attr(attrs, k)else:# 从原始的属性表中查找该属性的值,如果没找到,则为新属性为Nonenew_attr = attrs.get(k, None)if new_attr is None:# 如果新属性为None,在新的属性表中添加该属性,值为转换表中得到的默认值new_attrs[new_name] = defaultselse:# 在新的属性表中添加该属性,调用转换函数得到新的属性值new_attrs[new_name] = transform(new_attr)else:# copy# 如果属性不在转换表中,直接原封不动的加入新属性表new_attrs[k] = attrs[k]# add extras# 更新额外的属性new_attrs.update(self._extras)# 将输入和新属性表传入算子转换接口,返回转换后tvm relay irreturn get_relay_op(op_name)(*inputs, **new_attrs)
仍然以conv2d为例,这里get_relay_op(conv2d)将返回nn.conv2d。nn.conv2d代码如下:
def conv2d(data,weight,strides=(1, 1),padding=(0, 0),dilation=(1, 1),groups=1,channels=None,kernel_size=None,data_layout="NCHW",kernel_layout="OIHW",out_layout="",out_dtype="",
):if isinstance(kernel_size, int):kernel_size = (kernel_size, kernel_size)if isinstance(strides, int):strides = (strides, strides)if isinstance(dilation, int):dilation = (dilation, dilation)# TODO enforce 4-way padding in topi/nn/conv2d after #4644 merged# convert 2-way padding to 4-way paddingpadding = get_pad_tuple2d(padding)return _make.conv2d(data,weight,strides,padding,dilation,groups,channels,kernel_size,data_layout,kernel_layout,out_layout,out_dtype,)
_make.conv2d会调用到C++代码src/relay/op/nn/convolution_make.h中实现的MakeConv接口:
template <typename T>
inline Expr MakeConv(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,Array<IndexExpr> dilation, int groups, IndexExpr channels,Array<IndexExpr> kernel_size, std::string data_layout,std::string kernel_layout, std::string out_layout, DataType out_dtype,std::string op_name) {auto attrs = make_object<T>();attrs->strides = std::move(strides);attrs->padding = std::move(padding);attrs->dilation = std::move(dilation);attrs->groups = groups;attrs->channels = std::move(channels);attrs->kernel_size = std::move(kernel_size);attrs->data_layout = std::move(data_layout);attrs->kernel_layout = std::move(kernel_layout);attrs->out_layout = std::move(out_layout);attrs->out_dtype = std::move(out_dtype);const Op& op = Op::Get(op_name);return Call(op, {data, weight}, Attrs(attrs), {});
}
所以最终是MakeConv返回了卷积算子的tvm relay ir。
_make.conv2d是如何调用到C++的MakeConv可以参考https://blog.csdn.net/zx_ros/article/details/122931616。
【TVM源码学习笔记】2.1 onnx算子转换相关推荐
- 【TVM源码学习笔记】附录1 TVM python调用C++的机制
1. tvm relay op python调用C++ 在tvm前端涉及relay算子(比如说外部框架算子转vm relay ir)的时候,会调用到tvm/relay/op下对应算子的接口,进而调用_ ...
- Java多线程之JUC包:Semaphore源码学习笔记
若有不正之处请多多谅解,并欢迎批评指正. 请尊重作者劳动成果,转载请标明原文链接: http://www.cnblogs.com/go2sea/p/5625536.html Semaphore是JUC ...
- RocketMQ 源码学习笔记 Producer 是怎么将消息发送至 Broker 的?
RocketMQ 源码学习笔记 Producer 是怎么将消息发送至 Broker 的? 文章目录 RocketMQ 源码学习笔记 Producer 是怎么将消息发送至 Broker 的? 前言 项目 ...
- Vuex 4源码学习笔记 - 通过Vuex源码学习E2E测试(十一)
在上一篇笔记中:Vuex 4源码学习笔记 - 做好changelog更新日志很重要(十) 我们学到了通过conventional-changelog来生成项目的Changelog更新日志,通过更新日志 ...
- Vuex 4源码学习笔记 - Vuex是怎么与Vue结合?(三)
在上一篇笔记中:Vuex源码学习笔记 - Vuex开发运行流程(二) 我们通过运行npm run dev命令来启动webpack,来开发Vuex,并在Vuex的createStore函数中添加了第一个 ...
- jquery源码学习笔记三:jQuery工厂剖析
jquery源码学习笔记二:jQuery工厂 jquery源码学习笔记一:总体结构 上两篇说过,query的核心是一个jQuery工厂.其代码如下 function( window, noGlobal ...
- 雷神FFMpeg源码学习笔记
雷神FFMpeg源码学习笔记 文章目录 雷神FFMpeg源码学习笔记 读取编码并依据编码初始化内容结构 每一帧的视频解码处理 读取编码并依据编码初始化内容结构 在开始编解码视频的时候首先第一步需要注册 ...
- Apache log4j-1.2.17源码学习笔记
(1)Apache log4j-1.2.17源码学习笔记 http://blog.csdn.net/zilong_zilong/article/details/78715500 (2)Apache l ...
- PHP Yac cache 源码学习笔记
YAC 源码学习笔记 本文地址 http://blog.csdn.net/fanhengguang_php/article/details/54863955 config.m4 检测系统共享内存支持情 ...
最新文章
- echarts - 条形图grid设置距离绘图区域的距离
- 理解 Neutron FWaaS - 每天5分钟玩转 OpenStack(117)
- Android 系统(75)---Android常用的网路框架
- data fastboot 擦除_fastboot是什么?如何解锁fastboot?
- BigDecimal类型加减乘除运算(Java必备知识)
- excel自动求和_excel工作表的行或列怎么自动求和
- 实战:第十五章:摸爬滚打这些年的心路历程
- Opencv中rect的功能应用
- Fedora 17: 安装 perl-Tk
- SpringCloud学习笔记(一)【Euraka集群搭建】
- linux输入特殊符号密码,linux 输入特殊符号
- arcgis农田图例_ArcGIS在高标准农田建设项目图件制作中的应用
- python获取当前日期_python获取当前的日期和时间
- windox连接电子秤通过COM口获取数据(java)
- 浅谈量子量化股票交易的基本原理
- 医疗软件实施入门02
- zeebe入门课程10-bpmn元素的支持7(exclusive gateway )
- mapreduce多目录输出(MultipleOutputFormat和MultipleOutputs)
- angular中组件changeDetection为ChangeDetectionStrategy.OnPush时的学习
- ASEMI整流模块MDA110-16参数,MDA110-16规格
热门文章
- Bean、BeanDefinition、BeanFactory、FactoryBean
- 计算机应用基础说课方案,广东省“XX杯”说课大赛计算机应用基础类一等奖作品:PPT写字动画的制作教学设计方案.doc...
- npm install 报错:found XXX vulnerabilities (XXX low, X moderate),run `npm audit fix` to fix them, or `
- leetcode-数据结构-566. 重塑矩阵
- 香帅的北大金融学课笔记7 -- 基金业绩
- 华为智慧屏SE55通过FTP远程文件管理-实简FTP v1.6.30
- Smarty核心内容:Smarty基本安装与调试
- Excel 如何让一列中的很多数 同时加上一个数
- 建议收藏:GitHub 上值得收藏的100个精选前端项目!
- 【Kubernetes快速实战】