将编译器pass添加到Relay
编译器pass是扩展Relay功能集和对Relay程序执行优化的主要接口。通过编写编译器pass,可以修改AST或收集有关AST的信息,具体取决于目标。事实上,Relay的一些最重要的内置功能(如autodiff和类型推断),只不过是“标准”编译器pass。
在高层次上,写pass有两个关键组成部分:
创建一个或多个遍历程序的C++类
将遍历实现及元数据包装在pass manager API中,以便可以与pass基础结构完整交互。
首先,将概述编写编译器pass的关键机制。然后,将介绍一个Relay中常量折叠pass的具体示例。
AST遍历器
用于遍历Relay程序的基类是ExprFunctor。提供的公共接口是一个VisitExpr方法,接受一个表达式和零个或多个参数,返回某种类型的实例。扩展此类时,可以通过为每种类型的表达式重写VisitExpr_ f的实现,定义AST遍历模式。
VisitExpr和VisitExpr_间的关系与调度有关。每个VisitExpr_定义都针对特定类型的表达式,但不总是知道要访问的节点类型。为了解决这个问题,ExprFunctor提供了一个VisitExpr函数,该函数从给定的表达式路由到处理VisitExpr_案例。尽管C++已经提供了动态调度,ExpPrimor还是定义了VisteExPR使用的VTe表。通过定义vtable,可以更好地控制调度。例如,如果想定义一个PrintVisitor遍历器,在每次访问前打印“Here”,可以覆盖VisitExpr:
void PrintVisitor::VisitExpr(const Expr& expr) {
std::cout << “Here” << std::endl;
ExprFunctor::VisitExpr(expr);
}
ExprFunctor本身是一个非常通用的类,这就是为什么经常会扩展ExprVisitor或ExprMutator。这些类扩展了ExprFunctor,提供了VisitExpr_的默认实现,该实现获取每种表达式类型的公共遍历模式。拥有这些默认实现,不同行为的表达式类型提供覆盖实现。在下面的介绍中,将单独描述每个子类。
ExprVisitor
ExprVisitor用于不修改程序,执行程序分析和收集信息的过程。在这个类中,VisitExpr和私有对应项不返回任何内容。此类提供的VisitExpr_实现,只需访问表达式的所有字段即可。IfNode的默认实现如下所示。
void ExprVisitor::VisitExpr_(const IfNode* op) {
this->VisitExpr(op->cond);
this->VisitExpr(op->true_branch);
this->VisitExpr(op->false_branch);
}
在这里调用的是VisitExpr,不是VisitExpr,可以使用vtable in ExprFunctor进行路由。
现在,如果想编写一个类调用检查器,检查程序中是否出现任何函数调用,只需要扩展ExprVisitor,定义以下VisitExpr_方法:
void VisitExpr_(const CallNode* n) final {
result_ = true;
}
其中result_是一个字段。在这种情况下,不需要在CallNode的字段上进一步递归,因为result_已经为true,原始表达式包含一个调用。为了使visitor可用,将提供以下公共方法:
bool Check(const Expr& expr) final {
result_ = false;
VisitExpr(expr);
return result_;
}
这就是所需要的。在调用顶级递归前,定义一个公共接口,执行一些bookkeeping记录是非常常见的。当然,可以通过创建一个独立的pass,进一步包装API,该pass创建一个CallChecker实例调用Check,只花了很少的努力就实现了目标。
Expression Mutators
ExprMutator用于以某种方式转换程序的pass。使用该类,VisitExpr及私有对应项返回Expr。此类提供的默认VisitExpr_,实现访问表达式的所有字段,这些字段都是表达式,将这些字段设置为访问结果。TupleGetItemNode的默认实现如下所示。
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef(g);
} else {
return TupleGetItem(t, g->index);
}
}
这里有几件事需要注意。首先,Mutate是ExprMutator中VisitExpr的别名。其次,如果Mutate调用修改了tuple字段,只返回一个新节点。这种更新方法称为功能更新,这样做可以避免不必要的分配。
ExprMutator的一个特性是ExprVisitor没有的,一个用于缓存结果的内置备注字段。ExprMutator有一个memoizer是有道理的,知道正在缓存哪些类型的结果(即Expr),ExprVisitor的访问方法不返回任何内容。通常,当想要将结果缓存在ExprVisitor的子类中时,需要定义缓存。
现在,如果想编写一个类IfCollapser,用真正分支替换每个if语句,将覆盖IfNode的VisitExpr_:
Expr ExprMutator::VisitExpr_(const IfNode* op) {
return this->Mutate(op->true_branch);
}
返回的表达式不一定是IfNode,因为返回类型是Expr。现在,创建公共接口:
Expr CollapseIfs(const Expr& expr) final {
return this->Mutate(expr);
}
有了这个mutator,不需要做任何记录,但仍然希望遵循使用描述性方法,作为接口的惯例。
示例:常量折叠
为了更好地理解编写pass,将以常量折叠pass(见src/relay/transforms/fold_constant.cc)为指导,因为是一个相对简单的过程,包含了两种类型的遍历。
常量折叠涉及计算程序中,只涉及常量值的表达式,然后用计算结果替换这些表达式。此pass的目标是预先加载所有可以进行的计算。为了实现这一点,常量折叠pass使用访客(ConstantChecker)和变异子(ConstantFolder)。
ConstantChecker Visitor
此访问者用于检查表达式是否为常量。在Relay中,如果表达式是常量节点或只有常量字段的元组节点,将定义为常量。
使用一个memo_字段,从节点映射是否为常量,缓存这些结果。以下是ConstantChecker中的VisitExpr_定义。
void VisitExpr_(const ConstantNode* n) final {
memo_[GetRef(n)] = true;
}

