1. tvm relay op python调用C++

在tvm前端涉及relay算子(比如说外部框架算子转vm relay ir)的时候,会调用到tvm/relay/op下对应算子的接口,进而调用_make.xxx()接口。这个接口最终会调用到C++端对应的算子处理接口。这里我们将探寻Python调用C++的实现。

我们以2D卷积算子为例,对应的接口在python/tvm/relay/op/nn/nn.py中:

def conv2d(...
):...return _make.conv2d(...)

_make的导入:

from . import _make 

也就是python/tvm/relay/op/nn/_make.py:

import tvm._ffitvm._ffi._init_api("relay.op.nn._make", __name__)

这里__name__是一个python内置变量,表示当前模块的文件名(不包括.py),即tvm/relay/op/nn/_make。

tvm._ffi模块位于python/tvm/_ffi。函数_init_api的定义在python/tvm/_ffi/registry.py中

def _init_api(namespace, target_module_name=None):"""Initialize api for a given module namenamespace : strThe namespace of the source registrytarget_module_name : strThe target module name if different from namespace"""target_module_name = target_module_name if target_module_name else namespaceif namespace.startswith("tvm."):_init_api_prefix(target_module_name, namespace[4:])else:_init_api_prefix(target_module_name, namespace)

这里传入的第一个参数namespace为relay.op.nn._make, target_module_name参数为tvm/relay/op/nn/_make。这样传入_init_api_prefix的两个参数将是 tvm.relay.op.nn._make和relay.op.nn._make。

def _init_api_prefix(module_name, prefix):module = sys.modules[module_name]for name in list_global_func_names():if not name.startswith(prefix):continuefname = name[len(prefix) + 1 :]target_module = moduleif fname.find(".") != -1:continuef = get_global_func(name)ff = _get_api(f)ff.__name__ = fnameff.__doc__ = "TVM PackedFunc %s. " % fnamesetattr(target_module, ff.__name__, ff)

module = sys.modules[module_name]获取的是tvm.relay.op.nn._make模块的句柄。list_global_func_names()定义在python/tvm/_ffi/registry.py中:

def list_global_func_names():"""Get list of global functions registered.Returns-------names : listList of global functions names."""plist = ctypes.POINTER(ctypes.c_char_p)()size = ctypes.c_uint()check_call(_LIB.TVMFuncListGlobalNames(ctypes.byref(size), ctypes.byref(plist)))fnames = []for i in range(size.value):fnames.append(py_str(plist[i]))return fnames

接口种通过ctypes方式,调用C++库的TVMFuncListGlobalNames接口,得到的结果字符串数组plist,该数组为所有全局接口的函数名集合。TVMFuncListGlobalNames接口定义在src/runtime/registry.cc中

int TVMFuncListGlobalNames(int* out_size, const char*** out_array) {API_BEGIN();TVMFuncThreadLocalEntry* ret = TVMFuncThreadLocalStore::Get();ret->ret_vec_str = tvm::runtime::Registry::ListNames();ret->ret_vec_charp.clear();for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());}*out_array = dmlc::BeginPtr(ret->ret_vec_charp);*out_size = static_cast<int>(ret->ret_vec_str.size());API_END();
}

函数中调用tvm::runtime::Registry::ListNames()得到函数名表:

std::vector<std::string> Registry::ListNames() {Manager* m = Manager::Global();std::lock_guard<std::mutex> lock(m->mutex);std::vector<std::string> keys;keys.reserve(m->fmap.size());for (const auto& kv : m->fmap) {keys.push_back(kv.first);}return keys;
}

可以看到,函数名都是从Manager类实例的fmap表的第一个元素。而且Manager还是各单实例类。而fmap的定义:

struct Registry::Manager {// map storing the functions.// We deliberately used raw pointer.// This is because PackedFunc can contain callbacks into the host language (Python) and the// resource can become invalid because of indeterministic order of destruction and forking.// The resources will only be recycled during program exit.std::unordered_map<std::string, Registry*> fmap;// mutexstd::mutex mutex;Manager() {}static Manager* Global() {// We deliberately leak the Manager instance, to avoid leak sanitizers// complaining about the entries in Manager::fmap being leaked at program// exit.static Manager* inst = new Manager();return inst;}
};

从注释看,这个fmap是一个存储函数的map表。表单元的第一个元素是string类型。

再看_init_api_prefix中的get_global_func:

def get_global_func(name, allow_missing=False):return _get_global_func(name, allow_missing)def _get_global_func(name, allow_missing=False):handle = PackedFuncHandle()check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))if handle.value:return _make_packed_func(handle, False)if allow_missing:return Noneraise ValueError("Cannot find global function %s" % name)

