Skip to content

Commit

Permalink
[Translator] Remove global symbol and follow-up fix for tlc-pack#262 (tlc-pack#316)
Browse files Browse the repository at this point in the history

This PR removes the `global_symbol` linkage added by the Relay Translator. It also addresses the remaining review comments from tlc-pack#262.

All tests pass locally, and I believe it is safe to merge this PR directly.
  • Loading branch information
MasterJH5574 authored and junrushao committed Jan 25, 2023
1 parent 92a1cb2 commit d88200d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
25 changes: 17 additions & 8 deletions python/tvm/relax/testing/relay_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,28 @@ def visit_func(node):
for arg in args:
if arg in var_map:
arg_expr = var_map[arg]
if not isinstance(arg_expr.shape, relax.Tuple):
if isinstance(arg_expr.checked_type, relax.DynTensorType):
new_args.append(arg_expr)
te_inputs.append(tvm.relax.expr.te_tensor(arg_expr))
else:
n_tensor = len(arg_expr.checked_type.fields)
assert isinstance(arg_expr.checked_type, relax.TupleType)
assert len(arg_expr.shape.fields) == n_tensor
elif isinstance(arg_expr.checked_type, relax.TupleType):
assert isinstance(arg_expr.shape, relax.Tuple)
assert len(arg_expr.shape.fields) == len(arg_expr.checked_type.fields)
n_tensor = len(arg_expr.shape.fields)
bound_tuple = bb.lookup_binding(arg_expr)
if isinstance(bound_tuple, relax.Tuple):
assert len(bound_tuple) == n_tensor
for i in range(n_tensor):
item = bb.emit(relax.TupleGetItem(arg_expr, i))
if isinstance(bound_tuple, relax.Tuple):
item = bb.emit(bound_tuple[i])
else:
item = bb.emit(relax.TupleGetItem(arg_expr, i))
new_args.append(item)
te_inputs.append(tvm.relax.expr.te_tensor(item))
else:
raise TypeError(
f"CallTIR argument type being {type(arg_expr.checked_type)} is not "
"supported."
)

op_name = node.op.name
attrs = node.attrs
Expand Down Expand Up @@ -217,6 +228,4 @@ def visit_func(node):
bb._begin_dataflow_block()
relay.analysis.post_order_visit(mod["main"], visit_func)

relax_mod = bb.get()
relax_mod["main"] = relax_mod["main"].with_attr("global_symbol", "main")
return bb.get()
11 changes: 5 additions & 6 deletions tests/python/relax/test_relay_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,8 @@ def tir_matmul(
def test_translate_tuple_arg():
x = relay.var("x", shape=(10, 16))
y = relay.var("y", shape=(10, 16))
t = relay.Tuple((x, y))
relay_mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.concatenate(t, axis=-1)))
relay_vm, relax_vm, relax_mod = translate_and_build_vms(relay_mod)
relay_mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.concatenate((x, y), axis=-1)))
relax_mod = relay_translator.from_relay(relay_mod["main"], target="llvm")

# Construct the expected module
bb = relax.BlockBuilder()
Expand All @@ -295,9 +294,9 @@ def test_translate_tuple_arg():
)
with bb.function("main", [x_relax, y_relax]):
with bb.dataflow():
lv = bb.emit(relax.Tuple((x_relax, y_relax)))
lv1 = bb.emit(relax.TupleGetItem(lv, 0))
lv2 = bb.emit(relax.TupleGetItem(lv, 1))
_ = bb.emit(relax.Tuple((x_relax, y_relax)))
lv1 = bb.emit(x_relax)
lv2 = bb.emit(y_relax)
lv3 = bb.emit_te(topi.x86.concatenate, (lv1, lv2), axis=-1)
gv = bb.emit_output(lv3)
bb.emit_func_output(gv)
Expand Down

0 comments on commit d88200d

Please sign in to comment.