[CORE][Relay] Swap and remove compile_engine with te_compiler, followup of apache#8775 (apache#9282)

* Remove compile_engine.h for real

* Fix format

* RM compile_engine.cc

* Swap compile engine with TECompiler

* Cleanup on compile engine py leftovers

* [WIP] Exposing legacy compile engine capabilities through TE Compiler

* Swap usages of deprecated compile engine with TE compiler

* Track and replace usages of compile engine, refactoring them to TE compiler

* [Docs] Log helper mod

* Remove deprecated function for compile engine cache lookup

* Fix typos

* Debug misc cleanups

* Register global pass for using te compiler for auto scheduler

* Fix tests using the legacy compile engine

* Fix broken autotuner tests and minor cleanups

* Swap compile engine with te_compiler in rst config

* PR nits

* Fix failed test

Co-authored-by: Jared Roesch <[email protected]>
2 people authored and ylc committed Jan 7, 2022
1 parent dd6922a commit 9fc9db1
Showing 37 changed files with 151 additions and 629 deletions.
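For orientation before the per-file diffs: the user-facing Python change is a module rename, so code that previously reached the global compile engine through relay.backend.compile_engine now goes through relay.backend.te_compiler. A minimal migration sketch (the API names mirror the renames in this diff; the cache-clearing call is just an illustrative usage):

    from tvm import relay

    # Legacy API removed by this commit:
    #   relay.backend.compile_engine.get().clear()
    # Replacement: the global lowering cache now lives on the TE compiler.
    relay.backend.te_compiler.get().clear()

    # Implementation selection moves along with the module, e.g.
    #   relay.backend.compile_engine.select_implementation(...)
    # becomes
    #   relay.backend.te_compiler.select_implementation(...)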
8 changes: 4 additions & 4 deletions docs/arch/relay_op_strategy.rst
@@ -269,14 +269,14 @@ will then be chosen. Implementations with same priority level in this case leads
to an undefined behavior, and any of them might be selected.

The selection policy for ops with symbolic input shapes is still work in
progess. Currently, if any input tensor has a symbolic shape, only the
progress. Currently, if any input tensor has a symbolic shape, only the
implementation with highest priority level will be used for this operator. This
will be updated after the implemention finishes.
will be updated after the implementation finishes.

For debug purpose, you can add the following lines before you compile the Relay
model to learn which implementation is used for each operator.

.. code:: python
logging.getLogger("compile_engine").setLevel(logging.INFO)
logging.getLogger("compile_engine").addHandler(logging.StreamHandler(sys.stdout))
logging.getLogger("te_compiler").setLevel(logging.INFO)
logging.getLogger("te_compiler").addHandler(logging.StreamHandler(sys.stdout))
2 changes: 1 addition & 1 deletion docs/reference/api/python/relay/backend.rst
@@ -23,7 +23,7 @@ tvm.relay.backend
.. automodule:: tvm.relay.backend.interpreter
:members:

.. automodule:: tvm.relay.backend.compile_engine
.. automodule:: tvm.relay.backend.te_compiler
:members:

.. automodule:: tvm.relay.backend.graph_executor_codegen
4 changes: 2 additions & 2 deletions python/tvm/auto_scheduler/relay_integration.py
@@ -58,7 +58,6 @@ def call_all_topi_funcs(mod, params, target, opt_level=3):
opt_level=opt_level,
config={
"relay.backend.use_auto_scheduler": True,
"relay.backend.disable_compile_engine_cache": True,
},
disabled_pass={"AutoSchedulerLayoutRewrite"},
):
@@ -165,7 +164,8 @@ class TracingMode:
"""Two modes for tracing"""

EXTRACT_TASK = 0 # trace all topi calls to extract tasks
EXTRACT_COMPLEX_TASK_ONLY = 1 # same as EXTRACT_TASK but ignore the task without complex ops
# same as EXTRACT_TASK but ignore the task without complex ops
EXTRACT_COMPLEX_TASK_ONLY = 1
PREPARE_LAYOUT_REWRITE = 2 # trace topi calls to prepare layout rewrite


2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
@@ -142,7 +142,7 @@ def _traverse_expr(node):
params.append(free_var)
call = relay.Call(node.op, params, node.attrs)
mod = tvm.IRModule.from_expr(relay.Function(params, call))
relay.backend.compile_engine.get().clear()
relay.backend.te_compiler.get().clear()
tracing_target = _replace_device_with_tracing(tvm_target)
build_thread = threading.Thread(
target=relay.build, args=(mod, tracing_target, None, None)
4 changes: 2 additions & 2 deletions python/tvm/autotvm/task/relay_integration.py
@@ -127,12 +127,12 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
assert isinstance(
mod, tvm.IRModule
), "only support relay Module or Function to be tuned"
relay.backend.compile_engine.get().clear()
relay.backend.te_compiler.get().clear()
# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=_lower, args=(mod, target, param))
build_thread.start()
build_thread.join()
relay.backend.compile_engine.get().clear()
relay.backend.te_compiler.get().clear()
# Clear the warning message cache in FallbackContext
if isinstance(DispatchContext.current, FallbackContext):
DispatchContext.current.memory = {}
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/__init__.py
@@ -15,4 +15,4 @@
# specific language governing permissions and limitations
# under the License.
"""Backend codegen modules for relay."""
from . import compile_engine
from . import te_compiler
python/tvm/relay/backend/{compile_engine.py → te_compiler.py}
@@ -15,11 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=len-as-condition,no-else-return,invalid-name
"""Backend code generation engine."""
"""TE compiler engine (replacing legacy compile_engine)."""
from __future__ import absolute_import

import logging
import numpy as np
import tvm
from tvm import te, autotvm
from tvm.ir.transform import PassContext
@@ -31,7 +30,7 @@
from .. import ty as _ty
from . import _backend

logger = logging.getLogger("compile_engine")
logger = logging.getLogger("te_compiler")
autotvm_logger = logging.getLogger("autotvm")

_first_warning = True
@@ -47,7 +46,7 @@ def __init__(self, outputs, implement):

@tvm._ffi.register_object("relay.CCacheKey")
class CCacheKey(Object):
"""Key in the CompileEngine.
"""Key in the TE Compiler.
Parameters
----------
@@ -64,7 +63,7 @@ def __init__(self, source_func, target):

@tvm._ffi.register_object("relay.CCacheValue")
class CCacheValue(Object):
"""Value in the CompileEngine, including usage statistics."""
"""Value in the TE Compiler, including usage statistics."""


def _get_cache_key(source_func, target):
@@ -79,24 +78,6 @@ def _get_cache_key(source_func, target):
return source_func


def get_shape(shape):
"""Convert the shape to correct dtype and vars."""
ret = []
for dim in shape:
if isinstance(dim, tvm.tir.IntImm):
if libinfo()["INDEX_DEFAULT_I64"] == "ON":
ret.append(dim)
else:
val = int(dim)
assert val <= np.iinfo(np.int32).max
ret.append(tvm.tir.IntImm("int32", val))
elif isinstance(dim, tvm.tir.Any):
ret.append(te.var("any_dim", "int32"))
else:
ret.append(dim)
return ret


def get_valid_implementations(op, attrs, inputs, out_type, target):
"""Get all valid implementations from the op strategy.
@@ -275,6 +256,24 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
return best_plevel_impl, outputs[best_plevel_impl]


def get_shape(shape):
"""Convert the shape to correct dtype and vars."""
ret = []
for dim in shape:
if isinstance(dim, tvm.tir.IntImm):
if libinfo()["INDEX_DEFAULT_I64"] == "ON":
ret.append(dim)
else:
val = int(dim)
assert val <= np.iinfo(np.int32).max
ret.append(tvm.tir.IntImm("int32", val))
elif isinstance(dim, tvm.tir.Any):
ret.append(te.var("any_dim", "int32"))
else:
ret.append(dim)
return ret


@tvm._ffi.register_func("relay.backend.lower_call")
def lower_call(call, inputs, target):
"""Lower the call expression to op implementation and tensor outputs."""
@@ -322,12 +321,12 @@ def lower_call(call, inputs, target):
return LoweredOutput(outputs, best_impl)


@tvm._ffi.register_object("relay.CompileEngine")
class CompileEngine(Object):
"""CompileEngine to get lowered code."""
@tvm._ffi.register_object("relay.TECompiler")
class TECompiler(Object):
"""TECompiler to get lowered code."""

def __init__(self):
raise RuntimeError("Cannot construct a CompileEngine")
raise RuntimeError("Cannot construct a TECompiler")

def lower(self, source_func, target=None, mod_name="default"):
"""Lower a source_func to a CachedFunc.
@@ -349,7 +348,7 @@ def lower(self, source_func, target=None, mod_name="default"):
try:
mod_name = mangle_module_name(mod_name)
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLower(self, key, mod_name)
return _backend._TECompilerLower(self, key, mod_name)
except Exception:
import traceback

@@ -360,10 +359,6 @@ def lower(self, source_func, target=None, mod_name="default"):
msg += "--------------------------\n"
raise RuntimeError(msg)

def lower_shape_func(self, source_func, target=None):
key = _get_cache_key(source_func, target)
return _backend._CompileEngineLowerShapeFunc(self, key)

def jit(self, source_func, target=None):
"""JIT a source_func to a tvm.runtime.PackedFunc.
@@ -381,87 +376,30 @@ def jit(self, source_func, target=None):
The result of jited function.
"""
key = _get_cache_key(source_func, target)
return _backend._CompileEngineJIT(self, key)
return _backend._TECompilerJIT(self, key)

def clear(self):
"""clear the existing cached functions"""
_backend._CompileEngineClear(self)
_backend._TECompilerClear(self)

def items(self):
"""List items in the cache.
Returns
-------
item_list : List[Tuple[CCacheKey, CCacheValue]]
The list of items.
"""
res = _backend._CompileEngineListItems(self)
assert len(res) % 2 == 0
return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)]

def shape_func_items(self):
"""List items in the shape_func_cache.
Returns
-------
item_list : List[Tuple[CCacheKey, CCacheValue]]
The list of shape_func_items.
"""
res = _backend._CompileEngineListShapeFuncItems(self)
res = _backend._TECompilerListItems(self)
assert len(res) % 2 == 0
return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)]

def get_current_ccache_key(self):
return _backend._CompileEngineGetCurrentCCacheKey(self)

def dump(self):
"""Return a string representation of engine dump.
Returns
-------
dump : str
The dumped string representation
"""
items = self.items()
res = "====================================\n"
res += "CompilerEngine dump, %d items cached\n" % len(items)
for k, v in items:
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
res += "inputs={}\n".format(v.cached_func.inputs)
res += "outputs={}\n".format(v.cached_func.outputs)
res += "function: \n"
res += v.cached_func.funcs.astext() + "\n"
res += "===================================\n"
shape_func_items = self.shape_func_items()
res += "%d shape_func_items cached\n" % len(shape_func_items)
for k, v in shape_func_items:
res += "------------------------------------\n"
res += "target={}\n".format(k.target)
res += "use_count={}\n".format(v.use_count)
res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint)
res += "----relay function----\n"
res += k.source_func.astext() + "\n"
res += "----tir function----- \n"
res += "inputs={}\n".format(v.cached_func.inputs)
res += "outputs={}\n".format(v.cached_func.outputs)
res += "function: \n"
res += v.cached_func.funcs.astext() + "\n"
res += "===================================\n"
return res


def get():
"""Get the global compile engine.
"""Get the global TE Compiler.
Returns
-------
engine : tvm.relay.backend.CompileEngine
The compile engine.
engine : tvm.relay.backend.TECompiler
The TE Compiler.
"""
return _backend._CompileEngineGlobal()
return _backend._TECompilerGlobal()
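Beyond the lower_call hook, the renamed cache object stays directly scriptable. A minimal sketch, assuming (as with the legacy engine) that the function handed in is a primitive function; the tiny add function and the explicit Primitive attribute are illustrative and not part of this diff, since real callers typically pass fused functions produced by FuseOps:

    import tvm
    from tvm import relay
    from tvm.relay.backend import te_compiler

    # Hypothetical primitive function to lower.
    x = relay.var("x", shape=(8,), dtype="float32")
    prim_fn = relay.Function([x], relay.add(x, x))
    prim_fn = prim_fn.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

    target = tvm.target.Target("llvm")
    tec = te_compiler.get()              # global TECompiler instance
    cached = tec.lower(prim_fn, target)  # lower to a CachedFunc
    packed = tec.jit(prim_fn, target)    # or JIT straight to a PackedFunc
    print(len(tec.items()))              # inspect (CCacheKey, CCacheValue) pairs
    tec.clear()                          # drop the cached functions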
13 changes: 8 additions & 5 deletions python/tvm/relay/testing/py_converter.py
@@ -24,7 +24,7 @@
import tvm
from tvm import relay
from tvm.relay.adt import Pattern
from tvm.relay.backend import compile_engine
from tvm.relay.backend import te_compiler
from tvm.relay.expr import Expr, GlobalVar, Var
from tvm.relay.function import Function
from tvm.relay.expr_functor import ExprFunctor
@@ -61,7 +61,7 @@ def __init__(self, mod, target) -> None:
super().__init__()
self.mod = mod
self.tgt = target
self.engine = compile_engine.get()
self.tec = te_compiler.get()
self.fun_no = 0
self.var_no = 0
self.var_map = {}
@@ -153,7 +153,10 @@ def parse_name(self, name: str):
def parse_numpy_array(self, arr):
"""Given a Numpy array, produces an appropriate Python array
or numerical literal representing its contents."""
parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i)

def parse_single(i):
return NameConstant(i) if isinstance(i, bool) else Num(i)

if arr.ndim == 0:
return parse_single(arr.item())
if arr.ndim == 1:
@@ -240,11 +243,11 @@ def create_op_call(self, op: Function, relay_args, py_args):
the generated Python code."""

# compile the function and register globally
cc_key = compile_engine.CCacheKey(op, self.tgt)
cc_key = te_compiler.CCacheKey(op, self.tgt)
func_hash = tvm.ir.structural_hash(op)
op_name = "_lowered_op_{}".format(func_hash)
if not tvm.get_global_func(op_name, allow_missing=True):
jitted = self.engine.jit(cc_key, self.tgt)
jitted = self.tec.jit(cc_key, self.tgt)
tvm.register_func(op_name, jitted)

def convert_input(py_input, arg_type):
2 changes: 1 addition & 1 deletion python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -90,7 +90,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
target = tvm.target.Target.current(allow_none=False)
dispatch_ctx = autotvm.task.DispatchContext.current

_, outs = relay.backend.compile_engine.select_implementation(
_, outs = relay.backend.te_compiler.select_implementation(
relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
)
workload = autotvm.task.get_workload(outs)
2 changes: 1 addition & 1 deletion python/tvm/topi/bifrost/conv2d.py
@@ -477,7 +477,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
target = tvm.target.Target.current(allow_none=False)
dispatch_ctx = autotvm.task.DispatchContext.current

_, outs = relay.backend.compile_engine.select_implementation(
_, outs = relay.backend.te_compiler.select_implementation(
relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
)
workload = autotvm.task.get_workload(outs)
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/conv2d_alter_op.py
@@ -46,7 +46,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
data, kernel = tinfos
out_dtype = out_type.dtype

impl, outs = relay.backend.compile_engine.select_implementation(
impl, outs = relay.backend.te_compiler.select_implementation(
relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
)
workload = autotvm.task.get_workload(outs)
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/conv3d_alter_op.py
@@ -35,7 +35,7 @@ def _alter_conv3d_layout(attrs, inputs, tinfos, out_type):
target = tvm.target.Target.current(allow_none=False)
dispatch_ctx = autotvm.task.DispatchContext.current

_, outs = relay.backend.compile_engine.select_implementation(
_, outs = relay.backend.te_compiler.select_implementation(
relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target
)
workload = autotvm.task.get_workload(outs)
2 changes: 1 addition & 1 deletion python/tvm/topi/intel_graphics/conv2d_alter_op.py
@@ -35,7 +35,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
cfg = dispatch_ctx.query(target, None)
workload = cfg.workload
else:
_, outs = relay.backend.compile_engine.select_implementation(
_, outs = relay.backend.te_compiler.select_implementation(
relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
)
workload = autotvm.task.get_workload(outs)
2 changes: 1 addition & 1 deletion python/tvm/topi/mali/conv2d.py
@@ -531,7 +531,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
data, kernel = tinfos
out_dtype = out_type.dtype

impl, outs = relay.backend.compile_engine.select_implementation(
impl, outs = relay.backend.te_compiler.select_implementation(
relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
)
workload = autotvm.task.get_workload(outs)
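The alter-op-layout changes above all follow the same pattern: ask the TE compiler which implementation it would pick for the convolution, then derive the AutoTVM workload from the resulting tensors. An illustrative, self-contained sketch of that pattern outside an alter-op hook; the shapes and the conv2d call are made up for illustration:

    import tvm
    from tvm import autotvm, relay, te

    # Hypothetical conv2d call whose attrs we reuse for implementation selection.
    data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
    weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
    call = relay.nn.conv2d(data, weight, padding=(1, 1))

    target = tvm.target.Target("llvm")
    tinfos = [te.placeholder((1, 3, 32, 32), "float32"),
              te.placeholder((8, 3, 3, 3), "float32")]
    out_type = relay.TensorType((1, 8, 32, 32), "float32")

    impl, outs = relay.backend.te_compiler.select_implementation(
        relay.op.get("nn.conv2d"), call.attrs, tinfos, out_type, target
    )
    workload = autotvm.task.get_workload(outs)
    print(impl.name, workload)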