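Theory

Prerequisites

  • TIR Let Binding

    • Let(var, value, body) evaluates value, binds the result to var, and then returns the result of evaluating body. In other words, let binds an expression (Expr) to an immutable var in a local scope.
  • Scope: a scope, or code block. The TVM AST organizes scopes as a tree: an outer scope that contains an inner scope is its parent node. Variables of a parent node are visible to its children, but not the other way round. Deciding whether a variable is in scope at a given point is an important part of the CSE algorithm.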
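To make the Let semantics and the parent/child visibility rule concrete, here is a minimal sketch of an evaluator for a toy expression language. The tuple encoding (`("let", var, value, body)`, `("add", a, b)`) and the name `evaluate` are assumptions made for illustration; this is not TVM's API.

```python
def evaluate(expr, env=None):
    """Evaluate a toy expression; `env` maps variable names to values."""
    env = {} if env is None else env
    if isinstance(expr, int):          # constant
        return expr
    if isinstance(expr, str):          # variable: must be in scope
        return env[expr]               # KeyError if it is not
    tag = expr[0]
    if tag == "add":
        return evaluate(expr[1], env) + evaluate(expr[2], env)
    if tag == "let":
        _, var, value, body = expr
        # The child scope sees the parent's bindings plus `var`;
        # the parent's env is left untouched (bindings are immutable).
        child_env = {**env, var: evaluate(value, env)}
        return evaluate(body, child_env)
    raise ValueError(f"unknown node: {tag}")


# let x = 1 + 2 in x + x
program = ("let", "x", ("add", 1, 2), ("add", "x", "x"))
result = evaluate(program)
```

Evaluating the bare variable `"x"` outside the let raises `KeyError`, which is the toy counterpart of "a child's variable is not visible to its parent".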

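What it is

TIR code is produced from Relay and by other passes. Because it is generated automatically, it tends to contain a lot of repeated computation. Common Subexpression Elimination (CSE) is a TIR pass that locates repeated computations and replaces them.

  • It creates a new variable and replaces every occurrence of the expression with it.
  • It supports replacing whole expressions.
  • It supports replacing subexpressions. For example, (w+x)+(y+z); (w+x)+u; becomes new_var = (w+x); new_var+(y+z); new_var+u; where (w+x) is the common subexpression.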

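How it works

Premise

TIR has the SSA (Static Single Assignment) property: the value of a variable never changes (it is immutable). Without this premise, replacement would be unsound. Suppose every a+b is replaced with y, where y=a+b, but a is modified somewhere in between (say, set to 0). Every a+b after that point actually means 0+b, so reusing y without recomputing a+b gives a wrong result.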

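Locating repeated expressions

Selecting candidate expressions

  • The expression is not a constant or a variable (if it is already a variable, there is no point in creating a new variable to replace it).
  • The expression is not a function call or a memory load.
    • Functions are not necessarily pure. For a function with side effects, two calls with the same name and the same arguments may return different results.
    • Likewise, two memory loads from the same location may return different results, since memory may have been written in between, so a load cannot be treated as a candidate expression.
  • The expression does not contain a function call or a memory load as a subexpression either.
    • Replacing sum(f,f) is also unsafe, because f may have side effects.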
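The three rules above can be sketched as a predicate over a toy expression encoding. The encoding (ints for constants, strings for variables, tuples such as `("add", a, b)`, `("call", name, args...)` and `("load", buffer, index)`) and the function names are assumptions for illustration only.

```python
FORBIDDEN = {"call", "load"}  # node kinds that must never be hoisted


def contains_forbidden(expr):
    """True if `expr` is, or contains, a function call or a memory load."""
    if not isinstance(expr, tuple):
        return False
    if expr[0] in FORBIDDEN:
        return True
    return any(contains_forbidden(child) for child in expr[1:])


def is_eligible(expr):
    """Toy version of the candidate rules listed above."""
    if not isinstance(expr, tuple):   # constant (int) or variable (str)
        return False
    return not contains_forbidden(expr)


checks = {
    "x + y": is_eligible(("add", "x", "y")),
    "x": is_eligible("x"),
    "3": is_eligible(3),
    "f(y)": is_eligible(("call", "f", "y")),
    "x + Mem[i]": is_eligible(("add", "x", ("load", "Mem", "i"))),
}
```

Only `x + y` passes; the variable, the constant, the call, and the expression that contains a load are all rejected.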

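Deciding whether a candidate should be processed further
The variables used by the expression must all be visible in the current scope, and the expression must occur repeatedly.

For an expression that does not meet these conditions, its subexpressions are considered recursively.
For example, (w+x)*(y+z) contains (w+x) and (y+z).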
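The fallback to subexpressions can be sketched with a small worklist. This is a simplified model, not TVM's implementation: it assumes a threshold of two occurrences, ignores the scope check, and uses a hypothetical tuple encoding and hypothetical names (`size`, `direct_subexprs`, `pick_common`).

```python
def size(expr):
    """Number of nodes in a toy expression (tuples are operators)."""
    if not isinstance(expr, tuple):
        return 1
    return 1 + sum(size(child) for child in expr[1:])


def direct_subexprs(expr):
    """Compound children of `expr`; variables and constants are skipped."""
    return [child for child in expr[1:] if isinstance(child, tuple)]


def pick_common(table, threshold=2):
    """Choose expressions worth binding to a new variable.

    `table` maps expression -> occurrence count. Larger expressions are
    examined first; one that is too rare hands its count to its direct
    subexpressions, which may then reach the threshold themselves.
    """
    work = dict(table)
    chosen = []
    while work:
        expr = max(work, key=size)
        count = work.pop(expr)
        if count >= threshold:
            chosen.append(expr)
            count = 1  # after hoisting, `expr` is computed exactly once
        for sub in direct_subexprs(expr):
            work[sub] = work.get(sub, 0) + count
    return chosen


wx = ("add", "w", "x")
yz = ("add", "y", "z")
# (w+x)*(y+z) occurs once and (w+x) occurs once more on its own.
chosen = pick_common({("mul", wx, yz): 1, wx: 1})
```

Here the product is too rare to hoist, but decomposing it brings `(w+x)` up to two occurrences, so `(w+x)` is chosen while `(y+z)` is not.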

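Data structures

  • Context: vector<pair<Var,MaybeValue>>

    • Records which variables are in the current scope.
  • Table of computations: a table that counts expressions, implemented as an unordered_map.
    • The key is an expression (PrimExpr). For example, in the Stmt buffer[i1] = ((x + y) + z), both ((x + y) + z) and (x + y) are PrimExprs; the latter is a subexpression of the former and is reached by visiting it.
    • The value is the number of times the expression occurs.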
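As a rough picture of the table of computations, the sketch below counts expressions with a `Counter`. It assumes that counting stops at the outermost compound expression of each statement (inner subexpressions are only considered later, if the outer one is not hoisted); the tuple encoding and the name `collect` are hypothetical.

```python
from collections import Counter


def collect(expr, table):
    """Count `expr` if it is a compound expression; do not descend into it."""
    if isinstance(expr, tuple):
        table[expr] += 1


x_plus_y = ("add", "x", "y")
# (index, value) pairs standing in for `buffer[index] = value`
stores = [
    ("i1", ("add", x_plus_y, "z")),
    ("i2", ("add", x_plus_y, "z")),
    ("i3", x_plus_y),
]

table = Counter()
for _index, value in stores:
    collect(value, table)
```

The resulting table holds `((x + y) + z)` with count 2 and `(x + y)` with count 1, because the two inner occurrences of `(x + y)` are not counted at this stage.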

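Creating new variables
Expressions bound to new variables may contain one another, so the new variable for a longer expression must be placed inside the let scope of the shorter ones: variables for short expressions go on the outside.

As in the figure, the outer smallComp could be y=a+b and z=d+e, and the inner bigComp could then be p=y+z.

Limitations and improvements

  • Richer semantic equivalences could be supported in the future, for example (x+y)+z <=> z+(x+y).
  • Functions with side effects could be distinguished from those without, which would allow deeper optimization.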


Code walkthrough

Let's first look at tests/python/unittest/test_tir_transform_common_subexpr_elim.py::test_cse:

@main = primfn(i1: int32, i2: int32, z3: int32) -> () {
  let z1: int32 = 1
  let z2: int32 = 2
  {
    buffer: Pointer(int32)[i1] = (z1 + z2)
    let x: int32 = 1
    let y: int32 = 1
    let a: int32 = ((x + y) + (z1 + z2))
    let b: int32 = ((x + y) + z3)
    buffer[i2] = (a + b)
  }
}
[15:01:55] /home/yuan/Coding/compiler/repos/tvm/src/ir/transform.cc:566: PrintIR():
#[version = "0.0.5"]
@main = primfn(i1: int32, i2: int32, z3: int32) -> () {
  let z1: int32 = 1
  let z2: int32 = 2
  let cse_var_1: int32 = (z1 + z2)
  {
    buffer: Pointer(int32)[i1] = cse_var_1
    let x: int32 = 1
    let y: int32 = 1
    let cse_var_2: int32 = (x + y)
    let a: int32 = (cse_var_2 + cse_var_1)
    let b: int32 = (cse_var_2 + z3)
    buffer[i2] = (a + b)
  }
}

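Observe that CSE generated two variables, cse_var_1=z1+z2 and cse_var_2=x+y, and substituted them for the corresponding expressions. The rule at work: z1+z2 and x+y each occur twice in the same scope, and no load involves these variables.

test_tir_transform_common_subexpr_elim.py::test_cse_cascade: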

yuan@yuan:~/Coding/compiler/repos/tvm$ python -m pytest /home/yuan/Coding/compiler/repos/tvm/tests/python/unittest/test_tir_transform_common_subexpr_elim.py::test_cse_cascade -s
enabled targets: llvm; cuda; nvptx
pytest marker:
====================================================================== test session starts ======================================================================
platform linux -- Python 3.8.10, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /home/yuan/Coding/compiler/repos/tvm
collected 1 item

tests/python/unittest/test_tir_transform_common_subexpr_elim.py @main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
  buffer: Pointer(int32)[i1] = ((x + y) + z)
  buffer[i2] = ((x + y) + z)
  buffer[i3] = (x + y)
}
[15:17:37] /home/yuan/Coding/compiler/repos/tvm/src/ir/transform.cc:566: PrintIR():
#[version = "0.0.5"]
@main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
  let cse_var_2: int32 = (x + y)
  let cse_var_1: int32 = (cse_var_2 + z)
  {
    buffer: Pointer(int32)[i1] = cse_var_1
    buffer[i2] = cse_var_1
    buffer[i3] = cse_var_2
  }
}

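This matches what was described earlier: longer expressions go on the inside, shorter ones on the outside.

Having seen the examples, we now trace the source code against the second one. Logging was inserted at the entry of the recursion, at the function's Input and Output, and where the counting hash table is built, to observe the call flow: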

yuan@yuan:~/Coding/compiler/repos/tvm$ python -m pytest /home/yuan/Coding/compiler/repos/tvm/tests/python/unittest/test_tir_transform_common_subexpr_elim.py::test_cse_cascade -s
enabled targets: llvm; cuda; nvptx
pytest marker:
============================================================= test session starts ==============================================================
platform linux -- Python 3.8.10, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /home/yuan/Coding/compiler/repos/tvm
collected 1 item

tests/python/unittest/test_tir_transform_common_subexpr_elim.py @main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
  buffer: Pointer(int32)[i1] = ((x + y) + z)
  buffer[i2] = ((x + y) + z)
  buffer[i3] = (x + y)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/runtime/logging.cc:239: TVM_LOG_DEBUG enables VLOG statements in 'tir/transforms/common_subexpr_elim.cc' up to level 1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
(((x + y) + z), 2)
((x + y), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:493: variables_created true
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
let cse_var_1 = ((x + y) + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = (x + y)
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
(((x + y) + z), 1)
((x + y), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:493: variables_created true
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
((x + y), 1)
((cse_var_2 + z), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
((cse_var_2 + z), 1)
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i1] = cse_var_1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i1] = cse_var_1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i2] = cse_var_1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i2] = cse_var_1
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
}
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:499: variables_created false
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:502: Output: result=let cse_var_2 = (x + y)
let cse_var_1 = (cse_var_2 + z)
buffer[i1] = cse_var_1
buffer[i2] = cse_var_1
buffer[i3] = cse_var_2
[17:08:32] /home/yuan/Coding/compiler/repos/tvm/src/ir/transform.cc:566: PrintIR():
#[version = "0.0.5"]
@main = primfn(i1: int32, i2: int32, i3: int32, x: int32, y: int32, z: int32) -> () {
  let cse_var_2: int32 = (x + y)
  let cse_var_1: int32 = (cse_var_2 + z)
  {
    buffer: Pointer(int32)[i1] = cse_var_1
    buffer[i2] = cse_var_1
    buffer[i3] = cse_var_2
  }
}


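Following the log output, the execution flow was simulated in a diagram. The box contains the original computation graph. cse_var_1 and then cse_var_2 are added in bottom-to-top order, while child nodes are traversed top-down with DFS.

Context update / entry point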

PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
  // At this point, we have already done the generic treatment of introducing (via let-in) what
  // was doable at the toplevel of the given let-in.

  // Save the context at the entry of the function
  Context context_at_entry = context_;

  // Recurse on the `value` field for potentially rewriting it
  PrimExpr value_new = VisitExpr(op->value);

  // Augment the context with the association (`var`, `value`) for preparing the next recursion
  // on the `body`
  context_.push_back({op->var, MaybeValue(op->value)});

  // Recurse on the `body` (with this extended context)
  // The recursive call will have potentially done new simplifications, because in this recursive
  // call `var` will be a part of the context.
  // (see in VisitExpr() that no introduction were performed when a computation was using an
  // undefined variable, as that would lead to ill-formed code)
  PrimExpr body_new = VisitExpr(op->body);

  // Restaure the context to its content at the entrance to not carry out of scope declarations
  // as the variable introduced by the let-in is not in scope outside of its body
  context_ = context_at_entry;

  // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
  // have been done.

  // If the `value` and the `body` of the let-in have been rewritten to the same thing
  if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
    // then return a reference to the same node
    return GetRef<PrimExpr>(op);
  } else {
    // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
    // have just been obtained
    return Let(op->var, value_new, body_new, op->span);
  }
}

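CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) is the entry point that calls VisitExpr. It does four things:

  • First, recurse on the value of the let.
  • Add the variable bound by the let to the context, in preparation for recursing on the body.
  • Recurse on the body; the variable bound by the let is now stored in the context and is therefore visible inside the body.
  • Restore the context.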
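The save / extend / restore pattern can be sketched in a few lines on a toy AST. The tuple encoding and the `visit` function are assumptions for illustration; the real pass stores (Var, MaybeValue) pairs rather than bare names.

```python
def visit(expr, context, seen):
    """Record, for every variable use, whether the variable was in scope."""
    if isinstance(expr, str):
        seen.append((expr, expr in context))
    elif isinstance(expr, tuple) and expr[0] == "let":
        _, var, value, body = expr
        saved = list(context)        # save the context at entry
        visit(value, context, seen)  # `var` is not yet visible in `value`
        context.append(var)          # extend the context for the body
        visit(body, context, seen)
        context[:] = saved           # restore: `var` leaves scope
    elif isinstance(expr, tuple):
        for child in expr[1:]:
            visit(child, context, seen)


# (let t = a in t + a) + t, with only `a` defined outside
expr = ("add", ("let", "t", "a", ("add", "t", "a")), "t")
context = ["a"]
seen = []
visit(expr, context, seen)
```

Inside the let body `t` is found in the context, while the trailing use of `t` outside the let is reported as out of scope, and `context` is back to `["a"]` afterwards.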

Rules for candidate expressions

bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
  return (
      // In order to be eligible, the given expression should not be a constant
      (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
      (expr.as<StringImmNode>() == nullptr)
      // and it should not be a variable
      && (expr.as<VarNode>() == nullptr)
      // and it should not be a forbidden computation (function calls and loads)
      && (!ForbiddenComputation(expr))
      // and it should not even contain a forbidden computation (function calls and loads)
      // the reason is that we don't want to register expressions like (x + f(y)) or
      // (x + Mem[i]) as introducing them into variables could change the semantics
      && (!CheckContains::ExprContains(expr, ForbiddenComputation))
      // and it should not be a ramp node or a broadcast node due to some internals TVM
      // constraints (which check for these node explicitely without performing any
      // evaluation first, so if they have been put into variables it fails)
      && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
}

About the ComputationTable

  // Obtain the (syntactic) eligible computations done by the input statement, and keep it as
  // a ComputationTable, which is a mapping from PrimExpr to size_t, where the size_t is the
  // number of time this exact syntactic computation is being computed.
  ComputationTable table_syntactic_comp_done_by_stmt = ComputationsDoneBy::GetComputationsDoneBy(
      stmt, IsEligibleComputation, CanContainEligibleComputations);
  ...
  std::unordered_map<Stmt, ComputationTable, ObjectPtrHash, ObjectPtrEqual>
      cache_stmt_table_computations_;
  ...
void ComputationsDoneBy::VisitStmt(const Stmt& stmt) {
  // See if we have already computed the (table of) computations done by `stmt`
  auto it_table_stmt = cache_.cache_stmt_table_computations_.find(stmt);
  if (it_table_stmt != cache_.cache_stmt_table_computations_.end()) {
    // We need to do the union with `table_of_computations_` instead of just writing into it,
    // because some other childs might have added things into it too. The reason for that is
    // that `table_of_computations_` is shared between the child nodes of a given statement.
    UnionOfComputationTables(&table_of_computations_, it_table_stmt->second);
    return;
  }

  // If we reach this point, it means that we have never computed before the computations done
  // by `stmt` and will do so now.

  // The computations done by a Stmt node are just the ones done by its children
  ComputationTable temp =
      ComputationsDoneByChildrenOf(stmt, is_eligible_computation_, can_contain_computations_);
  // We need to do the union with `table_of_computations_` instead of just writing into it,
  // because some other childs might have added things into it too. The reason for that is
  // that `table_of_computations_` is shared between the child nodes of a given expression.
  UnionOfComputationTables(&table_of_computations_, temp);
}

The program keeps a cache of type unordered_map<Stmt, ComputationTable> that maps each Stmt to its table. To use it, the Stmt is looked up to obtain the corresponding ComputationTable.

ComputationTable ComputationsDoneBy::ComputationsDoneByChildrenOf(
    const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation,
    std::function<bool(const PrimExpr&)> can_contain_computations) {
  // We will be using an instance of the class ComputationsDoneBy for the child nodes
  // (ie, they will share the "result" that `table_of_computations_` is)
  ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations);
  // Calls the *dispatcher* (not the overriden method)
  computations_done_by.StmtExprVisitor::VisitStmt(stmt);
  // So now we can copy table_of_computations_ into the cache for the future queries
  // Note : in the table, the computations done by `stmt` is set to the computations done by its
  // children, because that's exactly what we mean by "the computations of a statement".
  cache_.cache_stmt_table_computations_[stmt] = computations_done_by.table_of_computations_;

  return computations_done_by.table_of_computations_;
}

ComputationsDoneByChildrenOf and ComputationsDoneBy::VisitStmt are in fact mutually recursive, because of this call:

  // Calls the *dispatcher* (not the overriden method)
  computations_done_by.StmtExprVisitor::VisitStmt(stmt);

StmtExprVisitor::VisitStmt recursively visits a Stmt and its expressions.
The order of execution can be traced as:
ComputationsDoneBy::VisitStmt (receives a Stmt) -> ComputationsDoneByChildrenOf -> computations_done_by.StmtExprVisitor::VisitStmt(stmt) (passes on the next Stmt) -> ComputationsDoneBy::VisitStmt -> …

To make this easier to follow, a log statement was added in ComputationsDoneByChildrenOf:

[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:369: Input Stmt :
buffer[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i1] = ((x + y) + z)
buffer[i2] = ((x + y) + z)
buffer[i3] = (x + y)
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i1] = ((x + y) + z)
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i2] = ((x + y) + z)
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim_tools.cc:555: Recursively calling child node:
buffer[i3] = (x + y)
[22:09:45] /home/yuan/Coding/compiler/repos/tvm/src/tir/transforms/common_subexpr_elim.cc:379: ComputationTable :
{
(((x + y) + z), 2)
((x + y), 1)
}

Extension: from syntax to semantics
This was mentioned in the earlier section on limitations and improvements. SyntacticToSemanticComputations, EquivalentTerms and related functions form a separate module.

The goal is to support:

  • Commutativity: x+y <=> y+x
  • Associativity: (x+y)+z <=> x+(y+z)
  • Distributivity: x*(y+z) <=> x*y+x*z

// Transform the hashtable of *syntactic* eligible computations into a vector of pairs
// containing *semantic* entities, i.e. where equivalent computations are merged.
std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
    SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt);
...
/*!
 * \brief Decides if two terms are equivalent semantically
 */
bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) {
  // For now, we just check the syntactic equality, but that could later become a semantic test,
  // for instance identifying computations modulo commutativity (like x+y and y+x), or modulo
  // associativity (like (x+y)+z and x+(y+z)), etc.
  arith::Analyzer analyser;
  PrimExpr a_simplified = analyser.Simplify(a);
  PrimExpr b_simplified = analyser.Simplify(b);
  return EqualTerms(a_simplified, b_simplified);
}

The above is a change I made to EquivalentTerms; tvm/arith supports a certain amount of semantic analysis. I also discussed it once with the developers. The conclusion was that although it cannot cover every case, it is better than nothing.
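For intuition only, here is a much simpler technique than calling a simplifier: put both terms into a canonical form by sorting the operands of commutative operators, then compare syntactically. It handles commutativity only, not associativity or distributivity, and the tuple encoding and the names `canonical` and `equivalent_terms` are hypothetical rather than taken from TVM.

```python
COMMUTATIVE = {"add", "mul"}


def canonical(expr):
    """Return `expr` with the operands of commutative operators sorted."""
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    args = [canonical(arg) for arg in args]
    if op in COMMUTATIVE:
        # repr gives one total order over mixed leaves and sub-tuples
        args.sort(key=repr)
    return (op, *args)


def equivalent_terms(a, b):
    """Equality modulo commutativity of the operators in COMMUTATIVE."""
    return canonical(a) == canonical(b)


same_sum = equivalent_terms(("add", "x", "y"), ("add", "y", "x"))
same_nested = equivalent_terms(
    ("add", ("add", "x", "y"), "z"),
    ("add", "z", ("add", "y", "x")),
)
same_sub = equivalent_terms(("sub", "x", "y"), ("sub", "y", "x"))
```

The first two comparisons succeed, including the `(x+y)+z` versus `z+(x+y)` case, while subtraction is correctly left alone because it is not commutative.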

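Creating new variables
Following the longest-first rule, the computations are sorted in descending order of length (complexity). The sorted list is then traversed and each computation is wrapped in a Let. Longer expressions end up on the inside, enclosed by the Let blocks of the shorter expressions that are added afterwards.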

  std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
            });
  for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
    ...
    Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
    ...
    result = Let(new_var, computation_and_nb.first, result);
  }
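The effect of sorting by descending complexity and then wrapping can be reproduced on a toy AST. This is a sketch under simplifying assumptions: the list of computations to hoist is given up front, node count stands in for CalculateExprComplexity, and `complexity`, `replace` and `introduce` are hypothetical names.

```python
def complexity(expr):
    """Node count, used here as a stand-in for expression complexity."""
    if not isinstance(expr, tuple):
        return 1
    return 1 + sum(complexity(child) for child in expr[1:])


def replace(expr, target, var):
    """Replace every occurrence of `target` inside `expr` with `var`."""
    if expr == target:
        return var
    if isinstance(expr, tuple):
        return (expr[0],) + tuple(replace(c, target, var) for c in expr[1:])
    return expr


def introduce(body, computations):
    """Bind each computation to a fresh variable, most complex first."""
    result = body
    ordered = sorted(computations, key=complexity, reverse=True)
    for i, comp in enumerate(ordered, start=1):
        var = f"cse_var_{i}"
        result = replace(result, comp, var)
        # Wrapping last puts the most recent (simpler) binding outermost.
        result = ("let", var, comp, result)
    return result


xy = ("add", "x", "y")
xyz = ("add", xy, "z")
body = ("seq", ("store", "i1", xyz), ("store", "i2", xyz), ("store", "i3", xy))
result = introduce(body, [xy, xyz])
```

The result nests `let cse_var_2 = (x + y)` outside `let cse_var_1 = (cse_var_2 + z)`, the same shape as the test_cse_cascade output shown earlier.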

Replacing every occurrence with the new variable

// Replace in the current `result` everything that is selected by the selector with
// the new variable, without diving into expressions in which we don't have the
// right to dive.
result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr(result, predicate_selector, new_var,
                                                        CanContainEligibleComputations);
...
class ReplaceSelectedExpr : public StmtExprMutator
...
PrimExpr ReplaceSelectedExpr::VisitExpr(const PrimExpr& expr) {
  // If the current expression is selected by the predicate
  if (predicate_selector_(expr)) {
    // Then simply return the new expression
    return new_expr_;
  } else {
    // If replacing inside the current expression is allowed
    if (can_replace_inside_(expr)) {
      // then we continue the exploration recursively
      return StmtExprMutator::VisitExpr(expr);
    } else {
      // otherwise we simply return the current expression
      return expr;
    }
  }
}

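Case by case:

  • If the current expression is selected by predicate_selector_, new_expr_ is returned. This is where the substitution happens.
  • If it is not selected but can_replace_inside_ allows replacing inside it, StmtExprMutator::VisitExpr(expr) recursively calls ReplaceSelectedExpr::VisitExpr on its subexpressions. The ComputationTable construction follows a similar call pattern.
  • Otherwise nothing is changed and expr itself is returned.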
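The three cases map directly onto a small recursive function over a toy AST. The tuple encoding and the name `replace_selected` are assumptions for illustration; the two callbacks play the roles of predicate_selector_ and can_replace_inside_.

```python
def replace_selected(expr, selected, new_expr, can_replace_inside):
    """Toy counterpart of the three cases described above."""
    if selected(expr):
        return new_expr                        # case 1: substitute
    if isinstance(expr, tuple) and can_replace_inside(expr):
        return (expr[0],) + tuple(             # case 2: recurse into children
            replace_selected(child, selected, new_expr, can_replace_inside)
            for child in expr[1:]
        )
    return expr                                # case 3: leave untouched


target = ("add", "x", "y")
expr = ("add", target, ("call", "f", target))

rewritten = replace_selected(
    expr,
    selected=lambda e: e == target,
    new_expr="cse_var_1",
    # Do not dive into calls in this sketch.
    can_replace_inside=lambda e: e[0] != "call",
)
```

The top-level occurrence of `(x + y)` becomes `cse_var_1`, while the occurrence inside the call is left alone because the sketch refuses to dive into calls.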

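Extension: pure functions
The earliest commit excluded function calls from the optimization, for the reason given earlier. A deeper optimization could, however, use whether a function is pure to decide whether replacement is allowed. The original author, FrankQC, considers that two rules decide this: first, whether the function returns the same output for the same arguments; second, whether the function modifies external state. At the time of writing, a function's effect can be registered as one of: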

enum class CallEffectKind : int {
  /*! \brief Function corresponds to an annotation(e.g. likely) and can translate to identity. */
  kExprAnnotation = 0,
  /*!
   * \brief Pure function that do not interacts
   *        with any external state.
   */
  kPure = 1,
  /*!
   * \brief Function's that may read from states(e.g. RAM)
   */
  kReadState = 2,
  /*!
   * \brief Function that may read/write from states(e.g. RAM).
   */
  kUpdateState = 3,
  /*!
   * \brief Opaque function, cannot make any assumption
   */
  kOpaque = kUpdateState,
  /*!
   * \brief Special intrinsic to annotate call arguments info
   *        only valid as a direct argument to a call.
   */
  kSpecialCallArg = 4,
  /*!
   * \brief Embed opaque information in the Expr, cannot be codegen.
   */
  kEmbedInfo = 5,
  /*!
   * \brief Function that changes control flow
   */
  kControlJump = 6,
};

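A function registered as kPure=1 has the semantics "do not interacts with any external state", so one cannot hastily assume that every kPure=1 function in the system is pure in both senses. Introducing a new property, for example kDeterminant=1 to express the first rule, would require at least a consensus among all backend developers, because they must annotate the function's property when registering an operator, as in src/tir/op/builtin.cc: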

TIR_DEFINE_BUILTIN_FUNC(shift_left)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
    .set_attr<TVectorizable>("TVectorizable", true);
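To illustrate why both rules are needed, the sketch below models an operator registry in which each operator carries an effect kind plus a separate determinism flag. The `deterministic` flag, `OpInfo`, the registry contents and `can_eliminate_call` are all hypothetical; TVM registers only the effect kind, and the first four enum values are copied from the listing above.

```python
from dataclasses import dataclass
from enum import IntEnum


class CallEffectKind(IntEnum):
    kExprAnnotation = 0
    kPure = 1
    kReadState = 2
    kUpdateState = 3
    kOpaque = 3  # alias of kUpdateState, as in the C++ enum


@dataclass(frozen=True)
class OpInfo:
    effect: CallEffectKind
    deterministic: bool  # hypothetical property, not part of TVM


def can_eliminate_call(info):
    """Rule 2 (no external state touched) and rule 1 (same args, same result)."""
    return info.effect <= CallEffectKind.kPure and info.deterministic


# Hypothetical registry entries, for illustration only.
registry = {
    "shift_left": OpInfo(CallEffectKind.kPure, deterministic=True),
    "random_like": OpInfo(CallEffectKind.kPure, deterministic=False),
    "store_like": OpInfo(CallEffectKind.kUpdateState, deterministic=True),
}
decisions = {name: can_eliminate_call(info) for name, info in registry.items()}
```

Only the `shift_left` entry qualifies: an operator that touches no external state but is not deterministic fails rule 1, and one that updates state fails rule 2.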

See the discussion here.

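Postscript

CSE was contributed by Qualcomm. The code is commented in great detail and there is extensive discussion on the forum, so it is well worth studying closely.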

