diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 5e8c2b92de..7b8f2c0274 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -53,6 +54,60 @@ TVM_DLL bool WellFormed(const IRModule& m, */ TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func); +/*! + * \brief Get all bound variables from expression expr. + * + * Bound variables are all variables that are declared in the expr. + * They only have meaning inside that expr, and can only be used in it. + * + * \param expr the expression. + * + * \return List of bound vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array BoundVars(const Expr& expr); + +/*! + * \brief Get free variables from expression expr. + * + * Free variables are variables that are not bound by a + * VarBinding or a function parameter in the context. + * + * \param expr the expression. + * + * \return List of free vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array FreeVars(const Expr& expr); + +/*! + * \brief Get all variables from expression expr. + * + * AllVars is a superset of BoundVars and FreeVars. + * The union of BoundVars and FreeVars is AllVars. + * + * \param expr the expression. + * + * \return List of all vars, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array AllVars(const Expr& expr); + +/*! + * \brief Get all global variables that are called recursively in expression expr. + * + * \param expr the expression. + * + * \return List of all global variables that are called recursively. + */ +TVM_DLL tvm::Array RecGlobalVars(const Expr& expr); + +/*! + * \brief Get all global variables from expression expr. + * + * \param expr the expression. + * + * \return List of all global variables, in the PostDFS order in the expression. + */ +TVM_DLL tvm::Array AllGlobalVars(const Expr& expr); + } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index baf51e09f1..0f229ee2bd 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -81,6 +81,13 @@ TVM_DLL Pass FailTestRewrite(); */ TVM_DLL Pass FMARewrite(); +/*! + * \brief Perform lambda lifting to lift nested (local) functions into global functions. + * + * \return The Pass. + */ +TVM_DLL Pass LambdaLift(); + /*! * \brief Transform all dataflow structure to non-dataflow version. * diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 682e1494ae..9c6beb2697 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -57,6 +57,22 @@ class NameTable { std::unordered_map alloc_map_; }; +/*! + * \brief Bind the variables to a Relax expression. This is a helper + * function usually called by other passes to help optimizations. + * If any free variables are introduced into a function, those are added + * to the function parameters. + * Additionally, this may change the order of parameters if a variable is + * mapped to another variable. + * + * \param expr The input expression. + * \param binds The variable to expression map that will be used to help the + * binding. + * + * \return The updated expression.
+ */ +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index d6145a61ff..686e6ac07b 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -204,6 +204,16 @@ def create_unchecked( """Construct a relax.Function but without type checking.""" return _ffi_api.Function_CreateUnchecked(params, body, ret_type, attrs, span) + def __call__(self, *args): + """Invoke the global function. + + Parameters + ---------- + args: List[relax.Expr] + Arguments. + """ + return Call(self, args, None, None) + @tvm._ffi.register_object("relax.expr.ExternFunc") class ExternFunc(BaseFunc): diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index f15a6df60e..ce9bc5723f 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -70,6 +70,17 @@ def FuseFMA() -> tvm.ir.transform.Pass: return _ffi_api.FuseFMA() +def LambdaLift(): + """ + Lift local functions into global. + + Returns + ------- + ret : tvm.ir.transform.Pass + """ + return _ffi_api.LambdaLift() + + def ToNonDataflow() -> tvm.ir.transform.Pass: """Transform all dataflow structure to non-dataflow version. diff --git a/python/tvm/script/relax/parser.py b/python/tvm/script/relax/parser.py index 2f6fd6840c..f3579c5b30 100644 --- a/python/tvm/script/relax/parser.py +++ b/python/tvm/script/relax/parser.py @@ -920,8 +920,7 @@ def transform_stmt( elif isinstance(stmt, ast.Function): func = self.transform_function(stmt) - func_var = self.decl_var(stmt.name, None, None, stmt.span) - return relax.VarBinding(func_var, func, self.to_tvm_span(stmt.span)) + return func else: self.report_error( @@ -1559,8 +1558,15 @@ def transform_block(self, block: ast.Block) -> relax.SeqExpr: blocks.append(relax.BindingBlock(current_block, self.to_tvm_span(stmt.span))) current_block = [] blocks.append(parsed_stmt) + elif isinstance(parsed_stmt, (relax.Function, tir.PrimFunc)): + func_var = self.decl_var(stmt.name, None, None, stmt.span) + current_block.append( + relax.VarBinding(func_var, parsed_stmt, self.to_tvm_span(stmt.span)) + ) else: - assert isinstance(parsed_stmt, relax.Binding) + assert isinstance( + parsed_stmt, relax.Binding + ), "Expected relax.Binding, but got " + str(type(parsed_stmt)) current_block.append(parsed_stmt) if len(current_block) > 0: blocks.append(relax.BindingBlock(current_block, self.to_tvm_span(block.stmts[-1].span))) @@ -1573,6 +1579,19 @@ def transform_block(self, block: ast.Block) -> relax.SeqExpr: ) ret_expr = self.transform_stmt(ret_stmt) + # only a call node in the function body + if isinstance(ret_expr, relax.Call) and len(blocks) == 0: + return ret_expr + + # return a defined inner function + if ( + len(blocks) > 0 + and isinstance(blocks[-1].bindings[-1].value, relax.Function) + and hasattr(ret_expr, "name_hint") + and ret_expr.name_hint == blocks[-1].bindings[-1].var.name_hint + ): + return blocks[-1].bindings[-1].value + return relax.SeqExpr(blocks, ret_expr, self.to_tvm_span(block.span)) diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc new file mode 100644 index 0000000000..a9401129e7 --- /dev/null +++ b/src/relax/analysis/analysis.cc @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file analysis.cc + * + * \brief Analysis functions for Relax. + */ + +#include +#include + +namespace tvm { +namespace relax { + +template +struct InsertionSet { + std::unordered_set set; + std::vector data; + void Insert(const T& t) { + if (set.count(t) == 0) { + set.insert(t); + data.push_back(t); + } + } +}; + +class VarVisitor : protected ExprVisitor { + public: + Array Free(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + if (bound_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; + } + + Array Collect() { + Array ret; + for (const auto& v : bound_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array Bound(const Expr& expr) { + this->VisitExpr(expr); + return Collect(); + } + + Array All(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array AllGlobalVars(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : global_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array RecGlobalVars(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : rec_global_vars_.data) { + ret.push_back(v); + } + return ret; + } + + void MarkBounded(const Var& v) { + bound_vars_.Insert(v); + vars_.Insert(v); + } + + void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } + + void VisitExpr_(const FunctionNode* op) final { + for (const auto& param : op->params) { + MarkBounded(param); + } + VisitExpr(op->body); + } + void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef(op)); } + + void VisitExpr_(const CallNode* call_node) final { + VisitSpan(call_node->span); + VisitExpr(call_node->op); + + for (Type ty_arg : call_node->type_args) { + VisitType(ty_arg); + } + + for (Expr arg : call_node->args) { + VisitExpr(arg); + } + + if (call_node->shape_) { + VisitExpr(Downcast(call_node->shape_.value())); + } + + if (const GlobalVarNode* global_var_node = call_node->op.as()) { + rec_global_vars_.Insert(GetRef(global_var_node)); + } + } + + void VisitBinding_(const VarBindingNode* binding) final { + MarkBounded(binding->var); + VisitExpr(binding->value); + VisitVarDef(binding->var); + } + + private: + InsertionSet vars_; + InsertionSet bound_vars_; + InsertionSet global_vars_; + InsertionSet rec_global_vars_; +}; + +tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } + +tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } + +tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } + +tvm::Array AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); } + +tvm::Array RecGlobalVars(const Expr& expr) { return VarVisitor().RecGlobalVars(expr); } + +TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars); + 
+TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars); + +TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars); + +TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars); + +TVM_REGISTER_GLOBAL("relax.analysis.rec_global_vars").set_body_typed(RecGlobalVars); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 2e5ced07bc..a1ea34804b 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -30,8 +30,8 @@ using namespace tir; class PatternKindAnalyzer : public StmtExprVisitor { public: - explicit PatternKindAnalyzer(const PrimFunc& func) { - for (const Var& param : func->params) { + explicit PatternKindAnalyzer(const tir::PrimFunc& func) { + for (const tir::Var& param : func->params) { param_buffers_.insert(func->buffer_map.Get(param).value()); } } @@ -111,7 +111,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { // Step 4. Checking if the block contains reduce axis by looking into block iterators. bool has_reduction = false; - Array reduce_vars; + Array reduce_vars; for (const IterVar& it : op->iter_vars) { if (it->iter_type == kCommReduce) { has_reduction = true; @@ -203,9 +203,9 @@ class PatternKindAnalyzer : public StmtExprVisitor { * A[i, j] = B[i - j] is injective since the load indice vars are only i, j */ static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) { - std::unordered_set vars; + std::unordered_set vars; for (const PrimExpr& store_index : store->indices) { - if (const auto* v = store_index.as()) { + if (const auto* v = store_index.as()) { vars.insert(v); } else { return false; @@ -213,7 +213,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { } for (const PrimExpr& load_index : load->indices) { // return false if there are vars used in load indices but not in store indices. - if (tir::UsesVar(load_index, [&vars](const VarNode* var) { return !vars.count(var); })) { + if (tir::UsesVar(load_index, [&vars](const tir::VarNode* var) { return !vars.count(var); })) { return false; } } @@ -227,9 +227,9 @@ class PatternKindAnalyzer : public StmtExprVisitor { * Store = A[i, j] and Load = B[i, j + k] allow data reuse. 
*/ static bool IsAllowReusePattern(const BufferStore& store, const BufferLoad& load) { - std::unordered_set vars; + std::unordered_set vars; for (const PrimExpr& index : store->indices) { - if (const auto* v = index.as()) { + if (const auto* v = index.as()) { vars.insert(v); } else { return false; @@ -237,7 +237,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { } for (const PrimExpr& index : load->indices) { PreOrderVisit(index, [&](const ObjectRef& node) { - if (const auto* v = node.as()) { + if (const auto* v = node.as()) { if (vars.count(v)) { vars.erase(v); } @@ -276,10 +276,10 @@ class PatternKindAnalyzer : public StmtExprVisitor { * A[i] = sum(B[i, j + k]) is not pure reduce * pooling is not pure reduce */ - static bool IsPureReducePattern(Array reduce_loops, Array indices) { + static bool IsPureReducePattern(Array reduce_loops, Array indices) { for (const PrimExpr& e : indices) { int id = -1; - if (UsesVar(e, [&](const VarNode* var) { + if (UsesVar(e, [&](const tir::VarNode* var) { for (size_t i = 0; i < reduce_loops.size(); ++i) { if (reduce_loops[i].get() == var) { id = i; diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 38a0882efe..d17b27e208 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -71,11 +71,7 @@ class WellFormedChecker : public relax::ExprVisitor { void Malformed(Diagnostic diag) { well_formed = false; - if (diag_ctx) { - diag_ctx.value().Emit(diag); - } else { - LOG(WARNING) << "This IR is not well formed: " << diag->message; - } + LOG(WARNING) << "This IR is not well formed: " << diag->message; } void RegisterGlobalVar(GlobalVar var) { global_var_set_.insert(var); } @@ -126,6 +122,8 @@ class WellFormedChecker : public relax::ExprVisitor { } void VisitExpr_(const FunctionNode* op) { + // save the var_set_ for local function + std::unordered_set previous_var_set_ = var_set_; for (Var param : op->params) { // register symbolic var defined in the shape annotation of function params if (param->shape_) { @@ -146,7 +144,7 @@ class WellFormedChecker : public relax::ExprVisitor { this->VisitVarDef(param); } this->VisitBody(op->body); - var_set_.clear(); + var_set_ = previous_var_set_; prim_expr_visitor_.symbolic_var_set_.clear(); } diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index e0a023c65d..6afcb7f8a6 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -418,7 +418,7 @@ class CodeGenVM : public ExprFunctor { ICHECK(call_node->args[1]->IsInstance()); std::vector args; - // VMState is utilized to help get the Function in builtin packedfunc + // VM is utilized to help get the Function in builtin packedfunc args.push_back(Instruction::Arg(Instruction::kVMRegister)); auto lv = Downcast(call_node->args[0]); @@ -429,9 +429,9 @@ class CodeGenVM : public ExprFunctor { args.push_back(Instruction::Arg(Instruction::kRegister, registers_num_)); } - // free_vars of VMClosure - auto closure_args = Downcast(call_node->args[1]); - for (Expr arg : closure_args->fields) { + // args for the invoke_closure + auto invoke_closure_args = Downcast(call_node->args[1]); + for (Expr arg : invoke_closure_args->fields) { args.push_back(ConvertArg(arg)); } diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index ce97b7bcfc..3481378735 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -398,6 +398,23 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { // LOG(FATAL) << 
"ValueError: Cannot find function " << gv->name_hint // << " in the context IRModule."; // } + } else if (const auto* var = call->op.as()) { + if (var->shape_) { + return Downcast(var->shape_.value()); + } + Optional val = builder_->LookupBinding(GetRef(var)); + if (const auto* func_node = val.value().as()) { + Function func = GetRef(func_node); + if (func->ret_type.as()) { + Expr func_shape = Downcast(func_node->body->shape_); + if (IsConstantShapes(func_shape)) { + return func_shape; + } else { + // TODO(@yuchen, @yongwww): add deducer for other cases + return RuntimeDepShape(Span()); + } + } + } } else { LOG(FATAL) << "ValueError: Failed to do shape inference for " << call->op->GetTypeKey(); } @@ -441,6 +458,15 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { // << " in the context IRModule."; // } } + } else if (auto* var = call->op.as()) { + // TODO(@yongwww, yuchen): handle the infer with more specific cases + Optional val = builder_->LookupBinding(GetRef(var)); + if (const auto* func_node = val.value().as()) { + return func_node->ret_type; + } + if (auto* ft_node = var->checked_type_.as()) { + return ft_node->ret_type; + } } else { // TODO(@yuchen): call to local var/function support LOG(FATAL) << "ValueError: Failed to do type inference for " << call->op->GetTypeKey(); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index b97373be61..b89b176470 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -25,9 +25,9 @@ * the cost of using functional updates. */ #include +#include #include #include -#include #include namespace tvm { @@ -658,6 +658,5 @@ Var ExprMutator::WithShapeAndType(Var var, Optional shape, Type type) return var; } - } // namespace relax } // namespace tvm diff --git a/src/relax/ir/type.cc b/src/relax/ir/type.cc index 3207c20153..cd9062f490 100644 --- a/src/relax/ir/type.cc +++ b/src/relax/ir/type.cc @@ -102,7 +102,7 @@ bool IsBaseOf(const Type& base, const Type& derived) { return false; } - for (size_t i = 0; i < base_tuple->fields.size(); i++) { + for (size_t i = 0; i < base_tuple->fields.size(); ++i) { if (!IsBaseOf(base_tuple->fields[i], derived_tuple->fields[i])) { return false; } diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 2a366098ed..550ba43237 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -30,50 +30,6 @@ namespace tvm { namespace relax { -/*! \brief Helper to implement bind params.*/ -class ExprBinder : public ExprMutator { - public: - explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) {} - - Expr VisitExpr_(const VarNode* op) final { - auto id = GetRef(op); - auto it = args_map_.find(id); - if (it != args_map_.end()) { - return (*it).second; - } else { - return ExprMutator::VisitExpr_(op); - } - } - - private: - const tvm::Map& args_map_; -}; - -/*! 
- * \brief Bind params on expr - * \param expr The expr where to bind params - * \param args_map The map from param var to the expr it binds to - * \return The result expr after bind params - */ -Expr Bind(const Expr& expr, const tvm::Map& args_map) { - if (const FunctionNode* func = expr.as()) { - Expr new_body = ExprBinder(args_map).VisitExpr(func->body); - Array new_params; - for (size_t i = 0; i < func->params.size(); ++i) { - if (!args_map.count(func->params[i])) { - new_params.push_back(func->params[i]); - } - } - if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { - return expr; - } - // The checked_type_ of the new function is deduced from the function body - return Function(new_params, new_body, Type(), func->attrs); - } else { - return ExprBinder(args_map).VisitExpr(expr); - } -} - /*! * \brief Bind params to function by using name * \param func Relax function diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc new file mode 100644 index 0000000000..70bc0b3b80 --- /dev/null +++ b/src/relax/transform/lambda_lift.cc @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/lambda_lift.cc + * \brief Lift local functions into global functions. + */ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/* The goal of this class is to lift out any nested functions into top-level + * functions. + * + * We lift a function out into a global function which takes the set of the free + * vars, and then return the newly created function. + */ +class LambdaLifter : public ExprMutator { + public: + explicit LambdaLifter(const IRModule& module) : ExprMutator(module) { mod_ = module; } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + if (auto const* var = call_node->op.as()) { + bool has_closure = HasClosure(GetRef(var)); + auto val = builder_->LookupBinding(GetRef(var)); + // Call "relax.invoke_closure" to invoke the closure + if (has_closure && val.as()) { + Var clo_arg = GetRef(var); + if (this->var_remap_.find(var->vid) != this->var_remap_.end()) { + clo_arg = this->var_remap_.at(var->vid); + } + return Call(invoke_closure_op_, {clo_arg, Tuple(call_node->args)}, {}, + {call_node->checked_type_}); + } + } + if (auto global_var_node = call_node->op.as()) { + String rec_name = global_var_node->name_hint; + auto global_var = GetRef(global_var_node); + auto it = lambda_map_.find(global_var); + if (it != lambda_map_.end()) { + // flatten nested call, e.g.
call(y)(x) -> call(x, y)) + Array new_args; + for (const auto arg : call->args) { + new_args.push_back(arg); + } + if (const auto* nest_call = it->second.as()) { + for (const auto arg : nest_call->args) { + new_args.push_back(arg); + } + return Call(nest_call->op, new_args, call_node->attrs, call_node->type_args); + } + return Call(it->second, call->args, call_node->attrs, call_node->type_args); + } + } + return std::move(call); + } + + Expr VisitExpr_(const FunctionNode* func_node) final { + auto func = GetRef(func_node); + + // TODO(@yongwww): consider appending inner func name into the lifted func name + String lift_func_name = "lifted_func_" + std::to_string(lift_func_num_++); + auto global = GlobalVar(lift_func_name); + Array captured_vars = FreeVars(func); + recur_vars_ = RecGlobalVars(func); + auto all_global_vars = AllGlobalVars(func); + + Array typed_captured_vars; + Map rebinding_map; + for (auto free_var : captured_vars) { + Var var = Var(free_var->name_hint(), NullOpt, free_var->checked_type_, free_var->span); + var->shape_ = free_var->shape_; + typed_captured_vars.push_back(var); + rebinding_map.Set(free_var, var); + } + + // recursive call + if (!recur_vars_.empty()) { + if (!captured_vars.empty()) { + Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); + } + lambda_map_.emplace(recur_vars_.back(), Call(global, fvs)); + } else { + if (recur_vars_.size() > 0) { + lambda_map_.emplace(recur_vars_.back(), global); + } + } + } + + tvm::Array params; + bool all_params_unchanged = true; + for (Var param : func_node->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + all_params_unchanged &= param.same_as(new_param); + } + + Expr body = this->VisitWithNewScope(func_node->body); + Expr visited_func; + + if (all_params_unchanged && body.same_as(func_node->body)) { + visited_func = GetRef(func_node); + } else if (body->checked_type_.as()) { + // make_closure was introduced + visited_func = Function(params, body, body->checked_type_, func_node->attrs); + } else { + visited_func = Function(params, body, func_node->ret_type, func_node->attrs); + } + auto new_func = Downcast(visited_func); + + Function lifted_func; + bool is_closure = IsClosure(captured_vars); + if (!is_closure) { + lifted_func = Function( + /*params=*/new_func->params, + /*body=*/new_func->body, + /*ret_type=*/new_func->ret_type, + /*attrs=*/new_func->attrs, + /*span=*/new_func->span); + } else { + // Flatten the Closure + std::vector closure_params; + closure_params.reserve(func->params.size() + typed_captured_vars.size()); + for (size_t i = 0; i < func->params.size(); ++i) { + closure_params.emplace_back(func->params[i]); + } + for (size_t i = 0; i < typed_captured_vars.size(); ++i) { + closure_params.emplace_back(typed_captured_vars[i]); + } + + lifted_func = Function(/*params=*/closure_params, + /*body=*/Bind(new_func->body, rebinding_map), + /*ret_type=*/new_func->ret_type, + /*attrs=*/new_func->attrs, + /*span=*/func->span); + + Array param_types; + for (Var param : closure_params) { + CHECK(param->checked_type_.defined()) + << "relax.Function requires params to contain checked_type_"; + param_types.push_back(param->checked_type_); + } + } + lifted_func = WithAttr(std::move(lifted_func), tvm::attr::kGlobalSymbol, lift_func_name); + + ICHECK(lifted_func.defined()); + + // Add the lifted function to the module. 
+ builder_->UpdateFunction(global, lifted_func); + + if (!is_closure) { + return std::move(global); + } else { + // If we need to allocate a closure, + // we pass the variables in its environment here. + Array fvs; + for (auto fv : captured_vars) { + fvs.push_back(fv); + } + // Call make_closure intrinsic + return Call(make_closure_op_, {global, Tuple(fvs)}, {}, {}); + } + } + + bool HasClosure(const Var& var) { + auto val = builder_->LookupBinding(var); + if (const auto* value = val.as()) { + IRModule ctx_mod = builder_->GetContextIRModule(); + ICHECK(ctx_mod->functions.size() > 0); + BaseFunc func = ctx_mod->Lookup(GetRef(value)); + if (const auto* func_node = func.as()) { + if (const auto* call_node = func_node->body.as()) { + if (call_node->op == make_closure_op_) { + return true; + } + } + } + } else if (const auto* func_node = val.as()) { + if (const auto* call_node = func_node->body.as()) { + if (call_node->op == make_closure_op_) { + return true; + } + } + } else if (const auto* call_node = val.as()) { + // recursive call + auto op = call_node->op; + if (make_closure_op_ == op) { + return true; + } + if (const auto* lv = op.as()) { + return HasClosure(GetRef(lv)); + } + } + return false; + } + + bool IsClosure(const Array& captured_vars) { return captured_vars.size() > 0; } + + IRModule Lift() { + auto glob_funcs = mod_->functions; + for (auto pair : glob_funcs) { + if (auto* n = pair.second.as()) { + auto func = GetRef(n); + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->attrs); + builder_->UpdateFunction(pair.first, func); + } + } + return builder_->GetContextIRModule(); + } + + private: + std::unordered_map lambda_map_; + Array recur_vars_; + IRModule mod_; + size_t lift_func_num_ = 0; + /*! \brief Cache ops that would be used later to reduce lookup overhead. */ + const Op& make_closure_op_ = Op::Get("relax.make_closure"); + const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); +}; + +namespace transform { + +Pass LambdaLift() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::LambdaLifter(m).Lift(); }; + return CreateModulePass(pass_func, 1, "LambdaLift", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.LambdaLift").set_body_typed(LambdaLift); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/resolve_globals.cc b/src/relax/transform/resolve_globals.cc index 2851a97d5b..31ca020baa 100644 --- a/src/relax/transform/resolve_globals.cc +++ b/src/relax/transform/resolve_globals.cc @@ -32,8 +32,6 @@ class GlobalVarResolver : public ExprMutator { Expr VisitExpr_(const GlobalVarNode* gvar) { if (!mod_->ContainGlobalVar(gvar->name_hint)) { - diag_ctx_.Emit(Diagnostic::Error(gvar->span) - << "undefined variable/global \"" << gvar->name_hint << "\""); return GetRef(gvar); } return mod_->GetGlobalVar(gvar->name_hint); diff --git a/src/relax/utils.cc b/src/relax/utils.cc new file mode 100644 index 0000000000..2c0d5a5850 --- /dev/null +++ b/src/relax/utils.cc @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Helper to implement bind params.*/ +class ExprBinder : public ExprMutator { + public: + explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) {} + + Expr VisitExpr_(const VarNode* op) final { + auto id = GetRef(op); + auto it = args_map_.find(id); + if (it != args_map_.end()) { + return (*it).second; + } else { + return ExprMutator::VisitExpr_(op); + } + } + + private: + const tvm::Map& args_map_; +}; + +/*! + * \brief Bind params on expr + * \param expr The expr where to bind params + * \param args_map The map from param var to the expr it binds to + * \return The result expr after bind params + */ +Expr Bind(const Expr& expr, const tvm::Map& args_map) { + if (const FunctionNode* func = expr.as()) { + Expr new_body = ExprBinder(args_map).VisitExpr(func->body); + Array new_params; + for (size_t i = 0; i < func->params.size(); ++i) { + if (!args_map.count(func->params[i])) { + new_params.push_back(func->params[i]); + } + } + if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { + return expr; + } + // The checked_type_ of the new function is deduced from the function body + return Function(new_params, new_body, Type(), func->attrs); + } else { + return ExprBinder(args_map).VisitExpr(expr); + } +} + +} // namespace relax +} // namespace tvm diff --git a/src/relay/printer/relax_script_printer.cc b/src/relay/printer/relax_script_printer.cc index d589ab1ffa..205ae62b17 100644 --- a/src/relay/printer/relax_script_printer.cc +++ b/src/relay/printer/relax_script_printer.cc @@ -558,6 +558,18 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function& if (const relax::SeqExprNode* body = func->body.as()) { doc << Doc::Indent(4, Print(func->body)); doc << Doc::Indent(4, Doc::Text("return ") << Print(body->body)) << Doc::NewLine(); + } else if (const relax::FunctionNode* body = func->body.as()) { + // nested function + String func_name; + Optional gsymbol = body->GetAttr(tvm::attr::kGlobalSymbol); + if (gsymbol.defined()) { + func_name = gsymbol.value(); + } else { + func_name = "local_func_" + std::to_string(local_func_counter_++); + } + Doc nested_func = PrintFunctionDef(Doc::Text(func_name), GetRef(body)); + doc << Doc::Indent(4, nested_func); + doc << Doc::Indent(4, Doc::Text("return ") << func_name) << Doc::NewLine(); } else { doc << Doc::Indent(4, Doc::Text("return ") << Print(func->body)) << Doc::NewLine(); } diff --git a/src/relay/printer/text_printer.h b/src/relay/printer/text_printer.h index 57855edf9f..58ee3ab09f 100644 --- a/src/relay/printer/text_printer.h +++ b/src/relay/printer/text_printer.h @@ -250,6 +250,8 @@ class RelaxScriptPrinter : public relax::IRFunctor, NameTable name_table_; /*! \brief Whether to print meta data. */ bool show_meta_data_; + /*! \brief A counter for naming local functions. */ + size_t local_func_counter_ = 0; /*! 
\brief meta data context */ TextMetaDataContext* meta_; std::unordered_map var_id_map_; @@ -584,9 +586,10 @@ class TextPrinter { (node->IsInstance() || node->IsInstance() || node->IsInstance())) { doc << tir_text_printer_.Print(node); - } else if (node.defined() && (node->IsInstance() || - node->IsInstance() || - node->IsInstance())) { + } else if (node.defined() && + (node->IsInstance() || node->IsInstance() || + node->IsInstance() || + node->IsInstance())) { doc << relax_text_printer_.Print(node); } else { doc << relay_text_printer_.PrintFinal(node); diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 702548862a..7e75be8808 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -781,8 +781,8 @@ def k(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tensor: j = my_module[var_j] k = my_module[var_k] - assert f.body.body.op == var_g - assert g.body.body.args[0] == var_my_matmul + assert f.body.op == var_g + assert g.body.args[0] == var_my_matmul gv_bind = j.body.blocks[0].bindings[0] assert gv_bind.value.checked_type.ndim == 2 diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index c098ddb499..2dd2a91846 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -311,11 +311,14 @@ def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: @R.function def f(x: Tensor((n, n), _)) -> Tensor: - return g(x) + # todo(@yongwww): Update the check_type_ function's body is a call_node + r = g(x) + return r @R.function def g(y: Tensor((n, n), _)) -> Tensor: - return relax.call_tir(my_matmul, (y, y), (n, n), dtype="float32") + r = relax.call_tir(my_matmul, (y, y), (n, n), dtype="float32") + return r @R.function def h(x: Tensor((n, n), _), y: Tensor((n, n), _), z: Tensor((n, n), _)) -> Tensor: diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py new file mode 100644 index 0000000000..996a25eea1 --- /dev/null +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -0,0 +1,295 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations +import pytest +import tvm +from tvm import relax +from tvm.runtime.object import Object +import tvm.script +from tvm.script import relax as R, tir as T +from tvm.relax import transform +from tvm.ir.base import assert_structural_equal + + +def _check_equal(x, y): + tvm.ir.assert_structural_equal(x, y) + tvm.ir.assert_structural_equal(y, x) + + xhash = tvm.ir.structural_hash(x) + yhash = tvm.ir.structural_hash(y) + + assert xhash == yhash + + +def _check_save_roundtrip(x): + y = tvm.ir.load_json(tvm.ir.save_json(x)) + _check_equal(x, y) + + +def test_basic(): + # the target IRModule + @tvm.script.ir_module + class Expected: + @R.function + def lifted_func_0(x2: Tensor((10, 5), "float32"), y2: Tensor((10, 5), "float32")) -> Tensor: + s: Tensor((10, 5), "float32") = relax.add(x2, y2) + return s + + @R.function + def main( + x1: Tensor((10, 5), "float32"), y1: Tensor((10, 5), "float32") + ) -> Tensor((10, 5), "float32"): + inner = lifted_func_0 + gv1 = inner(x1, y1) + return gv1 + + @tvm.script.ir_module + class Before: + @R.function + def main( + x1: Tensor((10, 5), "float32"), y1: Tensor((10, 5), "float32") + ) -> Tensor((10, 5), "float32"): + @R.function + def inner( + x2: Tensor((10, 5), "float32"), y2: Tensor((10, 5), "float32") + ) -> Tensor((10, 5), "float32"): + s: Tensor((10, 5), "float32") = relax.add(x2, y2) + return s + + gv1: Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + before = Before + expected = Expected + # Perform Lambda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 2 + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_closure(): + # the expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def main(x: Tensor((2, 3), "float32"), y: Tensor((2, 3), "float32")): + outer_func = lifted_func_0 + in_call = outer_func(x) + res = relax.invoke_closure(in_call, (y,), type_args=(Tensor(ndim=2, dtype="float32"))) + return res + + @R.function + def lifted_func_1(x1: Tensor((2, 3), "float32"), c1: Tensor((2, 3), "float32")): + r_1: Tensor((2, 3), "float32") = relax.add(x1, c1) + return r_1 + + @R.function + def lifted_func_0(y: Tensor((2, 3), "float32")): + return relax.make_closure(lifted_func_1, (y,)) + + # IRModule to perform Lambda Lifting + @tvm.script.ir_module + class Before: + @R.function + def main( + x: Tensor((2, 3), "float32"), y: Tensor((2, 3), "float32") + ) -> Tensor((2, 3), "float32"): + @R.function + def outer_func(c1: Tensor((2, 3), "float32")): + @R.function + def inner_func(x1: Tensor((2, 3), "float32")): + s: Tensor((2, 3), "float32") = relax.add(x1, c1) + return s + + return inner_func + + in_call = outer_func(x) + res = in_call(y) + return res + + before = Before + after = transform.LambdaLift()(before) + expected = Expected + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_recursive(): + # the expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def lifted_func_0( + i: Tensor((), "int32"), s: Tensor((2, 3), "float32"), x: Tensor((2, 3), "float32") + ) -> Tensor((2, 3), "float32"): + cond: Tensor((), "bool") = relax.call_packed( + "test.vm.less", i, relax.const(10), type_args=(Tensor(ndim=0, dtype="bool")) + ) + c: Tensor((), "int32") = relax.const(1, dtype="int32") + if cond: + new_i: Tensor((), "int32") = relax.add(i, c) + new_s: Tensor((2, 3), "float32") = relax.add(s, x) + r = lifted_func_0(new_i, new_s, x) + else: + r = s 
+ return r + + @R.function + def main(x: Tensor((2, 3), "float32")) -> Tensor: + while_loop = relax.make_closure(lifted_func_0, (x,)) + gv = relax.invoke_closure( + while_loop, (relax.const(0), x), type_args=(Tensor(ndim=2, dtype="float32")) + ) + return gv + + # the IRModule to apply lambda lifting + @tvm.script.ir_module + class Before: + @R.function + def main(x: Tensor((2, 3), "float32")) -> Tensor: + @R.function + def while_loop( + i: Tensor((), "int32"), s: Tensor((2, 3), "float32") + ) -> Tensor((2, 3), "float32"): + cond: Tensor((), "bool") = relax.call_packed( + "test.vm.less", i, relax.const(10), type_args=(Tensor(ndim=0, dtype="bool")) + ) + c: Tensor((), "int32") = relax.const(1, dtype="int32") + if cond: + new_i: Tensor((), "int32") = relax.add(i, c) + new_s: Tensor((2, 3), "float32") = relax.add(s, x) + r: Tensor((2, 3), "float32") = while_loop(new_i, new_s) + else: + r: Tensor((2, 3), "float32") = s + return r + + gv: Tensor((2, 3), "float32") = while_loop(relax.const(0), x) + return gv + + before = Before + expected = Expected + # Perform Lambda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 2 + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_multi_func(): + # expected IRModule + @tvm.script.ir_module + class Expected: + @R.function + def glob_func_1( + x1: Tensor((10, 5), "float32"), y1: Tensor((10, 5), "float32") + ) -> Tensor(None, "float32", ndim=2): + inner = lifted_func_1 + gv1 = inner(x1, y1) + return gv1 + + @R.function + def glob_func_2( + x11: Tensor((10, 5), "float32"), y11: Tensor((10, 5), "float32") + ) -> Tensor(None, "float32", ndim=2): + inner1 = lifted_func_0 + gv11 = inner1(x11, y11) + return gv11 + + @R.function + def lifted_func_0( + x2: Tensor((10, 5), "float32"), y2: Tensor((10, 5), "float32") + ) -> Tensor(None, "float32", ndim=2): + s: Tensor((10, 5), "float32") = relax.add(x2, y2) + return s + + @R.function + def lifted_func_1( + x21: Tensor((10, 5), "float32"), y21: Tensor((10, 5), "float32") + ) -> Tensor(None, "float32", ndim=2): + s1: Tensor((10, 5), "float32") = relax.add(x21, y21) + return s1 + + # the IRModule to apply lambda lifting + @tvm.script.ir_module + class Before: + @R.function + def glob_func_1( + x1: Tensor((10, 5), "float32"), y1: Tensor((10, 5), "float32") + ) -> Tensor((10, 5), "float32"): + @R.function + def inner( + x2: Tensor((10, 5), "float32"), y2: Tensor((10, 5), "float32") + ) -> Tensor((10, 5), "float32"): + s: Tensor((10, 5), "float32") = relax.add(x2, y2) + return s + + gv1: Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + @R.function + def glob_func_2( + x1: Tensor((10, 5), "float32"), y1: Tensor((10, 5), "float32") + ) -> Tensor((10, 5), "float32"): + @R.function + def inner( + x2: Tensor((10, 5), "float32"), y2: Tensor((10, 5), "float32") + ) -> Tensor((10, 5), "float32"): + s: Tensor((10, 5), "float32") = relax.add(x2, y2) + return s + + gv1: Tensor((10, 5), "float32") = inner(x1, y1) + return gv1 + + before = Before + expected = Expected + # Perform Lambda Lifting + after = transform.LambdaLift()(before) + assert len(after.functions) == 4 + assert_structural_equal(after, expected, map_free_vars=True) + _check_save_roundtrip(after) + + +def test_no_local_func(): + @tvm.script.ir_module + class Before: + @T.prim_func + def sub( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], + ) -> None: + for i, j in T.grid(16, 16): + with T.block("sub"): + vi, vj =
T.axis.remap("SS", [i, j]) + C[vi, vj] = A[vi, vj] - B[vi, vj] + + @R.function + def before(c0: Tensor((16, 16), "float32"), x: Tensor((_, _), "float32")): + s = relax.call_tir(sub, (c0, x), (16, 16), dtype="float32") + return s + + before = Before + # Perform lambda lifting + after = transform.LambdaLift()(before) + # No local functions are lifted + assert_structural_equal(after, before, map_free_vars=True) + _check_save_roundtrip(after) + + +if __name__ == "__main__": + pytest.main((__file__))