_get_global_func中使用ctypes方式调用C++库中的TVMFuncGetGlobal函数:

int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {API_BEGIN();const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name);if (fp != nullptr) {*out = new tvm::runtime::PackedFunc(*fp);  // NOLINT(*)} else {*out = nullptr;}API_END();
}const PackedFunc* Registry::Get(const std::string& name) {Manager* m = Manager::Global();std::lock_guard<std::mutex> lock(m->mutex);auto it = m->fmap.find(name);if (it == m->fmap.end()) return nullptr;return &(it->second->func_);
}

TVMFuncGetGlobal调用了Registry::Get,从Manager的fmap表中,找到第一个元素为python传入的函数名的单元,从该单元的第二个元素中获取了函数指针。也就是根据函数名获取函数句柄。

搜索下谁在往fmap成员中写数据,可以看到是Registry::Register接口:

Registry& Registry::Register(const std::string& name, bool can_override) {  // NOLINT(*)Manager* m = Manager::Global();std::lock_guard<std::mutex> lock(m->mutex);if (m->fmap.count(name)) {ICHECK(can_override) << "Global PackedFunc " << name << " is already registered";}Registry* r = new Registry();r->name_ = name;m->fmap[name] = r;return *r;
}

可以看到调用Registry::Register接口接口时,如果name在fmap中不存在,就会创建一个Registry实例,加入Manager的fmap表,并返回新建的Registry实例。搜索Registry::Register接口的调用,在include/tvm/runtime/registry.h中有定义

/*!* \brief Register a function globally.* \code*   TVM_REGISTER_GLOBAL("MyPrint")*   .set_body([](TVMArgs args, TVMRetValue* rv) {*   });* \endcode*/
#define TVM_REGISTER_GLOBAL(OpName) \TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName)

这里调用Registry::Register接口,传入的是一个函数名。在代码中搜索TVM_REGISTER_GLOBAL宏的使用会有很多。这里我们继续关注relay.op.nn._make.conv2d的,搜索到src/relay/op/nn/convolution.cc中代码:

TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d").set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,Array<IndexExpr> dilation, int groups, IndexExpr channels,Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,String out_layout, DataType out_dtype) {return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,kernel_size, data_layout, kernel_layout, out_layout, out_dtype,"nn.conv2d");});

这里set_body_typed的参数为一个lamad表达式,函数体部分调用了MakeConv。所以这里是向Manager的fmap注册了一个函数,名字为relay.op.nn._make.conv2d, 函数体部分是调用MakeConv。

简单的讲,TVM relay op的python到C++调用,就是在C++里,创建一个函数管理表(Manager::fmap),各算子向这个表注册接口,并给每个接口一个标记符;python部分在对应目录下放一个_make.py文件,在这个文件中设置注册的函数标记符和对应的C++函数句柄;当算子转换以标记符调用接口时,就会调用到C++里面的对应的函数体。

我们看下MakeConv的实现:

template <typename T>
inline Expr MakeConv(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,Array<IndexExpr> dilation, int groups, IndexExpr channels,Array<IndexExpr> kernel_size, std::string data_layout,std::string kernel_layout, std::string out_layout, DataType out_dtype,std::string op_name) {auto attrs = make_object<T>();attrs->strides = std::move(strides);attrs->padding = std::move(padding);attrs->dilation = std::move(dilation);attrs->groups = groups;attrs->channels = std::move(channels);attrs->kernel_size = std::move(kernel_size);attrs->data_layout = std::move(data_layout);attrs->kernel_layout = std::move(kernel_layout);attrs->out_layout = std::move(out_layout);attrs->out_dtype = std::move(out_dtype);const Op& op = Op::Get(op_name);return Call(op, {data, weight}, Attrs(attrs), {});
}

这里将卷积参数打包,生成一个Op实例,然后生成一个Call实例返回。

python/tvm/relay/op/tensor.py中的数学算符可能会多绕一步,根据算符的类型,先在src/relay/op/tensor目录下的文件中TVM_REGISTER_NODE_TYPE、RELAY_REGISTER_BINARY_OP或者RELAY_REGISTER_UNARY_OP定义,这些宏中再调用TVM_REGISTER_GLOBAL宏,参见src/relay/op/op_common.h

2. tvm relay function python调用C++

在Graphproto.from_onnx的最后,使用网络的输入输出和权重参数打包成一个Function实例,然后生成一个IRModule实例:

