Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[Pass] Lambda Lifting (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww authored and junrushao committed Feb 8, 2023
1 parent b9e2caf commit add66b2
Show file tree
Hide file tree
Showing 22 changed files with 987 additions and 80 deletions.
55 changes: 55 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/ir/diagnostic.h>
#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/function.h>

Expand Down Expand Up @@ -53,6 +54,60 @@ TVM_DLL bool WellFormed(const IRModule& m,
*/
TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func);

/*!
* \brief Get all bound variables from expression expr.
*
* Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);

/*!
* \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
* varbinding or a function parameter in the context.
*
* \param expr the expression.
*
* \return List of free vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);

/*!
* \brief Get all variables from expression expr.
*
* \param expr the expression.
*
* \return List of all vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);

/*!
* \brief Get all glabal variables for recursive call from expression expr.
*
* \param expr the expression.
*
* \return List of all global variables for recursive call.
*/
TVM_DLL tvm::Array<GlobalVar> RecGlobalVars(const Expr& expr);

/*!
* \brief Get all glabal variables from expression expr.
*
* AllVars is a superset of BoundVars and FreeVars.
* The union of BoundVars and FreeVars is Allvars.
*
* \param expr the expression.
*
* \return List of all global variables, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr);

} // namespace relax
} // namespace tvm

Expand Down
7 changes: 7 additions & 0 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ TVM_DLL Pass FailTestRewrite();
*/
TVM_DLL Pass FMARewrite();

/*!
* \brief Perform lambda lifting to lift functions from nested into global.
*
* \return The Pass.
*/
TVM_DLL Pass LambdaLift();

/*!
* \brief Transform all dataflow structure to non-dataflow version.
*
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ class NameTable {
std::unordered_map<std::string, uint32_t> alloc_map_;
};

/*!
* \brief Bind the variables to a Relax expression. This is a helper
* function usually called by other pass functions to help optimizations.
* If any free variables are introduced into a function, those are added
* to the function parameters.
* Additionally this may change the order of parameters if you map a variable
* to a variable.
*
* \param expr The input expression.
* \param binds The variable to expression map that will be used to help the
* binding.
*
* \return The updated expression.
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

} // namespace relax
} // namespace tvm

Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,16 @@ def create_unchecked(
"""Construct a relax.Function but without type checking."""
return _ffi_api.Function_CreateUnchecked(params, body, ret_type, attrs, span)

def __call__(self, *args):
"""Invoke the global function.
Parameters
----------
args: List[relax.Expr]
Arguments.
"""
return Call(self, args, None, None)


@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ def FuseFMA() -> tvm.ir.transform.Pass:
return _ffi_api.FuseFMA()


def LambdaLift():
"""
Lift local functions into global.
Returns
-------
ret : tvm.ir.transform.Pass
"""
return _ffi_api.LambdaLift()


def ToNonDataflow() -> tvm.ir.transform.Pass:
"""Transform all dataflow structure to non-dataflow version.
Expand Down
25 changes: 22 additions & 3 deletions python/tvm/script/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,8 +920,7 @@ def transform_stmt(

elif isinstance(stmt, ast.Function):
func = self.transform_function(stmt)
func_var = self.decl_var(stmt.name, None, None, stmt.span)
return relax.VarBinding(func_var, func, self.to_tvm_span(stmt.span))
return func

else:
self.report_error(
Expand Down Expand Up @@ -1559,8 +1558,15 @@ def transform_block(self, block: ast.Block) -> relax.SeqExpr:
blocks.append(relax.BindingBlock(current_block, self.to_tvm_span(stmt.span)))
current_block = []
blocks.append(parsed_stmt)
elif isinstance(parsed_stmt, (relax.Function, tir.PrimFunc)):
func_var = self.decl_var(stmt.name, None, None, stmt.span)
current_block.append(
relax.VarBinding(func_var, parsed_stmt, self.to_tvm_span(stmt.span))
)
else:
assert isinstance(parsed_stmt, relax.Binding)
assert isinstance(
parsed_stmt, relax.Binding
), "Expected relax.Binding, but got " + str(type(parsed_stmt))
current_block.append(parsed_stmt)
if len(current_block) > 0:
blocks.append(relax.BindingBlock(current_block, self.to_tvm_span(block.stmts[-1].span)))
Expand All @@ -1573,6 +1579,19 @@ def transform_block(self, block: ast.Block) -> relax.SeqExpr:
)
ret_expr = self.transform_stmt(ret_stmt)

# only a call node in the function body
if isinstance(ret_expr, relax.Call) and len(blocks) == 0:
return ret_expr

# return a defined inner function
if (
len(blocks) > 0
and isinstance(blocks[-1].bindings[-1].value, relax.Function)
and hasattr(ret_expr, "name_hint")
and ret_expr.name_hint == blocks[-1].bindings[-1].var.name_hint
):
return blocks[-1].bindings[-1].value

return relax.SeqExpr(blocks, ret_expr, self.to_tvm_span(block.span))


Expand Down
168 changes: 168 additions & 0 deletions src/relax/analysis/analysis.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
*
* \file analysis.cc
*
* \brief Analysis functions for Relax.
*/

#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>

namespace tvm {
namespace relax {

template <typename T>
struct InsertionSet {
std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual> set;
std::vector<T> data;
void Insert(const T& t) {
if (set.count(t) == 0) {
set.insert(t);
data.push_back(t);
}
}
};

class VarVisitor : protected ExprVisitor {
public:
Array<Var> Free(const Expr& expr) {
this->VisitExpr(expr);
Array<Var> ret;
for (const auto& v : vars_.data) {
if (bound_vars_.set.count(v) == 0) {
ret.push_back(v);
}
}
return ret;
}

Array<Var> Collect() {
Array<Var> ret;
for (const auto& v : bound_vars_.data) {
ret.push_back(v);
}
return ret;
}

Array<Var> Bound(const Expr& expr) {
this->VisitExpr(expr);
return Collect();
}

Array<Var> All(const Expr& expr) {
this->VisitExpr(expr);
Array<Var> ret;
for (const auto& v : vars_.data) {
ret.push_back(v);
}
return ret;
}

Array<GlobalVar> AllGlobalVars(const Expr& expr) {
this->VisitExpr(expr);
Array<GlobalVar> ret;
for (const auto& v : global_vars_.data) {
ret.push_back(v);
}
return ret;
}

Array<GlobalVar> RecGlobalVars(const Expr& expr) {
this->VisitExpr(expr);
Array<GlobalVar> ret;
for (const auto& v : rec_global_vars_.data) {
ret.push_back(v);
}
return ret;
}

void MarkBounded(const Var& v) {
bound_vars_.Insert(v);
vars_.Insert(v);
}

void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }

void VisitExpr_(const FunctionNode* op) final {
for (const auto& param : op->params) {
MarkBounded(param);
}
VisitExpr(op->body);
}
void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef<GlobalVar>(op)); }

void VisitExpr_(const CallNode* call_node) final {
VisitSpan(call_node->span);
VisitExpr(call_node->op);

for (Type ty_arg : call_node->type_args) {
VisitType(ty_arg);
}

for (Expr arg : call_node->args) {
VisitExpr(arg);
}

if (call_node->shape_) {
VisitExpr(Downcast<Expr>(call_node->shape_.value()));
}

if (const GlobalVarNode* global_var_node = call_node->op.as<GlobalVarNode>()) {
rec_global_vars_.Insert(GetRef<GlobalVar>(global_var_node));
}
}

void VisitBinding_(const VarBindingNode* binding) final {
MarkBounded(binding->var);
VisitExpr(binding->value);
VisitVarDef(binding->var);
}

private:
InsertionSet<Var> vars_;
InsertionSet<Var> bound_vars_;
InsertionSet<GlobalVar> global_vars_;
InsertionSet<GlobalVar> rec_global_vars_;
};

tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }

tvm::Array<Var> BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); }

tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }

tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); }

tvm::Array<GlobalVar> RecGlobalVars(const Expr& expr) { return VarVisitor().RecGlobalVars(expr); }

TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);

TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars);

TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);

TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars);

TVM_REGISTER_GLOBAL("relax.analysis.rec_global_vars").set_body_typed(RecGlobalVars);

} // namespace relax
} // namespace tvm
Loading

0 comments on commit add66b2

Please sign in to comment.