void VisitExpr_(const TupleNode* n) final {
bool result = true;
for (const auto& field : n->fields) {
if (!Check(field)) {
result = false;
break;
}
}
memo_[GetRef(n)] = result;
}
用于协调这些定义的记录是一个检查方法,返回给定表达式是否被视为常量。
bool Check(const Expr& expr) {
const auto it = memo_.find(expr);
if (it != memo_.end())
return it->second;
VisitExpr(expr);
return memo_[expr];
}
不会为遇到的每个节点修改memo_;相反,只在遇到的节点可能是常量时修改memo_。然后,当memo_不包含expr时,依赖于默认值为false。
ConstantFolder Mutator常量折叠变异体
该mutator变异器执行大部分常量折叠pass,在内部使用ConstantChecker。在Relay中,常量折叠涉及三种节点类型:LetNode、TupleItemGetNode和CallNode。在下面的段落中,将解释pass中每个角色的作用。
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as()) {
memo_[op->var] = value;
return this->Mutate(op->body);
} else {
Var var = Downcast(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef(op);
} else {
return Let(var, value, body);
}
}
}
在LetNode的情况下,首先尝试对表达式中绑定的值进行常量折叠。填充memo_,返回访问主体的结果,将绑定值传播到主体中的使用站点。如果不能将绑定值常量化,将模拟默认实现。
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as();
if (const auto* tuple = op->tuple.as()) {
return tuple->fields[op->index];
} else {
return res;
}
}
在TupleItemGetNode的情况下,检查op->tuple字段是否是TupleNode。用op->index指向的元组字段替换元组get。需要检查的原因是op->tuple可能计算为一个tuple,本身不是tuple。
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttrMap(“TOpIsStateful”);
Expr res = ExprMutator::VisitExpr_(call);
call = res.as();
// We don’t constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as();
if (op == nullptr) return res;
// skip stateful ops.
if (op_stateful.get(GetRef(op), false)) return res;
bool all_const_args = true;
for (Expr arg : call->args) {
if (!checker_.Check(arg)) {
all_const_args = false;
}
}
if (all_const_args) {
return ConstEvaluate(res);
} else {
return res;
}
}
在CallNode的情况下,首先使用ExprMutator的VisitExpr_访问调用,将调用的所有字段折叠起来。使用ExprMutator::VisitExpr_uu而不是VisitExpr,因为希望绕过vtable(避免无限循环),使用ExprMutator提供的默认实现。然后,仅在所有参数都是常量时(使用ConstantChecker)计算调用。对调用求值会产生一个值,因此使用help方法ValueToExpr,将求值表达式放回AST中。
现在,为常量文件夹构造一个更方便的接口FoldConstant。FoldConstant是ConstantFolder类外的一个独立函数,接受一个表达式,在内部创建和使用ConstantFolder实例(完整定义可在src/relay/transforms/fold_constant.cc中找到)。
向pass管理器注册pass
参阅:ref:pass infra上的文档,了解有关此主题的更多详细信息。
编写AST遍历器后,可以使用以下代码,将pass注册为TVM API端点:
namespace transform {

Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, “FoldConstant”, {});
}

} // namespace transform
如果将上述代码生成的Pass对象,提供给Pass基础设施,将确保将AST遍历应用于给定Relay模块中的每个函数,这是常量折叠Pass的预期行为(它应尽可能折叠所有常数)。
函数CreateFunctionPass允许注册pass的优化级别(在本例中为2),该级别可用于根据pass的通用工具、pass名称以及pass的任何依赖项,将pas分组。pass的依赖项以任何pass的列表的形式给出,这些pass的结果是运行当前pass所必需的。FoldConstant没有任何依赖项,但许多Relay pass确实依赖于类型信息,因此InferType是一个常见的依赖项;另一些可能依赖于程序,通过ToANormalForm pass处于A-normal形式。
注意,PassContext对象包含pass用于错误报告和配置选项的信息;FoldConstant不需要此信息,但其它pass可能会引用PassContext对象。
现在可以通过pass基础设施调用pass,不过最好为pass添加一个Python绑定,如下面的代码片段所示:
TVM_REGISTER_GLOBAL(“relay._transform.FoldConstant”)
.set_body_typed(FoldConstant);
一旦以上述方式定义了Pass对象,就可以使用Pass基础设施的顺序构造调用,该构造获取一个Pass列表,按顺序应用于Relay模块,从而获得转换后的模块。例如,下面的代码将FoldConstant和ToANormalForm pass(一个接一个),应用于mod中的每个函数,获得一个新模块。
seq = transform.Sequential([
relay.transform.FoldConstant(),
relay.transform.ToANormalForm()
])
new_mod = seq(mod)
有关注册的更多详细信息,可以在TVM Runtime系统中找到,有关pass manager接口的更多信息可以在pass基础设施中找到。Relay的标准pass在include/tvm/Relay/transform.h中列出,在src/Relay/transforms/中实现。