# 由模型输入, 输出表达式依赖的权重和输出表达式生成一个function
func = _function.Function([v for k, v in self._inputs.items()], outputs)
# 返回表达式和所有权重
return IRModule.from_expr(func), self._params

这两步也都是会调用到C++代码。先看_function.Function的流程:

@tvm._ffi.register_object("relay.Function")
class Function(BaseFunc):def __init__(self, params, body, ret_type=None, type_params=None, attrs=None):if type_params is None:type_params = convert([])self.__init_handle_by_constructor__(_ffi_api.Function, params, body, ret_type, type_params, attrs)def __call__(self, *args):return Call(self, args, None, None)

__init__函数第二个参数body是函数体,而前面在调用_function.Function时传入的时outputs。这是因为outputs并不是网络或者函数的输出张量,而是输出的计算表达式,而且这个表达式描述的是从输入开始,一步一步的到输出的计算过程,也就是函数实现的所有计算过程了。所以这个outputs就是函数体。

__init__中调用了self.__init_handle_by_constructor__,参数_ffi_api.Function这种形式在前面算子调用流程中我们已经分析过,_ffi_api引入的是模块,Function是具体的函数,所以我们看下当前目录下的_ffi_api是什么模块, 见python/tvm/relay/_ffi_api.py:

import tvm._ffitvm._ffi._init_api("relay.ir", __name__)

模块为relay.ir,所以_ffi_api.Function就是relay.ir.Function。

搜索该标记符的注册TVM_REGISTER_GLOBAL("relay.ir.Function")可以看到:

TVM_REGISTER_GLOBAL("relay.ir.Function").set_body_typed([](tvm::Array<Var> params, Expr body, Type ret_type,tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {return Function(params, body, ret_type, ty_params, attrs);});

也就是调用_ffi_api.Function会在C++端实例化一个Function。

在python的Function类中, _ffi_api.Function是作为参数传给self.__init_handle_by_constructor__,这个方法定义在python/tvm/_ffi/_ctypes/object.py中的基类ObjectBase中,而ObjectBase.__init_handle_by_constructor__调用的是

def __init_handle_by_constructor__(fconstructor, args):"""Initialize handle by constructor"""temp_args = []values, tcodes, num_args = _make_tvm_args(args, temp_args)ret_val = TVMValue()ret_tcode = ctypes.c_int()if (_LIB.TVMFuncCall(fconstructor.handle,values,tcodes,ctypes.c_int(num_args),ctypes.byref(ret_val),ctypes.byref(ret_tcode),)!= 0):raise get_last_ffi_error()_ = temp_args_ = argsassert ret_tcode.value == ArgTypeCode.OBJECT_HANDLEhandle = ret_val.v_handlereturn handle

我们看下TVMFuncCall的调用链

src/runtime/c_runtime_api.cc:int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args,TVMValue* ret_val, int* ret_type_code) {API_BEGIN();TVMRetValue rv;(*static_cast<const PackedFunc*>(func)).CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv);...}include/tvm/runtime/packed_func.h:inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); }

这里最后调用到的body_就是TVM_REGISTER_GLOBAL("relay.ir.Function").set_typed_body设置的lamabd函数体。

这里比较绕,我们理下:

1. 首先将注册的relay.ir.Function作为参数传给了__init_handle_by_constructor__;

2. __init_handle_by_constructor__调用了_LIB.TVMFuncCall;

3. _LIB.TVMFuncCall相当于一个函数执行器,它执行了relay.ir.Function;

4. relay.ir.Function的函数体被执行时,返回一个C++端的Function对象句柄。

3. tvm relay op IRModule python调用C++

onnx.py中GraphProto.from_onnx最后return IRModule.from_expr(func), self._params,这个from_expr代码在python/tvm/ir/module.py中:

    def from_expr(expr, functions=None, type_defs=None):funcs = functions if functions is not None else {}defs = type_defs if type_defs is not None else {}return _ffi_api.Module_FromExpr(expr, funcs, defs)

这里直接调用_ffi_api.Module_FromExpr,python/tvm/ir/目录定义的模块名为ir(见python/tvm/ir/_ffi_api.py), 搜索对应的函数注册TVM_REGISTER_GLOBAL("ir.Module_FromExpr"),注册函数执行IRModule::FromExpr,FromExpr调用IRModule::FromExprInContext,生成一个C++端的IRModule实例

