From ae869680e3990d38fc4401d3c74ae8a78c9c197c Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Fri, 3 Feb 2023 18:38:53 +0000 Subject: [PATCH] Change dtype functions interface to take ints tuple for each tensor The original design for the dtype functions outlined in https://github.com/llvm/torch-mlir/issues/1462 was unable to properly handle ops that take optional tensors as an input when the optional tensor has a value of None. By the time the op gets imported into torch-mlir, if an optional value is None, all information about the original type is lost from the op type signature, preventing torch-mlir from knowing if a value of None was from an optional tensor or not, which was crucial in the original design since each tensor argument must be turned into two separate arguments for the dtype function. This commit changes the interface to dtype functions such that each tensor turns into a tuple of two ints, the first representing the rank of the tensor and the second the dtype of the tensor. Since now there is a one-to-one correspondence between the operands of an op and the operands of its dtype function, there is no ambiguity about which operand of the op corresponds with which operand of the dtype function. To test the implementation, this commit defines dtype functions for the convolution ops, all of which take one optional tensor as an argument. --- ...dding_abstract_interpretation_functions.md | 2 +- e2e_testing/xfail_sets.py | 1 - .../Transforms/AbstractInterpLibrary.cpp | 1019 +++++++++++------ lib/Dialect/Torch/Transforms/RefineTypes.cpp | 12 - .../Transforms/ReifyDtypeCalculations.cpp | 77 +- .../build_tools/abstract_interp_lib_gen.py | 314 +++-- .../importer/jit_ir/build_tools/registry.py | 7 +- .../jit_ir/build_tools/testing_framework.py | 38 +- .../test_suite/__init__.py | 2 + .../Torch/reify-dtype-calculations.mlir | 21 +- 10 files changed, 942 insertions(+), 551 deletions(-) diff --git a/docs/adding_abstract_interpretation_functions.md b/docs/adding_abstract_interpretation_functions.md index 68e1b98b78acb..b5e427e1adfd0 100644 --- a/docs/adding_abstract_interpretation_functions.md +++ b/docs/adding_abstract_interpretation_functions.md @@ -21,7 +21,7 @@ We will use the example of adding support for the `torch.aten.tanh` op. function signatures are: - `def aten〇tanh〡shape(self: List[int]) -> List[int]:` - - `def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:` + - `def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:` Note the use of `〇` as a separator since `.` or `::` aren't legal in a Python identifier. diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 836ed7a08c980..df51514dfb5cb 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -523,7 +523,6 @@ "ElementwiseFlattenBroadcastModule_basic", "SquareModule_basic", "MaxPool2dStaticModule_basic", - "ResNet18StaticModule_basic", "ReduceAmaxKeepDim_basic", "NativeLayerNormModule4D_basic", "LayerNormNormalizeOverAllDimsModule_basic", diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 682048b93654e..0f220b778d1e6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7350,19 +7350,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" " func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n" " %true = torch.constant.bool true\n" @@ -7404,209 +7405,222 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" " return %1 : !torch.bool\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `self` cannot have float16 dtype\"\n" " %int5 = torch.constant.int 5\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %3 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %3 : !torch.int\n" +" %4 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" " %int5 = torch.constant.int 5\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" %4 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %4 : !torch.bool\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %5 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %2 -> () {\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %3 : !torch.int\n" +" %4 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" " %false = torch.constant.bool false\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %6 : !torch.bool\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" %5 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %5 : !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %6 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %6 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %3 -> () {\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %4 : !torch.int\n" +" %5 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %5 : !torch.int\n" " }\n" " func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" " %int4 = torch.constant.int 4\n" @@ -7619,7 +7633,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" " return %1 : !torch.bool\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int9 = torch.constant.int 9\n" " %int7 = torch.constant.int 7\n" @@ -7629,85 +7643,90 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %false = torch.constant.bool false\n" " %true = torch.constant.bool true\n" " %0 = torch.prim.Uninitialized : !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.aten.eq.int %arg1, %int10 : !torch.int, !torch.int -> !torch.bool\n" -" %4:2 = torch.prim.If %3 -> (!torch.bool, !torch.int) {\n" +" %4 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n" " torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" " } else {\n" -" %6 = torch.aten.eq.int %arg1, %int9 : !torch.int, !torch.int -> !torch.bool\n" -" %7:2 = torch.prim.If %6 -> (!torch.bool, !torch.int) {\n" +" %7 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.int) {\n" " torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" " } else {\n" " torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" " }\n" -" torch.prim.If.yield %7#0, %7#1 : !torch.bool, !torch.int\n" +" torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.int\n" " }\n" -" %5 = torch.prim.If %4#0 -> (!torch.int) {\n" -" torch.prim.If.yield %4#1 : !torch.int\n" +" %6 = torch.prim.If %5#0 -> (!torch.int) {\n" +" torch.prim.If.yield %5#1 : !torch.int\n" " } else {\n" -" %6 = func.call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" torch.prim.If.yield %6 : !torch.int\n" +" %7 = func.call @__torch__._get_dtype_of_floating_point_op(%1#1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %7 : !torch.int\n" " }\n" -" return %5 : !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %int0 = torch.constant.int 0\n" -" %0 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" " torch.prim.If.yield %int0 : !torch.int\n" " } else {\n" " torch.prim.If.yield %int11 : !torch.int\n" " }\n" -" return %1 : !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %int0 = torch.constant.int 0\n" -" %0 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" " torch.prim.If.yield %int0 : !torch.int\n" " } else {\n" " torch.prim.If.yield %int11 : !torch.int\n" " }\n" -" return %1 : !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.eq.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7715,13 +7734,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7729,21 +7749,23 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7751,13 +7773,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7765,29 +7788,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_not\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_not\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_xor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_xor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7795,21 +7819,23 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7817,7 +7843,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" @@ -7839,7 +7865,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.int, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" " %int10 = torch.constant.int 10\n" @@ -7847,532 +7873,789 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int9 = torch.constant.int 9\n" " %int6 = torch.constant.int 6\n" " %0 = torch.prim.Uninitialized : !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.int) {\n" -" torch.prim.If.yield %arg1 : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" " } else {\n" -" %3 = torch.aten.eq.int %arg1, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" %4 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" " torch.prim.If.yield %int9 : !torch.int\n" " } else {\n" -" %5 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" %6 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" " torch.prim.If.yield %int10 : !torch.int\n" " } else {\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" " torch.prim.If.yield %int9 : !torch.int\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield %0 : !torch.int\n" " }\n" -" torch.prim.If.yield %8 : !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" " }\n" -" torch.prim.If.yield %6 : !torch.int\n" +" torch.prim.If.yield %7 : !torch.int\n" " }\n" -" torch.prim.If.yield %4 : !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" " }\n" -" return %2 : !torch.int\n" +" return %3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot have bool dtype\"\n" " %int11 = torch.constant.int 11\n" -" %0 = torch.aten.ne.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" -" %3 = torch.prim.ListConstruct %arg1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0#1, %3 : (!torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.union) -> !torch.int {\n" -" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %2 : !torch.int\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `self` and `mat2` must have the same dtype\"\n" " %str_0 = torch.constant.str \"AssertionError: Expected dtype of `mat2` to not be float16 or bool\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" " %int11 = torch.constant.int 11\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" +" %8 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %arg1 : !torch.int\n" +" return %1#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %int6 = torch.constant.int 6\n" " %true = torch.constant.bool true\n" -" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %8 = torch.aten.ne.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %8 : !torch.bool\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %10 = torch.aten.ne.int %4, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If.yield %7 : !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" " }\n" -" %5 = torch.prim.If %4 -> (!torch.int) {\n" -" torch.prim.If.yield %2 : !torch.int\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %4 : !torch.int\n" " } else {\n" " torch.prim.If.yield %int6 : !torch.int\n" " }\n" -" return %5 : !torch.int\n" +" return %7 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor_mode\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor_mode\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %int6 = torch.constant.int 6\n" " %true = torch.constant.bool true\n" -" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %8 = torch.aten.ne.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %8 : !torch.bool\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %10 = torch.aten.ne.int %4, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If.yield %7 : !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" " }\n" -" %5 = torch.prim.If %4 -> (!torch.int) {\n" -" torch.prim.If.yield %2 : !torch.int\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %4 : !torch.int\n" " } else {\n" " torch.prim.If.yield %int6 : !torch.int\n" " }\n" -" return %5 : !torch.int\n" +" return %7 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Result dtype for aten.floor_divide bool\"\n" " %int11 = torch.constant.int 11\n" " %str_0 = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" %7 = torch.aten.ne.int %6, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = torch.aten.ne.int %8, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %9 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %6 : !torch.int\n" +" return %8 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `self` and `other` must have the same dtype\"\n" " %str_0 = torch.constant.str \"AssertionError: Expected dtype of `other` to not be float16 or bool\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" " %int11 = torch.constant.int 11\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" +" %8 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %arg1 : !torch.int\n" +" return %1#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.maximum\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.maximum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %6 : !torch.int\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %8 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %6 : !torch.int\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %8 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `self` and `mat2` must have the same dtype\"\n" " %str_0 = torch.constant.str \"AssertionError: Expected dtype of `mat2` to not be float16 or bool\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" " %int11 = torch.constant.int 11\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" +" %8 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %arg1 : !torch.int\n" +" return %1#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Expected promoted dtype to be float but not `bfloat16`\"\n" " %false = torch.constant.bool false\n" " %int15 = torch.constant.int 15\n" " %str_0 = torch.constant.str \"AssertionError: `target` cannot be complex\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" %7 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.bool) {\n" -" %9 = torch.aten.ne.int %6, %int15 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %9 : !torch.bool\n" +" %6 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%8) : (!torch.int) -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" %11 = torch.aten.ne.int %8, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %8 -> () {\n" +" torch.prim.If %10 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %6 : !torch.int\n" +" return %8 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" -" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %2 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.mv\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mv\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `self` and `vec` must have the same dtype\"\n" " %str_0 = torch.constant.str \"AssertionError: Expected dtype of `vec` to not be float16 or bool\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" " %int11 = torch.constant.int 11\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" +" %8 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %7 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %8 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %9 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%7, %8) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %9 : !torch.int\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `other` cannot be of bool dtype\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be of bool dtype\"\n" " %int11 = torch.constant.int 11\n" -" %0 = torch.aten.ne.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = torch.aten.ne.int %arg3, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %3 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Result dtype for aten.threshold_backward cannot be bool or float16\"\n" " %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: `grad_output` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" %7 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %8 = torch.aten.__contains__.int_list %7, %6 : !torch.list, !torch.int -> !torch.bool\n" -" %9 = torch.aten.__not__ %8 : !torch.bool -> !torch.bool\n" -" torch.prim.If %9 -> () {\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = torch.aten.__contains__.int_list %9, %8 : !torch.list, !torch.int -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %6 : !torch.int\n" +" return %8 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %true = torch.constant.bool true\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" " }\n" "}\n" ""; diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 34acd2c80744d..32208cb682fe2 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -692,18 +692,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - // Promote the two dtypes assuming non-zero rank. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - // Dtype is always float32, except for bfloat16, float64 and nullptr after // promotion and assuming possible-zero rank. if (isa(op)) { diff --git a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp index e66cc76fe0826..9eac538743b40 100644 --- a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp @@ -19,55 +19,25 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -static bool isTensorTypeOrWrappedTensorType(Type type) { - // Allowing tuples as arguments to dtype calculation functions can cause - // issues. For example, if an argument is a tuple of tensors and ints, there - // would be no way of differentiating the original ints from the ints created - // to represent the dtype and rank of the tensors. Therefore, to avoid this - // and keep things simple, the tuple type is not allowed. This works well in - // practice, since PyTorch op signatures don't seem to take tuples as inputs. - assert(!type.isa() && - "dtype calculation functions are expected to not have tuples of " - "tensors as arguments"); - - if (type.isa()) - return true; - - if (auto optionalType = type.dyn_cast()) { - return isTensorTypeOrWrappedTensorType(optionalType.getContainedType()); - } else if (auto listType = type.dyn_cast()) { - return isTensorTypeOrWrappedTensorType(listType.getContainedType()); - } else { - return false; - } -} - // Massage the op operands to match the dtype function signature. // The dtype function generally takes the same operands as the op, with a few -// systematic modifications, such as replacing tensors with a rank and dtype -// argument. +// systematic modifications, such as replacing each tensor with a tuple of +// its rank and dtype. static FailureOr> dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, ValueRange originalOperands, func::FuncOp dtypeFunc) { - // Turns a tensor operand into an operand representing the rank of the tensor - auto rankArgAdjuster = [](OpBuilder &b, Location loc, Value operand, - Type desiredType) -> Value { - if (desiredType.isa() && - operand.getType().isa()) { - auto sizeListType = - Torch::ListType::get(Torch::IntType::get(b.getContext())); - Value size = b.create(loc, sizeListType, operand); - return b.create(loc, desiredType, size); - } - return operand; - }; - - // Turns a tensor operand into an operand representing the dtype of the tensor + // Turn every tensor into a tuple of (tensor_rank, tensor_dtype) auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand, Type desiredType) -> Value { - if (desiredType.isa() && + if (desiredType.isa() && operand.getType().isa()) { - return b.create(loc, desiredType, operand); + Type intType = Torch::IntType::get(b.getContext()); + Type sizeListType = Torch::ListType::get(intType); + Value size = b.create(loc, sizeListType, operand); + Value rank = b.create(loc, intType, size); + Value dtype = b.create(loc, intType, operand); + return b.create(loc, desiredType, + ArrayRef{rank, dtype}); } return operand; }; @@ -79,26 +49,11 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, "`dtypeFunc` should have at least one argument for each argument in " "`originalOperands`"); Type desiredType = desiredTypes.front(); - if (isTensorTypeOrWrappedTensorType(operand.getType())) { - assert(desiredTypes.size() >= 2 && - "`dtypeFunc` should have two arguments for each tensor argument " - "in `originalOperands`"); - FailureOr rankArg, dtypeArg; - if (failed(rankArg = adjustFunctionArg(b, loc, operand, desiredType, - rankArgAdjuster))) - return failure(); - desiredTypes = desiredTypes.drop_front(); - desiredType = desiredTypes.front(); - if (failed(dtypeArg = adjustFunctionArg(b, loc, operand, desiredType, - dtypeArgAdjuster))) - return failure(); - dtypeFuncArgs.append({*rankArg, *dtypeArg}); - } else { - FailureOr otherArg; - if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType))) - return failure(); - dtypeFuncArgs.push_back(*otherArg); - } + FailureOr otherArg; + if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType, + dtypeArgAdjuster))) + return failure(); + dtypeFuncArgs.push_back(*otherArg); desiredTypes = desiredTypes.drop_front(); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 65282ea940961..3f187bc0b6261 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -1133,81 +1133,98 @@ def _get_dtype_of_floating_point_op(input_dtype: int) -> int: return torch.float32 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇exp〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128})) -def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert self_dtype != torch.float16, "`self` cannot have float16 dtype" return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇reciprocal〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇log〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇log2〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇rsqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128})) -def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇erf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype) and self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, - torch.int64, torch.float16, torch.complex64, torch.complex128})) -def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, + torch.int64, torch.float16, torch.complex64, torch.complex128})) +def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype) and not is_integer_dtype(self_dtype) and self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=[0])) -def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int: + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=[0])) +def aten〇frobenius_norm〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int], keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_integer_dtype(self_dtype) if self_dtype == torch.complex128: return torch.float64 @@ -1216,25 +1233,28 @@ def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: Li return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int: +def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇all〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype return torch.uint8 if self_dtype == torch.uint8 else torch.bool @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇any〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇any〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype return torch.uint8 if self_dtype == torch.uint8 else torch.bool @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇eq〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) -def aten〇eq〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @check_dtype_function( @@ -1242,7 +1262,8 @@ def aten〇eq〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇ge〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool @@ -1251,13 +1272,16 @@ def aten〇ge〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇gt〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool @check_dtype_function( _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇gt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇gt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`self` cannot be complex" return torch.bool @@ -1267,24 +1291,25 @@ def aten〇gt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇le〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool @check_dtype_function(_check_two_tensor_op()) -def aten〇logical_and〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇logical_not〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) -def aten〇logical_or〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇logical_or〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) -def aten〇logical_xor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇logical_xor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @check_dtype_function( @@ -1292,13 +1317,16 @@ def aten〇logical_xor〡dtype(self_rank: int, self_dtype: int, other_rank: int, num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇lt〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool @check_dtype_function( _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇lt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇lt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`self` cannot be complex" return torch.bool @@ -1306,7 +1334,7 @@ def aten〇lt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇ne〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: return torch.bool @check_dtype_function([ @@ -1323,7 +1351,8 @@ def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int: @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.float16, torch.bfloat16})) -def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: +def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: + self_rank, self_dtype = self_rank_dtype if is_complex_dtype(self_dtype): return self_dtype elif self_dtype == torch.float: @@ -1340,14 +1369,17 @@ def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = N num_of_tensors=1, error_types={torch.bool}, other=0.0) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.bool}, other=0)) -def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int: +def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.bool, "`self` cannot have bool dtype" return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇__and__〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +@check_dtype_function(_check_two_tensor_op( + input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1355,35 +1387,43 @@ def aten〇__and__〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op()) -def aten〇add〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, alpha: Union[int, float] = 1) -> int: +def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇bitwise_and〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +@check_dtype_function(_check_two_tensor_op( + input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇bitwise_or〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +@check_dtype_function(_check_two_tensor_op( + input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇bitwise_xor〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +@check_dtype_function(_check_two_tensor_op( + input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1399,7 +1439,9 @@ def aten〇bitwise_xor〇Tensor〡dtype(self_rank: int, self_dtype: int, other_r # Different type ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32), TensorOfShape(2, 4, 3, dtype=torch.int32))]) -def aten〇bmm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtype: int) -> int: +def aten〇bmm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + mat2_rank, mat2_dtype = mat2_rank_dtype + self_rank, self_dtype = self_rank_dtype assert self_dtype not in [torch.float16, torch.bool], \ "Expected dtype of `self` to not be float16 or bool" assert mat2_dtype not in [torch.float16, torch.bool], \ @@ -1408,7 +1450,9 @@ def aten〇bmm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dty return self_dtype @check_dtype_function(_check_two_tensor_op()) -def aten〇div〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇div〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] promoted_dtype = promote_dtypes(ranks, dtypes) @@ -1419,7 +1463,9 @@ def aten〇div〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int return torch.float32 @check_dtype_function(_check_two_tensor_op(rounding_mode=None)) -def aten〇div〇Tensor_mode〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, rounding_mode: Optional[str]) -> int: +def aten〇div〇Tensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rounding_mode: Optional[str]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] promoted_dtype = promote_dtypes(ranks, dtypes) @@ -1430,7 +1476,9 @@ def aten〇div〇Tensor_mode〡dtype(self_rank: int, self_dtype: int, other_rank return torch.float32 @check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool})) -def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`other` cannot be complex" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1448,7 +1496,9 @@ def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int # Different type ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32), TensorOfShape(2, 4, 3, dtype=torch.int32))]) -def aten〇matmul〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇matmul〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert self_dtype not in [torch.float16, torch.bool], \ "Expected dtype of `self` to not be float16 or bool" assert other_dtype not in [torch.float16, torch.bool], \ @@ -1457,7 +1507,9 @@ def aten〇matmul〡dtype(self_rank: int, self_dtype: int, other_rank: int, othe return self_dtype @check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇maximum〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇maximum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`other` cannot be complex" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1465,7 +1517,9 @@ def aten〇maximum〡dtype(self_rank: int, self_dtype: int, other_rank: int, oth return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇minimum〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇minimum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`other` cannot be complex" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1481,7 +1535,9 @@ def aten〇minimum〡dtype(self_rank: int, self_dtype: int, other_rank: int, oth # Different type ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(4, 3, dtype=torch.int32))]) -def aten〇mm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtype: int) -> int: +def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + mat2_rank, mat2_dtype = mat2_rank_dtype + self_rank, self_dtype = self_rank_dtype assert self_dtype not in [torch.float16, torch.bool], \ "Expected dtype of `self` to not be float16 or bool" assert mat2_dtype not in [torch.float16, torch.bool], \ @@ -1489,10 +1545,13 @@ def aten〇mm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtyp assert self_dtype == mat2_dtype, "`self` and `mat2` must have the same dtype" return self_dtype -@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, - output_error_types={torch.bool, torch.uint8, torch.int8, torch.int16, - torch.int32, torch.int64, torch.bfloat16})) -def aten〇mse_loss〡dtype(self_rank: int, self_dtype: int, target_rank: int, target_dtype: int, reduction: int = 1) -> int: +@check_dtype_function(_check_two_tensor_op( + input_error_types={torch.complex64, torch.complex128}, + output_error_types={torch.bool, torch.uint8, torch.int8, torch.int16, + torch.int32, torch.int64, torch.bfloat16})) +def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + target_rank, target_dtype = target_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(target_dtype), "`target` cannot be complex" ranks: List[Optional[int]] = [self_rank, target_rank] @@ -1503,7 +1562,9 @@ def aten〇mse_loss〡dtype(self_rank: int, self_dtype: int, target_rank: int, t return promoted_dtype @check_dtype_function(_check_two_tensor_op()) -def aten〇mul〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇mul〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) @@ -1517,7 +1578,9 @@ def aten〇mul〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int # Different type ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(4, dtype=torch.int32))]) -def aten〇mv〡dtype(self_rank: int, self_dtype: int, vec_rank: int, vec_dtype: int) -> int: +def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + vec_rank, vec_dtype = vec_rank_dtype assert self_dtype not in [torch.float16, torch.bool], \ "Expected dtype of `self` to not be float16 or bool" assert vec_dtype not in [torch.float16, torch.bool], \ @@ -1528,7 +1591,9 @@ def aten〇mv〡dtype(self_rank: int, self_dtype: int, vec_rank: int, vec_dtype: return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op(input_error_types={torch.bool})) -def aten〇sub〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, alpha: Union[int, float] = 1) -> int: +def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.bool, "`self` cannot be of bool dtype" assert other_dtype != torch.bool, "`other` cannot be of bool dtype" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1536,7 +1601,9 @@ def aten〇sub〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool, torch.float16}, threshold=0)) -def aten〇threshold_backward〡dtype(grad_output_rank: int, grad_output_dtype: int, self_rank: int, self_dtype: int, threshold: Union[int, float]) -> int: +def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + grad_output_rank, grad_output_dtype = grad_output_rank_dtype assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex" assert not is_complex_dtype(self_dtype), "`self` cannot be complex" ranks: List[Optional[int]] = [grad_output_rank, self_rank] @@ -1546,6 +1613,115 @@ def aten〇threshold_backward〡dtype(grad_output_rank: int, grad_output_dtype: "Result dtype for aten.threshold_backward cannot be bool or float16" return promoted_dtype +_convolution_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], + "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False, "allow_tf32" : False} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], + error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_kwargs) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs) +]) +def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +_convolution_deprecated_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], + "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], + error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs) +]) +def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)], error_types={torch.bool, torch.float16}) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) +]) +def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert input_dtype not in [torch.bool, torch.float16] + assert weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)], + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.float16}) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) +]) +def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert (input_dtype == torch.int64 or not is_integer_dtype(input_dtype)) and input_dtype != torch.float16 + assert (weight_dtype == torch.int64 or not is_integer_dtype(weight_dtype)) and weight_dtype != torch.float16 + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +convolution_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], + error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **convolution_kwargs) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs) +]) +def aten〇convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + # ============================================================================== # Main # ============================================================================== diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py index 81c669e1abe3c..550b47802e76c 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py @@ -86,7 +86,7 @@ def _pytype_to_dtype_fn_pytype(pytype: str) -> str: """ # Dtype functions only care about the rank and dtype of tensors. if "Tensor" in pytype: - return pytype.replace("Tensor", "int") + return pytype.replace("Tensor", "Tuple[int, int]") return _pytype_to_fn_pytype_common(pytype) def _pytype_to_decomposition_fn_pytype(pytype: str) -> str: @@ -232,8 +232,7 @@ def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: default = _get_default_value(arg) parameter_name = _rename_python_keyword_parameter_name(arg["name"]) if "Tensor" in arg["pytype"]: - return ", ".join([f"{parameter_name}_rank: {pytype}{default}", - f"{parameter_name}_dtype: {pytype}{default}"]) + return f"{parameter_name}_rank_dtype: {pytype}{default}" return f"{parameter_name}: {pytype}{default}" def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: @@ -241,7 +240,7 @@ def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: # results of type `number`. Here we handle this case because # `_pytype_to_dtype_fn_pytype` will replace `number` with # `Union[int, float]`. - if arg["pytype"] == "number": + if arg["pytype"] in ["number", "Tensor"]: return "int" return _pytype_to_dtype_fn_pytype(arg["pytype"]) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py index 1e82a59706b7c..efd270b78a7f9 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py @@ -96,36 +96,6 @@ def _recursively_transform_tensor_args( return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o) raise Exception(f"Unhandled type {type(o)}") -def _convert_to_dtype_function_args(arguments: Iterable[Any]) -> List[Any]: - """Converts an Invocation argument to a dtype function argument. - - TensorOfShape is replaced with two ints representing the rank - and dtype of the tensor, respectively. - """ - def contains_tensor(o: Any) -> bool: - if o is None or isinstance(o, (float, int)): - return False - if isinstance(o, TensorOfShape): - return True - if isinstance(o, (list, tuple)): - for elem in o: - if contains_tensor(elem): - return True - return False - raise Exception(f"Unhandled type {type(o)}") - - result = [] - for arg in arguments: - if contains_tensor(arg): - rank_arg = _recursively_transform_tensor_args( - arg, lambda x: len(x.shape)) - dtype_arg = _recursively_transform_tensor_args( - arg, lambda x: x.dtype) - result += [rank_arg, dtype_arg] - else: - result.append(arg) - return result - class Invocation: """Representation of a single op invocation (i.e. list of args to the op). @@ -135,8 +105,8 @@ class Invocation: Specifically, this class has special knowledge of `TensorOfShape` and translates it appropriately to either a tensor (for the real op), a - `List[int]` for the shape function, and two `int`s representing - the tensor rank and dtype in the case of a dtype function. + `List[int]` for the shape function, and a tuple with two `int`s + representing the tensor rank and dtype in the case of a dtype function. This class also tracks whether the invocation is expected to raise an exception for greater precision when interpreting errors raised during @@ -170,7 +140,9 @@ def to_shape_function_args(self): def to_dtype_function_args(self): """Gets positional arguments appropriate for a dtype function.""" - return _convert_to_dtype_function_args(self.args) + tensor_transformer = lambda o: (len(o.shape), o.dtype) + return _recursively_transform_tensor_args( + self.args, tensor_transformer) def to_real_op_args(self): """Gets positional arguments appropriate for the real op.""" diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index d3795307fbf18..9dd80b0d22674 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -9,6 +9,8 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = { "QuantizedMLP_basic", "NormalizeModule_basic", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", "MobilenetV3Module_basic", } diff --git a/test/Dialect/Torch/reify-dtype-calculations.mlir b/test/Dialect/Torch/reify-dtype-calculations.mlir index 8ed19c68a159d..455dfbbfd07de 100644 --- a/test/Dialect/Torch/reify-dtype-calculations.mlir +++ b/test/Dialect/Torch/reify-dtype-calculations.mlir @@ -12,7 +12,8 @@ // CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list // CHECK: %[[RANK:.*]] = torch.aten.len.t %[[SIZE]] : !torch.list -> !torch.int // CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG]] : !torch.vtensor -> !torch.int -// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK]], %[[DTYPE]]) : (!torch.int, !torch.int) -> !torch.int +// CHECK: %[[RANK_DTYPE:.*]] = torch.prim.TupleConstruct %[[RANK]], %[[DTYPE]] : !torch.int, !torch.int -> !torch.tuple +// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK_DTYPE]]) : (!torch.tuple) -> !torch.int // CHECK: torch.dtype.calculate.yield.dtypes %[[RESULT_DTYPE]] : !torch.int // CHECK: } : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor @@ -38,6 +39,20 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor) // ----- +// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.convolution( + +// CHECK-LABEL: func.func @op_with_optional_tensor_arg$none( +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[OPTIONAL_TUPLE:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional> +// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.convolution({{.*}}, %[[OPTIONAL_TUPLE]], {{.*}}) : ({{.*}}, !torch.optional>, {{.*}}) -> !torch.int +func.func @op_with_optional_tensor_arg$none(%input: !torch.vtensor, %weight: !torch.vtensor, %stride: !torch.list, %padding: !torch.list, %dilation: !torch.list, %transposed: !torch.bool, %output_padding: !torch.list, %groups: !torch.int) -> !torch.vtensor { + %bias_none = torch.constant.none + %0 = torch.aten.convolution %input, %weight, %bias_none, %stride, %padding, %dilation, %transposed, %output_padding, %groups : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor + return %0 : !torch.vtensor +} + +// ----- + // CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide( // CHECK-LABEL: func.func @turn_tensors_into_rank_and_dtype_args( @@ -46,10 +61,12 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor) // CHECK: %[[SIZE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list // CHECK: %[[RANK0:.*]] = torch.aten.len.t %[[SIZE0]] : !torch.list -> !torch.int // CHECK: %[[DTYPE0:.*]] = torch.prim.dtype %[[ARG0]] : !torch.vtensor -> !torch.int +// CHECK: %[[RANK_DTYPE0:.*]] = torch.prim.TupleConstruct %[[RANK0]], %[[DTYPE0]] : !torch.int, !torch.int -> !torch.tuple // CHECK: %[[SIZE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list // CHECK: %[[RANK1:.*]] = torch.aten.len.t %[[SIZE1]] : !torch.list -> !torch.int // CHECK: %[[DTYPE1:.*]] = torch.prim.dtype %[[ARG1]] : !torch.vtensor -> !torch.int -// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK0]], %[[DTYPE0]], %[[RANK1]], %[[DTYPE1]]) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.int +// CHECK: %[[RANK_DTYPE1:.*]] = torch.prim.TupleConstruct %[[RANK1]], %[[DTYPE1]] : !torch.int, !torch.int -> !torch.tuple +// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK_DTYPE0]], %[[RANK_DTYPE1]]) : (!torch.tuple, !torch.tuple) -> !torch.int func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor return %0 : !torch.vtensor