参考链接:
https://tvm.apache.org/docs/dev/how_to/relay_add_pass.html

将编译器pass添加到Relay相关推荐

  1. 如何将算子添加到Relay

    如何将算子添加到Relay 本文将介绍在Relay中注册新TVM算子所需的步骤.将PR添加累积产品运算示例.PR本身建立在另一个PR的基础上,该PR添加了一个累积和运算. 注册新算子需要几个步骤: 添 ...

  2. 为Sublime Text 3的C++编译器(g++)添加C++11标准的方法

    写在前面 最近熟悉了使用Sublime写C++单文件并编译的方法, 但是美中不足的就是不能使用C++11的新特性, 网上有的方法是修改默认的编译命令, 这个方法需要修改安装目录下的一个文件, 还要解压 ...

  3. pycharm 使用anaconda python编译器时添加available packages 显示nothing to show的解决办法

    点击绿色的按钮刷新下就好了 参考文章:pycharm的project interpretr 安装包时nothing to show

  4. 【从零开始学深度学习编译器】十三,如何在MLIR里面写Pass?

    [GiantPandaCV导语]这篇文章是学习了比较久然后按照自己的理解步骤重新总结了下来,主要是MLIR Toy Tutorials第3,4篇文章的内容.这里主要讲解了如何在MLIR中自定义Pass ...

  5. TVM Relay Pass探究

    引言 Relay 是 TVM 中十分重要的基础组件之一,用于对接不同格式的深度学习模型以及进行模型的 transform.深度学习编译器的核心功能就是进行各种各样的 transform 变换,这个变换 ...

  6. TVM,Relay,Pass

    TVM,Relay,Pass Relay介绍 主要结合TVM的文档(https://tvm.apache.org/docs/dev/relay_intro.html),介绍一下NNVM的第二代Rela ...

  7. TVM Relay与Pass

    TVM Relay与Pass 本文介绍TVM的Relay,如何基于Relay构建一个Conv+BN+ReLU的小网络, TVM中的Pass的工作机制,并较为详细的介绍了RemoveUnusedFunc ...

  8. Qt Creator添加编译器

    Qt Creator添加编译器 添加编译器 重新检测编译器 指定编译器设置 添加Nim编译器 添加自定义编译器 添加编译器 Qt在各种32位和64位平台上受支持,通常可以在每个平台上使用GCC,供应商 ...

  9. android中c文件怎么加logo,c – 如何在CMake中添加“-l”(ell)编译器标志

    在Ubuntu 16上工作 我使用g main.cpp -lpq命令编译我的小项目.现在我使用Clion,并想做同样的事情.但我不能在cmake文件中添加编译器标志并得到编译错误. cmake_min ...

最新文章

  1. 一文看懂Python(五)-----文件篇
  2. Log4net使用详细说明
  3. 东大OJ-1430-PrimeNumbers
  4. tomcat8开启远程debug
  5. (转)贝莱德,从零到五万亿
  6. 计算机安装win10配置,安装Win10系统配置的最低要求
  7. L1-049 天梯赛座位分配 (20分) (C++)
  8. 华为应用市场APP上架流程
  9. PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快
  10. java获取文件后缀_Java获取文件后缀的两种方式
  11. 淘宝客商品推广图片合成(包含二维码、图片、价格)
  12. loachost 1.php,云豹短视频系统搭建部署文档
  13. python/sympy计算施密特正交化向量
  14. 为什么HashMap的key允许空值,而Hashtable却不允许
  15. mysql 根据身份证号码修改出生日期
  16. 服务器显示蜘蛛,解决因服务器而导致的蜘蛛抓取失败
  17. 【玩转嵌入式屏幕显示】(四)TFT-LCD屏幕显示英文字符(ASCII)和字符串
  18. 梦幻西游手游:工坊进阶考试题目攻略—考古、乐艺篇
  19. 影院活动管理系统 项目测试与部署
  20. 全面认识OpenStack架构

热门文章

  1. 2022-2028年中国轻型输送带行业市场发展规模及市场分析预测报告
  2. jieba词性说明字典
  3. /etc/profile ,/etc/bashrc ,~/.bash_profile,~/ .bashrc 区别与联系
  4. tornado压力测试
  5. LeetCode简单题之距离顺序排列矩阵单元格
  6. AI芯片加速图像识别
  7. 自动编码器的评级预测
  8. 人脸标记检测:ICCV2019论文解析
  9. TypeError: Total() missing 1 required positional argument: ‘self‘
  10. python 正则表达质 re.sub() 的使用