Skip to content

Commit

Permalink
[Unity] Relax Recursive function (#14092)
Browse files Browse the repository at this point in the history
This PR adds TVMScript local recursive function support. It also update lambda lifting pass. Removed CalledGlobalVars, it was not used anymore. It also updates well-form pass to allow un-defined vars for recursive call
  • Loading branch information
yongwww authored and tqchen committed Feb 24, 2023
1 parent fc5981b commit 3f4835c
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 73 deletions.
9 changes: 0 additions & 9 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,15 +296,6 @@ TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
*/
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);

/*!
* \brief Get all global variables used in calls in expression expr.
*
* \param expr the expression.
*
* \return List of all global variables called in expr.
*/
TVM_DLL tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr);

/*!
* \brief Get all global variables from expression expr.
*
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/script/ir_builder/relax/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ TVM_DLL tvm::relax::Var Emit(
TVM_DLL tvm::relax::Var EmitMatchCast(const tvm::relax::Expr& value,
const tvm::relax::StructInfo& struct_info);

/*!
* \brief Emit a binding to the last binding block frame.
* \param binding The binding to be emitted.
* \return The left side var of the emitted binding.
*/
TVM_DLL tvm::relax::Var EmitVarBinding(const tvm::relax::VarBinding& binding);

///////////////////////////// If Then Else /////////////////////////////

