diff --git a/python/tvm/relax/testing/relay_translator.py b/python/tvm/relax/testing/relay_translator.py index d1c8906fda..11ff881d25 100644 --- a/python/tvm/relax/testing/relay_translator.py +++ b/python/tvm/relax/testing/relay_translator.py @@ -122,17 +122,28 @@ def visit_func(node): for arg in args: if arg in var_map: arg_expr = var_map[arg] - if not isinstance(arg_expr.shape, relax.Tuple): + if isinstance(arg_expr.checked_type, relax.DynTensorType): new_args.append(arg_expr) te_inputs.append(tvm.relax.expr.te_tensor(arg_expr)) - else: - n_tensor = len(arg_expr.checked_type.fields) - assert isinstance(arg_expr.checked_type, relax.TupleType) - assert len(arg_expr.shape.fields) == n_tensor + elif isinstance(arg_expr.checked_type, relax.TupleType): + assert isinstance(arg_expr.shape, relax.Tuple) + assert len(arg_expr.shape.fields) == len(arg_expr.checked_type.fields) + n_tensor = len(arg_expr.shape.fields) + bound_tuple = bb.lookup_binding(arg_expr) + if isinstance(bound_tuple, relax.Tuple): + assert len(bound_tuple) == n_tensor for i in range(n_tensor): - item = bb.emit(relax.TupleGetItem(arg_expr, i)) + if isinstance(bound_tuple, relax.Tuple): + item = bb.emit(bound_tuple[i]) + else: + item = bb.emit(relax.TupleGetItem(arg_expr, i)) new_args.append(item) te_inputs.append(tvm.relax.expr.te_tensor(item)) + else: + raise TypeError( + f"CallTIR argument type being {type(arg_expr.checked_type)} is not " + "supported." + ) op_name = node.op.name attrs = node.attrs @@ -217,6 +228,4 @@ def visit_func(node): bb._begin_dataflow_block() relay.analysis.post_order_visit(mod["main"], visit_func) - relax_mod = bb.get() - relax_mod["main"] = relax_mod["main"].with_attr("global_symbol", "main") return bb.get() diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py index c39bc45def..06f5623356 100644 --- a/tests/python/relax/test_relay_translator.py +++ b/tests/python/relax/test_relay_translator.py @@ -277,9 +277,8 @@ def tir_matmul( def test_translate_tuple_arg(): x = relay.var("x", shape=(10, 16)) y = relay.var("y", shape=(10, 16)) - t = relay.Tuple((x, y)) - relay_mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.concatenate(t, axis=-1))) - relay_vm, relax_vm, relax_mod = translate_and_build_vms(relay_mod) + relay_mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.concatenate((x, y), axis=-1))) + relax_mod = relay_translator.from_relay(relay_mod["main"], target="llvm") # Construct the expected module bb = relax.BlockBuilder() @@ -295,9 +294,9 @@ def test_translate_tuple_arg(): ) with bb.function("main", [x_relax, y_relax]): with bb.dataflow(): - lv = bb.emit(relax.Tuple((x_relax, y_relax))) - lv1 = bb.emit(relax.TupleGetItem(lv, 0)) - lv2 = bb.emit(relax.TupleGetItem(lv, 1)) + _ = bb.emit(relax.Tuple((x_relax, y_relax))) + lv1 = bb.emit(x_relax) + lv2 = bb.emit(y_relax) lv3 = bb.emit_te(topi.x86.concatenate, (lv1, lv2), axis=-1) gv = bb.emit_output(lv3) bb.emit_func_output(gv)