diff --git a/docs/adding_abstract_interpretation_functions.md b/docs/adding_abstract_interpretation_functions.md index 68e1b98b78acb..b5e427e1adfd0 100644 --- a/docs/adding_abstract_interpretation_functions.md +++ b/docs/adding_abstract_interpretation_functions.md @@ -21,7 +21,7 @@ We will use the example of adding support for the `torch.aten.tanh` op. function signatures are: - `def aten〇tanh〡shape(self: List[int]) -> List[int]:` - - `def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:` + - `def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:` Note the use of `〇` as a separator since `.` or `::` aren't legal in a Python identifier. diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 836ed7a08c980..df51514dfb5cb 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -523,7 +523,6 @@ "ElementwiseFlattenBroadcastModule_basic", "SquareModule_basic", "MaxPool2dStaticModule_basic", - "ResNet18StaticModule_basic", "ReduceAmaxKeepDim_basic", "NativeLayerNormModule4D_basic", "LayerNormNormalizeOverAllDimsModule_basic", diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 682048b93654e..0f220b778d1e6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7350,19 +7350,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" " func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n" " %true = torch.constant.bool true\n" @@ -7404,209 +7405,222 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" " return %1 : !torch.bool\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `self` cannot have float16 dtype\"\n" " %int5 = torch.constant.int 5\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %3 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %3 : !torch.int\n" +" %4 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" " %int5 = torch.constant.int 5\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" %4 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %4 : !torch.bool\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %5 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %2 -> () {\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %3 : !torch.int\n" +" %4 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" " %false = torch.constant.bool false\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %6 : !torch.bool\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" %5 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %5 : !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %6 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %6 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %3 -> () {\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %4 : !torch.int\n" +" %5 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %5 : !torch.int\n" " }\n" " func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" " %int4 = torch.constant.int 4\n" @@ -7619,7 +7633,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" " return %1 : !torch.bool\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int9 = torch.constant.int 9\n" " %int7 = torch.constant.int 7\n" @@ -7629,85 +7643,90 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %false = torch.constant.bool false\n" " %true = torch.constant.bool true\n" " %0 = torch.prim.Uninitialized : !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.aten.eq.int %arg1, %int10 : !torch.int, !torch.int -> !torch.bool\n" -" %4:2 = torch.prim.If %3 -> (!torch.bool, !torch.int) {\n" +" %4 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n" " torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" " } else {\n" -" %6 = torch.aten.eq.int %arg1, %int9 : !torch.int, !torch.int -> !torch.bool\n" -" %7:2 = torch.prim.If %6 -> (!torch.bool, !torch.int) {\n" +" %7 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.int) {\n" " torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" " } else {\n" " torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" " }\n" -" torch.prim.If.yield %7#0, %7#1 : !torch.bool, !torch.int\n" +" torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.int\n" " }\n" -" %5 = torch.prim.If %4#0 -> (!torch.int) {\n" -" torch.prim.If.yield %4#1 : !torch.int\n" +" %6 = torch.prim.If %5#0 -> (!torch.int) {\n" +" torch.prim.If.yield %5#1 : !torch.int\n" " } else {\n" -" %6 = func.call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" torch.prim.If.yield %6 : !torch.int\n" +" %7 = func.call @__torch__._get_dtype_of_floating_point_op(%1#1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %7 : !torch.int\n" " }\n" -" return %5 : !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" +" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %int0 = torch.constant.int 0\n" -" %0 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" " torch.prim.If.yield %int0 : !torch.int\n" " } else {\n" " torch.prim.If.yield %int11 : !torch.int\n" " }\n" -" return %1 : !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %int0 = torch.constant.int 0\n" -" %0 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" " torch.prim.If.yield %int0 : !torch.int\n" " } else {\n" " torch.prim.If.yield %int11 : !torch.int\n" " }\n" -" return %1 : !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.eq.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7715,13 +7734,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7729,21 +7749,23 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7751,13 +7773,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7765,29 +7788,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_not\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_not\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_xor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_xor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7795,21 +7819,23 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" @@ -7817,7 +7843,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" @@ -7839,7 +7865,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.int, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" " %int10 = torch.constant.int 10\n" @@ -7847,532 +7873,789 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int9 = torch.constant.int 9\n" " %int6 = torch.constant.int 6\n" " %0 = torch.prim.Uninitialized : !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.int) {\n" -" torch.prim.If.yield %arg1 : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" " } else {\n" -" %3 = torch.aten.eq.int %arg1, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" %4 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" " torch.prim.If.yield %int9 : !torch.int\n" " } else {\n" -" %5 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" %6 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" " torch.prim.If.yield %int10 : !torch.int\n" " } else {\n" -" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" " torch.prim.If.yield %int9 : !torch.int\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield %0 : !torch.int\n" " }\n" -" torch.prim.If.yield %8 : !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" " }\n" -" torch.prim.If.yield %6 : !torch.int\n" +" torch.prim.If.yield %7 : !torch.int\n" " }\n" -" torch.prim.If.yield %4 : !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" " }\n" -" return %2 : !torch.int\n" +" return %3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: `self` cannot have bool dtype\"\n" " %int11 = torch.constant.int 11\n" -" %0 = torch.aten.ne.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" -" %3 = torch.prim.ListConstruct %arg1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0#1, %3 : (!torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.union) -> !torch.int {\n" -" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %2 : !torch.int\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `self` and `mat2` must have the same dtype\"\n" " %str_0 = torch.constant.str \"AssertionError: Expected dtype of `mat2` to not be float16 or bool\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" " %int11 = torch.constant.int 11\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" +" %8 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %arg1 : !torch.int\n" +" return %1#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %int6 = torch.constant.int 6\n" " %true = torch.constant.bool true\n" -" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %8 = torch.aten.ne.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %8 : !torch.bool\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %10 = torch.aten.ne.int %4, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If.yield %7 : !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" " }\n" -" %5 = torch.prim.If %4 -> (!torch.int) {\n" -" torch.prim.If.yield %2 : !torch.int\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %4 : !torch.int\n" " } else {\n" " torch.prim.If.yield %int6 : !torch.int\n" " }\n" -" return %5 : !torch.int\n" +" return %7 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor_mode\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor_mode\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %int6 = torch.constant.int 6\n" " %true = torch.constant.bool true\n" -" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %8 = torch.aten.ne.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %8 : !torch.bool\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %10 = torch.aten.ne.int %4, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If.yield %7 : !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" " }\n" -" %5 = torch.prim.If %4 -> (!torch.int) {\n" -" torch.prim.If.yield %2 : !torch.int\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %4 : !torch.int\n" " } else {\n" " torch.prim.If.yield %int6 : !torch.int\n" " }\n" -" return %5 : !torch.int\n" +" return %7 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Result dtype for aten.floor_divide bool\"\n" " %int11 = torch.constant.int 11\n" " %str_0 = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" %7 = torch.aten.ne.int %6, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = torch.aten.ne.int %8, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %9 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %6 : !torch.int\n" +" return %8 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `self` and `other` must have the same dtype\"\n" " %str_0 = torch.constant.str \"AssertionError: Expected dtype of `other` to not be float16 or bool\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" " %int11 = torch.constant.int 11\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" +" %8 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %arg1 : !torch.int\n" +" return %1#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.maximum\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.maximum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %6 : !torch.int\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %8 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %6 : !torch.int\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %8 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `self` and `mat2` must have the same dtype\"\n" " %str_0 = torch.constant.str \"AssertionError: Expected dtype of `mat2` to not be float16 or bool\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" " %int11 = torch.constant.int 11\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" +" %8 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %arg1 : !torch.int\n" +" return %1#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Expected promoted dtype to be float but not `bfloat16`\"\n" " %false = torch.constant.bool false\n" " %int15 = torch.constant.int 15\n" " %str_0 = torch.constant.str \"AssertionError: `target` cannot be complex\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" %7 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.bool) {\n" -" %9 = torch.aten.ne.int %6, %int15 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %9 : !torch.bool\n" +" %6 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%8) : (!torch.int) -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" %11 = torch.aten.ne.int %8, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %8 -> () {\n" +" torch.prim.If %10 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %6 : !torch.int\n" +" return %8 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" -" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %2 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.mv\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mv\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `self` and `vec` must have the same dtype\"\n" " %str_0 = torch.constant.str \"AssertionError: Expected dtype of `vec` to not be float16 or bool\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" " %int11 = torch.constant.int 11\n" " %int5 = torch.constant.int 5\n" -" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" +" %8 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %7 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %8 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %9 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%7, %8) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %9 : !torch.int\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: `other` cannot be of bool dtype\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be of bool dtype\"\n" " %int11 = torch.constant.int 11\n" -" %0 = torch.aten.ne.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %0 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = torch.aten.ne.int %arg3, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %3 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Result dtype for aten.threshold_backward cannot be bool or float16\"\n" " %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" " %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" " %none = torch.constant.none\n" " %str_1 = torch.constant.str \"AssertionError: `grad_output` cannot be complex\"\n" -" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" -" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" %7 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %8 = torch.aten.__contains__.int_list %7, %6 : !torch.list, !torch.int -> !torch.bool\n" -" %9 = torch.aten.__not__ %8 : !torch.bool -> !torch.bool\n" -" torch.prim.If %9 -> () {\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = torch.aten.__contains__.int_list %9, %8 : !torch.list, !torch.int -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %6 : !torch.int\n" +" return %8 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %true = torch.constant.bool true\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" " }\n" "}\n" ""; diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 34acd2c80744d..32208cb682fe2 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -692,18 +692,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - // Promote the two dtypes assuming non-zero rank. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - // Dtype is always float32, except for bfloat16, float64 and nullptr after // promotion and assuming possible-zero rank. if (isa(op)) { diff --git a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp index e66cc76fe0826..9eac538743b40 100644 --- a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp @@ -19,55 +19,25 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -static bool isTensorTypeOrWrappedTensorType(Type type) { - // Allowing tuples as arguments to dtype calculation functions can cause - // issues. For example, if an argument is a tuple of tensors and ints, there - // would be no way of differentiating the original ints from the ints created - // to represent the dtype and rank of the tensors. Therefore, to avoid this - // and keep things simple, the tuple type is not allowed. This works well in - // practice, since PyTorch op signatures don't seem to take tuples as inputs. - assert(!type.isa() && - "dtype calculation functions are expected to not have tuples of " - "tensors as arguments"); - - if (type.isa()) - return true; - - if (auto optionalType = type.dyn_cast()) { - return isTensorTypeOrWrappedTensorType(optionalType.getContainedType()); - } else if (auto listType = type.dyn_cast()) { - return isTensorTypeOrWrappedTensorType(listType.getContainedType()); - } else { - return false; - } -} - // Massage the op operands to match the dtype function signature. // The dtype function generally takes the same operands as the op, with a few -// systematic modifications, such as replacing tensors with a rank and dtype -// argument. +// systematic modifications, such as replacing each tensor with a tuple of +// its rank and dtype. static FailureOr> dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, ValueRange originalOperands, func::FuncOp dtypeFunc) { - // Turns a tensor operand into an operand representing the rank of the tensor - auto rankArgAdjuster = [](OpBuilder &b, Location loc, Value operand, - Type desiredType) -> Value { - if (desiredType.isa() && - operand.getType().isa()) { - auto sizeListType = - Torch::ListType::get(Torch::IntType::get(b.getContext())); - Value size = b.create(loc, sizeListType, operand); - return b.create(loc, desiredType, size); - } - return operand; - }; - - // Turns a tensor operand into an operand representing the dtype of the tensor + // Turn every tensor into a tuple of (tensor_rank, tensor_dtype) auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand, Type desiredType) -> Value { - if (desiredType.isa() && + if (desiredType.isa() && operand.getType().isa()) { - return b.create(loc, desiredType, operand); + Type intType = Torch::IntType::get(b.getContext()); + Type sizeListType = Torch::ListType::get(intType); + Value size = b.create(loc, sizeListType, operand); + Value rank = b.create(loc, intType, size); + Value dtype = b.create(loc, intType, operand); + return b.create(loc, desiredType, + ArrayRef{rank, dtype}); } return operand; }; @@ -79,26 +49,11 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, "`dtypeFunc` should have at least one argument for each argument in " "`originalOperands`"); Type desiredType = desiredTypes.front(); - if (isTensorTypeOrWrappedTensorType(operand.getType())) { - assert(desiredTypes.size() >= 2 && - "`dtypeFunc` should have two arguments for each tensor argument " - "in `originalOperands`"); - FailureOr rankArg, dtypeArg; - if (failed(rankArg = adjustFunctionArg(b, loc, operand, desiredType, - rankArgAdjuster))) - return failure(); - desiredTypes = desiredTypes.drop_front(); - desiredType = desiredTypes.front(); - if (failed(dtypeArg = adjustFunctionArg(b, loc, operand, desiredType, - dtypeArgAdjuster))) - return failure(); - dtypeFuncArgs.append({*rankArg, *dtypeArg}); - } else { - FailureOr otherArg; - if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType))) - return failure(); - dtypeFuncArgs.push_back(*otherArg); - } + FailureOr otherArg; + if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType, + dtypeArgAdjuster))) + return failure(); + dtypeFuncArgs.push_back(*otherArg); desiredTypes = desiredTypes.drop_front(); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 65282ea940961..3f187bc0b6261 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -1133,81 +1133,98 @@ def _get_dtype_of_floating_point_op(input_dtype: int) -> int: return torch.float32 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇exp〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128})) -def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert self_dtype != torch.float16, "`self` cannot have float16 dtype" return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇reciprocal〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇log〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇log2〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇rsqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128})) -def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇erf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype) and self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, - torch.int64, torch.float16, torch.complex64, torch.complex128})) -def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, + torch.int64, torch.float16, torch.complex64, torch.complex128})) +def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype) and not is_integer_dtype(self_dtype) and self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=[0])) -def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int: + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=[0])) +def aten〇frobenius_norm〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int], keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_integer_dtype(self_dtype) if self_dtype == torch.complex128: return torch.float64 @@ -1216,25 +1233,28 @@ def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: Li return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) -def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int: +def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇all〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype return torch.uint8 if self_dtype == torch.uint8 else torch.bool @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇any〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇any〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype return torch.uint8 if self_dtype == torch.uint8 else torch.bool @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇eq〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) -def aten〇eq〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @check_dtype_function( @@ -1242,7 +1262,8 @@ def aten〇eq〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇ge〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool @@ -1251,13 +1272,16 @@ def aten〇ge〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇gt〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool @check_dtype_function( _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇gt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇gt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`self` cannot be complex" return torch.bool @@ -1267,24 +1291,25 @@ def aten〇gt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇le〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool @check_dtype_function(_check_two_tensor_op()) -def aten〇logical_and〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇logical_not〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) -def aten〇logical_or〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇logical_or〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @check_dtype_function(_check_two_tensor_op()) -def aten〇logical_xor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇logical_xor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @check_dtype_function( @@ -1292,13 +1317,16 @@ def aten〇logical_xor〡dtype(self_rank: int, self_dtype: int, other_rank: int, num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇lt〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool @check_dtype_function( _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇lt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇lt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`self` cannot be complex" return torch.bool @@ -1306,7 +1334,7 @@ def aten〇lt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇ne〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: +def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: return torch.bool @check_dtype_function([ @@ -1323,7 +1351,8 @@ def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int: @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.float16, torch.bfloat16})) -def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: +def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: + self_rank, self_dtype = self_rank_dtype if is_complex_dtype(self_dtype): return self_dtype elif self_dtype == torch.float: @@ -1340,14 +1369,17 @@ def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = N num_of_tensors=1, error_types={torch.bool}, other=0.0) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.bool}, other=0)) -def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int: +def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.bool, "`self` cannot have bool dtype" return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇__and__〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +@check_dtype_function(_check_two_tensor_op( + input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1355,35 +1387,43 @@ def aten〇__and__〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op()) -def aten〇add〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, alpha: Union[int, float] = 1) -> int: +def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇bitwise_and〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +@check_dtype_function(_check_two_tensor_op( + input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇bitwise_or〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +@check_dtype_function(_check_two_tensor_op( + input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇bitwise_xor〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +@check_dtype_function(_check_two_tensor_op( + input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1399,7 +1439,9 @@ def aten〇bitwise_xor〇Tensor〡dtype(self_rank: int, self_dtype: int, other_r # Different type ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32), TensorOfShape(2, 4, 3, dtype=torch.int32))]) -def aten〇bmm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtype: int) -> int: +def aten〇bmm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + mat2_rank, mat2_dtype = mat2_rank_dtype + self_rank, self_dtype = self_rank_dtype assert self_dtype not in [torch.float16, torch.bool], \ "Expected dtype of `self` to not be float16 or bool" assert mat2_dtype not in [torch.float16, torch.bool], \ @@ -1408,7 +1450,9 @@ def aten〇bmm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dty return self_dtype @check_dtype_function(_check_two_tensor_op()) -def aten〇div〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇div〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] promoted_dtype = promote_dtypes(ranks, dtypes) @@ -1419,7 +1463,9 @@ def aten〇div〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int return torch.float32 @check_dtype_function(_check_two_tensor_op(rounding_mode=None)) -def aten〇div〇Tensor_mode〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, rounding_mode: Optional[str]) -> int: +def aten〇div〇Tensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rounding_mode: Optional[str]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] promoted_dtype = promote_dtypes(ranks, dtypes) @@ -1430,7 +1476,9 @@ def aten〇div〇Tensor_mode〡dtype(self_rank: int, self_dtype: int, other_rank return torch.float32 @check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool})) -def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`other` cannot be complex" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1448,7 +1496,9 @@ def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int # Different type ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32), TensorOfShape(2, 4, 3, dtype=torch.int32))]) -def aten〇matmul〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇matmul〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert self_dtype not in [torch.float16, torch.bool], \ "Expected dtype of `self` to not be float16 or bool" assert other_dtype not in [torch.float16, torch.bool], \ @@ -1457,7 +1507,9 @@ def aten〇matmul〡dtype(self_rank: int, self_dtype: int, other_rank: int, othe return self_dtype @check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇maximum〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇maximum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`other` cannot be complex" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1465,7 +1517,9 @@ def aten〇maximum〡dtype(self_rank: int, self_dtype: int, other_rank: int, oth return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇minimum〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇minimum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`other` cannot be complex" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1481,7 +1535,9 @@ def aten〇minimum〡dtype(self_rank: int, self_dtype: int, other_rank: int, oth # Different type ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(4, 3, dtype=torch.int32))]) -def aten〇mm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtype: int) -> int: +def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + mat2_rank, mat2_dtype = mat2_rank_dtype + self_rank, self_dtype = self_rank_dtype assert self_dtype not in [torch.float16, torch.bool], \ "Expected dtype of `self` to not be float16 or bool" assert mat2_dtype not in [torch.float16, torch.bool], \ @@ -1489,10 +1545,13 @@ def aten〇mm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtyp assert self_dtype == mat2_dtype, "`self` and `mat2` must have the same dtype" return self_dtype -@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, - output_error_types={torch.bool, torch.uint8, torch.int8, torch.int16, - torch.int32, torch.int64, torch.bfloat16})) -def aten〇mse_loss〡dtype(self_rank: int, self_dtype: int, target_rank: int, target_dtype: int, reduction: int = 1) -> int: +@check_dtype_function(_check_two_tensor_op( + input_error_types={torch.complex64, torch.complex128}, + output_error_types={torch.bool, torch.uint8, torch.int8, torch.int16, + torch.int32, torch.int64, torch.bfloat16})) +def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + target_rank, target_dtype = target_rank_dtype assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(target_dtype), "`target` cannot be complex" ranks: List[Optional[int]] = [self_rank, target_rank] @@ -1503,7 +1562,9 @@ def aten〇mse_loss〡dtype(self_rank: int, self_dtype: int, target_rank: int, t return promoted_dtype @check_dtype_function(_check_two_tensor_op()) -def aten〇mul〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇mul〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) @@ -1517,7 +1578,9 @@ def aten〇mul〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int # Different type ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(4, dtype=torch.int32))]) -def aten〇mv〡dtype(self_rank: int, self_dtype: int, vec_rank: int, vec_dtype: int) -> int: +def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + vec_rank, vec_dtype = vec_rank_dtype assert self_dtype not in [torch.float16, torch.bool], \ "Expected dtype of `self` to not be float16 or bool" assert vec_dtype not in [torch.float16, torch.bool], \ @@ -1528,7 +1591,9 @@ def aten〇mv〡dtype(self_rank: int, self_dtype: int, vec_rank: int, vec_dtype: return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op(input_error_types={torch.bool})) -def aten〇sub〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, alpha: Union[int, float] = 1) -> int: +def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype assert self_dtype != torch.bool, "`self` cannot be of bool dtype" assert other_dtype != torch.bool, "`other` cannot be of bool dtype" ranks: List[Optional[int]] = [self_rank, other_rank] @@ -1536,7 +1601,9 @@ def aten〇sub〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int return promote_dtypes(ranks, dtypes) @check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool, torch.float16}, threshold=0)) -def aten〇threshold_backward〡dtype(grad_output_rank: int, grad_output_dtype: int, self_rank: int, self_dtype: int, threshold: Union[int, float]) -> int: +def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + grad_output_rank, grad_output_dtype = grad_output_rank_dtype assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex" assert not is_complex_dtype(self_dtype), "`self` cannot be complex" ranks: List[Optional[int]] = [grad_output_rank, self_rank] @@ -1546,6 +1613,115 @@ def aten〇threshold_backward〡dtype(grad_output_rank: int, grad_output_dtype: "Result dtype for aten.threshold_backward cannot be bool or float16" return promoted_dtype +_convolution_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], + "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False, "allow_tf32" : False} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], + error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_kwargs) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs) +]) +def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +_convolution_deprecated_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], + "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], + error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs) +]) +def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)], error_types={torch.bool, torch.float16}) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) +]) +def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert input_dtype not in [torch.bool, torch.float16] + assert weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)], + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.float16}) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) +]) +def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert (input_dtype == torch.int64 or not is_integer_dtype(input_dtype)) and input_dtype != torch.float16 + assert (weight_dtype == torch.int64 or not is_integer_dtype(weight_dtype)) and weight_dtype != torch.float16 + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +convolution_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], + error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **convolution_kwargs) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs) +]) +def aten〇convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + # ============================================================================== # Main # ============================================================================== diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py index 81c669e1abe3c..550b47802e76c 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py @@ -86,7 +86,7 @@ def _pytype_to_dtype_fn_pytype(pytype: str) -> str: """ # Dtype functions only care about the rank and dtype of tensors. if "Tensor" in pytype: - return pytype.replace("Tensor", "int") + return pytype.replace("Tensor", "Tuple[int, int]") return _pytype_to_fn_pytype_common(pytype) def _pytype_to_decomposition_fn_pytype(pytype: str) -> str: @@ -232,8 +232,7 @@ def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: default = _get_default_value(arg) parameter_name = _rename_python_keyword_parameter_name(arg["name"]) if "Tensor" in arg["pytype"]: - return ", ".join([f"{parameter_name}_rank: {pytype}{default}", - f"{parameter_name}_dtype: {pytype}{default}"]) + return f"{parameter_name}_rank_dtype: {pytype}{default}" return f"{parameter_name}: {pytype}{default}" def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: @@ -241,7 +240,7 @@ def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: # results of type `number`. Here we handle this case because # `_pytype_to_dtype_fn_pytype` will replace `number` with # `Union[int, float]`. - if arg["pytype"] == "number": + if arg["pytype"] in ["number", "Tensor"]: return "int" return _pytype_to_dtype_fn_pytype(arg["pytype"]) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py index 1e82a59706b7c..efd270b78a7f9 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py @@ -96,36 +96,6 @@ def _recursively_transform_tensor_args( return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o) raise Exception(f"Unhandled type {type(o)}") -def _convert_to_dtype_function_args(arguments: Iterable[Any]) -> List[Any]: - """Converts an Invocation argument to a dtype function argument. - - TensorOfShape is replaced with two ints representing the rank - and dtype of the tensor, respectively. - """ - def contains_tensor(o: Any) -> bool: - if o is None or isinstance(o, (float, int)): - return False - if isinstance(o, TensorOfShape): - return True - if isinstance(o, (list, tuple)): - for elem in o: - if contains_tensor(elem): - return True - return False - raise Exception(f"Unhandled type {type(o)}") - - result = [] - for arg in arguments: - if contains_tensor(arg): - rank_arg = _recursively_transform_tensor_args( - arg, lambda x: len(x.shape)) - dtype_arg = _recursively_transform_tensor_args( - arg, lambda x: x.dtype) - result += [rank_arg, dtype_arg] - else: - result.append(arg) - return result - class Invocation: """Representation of a single op invocation (i.e. list of args to the op). @@ -135,8 +105,8 @@ class Invocation: Specifically, this class has special knowledge of `TensorOfShape` and translates it appropriately to either a tensor (for the real op), a - `List[int]` for the shape function, and two `int`s representing - the tensor rank and dtype in the case of a dtype function. + `List[int]` for the shape function, and a tuple with two `int`s + representing the tensor rank and dtype in the case of a dtype function. This class also tracks whether the invocation is expected to raise an exception for greater precision when interpreting errors raised during @@ -170,7 +140,9 @@ def to_shape_function_args(self): def to_dtype_function_args(self): """Gets positional arguments appropriate for a dtype function.""" - return _convert_to_dtype_function_args(self.args) + tensor_transformer = lambda o: (len(o.shape), o.dtype) + return _recursively_transform_tensor_args( + self.args, tensor_transformer) def to_real_op_args(self): """Gets positional arguments appropriate for the real op.""" diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index d3795307fbf18..9dd80b0d22674 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -9,6 +9,8 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = { "QuantizedMLP_basic", "NormalizeModule_basic", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", "MobilenetV3Module_basic", } diff --git a/test/Dialect/Torch/reify-dtype-calculations.mlir b/test/Dialect/Torch/reify-dtype-calculations.mlir index 8ed19c68a159d..455dfbbfd07de 100644 --- a/test/Dialect/Torch/reify-dtype-calculations.mlir +++ b/test/Dialect/Torch/reify-dtype-calculations.mlir @@ -12,7 +12,8 @@ // CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list // CHECK: %[[RANK:.*]] = torch.aten.len.t %[[SIZE]] : !torch.list -> !torch.int // CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG]] : !torch.vtensor -> !torch.int -// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK]], %[[DTYPE]]) : (!torch.int, !torch.int) -> !torch.int +// CHECK: %[[RANK_DTYPE:.*]] = torch.prim.TupleConstruct %[[RANK]], %[[DTYPE]] : !torch.int, !torch.int -> !torch.tuple +// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK_DTYPE]]) : (!torch.tuple) -> !torch.int // CHECK: torch.dtype.calculate.yield.dtypes %[[RESULT_DTYPE]] : !torch.int // CHECK: } : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor @@ -38,6 +39,20 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor) // ----- +// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.convolution( + +// CHECK-LABEL: func.func @op_with_optional_tensor_arg$none( +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[OPTIONAL_TUPLE:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional> +// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.convolution({{.*}}, %[[OPTIONAL_TUPLE]], {{.*}}) : ({{.*}}, !torch.optional>, {{.*}}) -> !torch.int +func.func @op_with_optional_tensor_arg$none(%input: !torch.vtensor, %weight: !torch.vtensor, %stride: !torch.list, %padding: !torch.list, %dilation: !torch.list, %transposed: !torch.bool, %output_padding: !torch.list, %groups: !torch.int) -> !torch.vtensor { + %bias_none = torch.constant.none + %0 = torch.aten.convolution %input, %weight, %bias_none, %stride, %padding, %dilation, %transposed, %output_padding, %groups : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor + return %0 : !torch.vtensor +} + +// ----- + // CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide( // CHECK-LABEL: func.func @turn_tensors_into_rank_and_dtype_args( @@ -46,10 +61,12 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor) // CHECK: %[[SIZE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list // CHECK: %[[RANK0:.*]] = torch.aten.len.t %[[SIZE0]] : !torch.list -> !torch.int // CHECK: %[[DTYPE0:.*]] = torch.prim.dtype %[[ARG0]] : !torch.vtensor -> !torch.int +// CHECK: %[[RANK_DTYPE0:.*]] = torch.prim.TupleConstruct %[[RANK0]], %[[DTYPE0]] : !torch.int, !torch.int -> !torch.tuple // CHECK: %[[SIZE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list // CHECK: %[[RANK1:.*]] = torch.aten.len.t %[[SIZE1]] : !torch.list -> !torch.int // CHECK: %[[DTYPE1:.*]] = torch.prim.dtype %[[ARG1]] : !torch.vtensor -> !torch.int -// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK0]], %[[DTYPE0]], %[[RANK1]], %[[DTYPE1]]) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.int +// CHECK: %[[RANK_DTYPE1:.*]] = torch.prim.TupleConstruct %[[RANK1]], %[[DTYPE1]] : !torch.int, !torch.int -> !torch.tuple +// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK_DTYPE0]], %[[RANK_DTYPE1]]) : (!torch.tuple, !torch.tuple) -> !torch.int func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor return %0 : !torch.vtensor