【TVM源码学习笔记】附录1 TVM python调用C++的机制相关推荐

  1. 【TVM源码学习笔记】2.1 onnx算子转换

    在https://blog.csdn.net/zx_ros/article/details/125897256中有调用_get_convert_map获取onnx算子到tvm relay ir的转换接 ...

  2. Java多线程之JUC包:Semaphore源码学习笔记

    若有不正之处请多多谅解,并欢迎批评指正. 请尊重作者劳动成果,转载请标明原文链接: http://www.cnblogs.com/go2sea/p/5625536.html Semaphore是JUC ...

  3. RocketMQ 源码学习笔记 Producer 是怎么将消息发送至 Broker 的?

    RocketMQ 源码学习笔记 Producer 是怎么将消息发送至 Broker 的? 文章目录 RocketMQ 源码学习笔记 Producer 是怎么将消息发送至 Broker 的? 前言 项目 ...

  4. Vuex 4源码学习笔记 - 通过Vuex源码学习E2E测试(十一)

    在上一篇笔记中:Vuex 4源码学习笔记 - 做好changelog更新日志很重要(十) 我们学到了通过conventional-changelog来生成项目的Changelog更新日志,通过更新日志 ...

  5. Vuex 4源码学习笔记 - Vuex是怎么与Vue结合?(三)

    在上一篇笔记中:Vuex源码学习笔记 - Vuex开发运行流程(二) 我们通过运行npm run dev命令来启动webpack,来开发Vuex,并在Vuex的createStore函数中添加了第一个 ...

  6. jquery源码学习笔记三:jQuery工厂剖析

    jquery源码学习笔记二:jQuery工厂 jquery源码学习笔记一:总体结构 上两篇说过,query的核心是一个jQuery工厂.其代码如下 function( window, noGlobal ...

  7. 雷神FFMpeg源码学习笔记

    雷神FFMpeg源码学习笔记 文章目录 雷神FFMpeg源码学习笔记 读取编码并依据编码初始化内容结构 每一帧的视频解码处理 读取编码并依据编码初始化内容结构 在开始编解码视频的时候首先第一步需要注册 ...

  8. Apache log4j-1.2.17源码学习笔记

    (1)Apache log4j-1.2.17源码学习笔记 http://blog.csdn.net/zilong_zilong/article/details/78715500 (2)Apache l ...

  9. PHP Yac cache 源码学习笔记

    YAC 源码学习笔记 本文地址 http://blog.csdn.net/fanhengguang_php/article/details/54863955 config.m4 检测系统共享内存支持情 ...

最新文章

  1. socket编程缓冲区大小对send()的影响
  2. office2007的界面
  3. Kotlin极简教程:第4章 基本数据类型与类型系统
  4. Cookie的特点和作用|| 案例:记住上一次访问时间
  5. 1.14 字符串查找(3种方法)indexOf(), lastlndexOf(), charAt()
  6. SQL server 中SQL语句实战操作
  7. C编程,随机数,排序
  8. OpenCV学习笔记十:hough变换
  9. 每天一道Leetcod或者Codeforce算法系列
  10. Spring 事务相关及@Transactional的使用建议
  11. python读单行文本求平均值_利用Python读取json数据并求数据平均值
  12. 2019无盘游戏服务器128g内存,云更新无盘客户端 v2019.8.15.12486官方版
  13. 浅谈mtk平台手机通过gprs网络连接pc
  14. 关于OpenFOAM的一些学习资料
  15. 查找所有的两个字姓名,中间加个空格(强迫症的福音)
  16. 原生js实现的日期选择插件
  17. 计算机一级中的高级筛选怎么做,详解Excel的高级筛选
  18. 计算机视觉论文-2021-09-06
  19. ​寒武纪思元370系列与飞桨完成II级兼容性测试,联合赋能AI落地实践
  20. 手机里android文件里哪些文件可删除,手机文件夹哪些可以删除?这4个删除它妥妥的...

热门文章

  1. c语言编辑 显示atd数在led上,2012 - 2013 学年第1学期《单片机原理及应用》课程答题纸 1_5.doc...
  2. JAVA提高篇(24)--CharArrayReader、CharArrayWriter简介
  3. AWS 容器三大新品:K8s 发行版,免费镜像库和 “Game Changer”AWS Proton
  4. 计算机专业考MBA有优势吗,工作后考mba有什么好处
  5. 一文详解TVS管应用的正确姿势,不懂的来看看
  6. 洛谷P2713 罗马游戏
  7. VS2019/MFC编程入门——文档、视图和框架:分割窗口
  8. 小程序scroll-view实现左右联动
  9. 如何使用 Lightly 邀请朋友在线协作?
  10. 程序设计c语言高速公路收费标准,C语言 高速公路超速处罚