/*!
Expand Down
17 changes: 16 additions & 1 deletion python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import tvm
from tvm import DataType, relax
from tvm.ir import PrimExpr
from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const
from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, VarBinding, const

############################### Operators ###############################
from tvm.relax.op import (
Expand Down Expand Up @@ -342,6 +342,20 @@ def emit_match_cast(value: Expr, struct_info: StructInfo) -> Var:
return _ffi_api.EmitMatchCast(value, struct_info) # type: ignore


def emit_var_binding(value: VarBinding) -> Var:
"""Emit a binding to the last binding block frame.
Parameters
----------
value: VarBinding
The binding to be emitted.
Returns
-------
var: Var
The left side var of the emitted binding.
"""
return _ffi_api.EmitVarBinding(value) # type: ignore


############################# If Then Else #############################


Expand Down Expand Up @@ -497,6 +511,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"divide",
"dtype",
"emit",
"emit_var_binding",
"emit_match_cast",
"equal",
"ewise_fma",
Expand Down
62 changes: 57 additions & 5 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
annotation = annotation()
if isinstance(annotation, StructInfoProxy):
return annotation
else:
raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.")
raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.")
except Exception as err:
self.report_error(node, str(err))
raise err
Expand All @@ -112,6 +111,38 @@ def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> St
raise err


def is_called(node: Any, func_name: str) -> bool:
# Check if it calls into a func
if isinstance(node, doc.Call):
# Recursive call was found
if isinstance(node.func, doc.Name) and node.func.id == func_name:
return True
elif isinstance(node, (list, tuple)):
for stmt in node:
if is_called(stmt, func_name):
return True
elif isinstance(node, (doc.AnnAssign, doc.Assign, doc.Return, doc.Expr)):
return is_called(node.value, func_name)
elif isinstance(node, doc.With):
return is_called(node.body, func_name)
elif isinstance(node, doc.If):
smts = []
if node.body is not None:
smts = smts + list(node.body)
if node.orelse is not None:
smts = smts + list(node.orelse)
return is_called(smts, func_name)
return False


def is_recursive(node: doc.FunctionDef) -> bool:
# Check if it is a recursive function
for stmt in node.body:
if is_called(stmt, node.name):
return True
return False


def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None:
# Collect symbolic vars from parameters
symbolic_vars = set()
Expand All @@ -128,6 +159,24 @@ def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> Non

@dispatch.register(token="relax", type_name="FunctionDef")
def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
# reserve a var for local function
func_val = self.var_table.get().get(node.name)
if not func_val and is_recursive(node):
collect_symbolic_var_from_params(self, node)
if node.returns is None:
ret_sinfo = relax.TupleStructInfo([])
else:
ret_sinfo = eval_struct_info(self, node.returns, eval_str=True)
params_sinfo = []
for arg in node.args.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation is required for function parameters.")
param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
params_sinfo.append(param_sinfo)
# created a var for the local function, the same var could be used for recursive call
local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo))
self.var_table.add(node.name, local_func_var)

with self.var_table.with_frame():
with self.with_dispatch_token("relax"):
with R.function():
Expand Down Expand Up @@ -164,12 +213,10 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None:
else:
ret_sinfo = eval_struct_info(self, node.returns, eval_str=True)
params = []
params_sinfo = []
for arg in node.args.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation is required for function parameters.")
param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True)
params_sinfo.append(param_sinfo)
params.append(relax.Var(arg.arg, param_sinfo))

func_signature = relax.Function.create_empty(params, ret_sinfo)
Expand All @@ -188,7 +235,12 @@ def post_token_switch(self: Parser, node: doc.Expr) -> None:
ir_builder = IRBuilder.current()
result = ir_builder.get()
ir_builder.__exit__(None, None, None)
var = R.emit(result)
# reuse var if it is reserved
reserved_var = self.var_table.get().get(node.name)
if reserved_var:
var = R.emit_var_binding(relax.VarBinding(reserved_var, result))
else:
var = R.emit(result)
IRBuilder.name(node.name, var)
self.var_table.add(node.name, var, allow_shadowing=False)

Expand Down
20 changes: 0 additions & 20 deletions src/relax/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,6 @@ class VarVisitor : protected ExprVisitor {
return ret;
}

Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
this->VisitExpr(expr);
Array<GlobalVar> ret;
for (const auto& v : called_global_vars_.data) {
ret.push_back(v);
}
return ret;
}

void MarkBounded(const Var& v) {
bound_vars_.Insert(v);
vars_.Insert(v);
Expand Down Expand Up @@ -123,10 +114,6 @@ class VarVisitor : protected ExprVisitor {
for (Expr arg : call_node->args) {
VisitExpr(arg);
}

if (const GlobalVarNode* global_var_node = call_node->op.as<GlobalVarNode>()) {
called_global_vars_.Insert(GetRef<GlobalVar>(global_var_node));
}
}

void VisitBinding_(const VarBindingNode* binding) final {
Expand All @@ -144,7 +131,6 @@ class VarVisitor : protected ExprVisitor {
InsertionSet<Var> vars_;
InsertionSet<Var> bound_vars_;
InsertionSet<GlobalVar> global_vars_;
InsertionSet<GlobalVar> called_global_vars_;
};

tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }
Expand All @@ -155,10 +141,6 @@ tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }

tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); }

tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
return VarVisitor().CalledGlobalVars(expr);
}

TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);

TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars);
Expand All @@ -167,7 +149,5 @@ TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);

TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars);

TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars);

} // namespace relax
} // namespace tvm
11 changes: 10 additions & 1 deletion src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class WellFormedChecker : public relax::ExprVisitor,

void VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
if (var_set_.count(var) == 0) {
if (var_set_.count(var) == 0 && recur_vars_.count(var) == 0) {
Malformed(Diagnostic::Error(var) << "Var " << op->name_hint() << " is not defined.");
}
CheckStructInfo(op);
Expand Down Expand Up @@ -316,12 +316,20 @@ class WellFormedChecker : public relax::ExprVisitor,
}

void VisitBinding_(const VarBindingNode* binding) final {
bool is_lambda = false;
if (binding->value->IsInstance<FunctionNode>()) {
is_lambda = true;
recur_vars_.insert(binding->var);
}
if (binding->value->IsInstance<tir::PrimFuncNode>()) {
Malformed(Diagnostic::Error(binding->value) << "Inline PrimFunc is disallowed in Relax IR.");
} else {
this->VisitExpr(binding->value);
}
this->VisitVarDef(binding->var);
if (is_lambda) {
recur_vars_.erase(binding->var);
}
}

void VisitBinding_(const MatchCastNode* binding) final {
Expand Down Expand Up @@ -451,6 +459,7 @@ class WellFormedChecker : public relax::ExprVisitor,
VisitMode mode_ = VisitMode::kDefault;
// set of context variables.
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> recur_vars_;
std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> dataflow_var_set_;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> symbolic_var_set_;
std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual> param_var_func_map_;
Expand Down
Loading

0 comments on commit 3f4835c

Please sign in to comment.