[Relay][TIR] Add utility to lower Relay func to TIR prim func (#13606)
* introduce LowerToPrimFunc to lower Relay func to TIR prim func

* add doc

* expose to python

* adding test

* another minor doc update

* Verify that the input is a primitive function
masahi authored Dec 14, 2022
1 parent c547bbb commit c6652bc
Showing 5 changed files with 124 additions and 8 deletions.
23 changes: 23 additions & 0 deletions python/tvm/relay/backend/te_compiler.py
@@ -412,3 +412,26 @@ def get():
        The TE Compiler.
    """
    return _backend._TECompilerGlobal()


def lower_to_primfunc(relay_func, target):
"""Lower Relay Function to TIR PrimFunc.
Parameters
----------
relay_func: relay.Function
The source primitive function, created by FuseOps.
target : Target
The compilation target.
Returns
-------
prim_func : tir.PrimFunc
The created prim func.
"""
f = tvm._ffi.get_global_func("relay.backend.LowerToPrimFunc")
assert f is not None, "relay.backend.LowerToPrimFunc does not exist. "

with target:
return f(relay_func, target)
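
For context, a minimal usage sketch of the new API (not part of the diff). The small add/ReLU workload, the llvm target, and the way the fused function is pulled out of the optimized module are illustrative assumptions; running `relay.optimize` (which includes FuseOps) is what produces the primitive function the docstring refers to.

```python
import tvm
from tvm import relay
from tvm.relay.backend.te_compiler import lower_to_primfunc

# Build a toy model and let the optimization pipeline fuse it.
x = relay.var("x", shape=(8, 16), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.relu(x + relay.const(1.0, "float32")))

target = tvm.target.Target("llvm")
with tvm.transform.PassContext(opt_level=3):
    opt_mod, _ = relay.optimize(mod, target=target)

# After optimization, "main" calls a fused function marked Primitive.
fused_func = opt_mod["main"].body.op  # assumes main's body is a single fused call
prim_func = lower_to_primfunc(fused_func, target)
print(prim_func)
```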
6 changes: 2 additions & 4 deletions src/relay/backend/task_extraction.cc
@@ -59,7 +59,6 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule mod, Target target,
   using meta_schedule::ExtractedTask;
   using meta_schedule::ModuleEqual;
   using meta_schedule::ModuleHash;
-  backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter();
   backend::BindParamsInModule(mod, params);
   // is_vm=true for backward compatibility
   Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
@@ -84,10 +83,9 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule mod, Target target,
     if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
       return;
     }
-    auto [inputs_outputs, constants, fused_name] =
-        tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true);

-    if (Optional<tir::PrimFunc> f = tir_converter(inputs_outputs, constants)) {
+    auto [f, fused_name] = tec::LowerToPrimFunc(relay_func, target, constant_name_supply);
+    if (f) {
       IRModule tir_mod = PrimFuncToIRModule(f.value());
       lower_results.push_back(std::make_tuple(fused_name, relay_func, tir_mod));
     }
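
The refactored code above sits on the MetaSchedule task-extraction path. A sketch of driving it from Python, assuming the `extract_tasks` entry point available in TVM releases of this era; the conv2d workload is illustrative:

```python
import tvm
from tvm import relay
from tvm.meta_schedule.relay_integration import extract_tasks

data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.conv2d(data, weight))

# Each extracted task pairs a fused Relay function with its lowered TIR module,
# now produced internally via tec::LowerToPrimFunc.
tasks = extract_tasks(mod, target="llvm", params={})
for task in tasks:
    print(task.task_name)
```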
27 changes: 27 additions & 0 deletions src/relay/backend/te_compiler_cache.cc
@@ -1088,6 +1088,33 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompute(
  return std::make_tuple(tensor_outs, constants, lower_te_compute.candidate_name_);
}

std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
                                                                Target target,
                                                                NameSupply constant_name_supply) {
  ICHECK(relay_func->HasNonzeroAttr(attr::kPrimitive))
      << "The input must be a Relay primitive function.";

  auto [inputs_outputs, constants, fused_name] =
      tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true);
  auto tir_converter = backend::GetTIRConverter();
  return std::make_pair(tir_converter(inputs_outputs, constants), fused_name);
}

tir::PrimFunc LowerToPrimFunc(const Function& relay_func, Target target) {
  auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply(""));
  (void)_;  // to suppress -Werror=unused-variable warning
  if (f_opt) {
    return f_opt.value();
  }
  LOG(FATAL) << "Failed to convert the Relay function: " << AsText(relay_func, false);
  return PrimFunc();
}

TVM_REGISTER_GLOBAL("relay.backend.LowerToPrimFunc")
    .set_body_typed([](Function relay_func, Target target) {
      return LowerToPrimFunc(relay_func, target);
    });

TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
  auto tgt = tvm::Target("ext_dev");
  LowerToTECompute lower_te_compute(tgt, NameSupply(""));
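
The last commit bullet, "Verify that the input is a primitive function", corresponds to the ICHECK above. A hedged sketch of what that looks like from Python; catching `tvm.TVMError` is an assumption about how the ICHECK surfaces across the FFI:

```python
import tvm
from tvm import relay
from tvm.relay.backend.te_compiler import lower_to_primfunc

x = relay.var("x", shape=(4,), dtype="float32")
plain_func = relay.Function([x], relay.nn.relu(x))  # lacks the kPrimitive attribute

try:
    lower_to_primfunc(plain_func, tvm.target.Target("llvm"))
except tvm.TVMError:
    print("rejected: the input must be a Relay primitive function")
```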
20 changes: 16 additions & 4 deletions src/relay/backend/te_compiler_cache.h
@@ -212,10 +212,10 @@ class CCacheValue : public ObjectRef {
 Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);

 /*!
- * \brief Lowers Relay primitive Function to TE Compute
+ * \brief Lower Relay primitive Function to TE Compute
  * \param source_func The primitive function to be lowered.
- * \param target The target we want to create schedule for.
- * \param constant_name_supply A name supplier for constants.
+ * \param target The compilation target.
+ * \param constant_name_supply A name supplier for constants
+ *        across different invocations of this function.
  * \param return_inputs If true, prepend input tensors to the output array of tensors.
  * \return Tuple of the lowered TE compute, constant raw data, and fused function name.
@@ -224,10 +224,22 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompute(
     const Function& source_func, Target target, NameSupply constant_name_supply,
     bool return_inputs = true);

+/*!
+ * \brief Lower Relay Function to TIR PrimFunc, by composing LowerTECompute and CreatePrimFunc.
+ * \param relay_func The primitive function to be lowered.
+ * \param target The compilation target.
+ * \param constant_name_supply A name supplier for constants
+ *        across different invocations of this function.
+ * \return A pair of the created prim func and the name of the fused function.
+ */
+std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
+                                                                Target target,
+                                                                NameSupply constant_name_supply);
+
 /*!
  * \brief Create schedule for target.
  * \param source_func The primitive function to be lowered.
- * \param target The target we want to create schedule for.
+ * \param target The compilation target.
  * \param global_var_supply A name supplier for global variables.
  * \param constant_name_supply A name supplier for constants.
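
The composition described in the new comment has a rough Python analogue: `te.create_prim_func` is the exposed counterpart of CreatePrimFunc. A sketch on a hand-built TE compute (the Relay-to-TE half, LowerTECompute, is internal to the compiler, so a placeholder graph stands in for it here):

```python
import tvm
from tvm import te

A = te.placeholder((128,), name="A")
B = te.compute((128,), lambda i: A[i] + 1.0, name="B")

prim_func = te.create_prim_func([A, B])  # TE graph -> TIR PrimFunc
print(prim_func.script())
```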
56 changes: 56 additions & 0 deletions
@@ -14,10 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
import tvm.testing
from tvm import te
from tvm.script import tir as T
from tvm import relay, tir
from tvm.relay.backend.te_compiler import lower_to_primfunc
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN


def _check(original, transformed):
@@ -360,5 +365,56 @@ def after(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]):
    _check(before, after)


def test_allocate_const_after_tensorize():
    i_size, o_size, h_size, w_size = 64, 64, 56, 56
    k_height_size = k_width_size = 3
    w_shape = (o_size, i_size, k_height_size, k_width_size)

    data = relay.var("data", shape=(1, i_size, h_size, w_size), dtype="uint8")
    weight = relay.var("weight", shape=w_shape, dtype="uint8")
    conv2d = relay.nn.conv2d(
        data=data,
        weight=weight,
        kernel_size=(k_height_size, k_width_size),
        channels=o_size,
        padding=(0, 0),
        strides=(1, 1),
        out_dtype="int32",
    )
    mod = tvm.IRModule.from_expr(conv2d)

    executor = relay.backend.Executor("graph", {"link-params": True})
    mod = mod.with_attr("executor", executor)

    weight_np = np.random.uniform(1, 10, size=w_shape).astype("uint8")

    target = tvm.target.Target("hexagon")

    with tvm.transform.PassContext(opt_level=3):
        opt_mod, _ = relay.optimize(mod, params={"weight": weight_np}, target=target)

    conv2d_func = opt_mod["main"].body.args[0].op
    prim_func = lower_to_primfunc(conv2d_func, target)

    sch = tir.Schedule(prim_func)
    block = sch.get_block("conv2d_NCHWc_int8")
    loops = sch.get_loops(block)

    sch.reorder(loops[8], loops[4], loops[-1])
    sch.decompose_reduction(block, loops[1])
    sch.tensorize(loops[4], VRMPY_u8u8i32_INTRIN)

    seq = tvm.transform.Sequential(
        [
            tvm.tir.transform.LowerInitBlock(),
            tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
        ]
    )

    # The following error is emitted if AllocateConst nodes are not correctly handled:
    # Check failed: (buffer_data_to_buffer_.count(source_var)) is false:
    _ = seq(sch.mod)


if __name__ == "__main__":
    tvm.testing.main()
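
A possible continuation of the test above, not part of the commit: inspect the transformed module and let `tvm.lower` run the remaining standard TIR passes. Treating the two-pass output as ready for `tvm.lower` is an assumption here, and producing a hexagon binary would additionally require an LLVM build with that backend.

```python
# Hypothetical continuation inside test_allocate_const_after_tensorize():
mod_after = seq(sch.mod)
print(mod_after.script())      # init block lowered, buffer allocations placed
lowered = tvm.lower(mod_after)  # finish the standard TIR lowering pipeline
```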
