diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py index 5594e36cb855..814e79329019 100644 --- a/python/tvm/relay/backend/te_compiler.py +++ b/python/tvm/relay/backend/te_compiler.py @@ -412,3 +412,26 @@ def get(): The TE Compiler. """ return _backend._TECompilerGlobal() + + +def lower_to_primfunc(relay_func, target): + """Lower Relay Function to TIR PrimFunc. + + Parameters + ---------- + relay_func: relay.Function + The source primitive function, created by FuseOps. + + target : Target + The compilation target. + + Returns + ------- + prim_func : tir.PrimFunc + The created prim func. + """ + f = tvm._ffi.get_global_func("relay.backend.LowerToPrimFunc") + assert f is not None, "relay.backend.LowerToPrimFunc does not exist. " + + with target: + return f(relay_func, target) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index e7e677938e1a..fc45311e085d 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -59,7 +59,6 @@ Array ExtractTask(IRModule mod, Target target, using meta_schedule::ExtractedTask; using meta_schedule::ModuleEqual; using meta_schedule::ModuleHash; - backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter(); backend::BindParamsInModule(mod, params); // is_vm=true for backward compatibility Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); @@ -84,10 +83,9 @@ Array ExtractTask(IRModule mod, Target target, if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) { return; } - auto [inputs_outputs, constants, fused_name] = - tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true); - if (Optional f = tir_converter(inputs_outputs, constants)) { + auto [f, fused_name] = tec::LowerToPrimFunc(relay_func, target, constant_name_supply); + if (f) { IRModule tir_mod = PrimFuncToIRModule(f.value()); lower_results.push_back(std::make_tuple(fused_name, relay_func, tir_mod)); } diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 511f0a901d11..d71cbcfc667d 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -1088,6 +1088,33 @@ std::tuple, Array, std::string> LowerTECompu return std::make_tuple(tensor_outs, constants, lower_te_compute.candidate_name_); } +std::pair, std::string> LowerToPrimFunc(const Function& relay_func, + Target target, + NameSupply constant_name_supply) { + ICHECK(relay_func->HasNonzeroAttr(attr::kPrimitive)) + << "The input must be a Relay primitive function."; + + auto [inputs_outputs, constants, fused_name] = + tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true); + auto tir_converter = backend::GetTIRConverter(); + return std::make_pair(tir_converter(inputs_outputs, constants), fused_name); +} + +tir::PrimFunc LowerToPrimFunc(const Function& relay_func, Target target) { + auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply("")); + (void)_; // to suppress -Werror=unused-variable warning + if (f_opt) { + return f_opt.value(); + } + LOG(FATAL) << "Failed to convert the Relay function: " << AsText(relay_func, false); + return PrimFunc(); +} + +TVM_REGISTER_GLOBAL("relay.backend.LowerToPrimFunc") + .set_body_typed([](Function relay_func, Target target) { + return LowerToPrimFunc(relay_func, target); + }); + TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { auto tgt = tvm::Target("ext_dev"); LowerToTECompute lower_te_compute(tgt, NameSupply("")); diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index fcbf10477fdf..76939a923cdf 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -212,10 +212,10 @@ class CCacheValue : public ObjectRef { Array GetShape(const Array& shape); /*! - * \brief Lowers Relay primitive Function to TE Compute + * \brief Lower Relay primitive Function to TE Compute * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. - * \param constant_name_supply A name supplier for constants. + * \param target The compilation target. + * \param constant_name_supply A name supplier for constants * across different invocations of this function. * \param return_inputs If true, prepend input tensors to the output array of tensors. * \return Tuple of the lowered TE compute, constant raw data, and fused function name. @@ -224,10 +224,22 @@ std::tuple, Array, std::string> LowerTECompu const Function& source_func, Target target, NameSupply constant_name_supply, bool return_inputs = true); +/*! + * \brief Lower Relay Function to TIR PrimFunc, by composing LowerTECompute and CreatePrimFunc. + * \param relay_func The primitive function to be lowered. + * \param target The compilation target. + * \param constant_name_supply A name supplier for constants + * across different invocations of this function. + * \return A pair of the created prim func and the name of the fused function. + */ +std::pair, std::string> LowerToPrimFunc(const Function& relay_func, + Target target, + NameSupply constant_name_supply); + /*! * \brief Create schedule for target. * \param source_func The primitive function to be lowered. - * \param target The target we want to create schedule for. + * \param target The compilation target. * \param global_var_supply A name supplier for global variables. * \param constant_name_supply A name supplier for constants. * \return Pair of schedule and cache. diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index 92e3cbd66e2f..0a8a0dd59fbf 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -14,10 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np + import tvm import tvm.testing from tvm import te from tvm.script import tir as T +from tvm import relay, tir +from tvm.relay.backend.te_compiler import lower_to_primfunc +from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN def _check(original, transformed): @@ -360,5 +365,56 @@ def after(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]): _check(before, after) +def test_allocate_const_after_tensorize(): + i_size, o_size, h_size, w_size = 64, 64, 56, 56 + k_height_size = k_width_size = 3 + w_shape = (o_size, i_size, k_height_size, k_width_size) + + data = relay.var("data", shape=(1, i_size, h_size, w_size), dtype="uint8") + weight = relay.var("weight", shape=w_shape, dtype="uint8") + conv2d = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=(k_height_size, k_width_size), + channels=o_size, + padding=(0, 0), + strides=(1, 1), + out_dtype="int32", + ) + mod = tvm.IRModule.from_expr(conv2d) + + executor = relay.backend.Executor("graph", {"link-params": True}) + mod = mod.with_attr("executor", executor) + + weight_np = np.random.uniform(1, 10, size=w_shape).astype("uint8") + + target = tvm.target.Target("hexagon") + + with tvm.transform.PassContext(opt_level=3): + opt_mod, _ = relay.optimize(mod, params={"weight": weight_np}, target=target) + + conv2d_func = opt_mod["main"].body.args[0].op + prim_func = lower_to_primfunc(conv2d_func, target) + + sch = tir.Schedule(prim_func) + block = sch.get_block("conv2d_NCHWc_int8") + loops = sch.get_loops(block) + + sch.reorder(loops[8], loops[4], loops[-1]) + sch.decompose_reduction(block, loops[1]) + sch.tensorize(loops[4], VRMPY_u8u8i32_INTRIN) + + seq = tvm.transform.Sequential( + [ + tvm.tir.transform.LowerInitBlock(), + tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(), + ] + ) + + # The following error is emitted if AllocateConst nodes are not correctly handled: + # Check failed: (buffer_data_to_buffer_.count(source_var)) is false: + _ = seq(sch.mod) + + if __name__ == "__main__": tvm.testing.main()