From 899d8bc8cf6cce4ed314da8115f7be04cae902fe Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Mon, 13 Mar 2023 21:39:14 +0000 Subject: [PATCH] Add dtype functions for `aten.ge.Tensor` and `aten.le.Tensor` --- .../python_deploy/build_linux_packages.sh | 2 +- e2e_testing/xfail_sets.py | 8 +- .../TorchToLinalg/TensorConstructors.cpp | 3 +- .../Transforms/AbstractInterpLibrary.cpp | 2794 ++++++++++++----- .../Torch/Transforms/DecomposeComplexOps.cpp | 14 +- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 601 +--- .../ReifyAbstractInterpCalculationsUtils.cpp | 22 +- ...implifyAbstractInterpCalculationsUtils.cpp | 252 ++ .../SimplifyAbstractInterpCalculationsUtils.h | 7 + .../Transforms/SimplifyDtypeCalculations.cpp | 5 + .../Transforms/SimplifyShapeCalculations.cpp | 242 +- .../build_tools/abstract_interp_lib_gen.py | 2206 ++++++++++--- .../jit_ir/build_tools/library_generator.py | 47 +- .../jit_ir/build_tools/testing_framework.py | 20 +- .../test_suite/__init__.py | 3 - python/torch_mlir_e2e_test/test_suite/rng.py | 24 + test/Dialect/Torch/refine-types-ops.mlir | 314 -- test/Dialect/Torch/refine-types.mlir | 124 - .../Torch/reify-shape-calculations.mlir | 11 + 19 files changed, 4139 insertions(+), 2560 deletions(-) delete mode 100644 test/Dialect/Torch/refine-types-ops.mlir diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index a970062b67f3..cfb4dbfe5aed 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -288,7 +288,7 @@ function test_in_tree() { python -m e2e_testing.main --config=lazy_tensor_core -v echo ":::: Run TorchDynamo e2e integration tests" - python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic + python -m e2e_testing.main --config=torchdynamo -v } function setup_venv() { diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 5be3b6c116ad..bf21e8bae35f 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -93,11 +93,6 @@ "ElementwiseAddScalar_NumToTensorFloat_Module_basic", # ERROR: assert isinstance(e, FakeTensor) "RsubInt0d_NumToTensor_Module_basic", - - # Dtype function transition failures - "MobilenetV3Module_basic", - "ResNet18Module_basic", - "ResNet18StaticModule_basic", } STABLEHLO_PASS_SET = { @@ -706,7 +701,8 @@ "FullLikeModuleInt2DStatic_basic", "FullModuleInt3D_basic", "FullModuleFloat2D_basic", - "RepeatModule_basic" + "RepeatModule_basic", + "ResNet18StaticModule_basic", } LTC_XFAIL_SET = { diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index e861a1877e99..724430401ab1 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -226,7 +226,8 @@ class ConvertAtenEmptyMemoryFormatOp typeConverter->convertType(op.getType()).cast(); Type resultElementType; if (op.getDtype().getType().isa()) { - resultElementType = resultType.getElementType(); + resultElementType = getDefaultDtypeForTorchScalar( + Torch::FloatType::get(op->getContext())); } else { int64_t dtypeInt; if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index fe5a20e8059b..aed70815431f 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -5935,6 +5935,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " return %3 : !torch.list\n" " }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.testing_framework._convert_dtype_to_int(%arg0: !torch.int) -> !torch.int {\n" +" return %arg0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.triu\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7380,20 +7383,46 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %7 : !torch.tuple, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.native_batch_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple, list, list> {\n" +" %none = torch.constant.none\n" +" %int2 = torch.constant.int 2\n" " %int1 = torch.constant.int 1\n" " %int0 = torch.constant.int 0\n" " %0 = torch.prim.If %arg5 -> (!torch.tuple, list, list>) {\n" -" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %2 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list\n" -" %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %4 = torch.prim.ListConstruct %3 : (!torch.int) -> !torch.list\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.aten.ge.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.tuple, list, list>) {\n" +" %4 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list\n" +" %6 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list\n" +" %8 = torch.prim.TupleConstruct %arg0, %5, %7 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" torch.prim.If.yield %8 : !torch.tuple, list, list>\n" +" } else {\n" +" %4 = torch.prim.ListConstruct : () -> !torch.list\n" +" %5 = torch.prim.ListConstruct : () -> !torch.list\n" +" %6 = torch.prim.TupleConstruct %arg0, %4, %5 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" torch.prim.If.yield %6 : !torch.tuple, list, list>\n" +" }\n" +" torch.prim.If.yield %3 : !torch.tuple, list, list>\n" +" } else {\n" +" %1 = torch.aten.__is__ %arg3, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %6 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" } else {\n" +" %6 = torch.prim.unchecked_cast %arg3 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" }\n" +" %3 = torch.aten.__is__ %arg4, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.list) {\n" +" %6 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" } else {\n" +" %6 = torch.prim.unchecked_cast %arg4 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" }\n" " %5 = torch.prim.TupleConstruct %arg0, %2, %4 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" " torch.prim.If.yield %5 : !torch.tuple, list, list>\n" -" } else {\n" -" %1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" -" %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" -" %3 = torch.prim.TupleConstruct %arg0, %1, %2 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" -" torch.prim.If.yield %3 : !torch.tuple, list, list>\n" " }\n" " return %0 : !torch.tuple, list, list>\n" " }\n" @@ -7635,19 +7664,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %4 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %2 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n" " %true = torch.constant.bool true\n" @@ -7674,104 +7693,53 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %3 : !torch.int\n" " }\n" " func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_float_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_float_dtypes() -> !torch.list {\n" " %int7 = torch.constant.int 7\n" " %int6 = torch.constant.int 6\n" " %int15 = torch.constant.int 15\n" " %int5 = torch.constant.int 5\n" " %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() : () -> !torch.list\n" " %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" " return %1 : !torch.bool\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() -> !torch.list {\n" " %int10 = torch.constant.int 10\n" " %int9 = torch.constant.int 9\n" " %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" -" return %1 : !torch.bool\n" +" return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %2 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: `self` cannot have float16 dtype\"\n" -" %int5 = torch.constant.int 5\n" -" %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %4 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %2 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %2 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %2 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -7779,134 +7747,52 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %2 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %2 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %2 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %2 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %2 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %false = torch.constant.bool false\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" %5 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %5 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %4 : !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" -" %false = torch.constant.bool false\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %7 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" -" %6 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %6 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" %3 = func.call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" " }\n" -" %5 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %5 : !torch.int\n" +" return %2 : !torch.int\n" " }\n" " func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" " %int4 = torch.constant.int 4\n" " %int3 = torch.constant.int 3\n" " %int2 = torch.constant.int 2\n" @@ -7914,8 +7800,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int0 = torch.constant.int 0\n" " %int11 = torch.constant.int 11\n" " %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" -" return %1 : !torch.bool\n" +" return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" @@ -7957,171 +7842,238 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %6 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" %3 = func.call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" " }\n" -" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: `self` cannot have float16 dtype\"\n" -" %int5 = torch.constant.int 5\n" -" %str_0 = torch.constant.str \"AssertionError: `self` cannot have integer dtype\"\n" -" %str_1 = torch.constant.str \"AssertionError: `self` cannot have complex dtype\"\n" -" %none = torch.constant.none\n" -" %str_2 = torch.constant.str \"AssertionError: `grad_output` and `self` must have the same dtype\"\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" +" func.func @\"__torch_mlir_dtype_fn.aten.abs\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %3:2 = torch.prim.If %2 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" %5 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %6:2 = torch.prim.If %5 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %6#0, %6#1 : !torch.bool, !torch.int\n" " }\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" +" %4 = torch.prim.If %3#0 -> (!torch.int) {\n" +" torch.prim.If.yield %3#1 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.If.yield %1#1 : !torch.int\n" " }\n" -" %7 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" return %1#1 : !torch.int\n" +" return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: `self` cannot have float16 dtype\"\n" -" %int5 = torch.constant.int 5\n" -" %str_0 = torch.constant.str \"AssertionError: `self` cannot have integer dtype\"\n" -" %str_1 = torch.constant.str \"AssertionError: `self` cannot have complex dtype\"\n" -" %none = torch.constant.none\n" -" %str_2 = torch.constant.str \"AssertionError: `grad_output` and `self` must have the same dtype\"\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %7 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If.yield\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli\"(%arg0: !torch.tuple, %arg1: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_not\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.broadcast_to\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ceil\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.If.yield %0#1 : !torch.int\n" " }\n" -" return %1#1 : !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.tuple) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" -" %int0 = torch.constant.int 0\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.int) {\n" -" torch.prim.If.yield %int0 : !torch.int\n" +" torch.prim.If.yield %int4 : !torch.int\n" " } else {\n" -" torch.prim.If.yield %int11 : !torch.int\n" +" torch.prim.If.yield %0#1 : !torch.int\n" " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.tuple) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" -" %int0 = torch.constant.int 0\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" " %2 = torch.prim.If %1 -> (!torch.int) {\n" -" torch.prim.If.yield %int0 : !torch.int\n" +" torch.prim.If.yield %int4 : !torch.int\n" " } else {\n" -" torch.prim.If.yield %int11 : !torch.int\n" +" torch.prim.If.yield %0#1 : !torch.int\n" " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" return %int11 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clone\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.eq.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" return %int11 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.copy\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cpu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cumsum\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" " }\n" -" return %int11 : !torch.int\n" +" return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.dropout\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expand_as\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expand\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fill.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.flatten.using_ints\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gather\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gelu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.str) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" +" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gelu\"(%arg0: !torch.tuple, %arg1: !torch.str) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardsigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardswish\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.If.yield %0#1 : !torch.int\n" " }\n" -" return %int11 : !torch.int\n" +" return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %int0, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %2 = torch.aten.__contains__.int_list %1, %0#1 : !torch.list, !torch.int -> !torch.bool\n" " %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" " torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" @@ -8129,22 +8081,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" return %int11 : !torch.int\n" +" return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_put.hacked_twin\"(%arg0: !torch.tuple, %arg1: !torch.list>, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._index_put_impl\"(%arg0: !torch.tuple, %arg1: !torch.list>>, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_put\"(%arg0: !torch.tuple, %arg1: !torch.list>>, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index.Tensor_hacked_twin\"(%arg0: !torch.tuple, %arg1: !torch.list>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.list>>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.layer_norm\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.int {\n" " %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" " %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" @@ -8152,61 +8119,379 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %int11 : !torch.int\n" +" return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" return %int11 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_not\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" return %int11 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" return %int11 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" return %arg3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.logical_xor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" return %int11 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_select\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.narrow\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.neg\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %int11 : !torch.int\n" +" return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.numpy_T\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pad\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.permute\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.prelu\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.relu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.repeat\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._reshape_alias\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reshape\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.resize_\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.roll\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.round\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scatter_reduce.two\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.select.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.select_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.slice_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.slice.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" return %arg3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.square\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.squeeze.dim\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.squeeze\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.t\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.prim_Device\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.transpose.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.triu\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.uniform\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._unsafe_view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.unsqueeze\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zero\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zero_\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union) -> !torch.int {\n" +" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union -> !torch.tensor\n" +" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" +" return %1#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ge.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_not\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_xor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.le.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" @@ -8222,15 +8507,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union) -> !torch.int {\n" -" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union -> !torch.tensor\n" -" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" -" return %1 : !torch.int\n" -" }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" -" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" -" return %0 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" @@ -8238,26 +8514,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int7 = torch.constant.int 7\n" " %int9 = torch.constant.int 9\n" " %int6 = torch.constant.int 6\n" +" %int8 = torch.constant.int 8\n" +" %int5 = torch.constant.int 5\n" " %0 = torch.prim.Uninitialized : !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" " %3 = torch.prim.If %2 -> (!torch.int) {\n" " torch.prim.If.yield %1#1 : !torch.int\n" " } else {\n" -" %4 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" " %5 = torch.prim.If %4 -> (!torch.int) {\n" -" torch.prim.If.yield %int9 : !torch.int\n" +" torch.prim.If.yield %int8 : !torch.int\n" " } else {\n" -" %6 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" " %7 = torch.prim.If %6 -> (!torch.int) {\n" -" torch.prim.If.yield %int10 : !torch.int\n" +" torch.prim.If.yield %int9 : !torch.int\n" " } else {\n" -" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" " %9 = torch.prim.If %8 -> (!torch.int) {\n" -" torch.prim.If.yield %int9 : !torch.int\n" +" torch.prim.If.yield %int10 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield %0 : !torch.int\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %11 : !torch.int\n" " }\n" " torch.prim.If.yield %9 : !torch.int\n" " }\n" @@ -8269,46 +8553,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: `self` cannot have bool dtype\"\n" -" %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" -" %4 = torch.prim.ListConstruct %0#1, %3 : (!torch.int, !torch.int) -> !torch.list\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %5 : !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" -" %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %6 : !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" @@ -8319,115 +8577,133 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" -" %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %6 : !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" -" %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %6 : !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" -" %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %6 : !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: `self` and `mat2` must have the same dtype\"\n" -" %str_0 = torch.constant.str \"AssertionError: Expected dtype of `mat2` to not be float16 or bool\"\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" -" %int11 = torch.constant.int 11\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If.yield\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%0#1) : (!torch.int) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%1#1) : (!torch.int) -> !torch.int\n" +" %4 = torch.aten.lt.int %2, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.If.yield %1#1 : !torch.int\n" " }\n" -" %8 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %8 -> () {\n" -" torch.prim.If.yield\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%arg0: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Cannot determine priority of dtype\"\n" +" %int15 = torch.constant.int 15\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %int5 = torch.constant.int 5\n" +" %int6 = torch.constant.int 6\n" +" %int7 = torch.constant.int 7\n" +" %int8 = torch.constant.int 8\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1 = torch.aten.eq.int %arg0, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" %3 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" %5 = torch.aten.eq.int %arg0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int2 : !torch.int\n" +" } else {\n" +" %7 = torch.aten.eq.int %arg0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" torch.prim.If.yield %int3 : !torch.int\n" +" } else {\n" +" %9 = torch.aten.eq.int %arg0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %11 = torch.aten.eq.int %arg0, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" torch.prim.If.yield %int5 : !torch.int\n" +" } else {\n" +" %13 = torch.aten.eq.int %arg0, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %15 = torch.aten.eq.int %arg0, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" %17 = torch.aten.eq.int %arg0, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.int) {\n" +" torch.prim.If.yield %int8 : !torch.int\n" +" } else {\n" +" %19 = torch.aten.eq.int %arg0, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %20 = torch.prim.If %19 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" %21 = torch.aten.eq.int %arg0, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.int) {\n" +" torch.prim.If.yield %int10 : !torch.int\n" +" } else {\n" +" %23 = torch.aten.eq.int %arg0, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %24 = torch.prim.If %23 -> (!torch.int) {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %24 : !torch.int\n" +" }\n" +" torch.prim.If.yield %22 : !torch.int\n" +" }\n" +" torch.prim.If.yield %20 : !torch.int\n" +" }\n" +" torch.prim.If.yield %18 : !torch.int\n" +" }\n" +" torch.prim.If.yield %16 : !torch.int\n" +" }\n" +" torch.prim.If.yield %14 : !torch.int\n" +" }\n" +" torch.prim.If.yield %12 : !torch.int\n" +" }\n" +" torch.prim.If.yield %10 : !torch.int\n" +" }\n" +" torch.prim.If.yield %8 : !torch.int\n" +" }\n" +" torch.prim.If.yield %6 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" " }\n" -" return %1#1 : !torch.int\n" +" return %2 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %false = torch.constant.bool false\n" @@ -8524,173 +8800,81 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %8 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: `self` and `other` must have the same dtype\"\n" -" %str_0 = torch.constant.str \"AssertionError: Expected dtype of `other` to not be float16 or bool\"\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" -" %int11 = torch.constant.int 11\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %8 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %8 -> () {\n" -" torch.prim.If.yield\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%0#1) : (!torch.int) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%1#1) : (!torch.int) -> !torch.int\n" +" %4 = torch.aten.lt.int %2, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.If.yield %1#1 : !torch.int\n" " }\n" -" return %1#1 : !torch.int\n" +" return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.maximum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" -" %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %8 : !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" -" %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %8 : !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: `self` and `mat2` must have the same dtype\"\n" -" %str_0 = torch.constant.str \"AssertionError: Expected dtype of `mat2` to not be float16 or bool\"\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" -" %int11 = torch.constant.int 11\n" +" %false = torch.constant.bool false\n" " %int5 = torch.constant.int 5\n" +" %int15 = torch.constant.int 15\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %2 = torch.prim.ListConstruct %int15, %int5 : (!torch.int, !torch.int) -> !torch.list\n" " %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %7 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" " } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.If.yield %false : !torch.bool\n" " }\n" -" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If.yield\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %7 = torch.aten.ne.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" " } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.If.yield %false : !torch.bool\n" " }\n" -" %8 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %8 -> () {\n" -" torch.prim.If.yield\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int5 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" %7 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %8 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %9 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%7, %8) : (!torch.list>, !torch.list) -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" " }\n" -" return %1#1 : !torch.int\n" +" return %6 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: Expected promoted dtype to be float but not `bfloat16`\"\n" -" %false = torch.constant.bool false\n" -" %int15 = torch.constant.int 15\n" -" %str_0 = torch.constant.str \"AssertionError: `target` cannot be complex\"\n" " %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %str = torch.constant.str \"AssertionError: \"\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %7 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" -" %9 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%8) : (!torch.int) -> !torch.bool\n" -" %10 = torch.prim.If %9 -> (!torch.bool) {\n" -" %11 = torch.aten.ne.int %8, %int15 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %11 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %10 -> () {\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" return %8 : !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" @@ -8701,108 +8885,28 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mv\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: `self` and `vec` must have the same dtype\"\n" -" %str_0 = torch.constant.str \"AssertionError: Expected dtype of `vec` to not be float16 or bool\"\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" -" %int11 = torch.constant.int 11\n" -" %int5 = torch.constant.int 5\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %5 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %8 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %8 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %11 : !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: `other` cannot be of bool dtype\"\n" -" %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: `self` cannot be of bool dtype\"\n" -" %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %5 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %6 : !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" -" %str = torch.constant.str \"AssertionError: Result dtype for aten.threshold_backward cannot be bool or float16\"\n" -" %int5 = torch.constant.int 5\n" -" %int11 = torch.constant.int 11\n" -" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: `grad_output` cannot be complex\"\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" -" %9 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = torch.aten.__contains__.int_list %9, %8 : !torch.list, !torch.int -> !torch.bool\n" -" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" -" torch.prim.If %11 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" return %8 : !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n" " %none = torch.constant.none\n" @@ -8895,138 +8999,1202 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %10 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.convolution_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.list, %arg7: !torch.bool, %arg8: !torch.list, %arg9: !torch.int, %arg10: !torch.list) -> !torch.tuple {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.convolution_backward_overrideable\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.list) -> !torch.tuple {\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %int5 = torch.constant.int 5\n" -" %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.eq.int %2#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" %9 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %6 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %10 : !torch.int\n" +" %7:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %8 = torch.prim.TupleConstruct %7#1, %7#1, %7#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %8 : !torch.tuple\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bincount\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.int {\n" +" %int7 = torch.constant.int 7\n" +" %int4 = torch.constant.int 4\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" -" %int15 = torch.constant.int 15\n" -" %int5 = torch.constant.int 5\n" -" %true = torch.constant.bool true\n" -" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.aten.eq.int %0#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %5 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" " } else {\n" -" %11 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %12 : !torch.bool\n" +" torch.prim.If.yield %false : !torch.bool\n" " }\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" " } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union, %arg4: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lerp.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" " }\n" +" %4 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %4 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" +" %5 = torch.aten.ne.int %2#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" " } else {\n" -" %11 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %12 = torch.aten.__not__ %11 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %12 : !torch.bool\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" " }\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %6 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %8 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" " } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" +" torch.prim.If.yield %5 : !torch.int\n" " }\n" -" torch.prim.If %7 -> () {\n" +" return %7 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %10 : !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %false = torch.constant.bool false\n" -" %int5 = torch.constant.int 5\n" " %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" " } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" " }\n" +" %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%3) : (!torch.int) -> !torch.bool\n" " torch.prim.If %4 -> () {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" -" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" -" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" -" torch.prim.If.yield %13 : !torch.bool\n" +" %5 = torch.prim.ListConstruct %0#1, %3 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union, %arg4: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.self\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int4 = torch.constant.int 4\n" +" %false = torch.constant.bool false\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If.yield\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.If.yield %int6 : !torch.int\n" " }\n" -" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %10 : !torch.int\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.native_layer_norm\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" %7 = torch.prim.TupleConstruct %0#1, %0#1, %6 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %7 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.native_batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" %3 = torch.prim.TupleConstruct %0#1, %0#1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.union, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.union, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %6 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sum.dim_IntList\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mean.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.argmax\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.any.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.amax\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.max\"(%arg0) : (!torch.tuple) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.max\"(%arg0) : (!torch.tuple) -> !torch.int\n" +" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %none : !torch.none to !torch.optional>\n" +" %1 = call @\"__torch_mlir_dtype_fn.aten.mean.dim\"(%arg0, %0, %false, %arg1) : (!torch.tuple, !torch.optional>, !torch.bool, !torch.optional) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.int {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.var\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.float, %arg3: !torch.optional) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" %5 = torch.prim.unchecked_cast %arg4 : !torch.optional -> !torch.int\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.prim.TupleConstruct %0#0, %5 : !torch.int, !torch.int -> !torch.tuple\n" +" %12 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%11, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %12 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.int\"(%arg0: !torch.int, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.bool\"(%arg0: !torch.bool, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zeros\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ones\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.empty.memory_format\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zeros_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ones_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.empty_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_zeros\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_ones\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_empty\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_empty_strided\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randn_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._to_copy\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.dtype\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.nvprims.convert_element_type\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.dtype_layout\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.device\"(%arg0: !torch.tuple, %arg1: !torch.Device, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool, %arg5: !torch.optional) -> !torch.int {\n" +" return %arg2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.other\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.type_as\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randn\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.tuple) {\n" +" %5 = torch.prim.TupleConstruct %int6, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %5 : !torch.tuple\n" +" } else {\n" +" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.tuple) {\n" +" %7 = torch.prim.TupleConstruct %int7, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" } else {\n" +" %7 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" }\n" +" torch.prim.If.yield %6 : !torch.tuple\n" +" }\n" +" return %4 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var_mean\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.tuple) {\n" +" %5 = torch.prim.TupleConstruct %int6, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %5 : !torch.tuple\n" +" } else {\n" +" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.tuple) {\n" +" %7 = torch.prim.TupleConstruct %int7, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" } else {\n" +" %7 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" }\n" +" torch.prim.If.yield %6 : !torch.tuple\n" +" }\n" +" return %4 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.atan2\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linear\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %6 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list>, !torch.int -> !torch.tuple\n" +" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple -> !torch.int, !torch.int\n" +" %8 = torch.aten.append.t %0, %7#0 : !torch.list>, !torch.int -> !torch.list>\n" +" %9 = torch.aten.append.t %1, %7#1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ScalarImplicit\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softmax.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._softmax\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.If %arg2 -> (!torch.int) {\n" +" %2 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.If %arg2 -> (!torch.int) {\n" +" %2 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log_softmax.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.embedding\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._embedding_bag\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.int) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4, %int4, %int4 : !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.embedding_bag.padding_idx\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.optional) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4, %int4, %int4 : !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bucketize.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.prim.If %arg2 -> (!torch.int) {\n" +" torch.prim.If.yield %int3 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" return %0 : !torch.int\n" " }\n" "}\n" ""; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index f54198e5056a..c218395cd1f2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3716,8 +3716,14 @@ class DecomposeAtenRandnGeneratorOp LogicalResult matchAndRewrite(AtenRandnGeneratorOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Type resultType = op.getType(); + auto resultType = op.getType().cast(); + + if (!resultType.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "expected result type to have a dtype"); + } + Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); Value none = rewriter.create(loc); Value low = rewriter.create( loc, rewriter.getF64FloatAttr((double)0.0)); @@ -3729,11 +3735,13 @@ class DecomposeAtenRandnGeneratorOp loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159))); Value emptyTensorA = rewriter.create( - loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(), + loc, resultType, op.getSize(), /*dtype=*/dtype, + /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*memory_format=*/none); Value emptyTensorB = rewriter.create( - loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(), + loc, resultType, op.getSize(), /*dtype=*/dtype, + /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*memory_format=*/none); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 550f0de5aa08..8df95def9ee4 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -85,16 +85,6 @@ static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) { return failed(result) ? Type() : *result; } -static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype, - Type defaultDtype) { - int64_t dtypeInt; - if (matchPattern(optionalDtype, m_TorchConstantInt(&dtypeInt))) - return getTypeForDTypeInteger(context, dtypeInt); - else if (optionalDtype.getType().isa()) - return defaultDtype; - return Type(); -} - // Get the kind enum for `ValueKnowledge.kind`. static torch_upstream::TypeKind getTypeKind(Type type) { if (type.isa()) @@ -429,12 +419,6 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis< void visitAtenLinearOp(AtenLinearOp op, ArrayRef operands); - void visitAtenArangeStartStepOp(AtenArangeStartStepOp op); - void visitAtenArangeStartOp(AtenArangeStartOp op); - void visitAtenArangeOp(AtenArangeOp op); - void visitAtenArangeLikeOpHelper(Operation *op, std::optional start, - Value end, std::optional step, - Value dtype); void visitReductionAlongAllDimsOp(Operation *op, Type dtype, ArrayRef operands); void visitReductionAlongDimIntListOp(Operation *op, Value dim, Value keepdim, @@ -536,37 +520,6 @@ static Type getPromotedResultScalarType(ArrayRef scalarTypes) { return *result; } -// Returns most generic type Type() if the tensor dtype is unknown. -static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) { - if (!tensor->dtype) - return Type(); - torch_upstream::ResultTypeState state = {}; - // No need to check if rank is zero for tensor because scalar uses - // wrappedResult which is a lower priority than both dimResult and zeroResult. - state = updateResultTypeState(tensor, /*rankIsNonZero=*/std::nullopt, state, - /*skipRankCheck=*/true); - state = - updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state); - FailureOr result = - getTypeForScalarType(scalarType.getContext(), result_type(state)); - return failed(result) ? Type() : *result; -} - -static SmallVector> -getRankIsNonZeroArray(ValueRange values) { - SmallVector> rankIsNonZero; - for (Value v : values) { - if (auto tensorType = v.getType().dyn_cast()) { - if (tensorType.hasSizes()) { - rankIsNonZero.push_back(tensorType.getSizes().size() != 0); - } else { - rankIsNonZero.push_back(std::nullopt); - } - } - } - return rankIsNonZero; -} - // Normally, tensor dimensions need to be known at compile time to do type // promotion. `skipRankCheck`, when equal to true, can be used to indicate // special cases that tensor operands are guaranteed to be not zero dimension @@ -622,514 +575,6 @@ void TypeAnalysis::fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge, void TypeAnalysis::visitOperation(Operation *op, ArrayRef operands, ArrayRef results) { - - // These ops have results that are dynamically the same as their operands. - if (isa(op)) { - incorporateKnowledge(op->getResult(0), operands[0]->getValue()); - return; - } - - // Take dtype from first operand. - if (isa(op)) { - return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); - } - - // Dtype is always si64. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Dtype is always float32, except for bfloat16, float64 and nullptr after - // promotion and assuming possible-zero rank. - if (isa(op)) { - ValueKnowledge knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type promotedDtype = getPromotedResultType( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}, - getRankIsNonZeroArray(op->getOperands())); - if (promotedDtype) { - knowledge.dtype = Float32Type::get(op->getContext()); - if (promotedDtype.isa()) - knowledge.dtype = promotedDtype; - } - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote three dtypes. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue(), - &operands[2]->getValue()}); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - if (auto linear = llvm::dyn_cast(op)) { - visitAtenLinearOp(linear, operands); - return; - } - - // Promote LHS with scalar RHS. - if (isa(op)) { - auto lhs = operands[0]->getValue(); - Value scalar = op->getOperand(1); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType()); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultType( - op->getContext(), {&operands[1]->getValue(), &operands[2]->getValue()}, - getRankIsNonZeroArray(op->getOperands().slice(1, 2))); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - Value lhsScalar = op->getOperand(1); - Value rhsScalar = op->getOperand(2); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getDefaultDtypeForTorchScalar(getPromotedResultScalarType( - {lhsScalar.getType(), rhsScalar.getType()})); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - auto lhs = operands[1]->getValue(); - Value scalar = op->getOperand(2); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType()); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - auto rhs = operands[2]->getValue(); - Value scalar = op->getOperand(1); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultDType(&rhs, scalar.getType()); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // 2 results take dtype from first operand. - if (isa(op)) { - auto self = operands[0]->getValue(); - auto result0Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result0Knowledge.dtype = self.dtype; - auto result1Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result1Knowledge.dtype = self.dtype; - incorporateKnowledge(op->getResult(0), result0Knowledge); - incorporateKnowledge(op->getResult(1), result1Knowledge); - return; - } - - // 3 results take dtype from first operand. - if (isa( - op)) { - auto self = operands[0]->getValue(); - auto result0Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result0Knowledge.dtype = self.dtype; - auto result1Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result1Knowledge.dtype = self.dtype; - auto result2Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result2Knowledge.dtype = self.dtype; - incorporateKnowledge(op->getResult(0), result0Knowledge); - incorporateKnowledge(op->getResult(1), result1Knowledge); - incorporateKnowledge(op->getResult(2), result1Knowledge); - return; - } - - if (isa(op)) { - auto self = operands[0]->getValue(); - auto result0Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result0Knowledge.dtype = self.dtype; - auto result1Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result1Knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - incorporateKnowledge(op->getResult(0), result0Knowledge); - incorporateKnowledge(op->getResult(1), result1Knowledge); - return; - } - - if (auto arange = dyn_cast(op)) { - visitAtenArangeOp(arange); - return; - } - if (auto arangeStart = dyn_cast(op)) { - visitAtenArangeStartOp(arangeStart); - return; - } - if (auto arangeStartStep = dyn_cast(op)) { - visitAtenArangeStartStepOp(arangeStartStep); - return; - } - - if (auto sum = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - // If the input dtype is bool, the result type should be i64. - if (defaultDtype.isInteger(1)) - defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - Type dtype = getDtypeOrDefault(sum.getContext(), sum.getDtype(), defaultDtype); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = dtype; - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - if (auto sumDimIntList = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - if (!defaultDtype) { - incorporateKnowledge( - sumDimIntList.getResult(), - ValueKnowledge::getTensorPessimisticValueState(op->getContext())); - return; - } - // If the input dtype is bool, the result type should be i64. - if (defaultDtype.isInteger(1)) - defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - Type dtype = getDtypeOrDefault(sumDimIntList.getContext(), - sumDimIntList.getDtype(), defaultDtype); - visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.getDim(), - sumDimIntList.getKeepdim(), dtype, operands); - return; - } - if (auto meanDim = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = - getDtypeOrDefault(meanDim.getContext(), meanDim.getDtype(), defaultDtype); - visitReductionAlongDimIntListOp(meanDim, meanDim.getDim(), meanDim.getKeepdim(), - dtype, operands); - return; - } - if (auto argmax = dyn_cast(op)) { - Value dim = argmax.getDim(); - Type dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed); - if (dim.getType().isa()) { - visitReductionAlongAllDimsOp(op, dtype, operands); - return; - } - if (dim.getType().isa()) { - visitReductionAlongDimIntOp(argmax, argmax.getDim(), argmax.getKeepdim(), dtype, - operands); - return; - } - } - if (auto anyDim = dyn_cast(op)) { - Type dtype = operands[0]->getValue().dtype; - visitReductionAlongDimIntOp(anyDim, anyDim.getDim(), anyDim.getKeepdim(), dtype, - operands); - return; - } - if (auto maxDim = dyn_cast(op)) { - Type firstResDtype = operands[0]->getValue().dtype; - Type secondResDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - visitReductionAlongDimIntOp(maxDim, maxDim.getDim(), maxDim.getKeepdim(), - firstResDtype, operands); - visitReductionAlongDimIntOp(maxDim, maxDim.getDim(), maxDim.getKeepdim(), - secondResDtype, operands, /*resNum=*/1); - return; - } - if (auto mean = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = - getDtypeOrDefault(mean.getContext(), mean.getDtype(), defaultDtype); - visitReductionAlongAllDimsOp(mean, dtype, operands); - return; - } else if (isa(op)) { - Type dtype = operands[0]->getValue().dtype; - visitReductionAlongAllDimsOp(op, dtype, operands); - return; - } else if (isa(op)) { - auto input = operands[0]->getValue(); - visitReductionAlongAllDimsOp(op, input.dtype, operands); - return; - } - - if (auto tensorFloat = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorFloat); - return; - } else if (auto tensorInt = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorInt); - return; - } else if (auto tensorBool = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorBool); - return; - } - - if (auto tensor = dyn_cast(op)) { - visitAtenTensorOp(tensor); - return; - } - - if (auto zeros = dyn_cast(op)) { - visitConstantTensorAllocOp(zeros, /*dataType=*/{}); - return; - } else if (auto ones = dyn_cast(op)) { - visitConstantTensorAllocOp(ones, /*dataType=*/{}); - return; - } else if (auto emptyMemoryFormat = dyn_cast(op)) { - visitConstantTensorAllocOp(emptyMemoryFormat, - /*dataType=*/{}); - return; - } else if (auto full = dyn_cast(op)) { - visitConstantTensorAllocOp( - full, /*dataType=*/full.getFillValue().getType()); - return; - } else if (auto zerosLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(zerosLike, operands); - return; - } else if (auto onesLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(onesLike, operands); - return; - } else if (auto emptyLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(emptyLike, operands); - return; - } else if (auto fullLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(fullLike, operands); - return; - } else if (auto newZeros = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newZeros, operands); - return; - } else if (auto newOnes = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newOnes, operands); - return; - } else if (auto newEmpty = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newEmpty, operands); - return; - } else if (auto newEmptyStrided = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newEmptyStrided, - operands); - return; - } else if (auto randLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(randLike, operands); - return; - } else if (auto randLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(randLike, operands); - return; - } else if (auto toCopy = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(toCopy, operands); - return; - } - - if (auto toDtype = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtype, operands); - return; - } - - if (auto primsConvertElementType = dyn_cast(op)) { - visitAtenToDtypeLikeOp(primsConvertElementType, - operands); - return; - } - - if (auto toDtypeLayout = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtypeLayout, operands); - return; - } - - if (auto toDtype = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtype, operands); - return; - } - - if (auto toOther = dyn_cast(op)) { - visitTypeConversionOp(toOther, operands); - return; - } else if (auto typeAs = dyn_cast(op)) { - visitTypeConversionOp(typeAs, operands); - return; - } - - if (auto cat = dyn_cast(op)) { - visitAtenCatLikeOp(cat, operands); - return; - } else if (auto stack = dyn_cast(op)) { - visitAtenCatLikeOp(stack, operands); - return; - } - - if (auto shapeAsTensor = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - incorporateKnowledge(shapeAsTensor.getResult(), knowledge); - return; - } - - if (auto embedding = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = operands[0]->getValue().dtype; - incorporateKnowledge(embedding.getResult(), knowledge); - return; - } - - if (isa(op)) { - visitAtenEmbeddingBagOp(op); - return; - } - - if (auto softmaxIntOp = dyn_cast(op)) { - visitAtenSoftmaxLikeOp(softmaxIntOp, operands); - return; - } - if (auto _softmaxOp = dyn_cast(op)) { - visitAten_SoftmaxLikeOp(_softmaxOp, operands); - return; - } else if (auto _logSoftmaxOp = dyn_cast(op)) { - visitAten_SoftmaxLikeOp(_logSoftmaxOp, operands); - return; - } else if (auto logSoftmaxIntOp = dyn_cast(op)) { - visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands); - return; - } - - if (auto numToTensorOp = dyn_cast(op)) { - visitNumToTensorOp(numToTensorOp); - return; - } - - if (isa(op)) { - visitBinaryScalarOp(op, operands); - return; - } - - if (auto scalarImplicit = dyn_cast(op)) { - visitAtenScalarImplicitOp(scalarImplicit, operands); - return; - } - - if (auto vectorNorm = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = getDtypeOrDefault(vectorNorm.getContext(), vectorNorm.getDtype(), - defaultDtype); - visitReductionAlongDimIntListOp(vectorNorm, vectorNorm.getDim(), - vectorNorm.getKeepdim(), dtype, operands); - return; - } - - if (auto randIntLow = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - knowledge.dtype = - getDtypeOrDefault(op->getContext(), randIntLow.getDtype(), defaultDtype); - incorporateKnowledge(randIntLow.getResult(), knowledge); - return; - } - - if (isa(op)) { - auto input = operands[0]->getValue(); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = input.dtype; - incorporateKnowledge(op->getResult(0), knowledge); - incorporateKnowledge(op->getResult(1), knowledge); - return; - } - - if (auto randn = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type defaultDtype = Float32Type::get(op->getContext()); - knowledge.dtype = - getDtypeOrDefault(op->getContext(), randn.getDtype(), defaultDtype); - incorporateKnowledge(randn.getResult(), knowledge); - return; - } - - if (auto randnGenerator = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type defaultDtype = Float32Type::get(op->getContext()); - knowledge.dtype = getDtypeOrDefault(op->getContext(), - randnGenerator.getDtype(), defaultDtype); - incorporateKnowledge(randnGenerator.getResult(), knowledge); - return; - } - - if (auto bucketize = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - bool outInt32; - if (matchPattern(bucketize.getOutInt32(), m_TorchConstantBool(&outInt32)) && - outInt32) { - knowledge.dtype = - IntegerType::get(op->getContext(), 32, IntegerType::Signed); - } else { - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - } - incorporateKnowledge(bucketize.getResult(), knowledge); - return; - } - // Otherwise, this is an unknown operation, so reset the state. setAllToEntryStates(results); return; @@ -1183,49 +628,6 @@ void TypeAnalysis::visitAtenEmbeddingBagOp(Operation *op) { return; } -// Arange like ops returns a 1-D tensor of size ceil(end - start). -void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op, - std::optional start, - Value end, - std::optional step, - Value dtype) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - int64_t dtypeInt; - if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) { - knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt); - } else if (dtype.getType().isa()) { - // From torch/_torch_docs.py: - // If `dtype` is not given, infer the data type from the other input - // arguments. If any of `start`, `end`, or `step` are floating-point, the - // `dtype` is inferred to be the default dtype, see - // `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to - // be `torch.int64` - if ((start.has_value() && (*start).getType().isa()) || - end.getType().isa() || - (step.has_value() && (*step).getType().isa())) { - // TODO: Should get the dtype from torch.get_default_dtype(). - // For now, use float32 which is the initial default dtype. - knowledge.dtype = Float32Type::get(op->getContext()); - } else - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - } - incorporateKnowledge(op->getResult(0), knowledge); -} - -void TypeAnalysis::visitAtenArangeStartStepOp(AtenArangeStartStepOp op) { - visitAtenArangeLikeOpHelper(op, op.getStart(), op.getEnd(), op.getStep(), op.getDtype()); -} - -void TypeAnalysis::visitAtenArangeStartOp(AtenArangeStartOp op) { - visitAtenArangeLikeOpHelper(op, op.getStart(), op.getEnd(), {}, op.getDtype()); -} - -void TypeAnalysis::visitAtenArangeOp(AtenArangeOp op) { - visitAtenArangeLikeOpHelper(op, {}, op.getEnd(), {}, op.getDtype()); -} - void TypeAnalysis::visitReductionAlongAllDimsOp( Operation *op, Type dtype, ArrayRef operands) { auto knowledge = @@ -1494,8 +896,7 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { // the right thing forthose ops. // static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) { - return op->hasTrait() || - isa(op); + return op->hasTrait(); } // Some operations have extra verification logic regarding the relationship diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 04bbe0220e71..7e3697302c5b 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -167,12 +167,14 @@ FailureOr Torch::adjustFunctionArg( return b.create(loc, desiredType, operand).getResult(); } - // !torch.union is the type used for `Scalar` inputs. At - // compile time, such inputs will usually be resolved to an `int` or a `float` - // so we need to derefine to match the library function signature. + // !torch.union or !torch.union is the type used + // for (optional) `Scalar` inputs. At compile time, such inputs will usually + // be resolved to an `int` or a `float` so we need to derefine to match the + // library function signature. if (auto unionType = desiredType.dyn_cast()) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { - return containedType.isa(); + return containedType + .isa(); })) return b.create(loc, desiredType, operand).getResult(); } @@ -180,11 +182,21 @@ FailureOr Torch::adjustFunctionArg( // If the operand is NoneType, then we just need to derefine it to the // optional type in the function signature. if (operandType.isa()) { - assert(desiredType.isa() && + assert(!desiredType.isa() && "Don't expect library functions to have NoneType parameters"); return b.create(loc, desiredType, operand).getResult(); } + // To keep things simple in shape functions, `Scalar` inputs are considered + // `float`s. This is safe since output shape of torch ops never depends on the + // dtype of input scalars. However, this also means we sometimes have to + // manually turn `Scalar`s into `float`s when inserting the shape functions + // into the IR. + if (operandType.isa() && + desiredType.isa()) { + return b.create(loc, desiredType, operand).getResult(); + } + // If the operand type is statically !torch.optional, then we need to do // different things for the None and non-None cases. // For the None case, we just need to derefine it to the desired type. diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp index fd58ead00367..1a2d3d545cbe 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp @@ -8,11 +8,248 @@ //===----------------------------------------------------------------------===// #include "SimplifyAbstractInterpCalculationsUtils.h" +#include "mlir/IR/IRMapping.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace { +class FoldPrimUncheckedCastOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimUncheckedCastOp op, + PatternRewriter &rewriter) const override { + if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) { + return rewriter.notifyMatchFailure( + op, "input tensor type is not a valid subtype of result type"); + } + rewriter.replaceOp(op, op.getX()); + return success(); + } +}; +} // namespace + +namespace { +// TODO: Only unroll inside the shape calculation region. +// Maybe do this by only applying patterns and folding greedily on the ops +// inside the region + the shape.calculate op itself? +class FullyUnrollPrimLoopOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimLoopOp op, + PatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); + if (!op.isForLike()) + return rewriter.notifyMatchFailure(op, "Loop is not for-like"); + int64_t maxTripCount; + if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount))) + return rewriter.notifyMatchFailure( + op, "Expected `maxTripCount` to be a constant int"); + ; + SmallVector indices; + for (int64_t i = 0; i < maxTripCount; i++) { + // TODO: Add convenience builder. + indices.push_back(rewriter.create( + loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i))); + } + Block *beforeBlock = op->getBlock(); + Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator()); + + SmallVector blocksToMerge; + IRMapping bvm; + // TODO: Helper for region().front() + auto condition = + cast(op.getRegion().front().getTerminator()); + for (int64_t i = 0; i < maxTripCount; i++) { + SmallVector iterArgs; + if (i == 0) { + llvm::append_range(iterArgs, op.getIterArgsInit()); + } else { + llvm::append_range( + iterArgs, llvm::map_range(condition.getIterArgs(), + [&](Value v) { return bvm.lookup(v); })); + } + bvm.clear(); + bvm.map(op.getRegion().front().getArgument(0), indices[i]); + bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs); + + op.getRegion().cloneInto(afterBlock->getParent(), + afterBlock->getIterator(), bvm); + Block *clonedBlock = bvm.lookup(&op.getRegion().front()); + rewriter.eraseOp(clonedBlock->getTerminator()); + blocksToMerge.push_back(clonedBlock); + } + + blocksToMerge.push_back(afterBlock); + for (Block *block : blocksToMerge) + rewriter.mergeBlocks(block, beforeBlock); + if (maxTripCount == 0) { + rewriter.replaceOp(op, op.getIterArgsInit()); + } else { + rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range( + condition.getIterArgs(), + [&](Value v) { return bvm.lookup(v); }))); + } + return success(); + } +}; +} // namespace + +namespace { +class AbstractlyInterpretListOpsWithinABlock + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimListConstructOp op, + PatternRewriter &rewriter) const override { + Block *block = op->getBlock(); + auto allUsers = llvm::to_vector<6>(op->getUsers()); + + // Sort the users into program order. + auto getParentInBlock = [&](Operation *op) { + while (op->getBlock() != block) + op = op->getParentOp(); + return op; + }; + // Use a stable sort for deterministic results when users are nested in two + // regions of the same parent op. + llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) { + return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs)); + }); + + // We cannot interpret all ops. So first do a check to see up until which + // point we can interpret. + int numUsersToInterpret = 0; + for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) { + Operation *user = allUsers[i]; + // If a user potentially mutates the list, then we require it to be in the + // same block for our simple abstract interpretation to work (we can't, + // for example, handle an "append" operation in a loop or other region). + // However, if the op is read-only, then from the purpose of our abstract + // interpretation, we can handle it effectively as though it was at the + // same position as the corresponding parent op in the block under + // consideration. + if (potentiallyMutatesListOperands(user)) { + if (user->getBlock() != block) + break; + } + } + + // Truncate the list of users to the number of users we're going to + // interpret. + allUsers.resize(numUsersToInterpret); + auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret); + + // For each mutating op (which must be in the same block), we save the + // current state of the list as a vector of Value's. These will then + // be converted to PrimListConstructOp's at the correct program points. + SmallVector> listLiterals; + SmallVector runningList; + llvm::append_range(runningList, op->getOperands()); + bool generatedNewLiteral = false; + for (Operation *user : usersToInterpret) { + if (auto append = dyn_cast(user)) { + if (!append.use_empty()) + return rewriter.notifyMatchFailure( + op, "Expected `AtenAppendTOp` to not have users"); + if (append.getSelf() == op) { + runningList.push_back(append.getEl()); + generatedNewLiteral = true; + } + listLiterals.push_back(runningList); + continue; + } + if (auto insert = dyn_cast(user)) { + if (!insert.use_empty()) + return rewriter.notifyMatchFailure( + op, "Expected `AtenInsertTOp` to not have users"); + int64_t index; + if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index))) + return rewriter.notifyMatchFailure( + op, "Expected `idx` of `AtenInsertTOp` to be a constant int"); + // The index might be statically out of bounds. + if (index < 0 || index > static_cast(runningList.size())) + return rewriter.notifyMatchFailure( + op, "Index in `AtenInsertTOp` is out of bounds"); + if (insert.getSelf() == op) { + runningList.insert(runningList.begin() + index, insert.getEl()); + generatedNewLiteral = true; + } + listLiterals.push_back(runningList); + continue; + } + if (auto setItem = dyn_cast(user)) { + if (!setItem.use_empty()) + return rewriter.notifyMatchFailure( + op, "Expected `Aten_SetItemTOp` to not have users"); + std::optional indexOpt = matchLegalConstantIndexIntoListOfSize( + setItem.getIdx(), runningList.size()); + // The index might be statically out of bounds. + if (!indexOpt) + return rewriter.notifyMatchFailure( + op, "Index in `Aten_SetItemTOp` is out of bounds"); + if (setItem.getL() == op) { + runningList[*indexOpt] = setItem.getEl(); + generatedNewLiteral = true; + } + listLiterals.push_back(runningList); + continue; + } + // If this user potentially mutates the list and isn't handled above, then + // we can't abstractly interpret any further. + if (potentiallyMutatesListOperands(user)) + break; + } + + if (!generatedNewLiteral) + return rewriter.notifyMatchFailure(op, "No new literal created"); + + // Rewrite all users to use the appropriate list literals. + Value latestLiteral = rewriter.create( + op->getLoc(), op.getType(), op->getOperands()); + int nextLiteral = 0; + for (Operation *user : usersToInterpret) { + if (auto append = dyn_cast(user)) { + rewriter.setInsertionPoint(append); + latestLiteral = rewriter.create( + append->getLoc(), op.getType(), listLiterals[nextLiteral++]); + if (append.getSelf() == op) + rewriter.eraseOp(append); + continue; + } + if (auto insert = dyn_cast(user)) { + rewriter.setInsertionPoint(insert); + latestLiteral = rewriter.create( + insert->getLoc(), op.getType(), listLiterals[nextLiteral++]); + if (insert.getSelf() == op) + rewriter.eraseOp(insert); + continue; + } + if (auto setItem = dyn_cast(user)) { + rewriter.setInsertionPoint(setItem); + latestLiteral = rewriter.create( + setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]); + if (setItem.getL() == op) + rewriter.eraseOp(setItem); + continue; + } + for (OpOperand &opOperand : user->getOpOperands()) { + if (opOperand.get() == op.getResult()) { + opOperand.set(latestLiteral); + } + } + } + + // Any remaining uses should use the updated value of the latest literal. + rewriter.replaceOp(op, latestLiteral); + return success(); + } +}; +} // namespace + LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, int resultNum, Type newResultType, @@ -97,3 +334,18 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, return success(); } + +void mlir::torch::Torch::populateFoldPrimUncheckedCastOpPattern( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.insert(context); +} + +void mlir::torch::Torch::populateFullyUnrollPrimLoopOpPattern( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.insert(context); +} + +void mlir::torch::Torch::populateAbstractlyInterpretListOpsWithinABlockPattern( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.insert(context); +} diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h index 9c618d4a27a8..172d27c00df8 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h +++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h @@ -23,6 +23,13 @@ LogicalResult updateCalculateOpResultTypes(Operation *calculateOp, int resultNum, Type newResultType, PatternRewriter &rewriter); +void populateFoldPrimUncheckedCastOpPattern(RewritePatternSet &patterns, + MLIRContext *context); +void populateFullyUnrollPrimLoopOpPattern(RewritePatternSet &patterns, + MLIRContext *context); +void populateAbstractlyInterpretListOpsWithinABlockPattern( + RewritePatternSet &patterns, MLIRContext *context); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 7b74e2c506d0..43f2b22a3d66 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -191,11 +191,16 @@ class SimplifyDtypeCalculationsPass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); + populateFullyUnrollPrimLoopOpPattern(patterns, context); + populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context); + populateFoldPrimUncheckedCastOpPattern(patterns, context); patterns.insert(context); patterns.insert(context); patterns.insert(context); PrimIfOp::getCanonicalizationPatterns(patterns, context); + Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context); + PrimTupleUnpackOp::getCanonicalizationPatterns(patterns, context); // TODO: Debug visitation order to make this more efficient. // A single linear scan should suffice. diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index f8d3651d9a5c..1669be7c4e62 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -10,7 +10,6 @@ #include "PassDetail.h" #include "SimplifyAbstractInterpCalculationsUtils.h" -#include "mlir/IR/IRMapping.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -19,225 +18,6 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -namespace { -// TODO: Only unroll inside the shape calculation region. -// Maybe do this by only applying patterns and folding greedily on the ops -// inside the region + the shape.calculate op itself? -class FullyUnrollPrimLoopOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(PrimLoopOp op, - PatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - MLIRContext *context = op->getContext(); - if (!op.isForLike()) - return rewriter.notifyMatchFailure(op, "Loop is not for-like"); - int64_t maxTripCount; - if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount))) - return rewriter.notifyMatchFailure( - op, "Expected `maxTripCount` to be a constant int"); - ; - SmallVector indices; - for (int64_t i = 0; i < maxTripCount; i++) { - // TODO: Add convenience builder. - indices.push_back(rewriter.create( - loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i))); - } - Block *beforeBlock = op->getBlock(); - Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator()); - - SmallVector blocksToMerge; - IRMapping bvm; - // TODO: Helper for region().front() - auto condition = - cast(op.getRegion().front().getTerminator()); - for (int64_t i = 0; i < maxTripCount; i++) { - SmallVector iterArgs; - if (i == 0) { - llvm::append_range(iterArgs, op.getIterArgsInit()); - } else { - llvm::append_range( - iterArgs, llvm::map_range(condition.getIterArgs(), - [&](Value v) { return bvm.lookup(v); })); - } - bvm.clear(); - bvm.map(op.getRegion().front().getArgument(0), indices[i]); - bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs); - - op.getRegion().cloneInto(afterBlock->getParent(), afterBlock->getIterator(), - bvm); - Block *clonedBlock = bvm.lookup(&op.getRegion().front()); - rewriter.eraseOp(clonedBlock->getTerminator()); - blocksToMerge.push_back(clonedBlock); - } - - blocksToMerge.push_back(afterBlock); - for (Block *block : blocksToMerge) - rewriter.mergeBlocks(block, beforeBlock); - if (maxTripCount == 0) { - rewriter.replaceOp(op, op.getIterArgsInit()); - } else { - rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range( - condition.getIterArgs(), - [&](Value v) { return bvm.lookup(v); }))); - } - return success(); - } -}; -} // namespace - -namespace { -class AbstractlyInterpretListOpsWithinABlock - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(PrimListConstructOp op, - PatternRewriter &rewriter) const override { - Block *block = op->getBlock(); - auto allUsers = llvm::to_vector<6>(op->getUsers()); - - // Sort the users into program order. - auto getParentInBlock = [&](Operation *op) { - while (op->getBlock() != block) - op = op->getParentOp(); - return op; - }; - // Use a stable sort for deterministic results when users are nested in two - // regions of the same parent op. - llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) { - return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs)); - }); - - // We cannot interpret all ops. So first do a check to see up until which - // point we can interpret. - int numUsersToInterpret = 0; - for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) { - Operation *user = allUsers[i]; - // If a user potentially mutates the list, then we require it to be in the - // same block for our simple abstract interpretation to work (we can't, - // for example, handle an "append" operation in a loop or other region). - // However, if the op is read-only, then from the purpose of our abstract - // interpretation, we can handle it effectively as though it was at the - // same position as the corresponding parent op in the block under - // consideration. - if (potentiallyMutatesListOperands(user)) { - if (user->getBlock() != block) - break; - } - } - - // Truncate the list of users to the number of users we're going to - // interpret. - allUsers.resize(numUsersToInterpret); - auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret); - - // For each mutating op (which must be in the same block), we save the - // current state of the list as a vector of Value's. These will then - // be converted to PrimListConstructOp's at the correct program points. - SmallVector> listLiterals; - SmallVector runningList; - llvm::append_range(runningList, op->getOperands()); - bool generatedNewLiteral = false; - for (Operation *user : usersToInterpret) { - if (auto append = dyn_cast(user)) { - if (!append.use_empty()) - return rewriter.notifyMatchFailure( - op, "Expected `AtenAppendTOp` to not have users"); - if (append.getSelf() == op) { - runningList.push_back(append.getEl()); - generatedNewLiteral = true; - } - listLiterals.push_back(runningList); - continue; - } - if (auto insert = dyn_cast(user)) { - if (!insert.use_empty()) - return rewriter.notifyMatchFailure( - op, "Expected `AtenInsertTOp` to not have users"); - int64_t index; - if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index))) - return rewriter.notifyMatchFailure( - op, "Expected `idx` of `AtenInsertTOp` to be a constant int"); - // The index might be statically out of bounds. - if (index < 0 || index > static_cast(runningList.size())) - return rewriter.notifyMatchFailure( - op, "Index in `AtenInsertTOp` is out of bounds"); - if (insert.getSelf() == op) { - runningList.insert(runningList.begin() + index, insert.getEl()); - generatedNewLiteral = true; - } - listLiterals.push_back(runningList); - continue; - } - if (auto setItem = dyn_cast(user)) { - if (!setItem.use_empty()) - return rewriter.notifyMatchFailure( - op, "Expected `Aten_SetItemTOp` to not have users"); - std::optional indexOpt = matchLegalConstantIndexIntoListOfSize( - setItem.getIdx(), runningList.size()); - // The index might be statically out of bounds. - if (!indexOpt) - return rewriter.notifyMatchFailure( - op, "Index in `Aten_SetItemTOp` is out of bounds"); - if (setItem.getL() == op) { - runningList[*indexOpt] = setItem.getEl(); - generatedNewLiteral = true; - } - listLiterals.push_back(runningList); - continue; - } - // If this user potentially mutates the list and isn't handled above, then - // we can't abstractly interpret any further. - if (potentiallyMutatesListOperands(user)) - break; - } - - if (!generatedNewLiteral) - return rewriter.notifyMatchFailure(op, "No new literal created"); - - // Rewrite all users to use the appropriate list literals. - Value latestLiteral = rewriter.create( - op->getLoc(), op.getType(), op->getOperands()); - int nextLiteral = 0; - for (Operation *user : usersToInterpret) { - if (auto append = dyn_cast(user)) { - rewriter.setInsertionPoint(append); - latestLiteral = rewriter.create( - append->getLoc(), op.getType(), listLiterals[nextLiteral++]); - if (append.getSelf() == op) - rewriter.eraseOp(append); - continue; - } - if (auto insert = dyn_cast(user)) { - rewriter.setInsertionPoint(insert); - latestLiteral = rewriter.create( - insert->getLoc(), op.getType(), listLiterals[nextLiteral++]); - if (insert.getSelf() == op) - rewriter.eraseOp(insert); - continue; - } - if (auto setItem = dyn_cast(user)) { - rewriter.setInsertionPoint(setItem); - latestLiteral = rewriter.create( - setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]); - if (setItem.getL() == op) - rewriter.eraseOp(setItem); - continue; - } - for (OpOperand &opOperand : user->getOpOperands()) { - if (opOperand.get() == op.getResult()) { - opOperand.set(latestLiteral); - } - } - } - - // Any remaining uses should use the updated value of the latest literal. - rewriter.replaceOp(op, latestLiteral); - return success(); - } -}; -} // namespace - namespace { class DecomposeAtenSizeOp : public OpRewritePattern { public: @@ -266,22 +46,6 @@ class DecomposeAtenSizeOp : public OpRewritePattern { }; } // namespace -namespace { -class FoldPrimUncheckedCastOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(PrimUncheckedCastOp op, - PatternRewriter &rewriter) const override { - if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) { - return rewriter.notifyMatchFailure( - op, "input tensor type is not a valid subtype of result type"); - } - rewriter.replaceOp(op, op.getX()); - return success(); - } -}; -} // namespace - static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op, int resultNum, PatternRewriter &rewriter) { @@ -367,11 +131,11 @@ class SimplifyShapeCalculationsPass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns.insert(context); - patterns.insert(context); + populateFullyUnrollPrimLoopOpPattern(patterns, context); + populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context); + populateFoldPrimUncheckedCastOpPattern(patterns, context); patterns.insert(context); patterns.insert(context); - patterns.insert(context); PrimIfOp::getCanonicalizationPatterns(patterns, context); Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index fd1a0ebe558f..2a175f5a031c 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -12,7 +12,7 @@ import torch.jit._shape_functions as upstream_shape_functions from .testing_framework import Invocation, ErrorInvocation, TensorOfShape, LongTensorOfShape, NonZeroDTensorWithDtype, ZeroDTensorWithDtype, check_shape_function, check_dtype_function -from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar, is_integer_dtype, is_float_dtype, is_complex_dtype +from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar, is_integer_dtype, is_float_dtype, is_complex_dtype, get_priority_of_dtype, all_integer_dtypes, all_float_dtypes, all_complex_dtypes # ============================================================================== # Shape Functions @@ -883,12 +883,17 @@ def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[in Invocation(TensorOfShape(2, 3), None, None, TensorOfShape(3), TensorOfShape(3), False, 1e-4, 1e-6), # Inference basic case. Invocation(TensorOfShape(2, 3, 4, 5, 6), None, None, None, None, True, 1e-4, 1e-6), # Training high-D case. Invocation(TensorOfShape(2, 3, 4, 5, 6), None, None, TensorOfShape(3), TensorOfShape(3), False, 1e-4, 1e-6), # Inference high-D case. - ErrorInvocation(TensorOfShape(2), None, None, None, None, True, 1e-4, 1e-6) # Dimensionality too low. + Invocation(TensorOfShape(2), None, None, None, None, True, 1e-4, 1e-6) # 1D input. ]) def aten〇native_batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float) -> Tuple[List[int], List[int], List[int]]: if training: - return input, [input[1]], [input[1]] - return input, [0], [0] + if len(input) >= 2: + return input, [input[1]], [input[1]] + else: + return input, [], [] + running_mean_list: List[int] = [0] if running_mean is None else running_mean + running_var_list: List[int] = [0] if running_var is None else running_var + return input, running_mean_list, running_var_list # TODO: This should be upstreamed. # See https://github.com/pytorch/pytorch/pull/76889 for an example. @@ -1083,6 +1088,7 @@ def _check_tensors_with_the_same_dtype( return invocations def _check_two_tensor_op( + tensor_shapes: Optional[list[tuple[int]]] = None, input_error_types: Optional[set[int]] = None, output_error_types: Optional[set[int]] = None, **kwargs): """Generate invocations for basic two-tensor dtype functions. @@ -1096,6 +1102,10 @@ def _check_two_tensor_op( must be specified in `input_error_types` and `output_error_types` to ensure the invocations are error invocations. """ + if tensor_shapes is None: + tensor_shapes = [(1,), (1,)] + shape_1, shape_2 = tensor_shapes + if input_error_types is not None and output_error_types is not None: assert len(input_error_types.intersection(output_error_types)) == 0, \ "An invalid input type implies an invalid output type, " \ @@ -1133,18 +1143,16 @@ def check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs): invocations = [] for type_ in _SORTED_TORCH_TYPES[constant_type_index:]: - tensor_1 = NonZeroDTensorWithDtype(type_) - tensor_2 = NonZeroDTensorWithDtype(constant_type) if input_error_types is not None and type_ in input_error_types: - invocations += [ErrorInvocation(tensor_1, tensor_2, **kwargs), - ErrorInvocation(tensor_2, tensor_1, **kwargs)] + invocation_type = ErrorInvocation else: - invocations += [Invocation(tensor_1, tensor_2, **kwargs), - Invocation(tensor_2, tensor_1, **kwargs)] + invocation_type = Invocation + invocations += [invocation_type(TensorOfShape(*shape_1, dtype=type_), TensorOfShape(*shape_2, dtype=constant_type), **kwargs), + invocation_type(TensorOfShape(*shape_1, dtype=constant_type), TensorOfShape(*shape_2, dtype=type_), **kwargs)] return invocations same_dtype_invocations = _check_tensors_with_the_same_dtype( - num_of_tensors=2, error_types=all_error_types, **kwargs) + tensor_shapes=tensor_shapes, error_types=all_error_types, **kwargs) varying_dtype_invocations = \ check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs) @@ -1156,42 +1164,34 @@ def _get_dtype_of_floating_point_op(input_dtype: int) -> int: return input_dtype return torch.float32 -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇exp〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - assert self_dtype != torch.float16, "`self` cannot have float16 dtype" return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) @@ -1199,49 +1199,41 @@ def aten〇reciprocal〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇log〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇log2〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇rsqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇erf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype) and self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype( - num_of_tensors=1, - error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, - torch.int64, torch.float16, torch.complex64, torch.complex128})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype) and not is_integer_dtype(self_dtype) and self_dtype != torch.float16 + if is_integer_dtype(self_dtype): + return self_dtype return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype( @@ -1256,525 +1248,1851 @@ def aten〇frobenius_norm〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: L return torch.float32 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.float16 + if is_integer_dtype(self_dtype): + return self_dtype return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype( - None, [(3,), (3, 4)], - {torch.complex128, torch.complex64, torch.float16, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool}, - TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)) + - [ErrorInvocation(TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, 4, dtype=torch.float64), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)), - ErrorInvocation(TensorOfShape(3, dtype=torch.float64), TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32))]) -def aten〇nll_loss_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], reduction: int, ignore_index: int, total_weight_rank_dtype: Tuple[int, int]) -> int: - grad_output_rank, grad_output_dtype = grad_output_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.complex128: + return torch.float64 + elif self_dtype == torch.complex64: + return torch.float32 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) +def aten〇adaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int: self_rank, self_dtype = self_rank_dtype - assert grad_output_dtype == self_dtype, "`grad_output` and `self` must have the same dtype" - assert not is_complex_dtype(self_dtype), "`self` cannot have complex dtype" - assert not is_integer_dtype(self_dtype), "`self` cannot have integer dtype" - assert self_dtype != torch.float16, "`self` cannot have float16 dtype" return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype( - None, [(2, 4, 7, 6), (2, 4, 6, 5)], - {torch.complex128, torch.complex64, torch.float16, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool}, - [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)) + - [ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float32), TensorOfShape(2, 4, 6, 5, dtype=torch.float64), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)), - ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float64), TensorOfShape(2, 4, 6, 5, dtype=torch.float32), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64))]) -def aten〇max_pool2d_with_indices_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices_rank_dtype: Tuple[int, int]) -> int: - grad_output_rank, grad_output_dtype = grad_output_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: self_rank, self_dtype = self_rank_dtype - assert grad_output_dtype == self_dtype, "`grad_output` and `self` must have the same dtype" - assert not is_complex_dtype(self_dtype), "`self` cannot have complex dtype" - assert not is_integer_dtype(self_dtype), "`self` cannot have integer dtype" - assert self_dtype != torch.float16, "`self` cannot have float16 dtype" return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype( + tensor_shapes=[(2, 3, 5), (3,), (3,), (3,), (3,)], training=False, momentum=0.1, eps=1e-5, cudnn_enabled=True)) +def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇all〡dtype(self_rank_dtype: Tuple[int, int]) -> int: +def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype - return torch.uint8 if self_dtype == torch.uint8 else torch.bool + return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇any〡dtype(self_rank_dtype: Tuple[int, int]) -> int: +def aten〇bernoulli〡dtype(self_rank_dtype: Tuple[int, int], generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype - return torch.uint8 if self_dtype == torch.uint8 else torch.bool + return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + - _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: - return torch.bool +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +def aten〇bernoulli〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], p_rank_dtype: Tuple[int, int], generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype -@check_dtype_function(_check_two_tensor_op()) -def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - return torch.bool +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇bitwise_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype -@check_dtype_function( - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[2, 2])) +def aten〇broadcast_to〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - return torch.bool + return self_dtype -@check_dtype_function( - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - return torch.bool + return self_dtype -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇gt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0)) +def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float]) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - assert not is_complex_dtype(other_dtype), "`self` cannot be complex" - return torch.bool + if self_dtype == torch.bool: + return torch.int64 + return self_dtype -@check_dtype_function( - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=0)) +def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float]) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - return torch.bool + if self_dtype == torch.bool: + return torch.int64 + return self_dtype -@check_dtype_function(_check_two_tensor_op()) -def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - return torch.bool +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=-1, max=1)) +def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: - return torch.bool +def aten〇clone〡dtype(self_rank_dtype: Tuple[int, int], memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype -@check_dtype_function(_check_two_tensor_op()) -def aten〇logical_or〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - return torch.bool +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1])) +def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float] = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype @check_dtype_function(_check_two_tensor_op()) -def aten〇logical_xor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - return torch.bool +def aten〇copy〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], non_blocking: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype -@check_dtype_function( - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) -def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: +# TODO: Cannot call .cpu() on meta tensor due to: Cannot copy out of meta tensor; no data! +def aten〇cpu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - return torch.bool + return self_dtype @check_dtype_function( - _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇lt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32)) +def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - assert not is_complex_dtype(other_dtype), "`self` cannot be complex" - return torch.bool + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype -@check_dtype_function( - _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + - _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) -def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: - return torch.bool +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype -@check_dtype_function([ - Invocation(0.0, 0.0), # float, float - Invocation(0.0, 0), # float, int - Invocation(0, 0.0), # int, float - Invocation(0, 0), # int, int -]) -def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int: - ranks: List[Optional[int]] = [None, None] - dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)] - return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False)) +def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype -@check_dtype_function( - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.float16, torch.bfloat16})) -def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: +@check_dtype_function(_check_two_tensor_op()) +def aten〇expand_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - if is_complex_dtype(self_dtype): - return self_dtype - elif self_dtype == torch.float: - return torch.complex64 - elif self_dtype == torch.double: - return torch.complex128 - elif is_integer_dtype(self_dtype): - return torch.complex64 - else: - assert False, "Unsupported dtype" + return self_dtype -@check_dtype_function( - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.bool}, other=0.0) + - _check_tensors_with_the_same_dtype( - num_of_tensors=1, error_types={torch.bool}, other=0)) -def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[2, 2])) +def aten〇expand〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], implicit: bool = False) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.bool, "`self` cannot have bool dtype" - return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) + return self_dtype -@check_dtype_function(_check_two_tensor_op( - input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, value=0)) +def aten〇fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: self_rank, self_dtype = self_rank_dtype - assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" - assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] - return promote_dtypes(ranks, dtypes) + return self_dtype -@check_dtype_function(_check_two_tensor_op()) -def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1,), ()])) +def aten〇fill〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] - return promote_dtypes(ranks, dtypes) + return self_dtype -@check_dtype_function(_check_two_tensor_op( - input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_dim: int = 0, end_dim: int = -1) -> int: self_rank, self_dtype = self_rank_dtype - assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" - assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] - return promote_dtypes(ranks, dtypes) + return self_dtype -@check_dtype_function(_check_two_tensor_op( - input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: self_rank, self_dtype = self_rank_dtype - assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" - assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] - return promote_dtypes(ranks, dtypes) + return self_dtype -@check_dtype_function(_check_two_tensor_op( - input_error_types={torch.float16, torch.bfloat16, torch.float32, - torch.float64, torch.complex64, torch.complex128})) -def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇floor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" - assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] - return promote_dtypes(ranks, dtypes) + return self_dtype -@check_dtype_function( - _check_tensors_with_the_same_dtype( - tensor_shapes=[(2, 3, 4), (2, 4, 3)], error_types={torch.float16, torch.bool}) + - # Different width - [ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float64), - TensorOfShape(2, 4, 3, dtype=torch.float32)), - # Different type - ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32), - TensorOfShape(2, 4, 3, dtype=torch.int32))]) -def aten〇bmm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: - mat2_rank, mat2_dtype = mat2_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, 0, TensorOfShape(1, dtype=torch.int64))) +def aten〇gather〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], sparse_grad: bool = False) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype not in [torch.float16, torch.bool], \ - "Expected dtype of `self` to not be float16 or bool" - assert mat2_dtype not in [torch.float16, torch.bool], \ - "Expected dtype of `mat2` to not be float16 or bool" - assert self_dtype == mat2_dtype, "`self` and `mat2` must have the same dtype" return self_dtype @check_dtype_function(_check_two_tensor_op()) -def aten〇div〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +def aten〇gelu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], approximate: str = "none") -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype self_rank, self_dtype = self_rank_dtype - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] promoted_dtype = promote_dtypes(ranks, dtypes) - if is_complex_dtype(promoted_dtype) or \ - (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32): - return promoted_dtype - else: - return torch.float32 + return promoted_dtype -@check_dtype_function(_check_two_tensor_op(rounding_mode=None)) -def aten〇div〇Tensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rounding_mode: Optional[str]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇gelu〡dtype(self_rank_dtype: Tuple[int, int], approximate: str = "none") -> int: self_rank, self_dtype = self_rank_dtype - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] - promoted_dtype = promote_dtypes(ranks, dtypes) - if is_complex_dtype(promoted_dtype) or \ - (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32): - return promoted_dtype - else: - return torch.float32 + return self_dtype -@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool})) -def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇hardsigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - assert not is_complex_dtype(other_dtype), "`other` cannot be complex" - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] - promoted_dtype = promote_dtypes(ranks, dtypes) - assert promoted_dtype != torch.bool, "Result dtype for aten.floor_divide bool" - return promoted_dtype + return self_dtype -@check_dtype_function( - _check_tensors_with_the_same_dtype( - tensor_shapes=[(2, 3, 4), (2, 4, 3)], error_types={torch.float16, torch.bool}) + - # Different width - [ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float64), - TensorOfShape(2, 4, 3, dtype=torch.float32)), - # Different type - ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32), - TensorOfShape(2, 4, 3, dtype=torch.int32))]) -def aten〇matmul〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇hardswish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype not in [torch.float16, torch.bool], \ - "Expected dtype of `self` to not be float16 or bool" - assert other_dtype not in [torch.float16, torch.bool], \ - "Expected dtype of `other` to not be float16 or bool" - assert self_dtype == other_dtype, "`self` and `other` must have the same dtype" return self_dtype -@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇maximum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_two_tensor_op(min_val=0.2, max_val=0.5)) +def aten〇hardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float], max_val: Union[int, float]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + if is_integer_dtype(grad_output_dtype): + return torch.float32 + return grad_output_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.uint8, torch.bool})) +def aten〇hardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float] = -1, max_val: Union[int, float] = 1) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + assert self_dtype not in [torch.uint8, torch.bool] + return self_dtype + +_index_put_invocations = [ + # same dtype + Invocation(TensorOfShape(3, dtype=dtype), [TensorOfShape(3, dtype=torch.int64)], TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES +] + [ + # different dtypes + Invocation(TensorOfShape(3, dtype=dtype), [TensorOfShape(3, dtype=torch.int64)], TensorOfShape(3, dtype=torch.float32)) for dtype in _SORTED_TORCH_TYPES +] + [ + # index dtype + Invocation(TensorOfShape(3, dtype=torch.float32), [TensorOfShape(3, dtype=dtype)], TensorOfShape(3, dtype=torch.float32)) for dtype in _SORTED_TORCH_TYPES +] +@check_dtype_function(_index_put_invocations) +def aten〇index_put〇hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_index_put_invocations) +def aten〇_index_put_impl〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False, unsafe: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_index_put_invocations) +def aten〇index_put〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, 0, TensorOfShape(1, dtype=torch.int64))) +def aten〇index_select〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, [TensorOfShape(1, dtype=torch.int64)])) +def aten〇index〇Tensor_hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, [TensorOfShape(1, dtype=torch.int64)])) +def aten〇index〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={*all_integer_dtypes()}, normalized_shape=[1])) +def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]] = None, bias_rank_dtype: Optional[Tuple[int, int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enable: bool = True) -> int: + input_rank, input_dtype = input_rank_dtype + assert not is_integer_dtype(input_dtype) + return input_dtype + +@check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False)) +def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float], self_is_result: bool) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_two_tensor_op(dim=0, input_dtype=torch.float32) + + _check_two_tensor_op(dim=0, input_dtype=torch.float64)) +def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int: + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, TensorOfShape(1, dtype=torch.bool), 0)) +def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, TensorOfShape(1, dtype=torch.bool), 0)) +def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, TensorOfShape(1, dtype=torch.bool), TensorOfShape(dtype=torch.float32))) +def aten〇masked_fill〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +# TODO: This op cannot be run on the Meta backend. We should have the test suite default on CPU when the Meta backend does not work. +def aten〇masked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, start=0, length=1)) +def aten〇narrow〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start: int, length: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) +def aten〇neg〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇numpy_T〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1])) +def aten〇pad〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def aten〇permute〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇pow〇Tensor_Tensor〡dtype(self_rank_dtype: Tuple[int, int], exponent_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + exponent_rank, exponent_dtype = exponent_rank_dtype + ranks: List[Optional[int]] = [self_rank, exponent_rank] + dtypes = [self_dtype, exponent_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if promoted_dtype == torch.bool: + return torch.int64 + return promoted_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=2) + + [ErrorInvocation(TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.float64)), + ErrorInvocation(TensorOfShape(1, dtype=torch.float64), TensorOfShape(1, dtype=torch.float32))]) +def aten〇prelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert self_dtype == weight_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) +def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇relu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, repeats=[1])) +def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1])) +def aten〇_reshape_alias〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, shape=[1])) +def aten〇reshape〡dtype(self_rank_dtype: Tuple[int, int], shape: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇resize_〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, shifts=[0], dims=[0])) +def aten〇roll〡dtype(self_rank_dtype: Tuple[int, int], shifts: List[int], dims: List[int] = ()) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype), "sum") for dtype in _SORTED_TORCH_TYPES]) +def aten〇scatter_reduce〇two〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], reduce: str, include_self: bool = True) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, index=0)) +def aten〇select〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(tensor_shapes=[(1, 1), (1,)], dim=0, index=0)) +def aten〇select_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], dim: int, index: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇silu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(dim=0)) +def aten〇slice_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) + + [Invocation(TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.float64), dim=0, input_dtype=torch.float32), + Invocation(TensorOfShape(1, dtype=torch.float64), TensorOfShape(1, dtype=torch.float32), dim=0, input_dtype=torch.float32)]) +def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int: + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇square〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇squeeze〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇squeeze〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇tanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + output_rank, output_dtype = output_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, output_rank] + dtypes = [grad_output_dtype, output_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, threshold=0, value=0)) +def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇t〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, device=torch.device("meta"))) +def aten〇to〇prim_Device〡dtype(self_rank_dtype: Tuple[int, int], device: Optional[device], dtype: Optional[int] = None, non_blocking: bool = False, copy: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim0=0, dim1=1)) +def aten〇transpose〇int〡dtype(self_rank_dtype: Tuple[int, int], dim0: int, dim1: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)])) +def aten〇triu〡dtype(self_rank_dtype: Tuple[int, int], diagonal: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., to: float = 1., generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇_unsafe_view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇unsqueeze〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 4, 8)], output_size=[4, 8], input_size=[1, 1, 2, 3])) +def aten〇upsample_nearest2d_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + return grad_output_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13])) +def aten〇upsample_nearest2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇zero〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇zero_〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function([Invocation(-1), Invocation(-1.0)]) +def prim〇abs〇Scalar〡dtype(a: Union[int, float]) -> int: + return get_dtype_of_scalar(a) + +@check_dtype_function(_check_tensors_with_the_same_dtype( + None, [(3,), (3, 4)], None, + TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)) + + [Invocation(TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, 4, dtype=torch.float64), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)), + Invocation(TensorOfShape(3, dtype=torch.float64), TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32))]) +def aten〇nll_loss_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], reduction: int, ignore_index: int, total_weight_rank_dtype: Tuple[int, int]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, grad_output_rank] + dtypes = [self_dtype, grad_output_dtype] + result = promote_dtypes(ranks, dtypes) + if result == torch.bool: + return torch.int64 + return result + +@check_dtype_function(_check_tensors_with_the_same_dtype( + None, [(2, 4, 7, 6), (2, 4, 6, 5)], None, + [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)) + + [ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float32), TensorOfShape(2, 4, 6, 5, dtype=torch.float64), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)), + ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float64), TensorOfShape(2, 4, 6, 5, dtype=torch.float32), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64))]) +def aten〇max_pool2d_with_indices_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices_rank_dtype: Tuple[int, int]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + assert grad_output_dtype == self_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇all〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return torch.uint8 if self_dtype == torch.uint8 else torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇any〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return torch.uint8 if self_dtype == torch.uint8 else torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇gt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇ge〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇logical_or〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇logical_xor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇lt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇le〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function([ + Invocation(0.0, 0.0), # float, float + Invocation(0.0, 0), # float, int + Invocation(0, 0.0), # int, float + Invocation(0, 0), # int, int +]) +def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int: + ranks: List[Optional[int]] = [None, None] + dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bfloat16})) +def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if is_complex_dtype(self_dtype): + return self_dtype + elif self_dtype == torch.float16: + return torch.complex32 + elif self_dtype == torch.float32: + return torch.complex64 + elif self_dtype == torch.float64: + return torch.complex128 + elif is_integer_dtype(self_dtype): + return torch.complex64 + else: + assert False, "Unsupported dtype" + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + + # Different width + [Invocation(TensorOfShape(2, 3, 4, dtype=torch.float64), + TensorOfShape(2, 4, 3, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float16), + TensorOfShape(2, 4, 3, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), + TensorOfShape(2, 4, 3, dtype=torch.int32))]) +def aten〇bmm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + mat2_rank, mat2_dtype = mat2_rank_dtype + self_rank, self_dtype = self_rank_dtype + mat2_priority = get_priority_of_dtype(mat2_dtype) + self_priority = get_priority_of_dtype(self_dtype) + return mat2_dtype if mat2_priority < self_priority else self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇div〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_complex_dtype(promoted_dtype) or \ + (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32): + return promoted_dtype + else: + return torch.float32 + +@check_dtype_function(_check_two_tensor_op(rounding_mode=None)) +def aten〇div〇Tensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rounding_mode: Optional[str]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_complex_dtype(promoted_dtype) or \ + (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32): + return promoted_dtype + else: + return torch.float32 + +@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool})) +def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" assert not is_complex_dtype(other_dtype), "`other` cannot be complex" ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert promoted_dtype != torch.bool, "Result dtype for aten.floor_divide bool" + return promoted_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + + # Different width + [Invocation(TensorOfShape(2, 3, 4, dtype=torch.float64), + TensorOfShape(2, 4, 3, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float16), + TensorOfShape(2, 4, 3, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), + TensorOfShape(2, 4, 3, dtype=torch.int32))]) +def aten〇matmul〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + other_priority = get_priority_of_dtype(other_dtype) + self_priority = get_priority_of_dtype(self_dtype) + return other_dtype if other_priority < self_priority else self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇maximum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇minimum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4), (4, 3)]) + + # Different width + [Invocation(TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, 3, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(3, 4, dtype=torch.float16), + TensorOfShape(4, 3, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.int32))]) +def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + mat2_rank, mat2_dtype = mat2_rank_dtype + self_rank, self_dtype = self_rank_dtype + + float16_types = [torch.bfloat16, torch.float16] + if self_dtype in float16_types and mat2_dtype in float16_types and self_dtype != mat2_dtype: + return torch.float16 + + ranks: List[Optional[int]] = [self_rank, mat2_rank] + dtypes = [self_dtype, mat2_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op( + output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64})) +def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + target_rank, target_dtype = target_rank_dtype + ranks: List[Optional[int]] = [self_rank, target_rank] + dtypes = [self_dtype, target_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert not is_integer_dtype(promoted_dtype) + return promoted_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇mul〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4), (4,)]) + + # Different width + [Invocation(TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(3, 4, dtype=torch.float16), + TensorOfShape(4, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, dtype=torch.int32))]) +def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + vec_rank, vec_dtype = vec_rank_dtype + ranks: List[Optional[int]] = [self_rank, vec_rank] + dtypes = [self_dtype, vec_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +# TODO: This op has incosistent behavior when run on CPU vs META +# If the following check is made to pass, e2e tests fail +# @check_dtype_function(_check_two_tensor_op(threshold=0)) +def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + ranks: List[Optional[int]] = [self_rank, grad_output_rank] + dtypes = [self_dtype, grad_output_dtype] + return promote_dtypes(ranks, dtypes) + +# TODO: This op fails when using meta backend with error: +# Op raised error 'convolution_overrideable not implemented. +# You are likely triggering this with tensor backend other than +# CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL +# to override this function ' but dtype function did not raise any error. +# +# This is similar to https://github.com/pytorch/pytorch/issues/97481 +def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +# TODO: This op fails when using meta backend with error: +# Op raised error 'convolution_overrideable not implemented. +# You are likely triggering this with tensor backend other than +# CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL +# to override this function ' but dtype function did not raise any error. +# +# This is similar to https://github.com/pytorch/pytorch/issues/97481 +def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) + + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) +]) +def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) + + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) +]) +def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +convolution_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], **convolution_kwargs) + + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs) +]) +def aten〇convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +convolution_backward_kwargs = { + "bias_sizes" : [1], "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1, "output_mask" : [True, True, True]} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1)], + **convolution_backward_kwargs) + + # dtype of first three tensors must be float + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.int32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # dtype of first three tensors must be float + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.int32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # dtype of first three tensors must be float + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.int32), **convolution_backward_kwargs), + # dtype of first three tensors must be float + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # grad_output, input, and weight must have same dtype + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # grad_output, input, and weight must have same dtype + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float64), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # grad_output, input, and weight must have same dtype + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float64), **convolution_backward_kwargs), +]) +def aten〇convolution_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_sizes: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[int, int, int]: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + return grad_output_dtype, grad_output_dtype, grad_output_dtype + +# TODO: Currently unable to test because the op `torch.ops.aten.convolution_backward_overrideable` +# fails to run on the CPU backend. A bug has been filed upstream: https://github.com/pytorch/pytorch/issues/97481 +# The dtype function for this op is unlikely to be different from `torch.ops.aten.convolution_backward` +# which is tested. +def aten〇convolution_backward_overrideable〡dtype(grad_output_rank_dtype: Tuple[int, int], input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[int, int, int]: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert grad_output_dtype == input_dtype + assert weight_dtype == input_dtype + assert is_float_dtype(input_dtype) and input_dtype != torch.float16 + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + return grad_output_dtype, grad_output_dtype, grad_output_dtype + +# TODO: This op cannot be run on the Meta backend. We should have the test suite default on CPU when the Meta backend does not work. +def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype: Optional[Tuple[int, int]] = None, minlength: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + assert is_integer_dtype(self_dtype) and self_dtype != torch.bool + if weights_rank_dtype is None: + return torch.int64 + return torch.float64 + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + + # Different width + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, 3, dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.int32)), + Invocation(TensorOfShape(3, 3, dtype=torch.int32), + TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32))]) +def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + mat1_rank, mat1_dtype = mat1_rank_dtype + mat2_rank, mat2_dtype = mat2_rank_dtype + + ranks: List[Optional[int]] = [self_rank, mat1_rank, mat2_rank] + dtypes = [self_dtype, mat1_dtype, mat2_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + + # Different width + [Invocation(TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float64), + TensorOfShape(4, 3, dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.int32)), + Invocation(TensorOfShape(4, 3, dtype=torch.int32), + TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32))]) +def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + end_rank, end_dtype = end_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + + ranks: List[Optional[int]] = [self_rank, end_rank, weight_rank] + dtypes = [self_dtype, end_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) -@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇minimum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool}) + + # Different width + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float64), + TensorOfShape(3, 3, dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.int32)), + Invocation(TensorOfShape(3, 3, dtype=torch.int32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32))]) +def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int: self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - assert not is_complex_dtype(other_dtype), "`other` cannot be complex" - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] + tensor1_rank, tensor1_dtype = tensor1_rank_dtype + tensor2_rank, tensor2_dtype = tensor2_rank_dtype + + assert self_dtype != torch.bool + assert tensor1_dtype != torch.bool + assert tensor2_dtype != torch.bool + + ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank] + dtypes = [self_dtype, tensor1_dtype, tensor2_dtype] return promote_dtypes(ranks, dtypes) @check_dtype_function( - _check_tensors_with_the_same_dtype( - tensor_shapes=[(3, 4), (4, 3)], error_types={torch.float16, torch.bool}) + + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + # Different width - [ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float64), - TensorOfShape(4, 3, dtype=torch.float32)), + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float64), + TensorOfShape(3, 3, dtype=torch.float32)), # Different type - ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.int32))]) -def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: - mat2_rank, mat2_dtype = mat2_rank_dtype + Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.int32)), + Invocation(TensorOfShape(3, 3, dtype=torch.int32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32))]) +def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype not in [torch.float16, torch.bool], \ - "Expected dtype of `self` to not be float16 or bool" - assert mat2_dtype not in [torch.float16, torch.bool], \ - "Expected dtype of `mat2` to not be float16 or bool" - assert self_dtype == mat2_dtype, "`self` and `mat2` must have the same dtype" - return self_dtype + tensor1_rank, tensor1_dtype = tensor1_rank_dtype + tensor2_rank, tensor2_dtype = tensor2_rank_dtype -@check_dtype_function(_check_two_tensor_op( - input_error_types={torch.complex64, torch.complex128}, - output_error_types={torch.bool, torch.uint8, torch.int8, torch.int16, - torch.int32, torch.int64, torch.bfloat16})) -def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int: + ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank] + dtypes = [self_dtype, tensor1_dtype, tensor2_dtype] + result = promote_dtypes(ranks, dtypes) + if is_integer_dtype(result): + return torch.float32 + return result + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: self_rank, self_dtype = self_rank_dtype - target_rank, target_dtype = target_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - assert not is_complex_dtype(target_dtype), "`target` cannot be complex" - ranks: List[Optional[int]] = [self_rank, target_rank] - dtypes = [self_dtype, target_dtype] + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] promoted_dtype = promote_dtypes(ranks, dtypes) - assert is_float_dtype(promoted_dtype) and promoted_dtype != torch.bfloat16, \ - "Expected promoted dtype to be float but not `bfloat16`" + if is_integer_dtype(promoted_dtype): + return torch.float32 return promoted_dtype -@check_dtype_function(_check_two_tensor_op()) -def aten〇mul〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: self_rank, self_dtype = self_rank_dtype - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] return promote_dtypes(ranks, dtypes) @check_dtype_function( - _check_tensors_with_the_same_dtype( - tensor_shapes=[(3, 4), (4,)], error_types={torch.float16, torch.bool}) + - # Different width - [ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float64), - TensorOfShape(4, dtype=torch.float32)), - # Different type - ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32), - TensorOfShape(4, dtype=torch.int32))]) -def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[int, int]) -> int: + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0)) +def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: self_rank, self_dtype = self_rank_dtype - vec_rank, vec_dtype = vec_rank_dtype - assert self_dtype not in [torch.float16, torch.bool], \ - "Expected dtype of `self` to not be float16 or bool" - assert vec_dtype not in [torch.float16, torch.bool], \ - "Expected dtype of `vec` to not be float16 or bool" - assert self_dtype == vec_dtype, "`self` and `vec` must have the same dtype" - ranks: List[Optional[int]] = [self_rank, vec_rank] - dtypes = [self_dtype, vec_dtype] + assert not is_complex_dtype(self_dtype) + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] return promote_dtypes(ranks, dtypes) -@check_dtype_function(_check_two_tensor_op(input_error_types={torch.bool})) -def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: - other_rank, other_dtype = other_rank_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0)) +def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(exponent)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}, negative_slope=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}, negative_slope=1.0)) +def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float] = 0.01) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + ranks: List[Optional[int]] = [self_rank, None] + negative_slope_dtype = get_dtype_of_scalar(negative_slope) + if is_float_dtype(negative_slope_dtype): + assert not is_integer_dtype(self_dtype) + dtypes = [self_dtype, negative_slope_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: self_rank, self_dtype = self_rank_dtype - assert self_dtype != torch.bool, "`self` cannot be of bool dtype" - assert other_dtype != torch.bool, "`other` cannot be of bool dtype" + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)]) + + [Invocation(TensorOfShape( + 1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.int16), TensorOfShape(1, 1, 1, dtype=torch.int32)), + Invocation( + TensorOfShape(1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.int64), TensorOfShape(1, 1, 1, dtype=torch.float16)), + Invocation( + TensorOfShape(1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, dtype=torch.int64)), + Invocation( + TensorOfShape(1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, dtype=torch.float16))]) +def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int: + batch2_rank, batch2_dtype = batch2_rank_dtype + return batch2_dtype + +@check_dtype_function([ + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int16), NonZeroDTensorWithDtype(torch.int32)), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), NonZeroDTensorWithDtype(torch.float16)), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.int64)), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.bfloat16), NonZeroDTensorWithDtype(torch.float16))]) +def aten〇where〇self〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) -@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool, torch.float16}, threshold=0)) -def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float]) -> int: +@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0.0), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0.0)]) +def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other: Union[int, float]) -> int: + if is_integer_dtype(get_dtype_of_scalar(self)) and is_integer_dtype(get_dtype_of_scalar(other)): + return torch.int64 + return torch.float32 + +@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int16), 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), 0.0), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float64), 0.0)]) +def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: self_rank, self_dtype = self_rank_dtype - grad_output_rank, grad_output_dtype = grad_output_rank_dtype - assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex" - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - ranks: List[Optional[int]] = [grad_output_rank, self_rank] - dtypes = [grad_output_dtype, self_dtype] - promoted_dtype = promote_dtypes(ranks, dtypes) - assert promoted_dtype not in [torch.bool, torch.float16], \ - "Result dtype for aten.threshold_backward cannot be bool or float16" - return promoted_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.int16)), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.int64)), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.float16)), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.float64))]) +def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [None, other_rank] + dtypes = [get_dtype_of_scalar(self), other_dtype] + return promote_dtypes(ranks, dtypes) -_convolution_kwargs = { - "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], - "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False, "allow_tf32" : False} @check_dtype_function( - _check_tensors_with_the_same_dtype( - tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], - error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_kwargs) + - [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), - TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), - TensorOfShape(1, dtype=torch.float32), **_convolution_kwargs) -]) -def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> int: + [Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int64), + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + ErrorInvocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int32), # target must be int64 + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + ErrorInvocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.float64), # target must be int64 + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + Invocation(TensorOfShape(2, 3, dtype=torch.float64), TensorOfShape(2, dtype=torch.int64), # self and weight must have same dtype + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + Invocation(TensorOfShape(2, 3, dtype=torch.int32), TensorOfShape(2, dtype=torch.int64), # self and weight must be float + TensorOfShape(3, dtype=torch.int32), reduction=0, ignore_index=0), + Invocation(TensorOfShape(2, 3, dtype=torch.complex64), TensorOfShape(2, dtype=torch.int64), # self and weight must be float + TensorOfShape(3, dtype=torch.complex64), reduction=0, ignore_index=0)]) +def aten〇nll_loss_forward〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], reduction: int, ignore_index: int) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + target_rank, target_dtype = target_rank_dtype + assert target_dtype == torch.int64 + return self_dtype, self_dtype + +@check_dtype_function( + [Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.float64), [3], TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float64), + TensorOfShape(3, dtype=torch.float32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float64), eps=0.0), + # Input must be float or complex + ErrorInvocation(TensorOfShape(2, 3, dtype=torch.int32), [3], TensorOfShape(3, dtype=torch.int32), + TensorOfShape(3, dtype=torch.int32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.complex64), [3], TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.complex128), [3], TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), eps=0.0), + ]) +def aten〇native_layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], eps: float) -> Tuple[int, int, int]: input_rank, input_dtype = input_rank_dtype - weight_rank, weight_dtype = weight_rank_dtype - assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] - assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] - ranks: List[Optional[int]] = [input_rank, weight_rank] - dtypes = [input_dtype, weight_dtype] - return promote_dtypes(ranks, dtypes) + assert not is_integer_dtype(input_dtype) + result_dtype = input_dtype + if input_dtype == torch.complex64: + result_dtype = torch.float32 + if input_dtype == torch.complex128: + result_dtype = torch.float64 + return input_dtype, input_dtype, result_dtype -_convolution_deprecated_kwargs = { - "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], - "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False} @check_dtype_function( - _check_tensors_with_the_same_dtype( - tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], - error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) + - [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs) -]) -def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + # Tensors with different dtype + Invocation(TensorOfShape(3, 3, dtype=torch.float64), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float64), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float64), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float64), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float64), training=False, momentum=0.0, eps=0.0), + # Non-float tensors + Invocation(TensorOfShape(3, 3, dtype=torch.int32), TensorOfShape(3, dtype=torch.int32), + TensorOfShape(3, dtype=torch.int32), TensorOfShape(3, dtype=torch.int32), + TensorOfShape(3, dtype=torch.int32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.complex64), TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), training=False, momentum=0.0, eps=0.0), + ]) +def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], training: bool, momentum: float, eps: float) -> Tuple[int, int, int]: input_rank, input_dtype = input_rank_dtype - weight_rank, weight_dtype = weight_rank_dtype - assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] - assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] - ranks: List[Optional[int]] = [input_rank, weight_rank] - dtypes = [input_dtype, weight_dtype] - return promote_dtypes(ranks, dtypes) + result_dtype = input_dtype + if is_integer_dtype(input_dtype): + result_dtype = torch.float32 + return input_dtype, input_dtype, result_dtype + +@check_dtype_function([Invocation(end=0, dtype=None), # No floats + Invocation(end=0.0, dtype=None), # One float + ErrorInvocation(end=0, dtype=torch.complex64), # Dtype specified + Invocation(end=0, dtype=torch.float16), # Dtype specified + Invocation(end=0, dtype=torch.int16)]) # Dtype specified +def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(end)): + return torch.float32 + return torch.int64 + +@check_dtype_function([Invocation(start=0, end=10, dtype=None), # No floats + Invocation(start=0.0, end=10, dtype=None), # One float + Invocation(start=0, end=10.0, dtype=None), # One float + ErrorInvocation(start=0, end=10, dtype=torch.complex64), # Dtype specified + Invocation(start=0, end=10, dtype=torch.float16), # Dtype specified + Invocation(start=0, end=10, dtype=torch.int16)]) # Dtype specified +def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(start)) or \ + is_float_dtype(get_dtype_of_scalar(end)): + return torch.float32 + return torch.int64 + +@check_dtype_function([Invocation(start=0, end=10, step=1, dtype=None), # No floats + Invocation(start=0.0, end=10, step=1, dtype=None), # One float + Invocation(start=0, end=10.0, step=1, dtype=None), # One float + Invocation(start=0, end=10, step=1.0, dtype=None), # One float + ErrorInvocation(start=0, end=10, step=1, dtype=torch.complex64), # Dtype specified + Invocation(start=0, end=10, step=1, dtype=torch.float16), # Dtype specified + Invocation(start=0, end=10, step=1, dtype=torch.int16)]) # Dtype specified +def aten〇arange〇start_step〡dtype(start: Union[int, float], end: Union[int, float], step: Union[int, float] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(start)) or \ + is_float_dtype(get_dtype_of_scalar(end)) or \ + is_float_dtype(get_dtype_of_scalar(step)): + return torch.float32 + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇sum〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.complex64)) +def aten〇sum〇dim_IntList〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> int: + return aten〇sum〡dtype(self_rank_dtype, dtype) @check_dtype_function( _check_tensors_with_the_same_dtype( - tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)], error_types={torch.bool, torch.float16}) + - [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) -]) -def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> int: - input_rank, input_dtype = input_rank_dtype - weight_rank, weight_dtype = weight_rank_dtype - assert input_dtype not in [torch.bool, torch.float16] - assert weight_dtype not in [torch.bool, torch.float16] - ranks: List[Optional[int]] = [input_rank, weight_rank] - dtypes = [input_dtype, weight_dtype] - return promote_dtypes(ranks, dtypes) + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=None) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.complex64) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dim=None, dtype=torch.int32)]) +def aten〇mean〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + result = aten〇sum〡dtype(self_rank_dtype, dtype) + assert not is_integer_dtype(result) + return result + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇argmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.uint8: + return self_dtype + return torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇max〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), keepdim: bool = False) -> int: + return aten〇max〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: + return aten〇max〡dtype(self_rank_dtype), torch.int64 @check_dtype_function( _check_tensors_with_the_same_dtype( - tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)], - error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.float16, torch.bfloat16}) + - [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) -]) -def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> int: + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) +def aten〇mean〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None) -> int: + return aten〇mean〇dim〡dtype(self_rank_dtype, dim=None, dtype=dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇std〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.complex64: + return torch.float32 + if self_dtype == torch.complex128: + return torch.float64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None)) +def aten〇std〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇var〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None)) +def aten〇var〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[], correction=0.0)) +def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], correction: float, output_dtype: Optional[int] = None) -> int: + return aten〇std〡dtype(inp_rank_dtype) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.complex64, torch.complex128}, dtype=torch.float64) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) +def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if dtype is not None: + assert not is_integer_dtype(dtype) + if is_complex_dtype(self_dtype): + assert is_complex_dtype(dtype) + return aten〇std〡dtype((self_rank, dtype)) + assert not is_complex_dtype(dtype) + return dtype + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function([Invocation(0.0), + Invocation(0.0, dtype=torch.int32), + Invocation(0.0, dtype=torch.float16), + Invocation(0.0, dtype=torch.complex64)]) +def aten〇tensor〇float〡dtype(t: float, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.float32 + return dtype + +@check_dtype_function([Invocation(0), + Invocation(0, dtype=torch.int32), + Invocation(0, dtype=torch.float16), + Invocation(0, dtype=torch.complex64)]) +def aten〇tensor〇int〡dtype(t: int, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.int64 + return dtype + +@check_dtype_function([Invocation(True), + Invocation(True, dtype=torch.int32), + Invocation(True, dtype=torch.float16), + Invocation(True, dtype=torch.complex64)]) +def aten〇tensor〇bool〡dtype(t: bool, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.bool + return dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇zeros〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇ones〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇empty〇memory_format〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1], 0.0), + Invocation([1], 0), + Invocation([1], 0.0, dtype=torch.int32), + Invocation([1], 0.0, dtype=torch.float16), + Invocation([1], 0.0, dtype=torch.complex64)]) +def aten〇full〡dtype(size: List[int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + return dtype + fill_value_dtype = get_dtype_of_scalar(fill_value) + if is_float_dtype(fill_value_dtype): + return torch.float32 + return fill_value_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇zeros_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇ones_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.complex64)) +def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_zeros〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_ones〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_empty〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.complex64)) +def aten〇new_empty_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇rand_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇randn_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇_to_copy〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇to〇dtype〡dtype(self_rank_dtype: Tuple[int, int], dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + return dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def nvprims〇convert_element_type〡dtype(a_rank_dtype: Tuple[int, int], dtype: int) -> int: + return dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇to〇dtype_layout〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.complex64)) +def aten〇to〇device〡dtype(self_rank_dtype: Tuple[int, int], device: device, dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + return dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇to〇other〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + other_rank, other_dtype = other_rank_dtype + return other_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇type_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + return other_dtype + +@check_dtype_function([Invocation(low=0, high=10, size=[1]), + Invocation(low=0, high=10, size=[1], dtype=torch.float32), + Invocation(low=0, high=10, size=[1], dtype=torch.int32), + ErrorInvocation(low=0, high=10, size=[1], dtype=torch.complex64)]) +def aten〇randint〇low〡dtype(low: int, high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.int64 + assert not is_complex_dtype(dtype) + return dtype + +@check_dtype_function([Invocation(size=[1]), + Invocation(size=[1], dtype=torch.float32), + ErrorInvocation(size=[1], dtype=torch.int32), + Invocation(size=[1], dtype=torch.complex64)]) +def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.float32 + assert not is_integer_dtype(dtype) + return dtype + +@check_dtype_function([Invocation(size=[1], generator=None), + Invocation(size=[1], generator=None, dtype=torch.float32), + ErrorInvocation(size=[1], generator=None, dtype=torch.int32), + Invocation(size=[1], generator=None, dtype=torch.complex64)]) +def aten〇randn〇generator〡dtype(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.float32 + assert not is_integer_dtype(dtype) + return dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes())) +def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if self_dtype == torch.complex64: + return torch.float32, self_dtype + if self_dtype == torch.complex128: + return torch.float64, self_dtype + return self_dtype, self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes())) +def aten〇var_mean〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if self_dtype == torch.complex64: + return torch.float32, self_dtype + if self_dtype == torch.complex128: + return torch.float64, self_dtype + return self_dtype, self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇atan2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_integer_dtype(promoted_dtype): + return torch.float32 + return promoted_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype - assert (input_dtype == torch.int64 or not is_integer_dtype(input_dtype)) and input_dtype not in [torch.float16, torch.bfloat16] - assert (weight_dtype == torch.int64 or not is_integer_dtype(weight_dtype)) and weight_dtype not in [torch.float16, torch.bfloat16] - ranks: List[Optional[int]] = [input_rank, weight_rank] - dtypes = [input_dtype, weight_dtype] + return input_dtype + +@check_dtype_function( + [Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), + Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), + Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32), + NonZeroDTensorWithDtype(torch.complex64)])]) +def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0) -> int: + ranks: List[Optional[int]] = [] + dtypes: List[int] = [] + assert len(tensors_rank_dtype) != 0 + for tensor_rank_dtype in tensors_rank_dtype: + tensor_rank, tensor_dtype = tensor_rank_dtype + ranks.append(tensor_rank) + dtypes.append(tensor_dtype) return promote_dtypes(ranks, dtypes) -convolution_kwargs = { - "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1} +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.int64 + +# Does not work on meta backend +#@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[()])) +def aten〇ScalarImplicit〡dtype(a_rank_dtype: Tuple[int, int]) -> int: + a_rank, a_dtype = a_rank_dtype + return a_dtype + +@check_dtype_function([Invocation(0), Invocation(0.0)]) +def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float]) -> int: + return get_dtype_of_scalar(a) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.complex64)) +def aten〇softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if dtype is None: + return self_dtype + return dtype + @check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + _check_tensors_with_the_same_dtype( - tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], - error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **convolution_kwargs) + - [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), - TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), - TensorOfShape(1, dtype=torch.float32), **convolution_kwargs) -]) -def aten〇convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> int: - input_rank, input_dtype = input_rank_dtype + num_of_tensors=1, + error_types=(all_integer_dtypes() + all_complex_dtypes() + [torch.bfloat16, torch.float32, torch.float64]), + dim=0, half_to_float=True)) +def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int: + self_rank, self_dtype = self_rank_dtype + if half_to_float: + assert self_dtype == torch.float16 + return torch.float32 + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types=(all_integer_dtypes() + all_complex_dtypes() + [torch.bfloat16, torch.float32, torch.float64]), + dim=0, half_to_float=True)) +def aten〇_log_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int: + self_rank, self_dtype = self_rank_dtype + if half_to_float: + assert self_dtype == torch.float16 + return torch.float32 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.complex64)) +def aten〇log_softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if dtype is None: + return self_dtype + return dtype + +# TODO: to test these functions, we need to be able to specify the tensor contents used in each invocation +def aten〇embedding〡dtype(weight_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> int: weight_rank, weight_dtype = weight_rank_dtype - assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] - assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] - ranks: List[Optional[int]] = [input_rank, weight_rank] - dtypes = [input_dtype, weight_dtype] - return promote_dtypes(ranks, dtypes) + return weight_dtype + +# TODO: to test these functions, we need to be able to specify the tensor contents used in each invocation +def aten〇_embedding_bag〡dtype(weight_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], offsets_rank_dtype: Tuple[int, int], scale_grad_by_freq: bool = False, mode: int = 0, sparse: bool = False, per_sample_weights_rank_dtype: Optional[Tuple[int, int]] = None, include_last_offset: bool = False, padding_idx: int = -1) -> Tuple[int, int, int, int]: + weight_rank, weight_dtype = weight_rank_dtype + return weight_dtype, torch.int64, torch.int64, torch.int64 + +# TODO: to test these functions, we need to be able to specify the tensor contents used in each invocation +def aten〇embedding_bag〇padding_idx〡dtype(weight_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], offsets_rank_dtype: Tuple[int, int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights_rank_dtype: Optional[Tuple[int, int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[int, int, int, int]: + weight_rank, weight_dtype = weight_rank_dtype + return weight_dtype, torch.int64, torch.int64, torch.int64 + +@check_dtype_function(_check_two_tensor_op(out_int32=True) + _check_two_tensor_op(out_int32=False)) +def aten〇bucketize〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], boundaries_rank_dtype: Tuple[int, int], out_int32: bool = False, right: bool = False) -> int: + if out_int32: + return torch.int32 + return torch.int64 # ============================================================================== # Main diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py index 3820e311a24f..717923cb17a2 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py @@ -14,15 +14,54 @@ from .registry import Registry +def all_integer_dtypes() -> List[int]: + return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64] + def is_integer_dtype(dtype: int) -> bool: - return dtype in [torch.bool, torch.uint8, torch.int8, - torch.int16, torch.int32, torch.int64] + return dtype in all_integer_dtypes() + +def all_complex_dtypes() -> List[int]: + return [torch.complex64, torch.complex128] def is_complex_dtype(dtype: int) -> bool: - return dtype in [torch.complex64, torch.complex128] + return dtype in all_complex_dtypes() + +def all_float_dtypes() -> List[int]: + return [torch.float16, torch.bfloat16, torch.float32, torch.float64] def is_float_dtype(dtype: int) -> bool: - return dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64] + return dtype in all_float_dtypes() + +def get_priority_of_dtype(dtype: int) -> int: + # If a loop is used to iterate over a list of sorted dtypes, TorchScript + # produces a loop with INT64_MAX max trip count, which causes problems + # during the loop unrolling that takes place when simplifying the dtype + # functions. Therefore, here we result to `if`s. + if dtype == torch.bool: + return 0 + if dtype == torch.uint8: + return 1 + if dtype == torch.int8: + return 2 + if dtype == torch.int16: + return 3 + if dtype == torch.int32: + return 4 + if dtype == torch.int64: + return 5 + if dtype == torch.bfloat16: + return 6 + if dtype == torch.float16: + return 7 + if dtype == torch.float32: + return 8 + if dtype == torch.float64: + return 9 + if dtype == torch.complex64: + return 10 + if dtype == torch.complex128: + return 11 + assert False, "Cannot determine priority of dtype" def get_dtype_of_scalar(scalar: Union[int, float]) -> int: # This is hacky. `NumToTensor` is the only PyTorch op for scalars diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py index efd270b78a7f..c9387cc4c8f7 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py @@ -86,7 +86,7 @@ def _recursively_transform_tensor_args( o: Any, tensor_transformer: Callable[[TensorOfShape], Any]) -> Any: """Replace `TensorOfShape` with the result of `tensor_transformer`""" - if o is None or isinstance(o, (float, int)): + if o is None or isinstance(o, (float, int, str)): return o if isinstance(o, TensorOfShape): return tensor_transformer(o) @@ -146,7 +146,7 @@ def to_dtype_function_args(self): def to_real_op_args(self): """Gets positional arguments appropriate for the real op.""" - tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype) + tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype).to("meta") return _recursively_transform_tensor_args(self.args, tensor_transformer) def __repr__(self) -> str: @@ -258,6 +258,15 @@ def decorator(f): return f return decorator +@torch.jit.script +def _convert_dtype_to_int(dtype: torch.dtype) -> int: + """Convert a PyTorch `dtype` into its underlying `int` representation. + + This works because in TorchScript there is no special type for `dtypes`; + they are simply `int`s. + """ + return dtype + def check_dtype_function(invocations: List[Invocation]): """Decorator that automatically tests a dtype function. @@ -281,7 +290,12 @@ def decorator(f): golden_dtype = torch.tensor([]).to(type(golden_result)).dtype else: raise ValueError(f"Unhandled return type {type(golden_result)}") - if result_dtype != golden_dtype: + # Some dtype funtions have default `dtype` parameters, which are + # represented as `int` values in the registry. In order to + # support returning the default `int` value, the comparisons of + # the result and golden dtypes are done using their underlying + # `int` representation. + if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(golden_dtype): _report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}") return f return decorator diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index b0ea4dd8b770..2d584489842b 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -8,9 +8,6 @@ # to the backend contract. COMMON_TORCH_MLIR_LOWERING_XFAILS = { "QuantizedMLP_basic", - "ResNet18Module_basic", - "ResNet18StaticModule_basic", - "MobilenetV3Module_basic", "ReduceMaxAlongDimUnsignedInt_basic", } diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py index 89fc81b8ba93..14fd9d2dba92 100644 --- a/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/python/torch_mlir_e2e_test/test_suite/rng.py @@ -400,6 +400,30 @@ def RandnGeneratorModule_basic(module, tu: TestUtils): # ============================================================================== + +class RandnGeneratorF64Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + a = torch.ops.aten.randn([4, 512, 1024], generator=None, dtype=torch.float64) + std = torch.std(a) + return std + + +@register_test_case(module_factory=lambda: RandnGeneratorF64Module()) +def RandnGeneratorF64Module_basic(module, tu: TestUtils): + module.forward() + + +# ============================================================================== + + class RandnLikeModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir deleted file mode 100644 index 6fc29daaba08..000000000000 --- a/test/Dialect/Torch/refine-types-ops.mlir +++ /dev/null @@ -1,314 +0,0 @@ -// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s - -// This file is for tests for individual ops that require a new transfer -// function (i.e. new code called from visitOperation). - -// ----- -// CHECK-LABEL: func.func @aten.arange.start$int64_dtype( -// CHECK-SAME: %[[START:.*]]: !torch.int, -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange.start -// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,si64> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,si64> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor { - %none = torch.constant.none - %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.arange.start$float32_dtype( -// CHECK-SAME: %[[START:.*]]: !torch.float, -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange.start -// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,f32> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor { - %none = torch.constant.none - %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.arange.start$specified_dtype( -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange -// CHECK-SAME: %[[END]], %[[CST6]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,f32> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor { - %int6 = torch.constant.int 6 - %none = torch.constant.none - %ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.linear( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[5,3],f32>, -// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor { -// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[ARG0]], %[[ARG1]], %[[ARG2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<*,f32> -// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RESULT]] : !torch.vtensor -func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor { - %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor - return %1 : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.sum.dim_IntList( -// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,si64>) -> !torch.vtensor { -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1 -// CHECK: %[[DIMLIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT_NEG1]] -// CHECK-SAME: : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[RET:.*]] = torch.aten.sum.dim_IntList %[[T]], %[[DIMLIST]], %[[FALSE]], %[[NONE]] -// CHECK-SAME: : !torch.vtensor<*,si64>, !torch.list, !torch.bool, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,si64> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,si64> to !torch.vtensor -// CHECK: return %[[CAST]] : !torch.vtensor -func.func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor { - %false = torch.constant.bool false - %none = torch.constant.none - %int0 = torch.constant.int 0 - %int-1 = torch.constant.int -1 - %dimList = torch.prim.ListConstruct %int0, %int-1 : (!torch.int, !torch.int) -> !torch.list - %ret = torch.aten.sum.dim_IntList %t, %dimList, %false, %none : !torch.vtensor<*,si64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.zeros( -// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor { - %none = torch.constant.none - %int2 = torch.constant.int 2 - %sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list - %ret = torch.aten.zeros %sizesList, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.type_as( -// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>, -// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor { -// CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor { - %ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor - return %ret: !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.cat( -// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>, -// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor { -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[?,1,4],f32>, !torch.tensor<[2,3,4],f32>) -> !torch.list -// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list, !torch.int -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor { - %int1 = torch.constant.int 1 - %tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list - %ret = torch.aten.cat %tensorList, %int1 : !torch.list, !torch.int -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.cat$promote_type( -// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>, -// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor { -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list -// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list, !torch.int -> !torch.tensor<*,si64> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor { - %int1 = torch.constant.int 1 - %tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list - %ret = torch.aten.cat %tensorList, %int1 : !torch.list, !torch.int -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten._shape_as_tensor( -// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor { -// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<*,si64> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor { - %ret= torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten._shape_as_tensor$unknown_input_shape( -// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor { -// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<*,si64> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor { - %ret= torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.embedding( -// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>, -// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor { -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[PADDING_IDX:.*]] = torch.constant.int 1 -// CHECK: %[[RET:.*]] = torch.aten.embedding %[[INPUT]], %[[INDEXES]], %[[PADDING_IDX]], %[[FALSE]], %[[FALSE]] : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor { - %false = torch.constant.bool false - %int1 = torch.constant.int 1 - %ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor - return %ret: !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.tensor.float( -// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor { - %none = torch.constant.none - %false = torch.constant.bool false - %ret = torch.aten.tensor.float %t, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.tensor.float$specified_dtype( -// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[CST11:.*]] = torch.constant.int 11 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,i1> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,i1> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor { - %none = torch.constant.none - %int11 = torch.constant.int 11 - %false = torch.constant.bool false - %ret = torch.aten.tensor.float %t, %int11, %none, %false : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.softmax.int( -// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>, -// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor { -// CHECK: %[[DTYPE:.*]] = torch.constant.none -// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<*,f32> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[RET]] : !torch.tensor -func.func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor { - %none = torch.constant.none - %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.softmax.int$specified_dtype( -// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>, -// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor { -// CHECK: %[[DTYPE:.*]] = torch.constant.int 4 -// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<*,si64> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,si64> to !torch.tensor -// CHECK: return %[[RET]] : !torch.tensor -func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor { - %int4 = torch.constant.int 4 - %ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.to.dtype( -// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor -// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype -// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : -// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -// CHECK-SAME: -> !torch.tensor<*,si64> -// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<*,si64> to !torch.tensor -// CHECK-NEXT: return %[[RES]] : !torch.tensor -func.func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{ - %none = torch.constant.none - %false = torch.constant.bool false - %int4 = torch.constant.int 4 - %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor - return %0 : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar( -// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor { -// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<*,si64> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<*,si64> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor { - %0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor - return %0: !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.tensor( -// CHECK-SAME: %[[DATA:.*]]: !torch.list>) -> !torch.tensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[NONE]], %[[NONE]], %[[FALSE]] -// CHECK-SAME: : !torch.list>, !torch.none, !torch.none, !torch.bool -// CHECK-SAME: -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.tensor(%t: !torch.list>) -> !torch.tensor { - %none = torch.constant.none - %false = torch.constant.bool false - %ret = torch.aten.tensor %t, %none, %none, %false : !torch.list>, !torch.none, !torch.none, !torch.bool -> !torch.tensor - return %ret : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.tensor$specified_dtype( -// CHECK-SAME: %[[DATA:.*]]: !torch.list>) -> !torch.tensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[INT4:.*]] = torch.constant.int 4 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.tensor$specified_dtype(%t: !torch.list>) -> !torch.tensor { - %none = torch.constant.none - %int4 = torch.constant.int 4 - %false = torch.constant.bool false - %ret = torch.aten.tensor %t, %int4, %none, %false : !torch.list>, !torch.int, !torch.none, !torch.bool -> !torch.tensor - return %ret : !torch.tensor -} diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index ae02a0339d8e..df2722037496 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -8,109 +8,6 @@ // ----- -// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static( -// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>, -// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> { -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32> -// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST:.*]] : !torch.vtensor<*,f32> to !torch.vtensor<*,f32> -// CHECK: torch.overwrite.tensor.contents %[[CAST2]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32> -func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> { - %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor - %static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor - %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor - torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor - %static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor - %result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32> - return %result : !torch.vtensor<[2],f32> -} - -// ----- -// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { -// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[2],f32> to !torch.vtensor<*,f32> -// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32> -// CHECK: %[[MUTABLE_COPY:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor<*,f32> -// CHECK: torch.overwrite.tensor.contents %[[ARG0_ERASED]] overwrites %[[MUTABLE_COPY]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32> -func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { - %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor - %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor - %dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor - torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor - %dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor - %result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32> - return %result : !torch.vtensor<[?],f32> -} - -// ----- -// CHECK-LABEL: func.func @bf16_result_type( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> { -// CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16> -// CHECK: return %[[SQRT]] : !torch.vtensor<[2],bf16> -func.func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> { - %1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16> - return %1 : !torch.vtensor<[2],bf16> -} - -// ----- -// CHECK-LABEL: func.func @propagate_scalar_type( -// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number { -// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number -// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int -// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number -// CHECK: return %[[RET]] : !torch.number -func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number { - %num = torch.derefine %arg0 : !torch.int to !torch.number - %1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number - return %1 : !torch.number -} - -// ----- -// CHECK-LABEL: func.func @prim.dtype( -// CHECK-SAME: %[[arg:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor { - -// CHECK: %[[zero:.*]] = torch.constant.int 0 -// CHECK: %[[false:.*]] = torch.constant.bool false - -// CHECK: %[[neg:.*]] = torch.aten.neg %[[arg]] : !torch.vtensor<*,bf16> -> !torch.vtensor<*,bf16> -// CHECK: %[[dtype0:.*]] = torch.prim.dtype %[[neg]] : !torch.vtensor<*,bf16> -> !torch.int -// CHECK: %[[device0:.*]] = torch.prim.device %[[neg]] : !torch.vtensor<*,bf16> -> !torch.Device -// CHECK: %[[tensor:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype0]], %[[device0]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16> - -// CHECK: %[[dtype1:.*]] = torch.prim.dtype %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.int -// CHECK: %[[device1:.*]] = torch.prim.device %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.Device -// CHECK: %[[result:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype1]], %[[device1]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16> - -// CHECK: %[[cast:.*]] = torch.tensor_static_info_cast %[[result]] : !torch.vtensor<*,bf16> to !torch.vtensor -// CHECK: return %[[cast]] : !torch.vtensor -// CHECK: } - -func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> { - %zero = torch.constant.int 0 - %false = torch.constant.bool false - - // Op that requires type refinement - %neg = torch.aten.neg %arg : !torch.vtensor<*,bf16> -> !torch.vtensor<*,unk> - - // Op whose processing requires type refinement on its source argument. - %dtype = torch.prim.dtype %neg : !torch.vtensor<*,unk> -> !torch.int - %device = torch.prim.device %neg : !torch.vtensor<*,unk> -> !torch.Device - - // Another op that requires type refinement - %result = torch.aten.tensor.int %zero, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk> - - // Repeat the above three ops a second time to ensure that the type refinement - // code works regardless of the number of alternating refinement+prim.dtype - // sequences. - %dtype2 = torch.prim.dtype %result : !torch.vtensor<*,unk> -> !torch.int - %device2 = torch.prim.device %result : !torch.vtensor<*,unk> -> !torch.Device - %result2 = torch.aten.tensor.int %zero, %dtype2, %device2, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk> - - return %result2 : !torch.vtensor<*,unk> -} - -// ----- - // Check that we don't crash on this input. // CHECK-LABEL: func.func @forward @@ -143,27 +40,6 @@ func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) { // ----- -// CHECK-LABEL: func.func @torch.aten.zeros_like( -// CHECK-SAME: %[[arg:.*]]: !torch.vtensor) { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[CPU:.*]] = torch.constant.device "cpu" -// CHECK: %[[ZEROS:.*]] = torch.aten.zeros_like %[[arg]], %[[INT6]], %[[INT0]], %[[CPU]], %[[FALSE]], %[[INT1]] : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor<*,f32> -// CHECK: return -func.func @torch.aten.zeros_like(%arg: !torch.vtensor) { - %int6 = torch.constant.int 6 - %false = torch.constant.bool false - %int1 = torch.constant.int 1 - %int0 = torch.constant.int 0 - %cpu = torch.constant.device "cpu" - %2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor - return -} - -// ----- - // The data-flow analysis does not always propagate information to the entire graph. // This results in some lattice elements being uninitialized, which must be properly // handled when using the lattice elements to rewrite the graph. diff --git a/test/Dialect/Torch/reify-shape-calculations.mlir b/test/Dialect/Torch/reify-shape-calculations.mlir index 245def5bade7..894763e12aec 100644 --- a/test/Dialect/Torch/reify-shape-calculations.mlir +++ b/test/Dialect/Torch/reify-shape-calculations.mlir @@ -231,3 +231,14 @@ func.func @adjust_shape_function_arg$list(%arg0: !torch.vtensor, %arg1: !torch.v %1 = torch.aten.index.Tensor %arg0, %0 : !torch.vtensor, !torch.list -> !torch.vtensor return %1 : !torch.vtensor } + +// ----- + +// CHECK-LABEL: func.func @adjust_shape_function_arg$number( +// CHECK: %[[FLOAT:.*]] = torch.aten.Float.Scalar {{.*}} : !torch.number -> !torch.float +// CHECK: %[[VAL_9:.*]] = func.call @__torch_mlir_shape_fn.aten.arange(%[[FLOAT]], {{.*}}) : (!torch.float, {{.*}} +func.func @adjust_shape_function_arg$number(%arg0: !torch.number) -> !torch.vtensor { + %none = torch.constant.none + %1 = torch.aten.arange %arg0, %none, %none, %none, %none : !torch.number, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor + return %1 : !torch.vtensor +}