From 89484b26d4dc7614094abf603b9bdebff2ecbe0b Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Tue, 9 May 2023 17:04:06 +0000
Subject: [PATCH] Squashed commit of the following:

commit bafe33906904be26ae295f8333d2d6d716393ef7
Author: Ramiro Leal-Cavazos
Date:   Mon May 8 21:26:56 2023 +0000

    Add dtype functions for aten.atan and prims.squeeze

commit bebf695383162c1874c0bee4d37dfedfb338e94e
Author: Ramiro Leal-Cavazos
Date:   Mon May 8 21:26:10 2023 +0000

    Remove duplicate code from merge with main

commit 0d11895696e7ebaea78d0796d9403fa515bc826f
Author: Ramiro Leal-Cavazos
Date:   Fri May 5 21:39:02 2023 +0000

    Update LLVM tag

commit 73d5c0773e879202494f4079bcff026483a508f5
Merge: 899d8bc8 eaaaeb6f
Author: Ramiro Leal-Cavazos
Date:   Fri May 5 21:30:09 2023 +0000

    Merge remote-tracking branch 'upstream/main' into merge-main

commit 899d8bc8cf6cce4ed314da8115f7be04cae902fe
Author: Ramiro Leal-Cavazos
Date:   Mon Mar 13 21:39:14 2023 +0000

    Add dtype functions for `aten.ge.Tensor` and `aten.le.Tensor`

commit f58f9c2e6cb1588fcbba8eb4db4bcc1feb097d2b
Merge: ce7abf49 4912c393
Author: Ramiro Leal-Cavazos
Date:   Mon Mar 13 21:32:00 2023 +0000

    Merge branch 'main' into merge-main

commit ce7abf4911f7defd5ad56a644c4be7242677eea3
Author: Jiahao Li
Date:   Wed Feb 22 06:54:41 2023 +0800

    Add dtype functions for ops that take dtype from 2nd operand (#1891)

commit 63945a2fd411e768186dcc82a54c56a2a9470cbb
Author: Ramiro Leal-Cavazos
Date:   Mon Feb 13 17:56:09 2023 -0800

    Change dtype functions interface to take ints tuple for each tensor (#1865)

    The original design for the dtype functions outlined in
    https://github.com/llvm/torch-mlir/issues/1462 was unable to properly
    handle ops that take optional tensors as inputs when the optional
    tensor has a value of None. By the time the op gets imported into
    torch-mlir, if an optional value is None, all information about the
    original type is lost from the op's type signature, so torch-mlir
    cannot tell whether a value of None came from an optional tensor.
    That distinction was crucial in the original design, since each
    tensor argument had to be turned into two separate arguments for the
    dtype function.

    This commit changes the interface to dtype functions such that each
    tensor turns into a tuple of two ints, the first representing the
    rank of the tensor and the second the dtype of the tensor. Since
    there is now a one-to-one correspondence between the operands of an
    op and the operands of its dtype function, there is no ambiguity
    about which operand of the op corresponds to which operand of the
    dtype function.

    To test the implementation, this commit defines dtype functions for
    the convolution ops, all of which take one optional tensor as an
    argument.

commit 981ac88758df33d0238b7a781630125774b165b0
Author: Ramiro Leal-Cavazos
Date:   Wed Feb 1 22:30:27 2023 +0000

    Add dtype functions for two tensor promotion ops (#1831)

    This commit adds dtype functions for the ops in RefineTypes under the
    category of "Promote the two dtypes". The only ops not added here are
    the convolution ops, since they take an optional tensor argument and
    the dtype pipeline currently does not correctly handle that case. I
    will add a follow-up patch fixing this.

    This commit also adds two helper functions that perform very thorough
    testing of dtype functions. The helper function `_check_two_tensor_op`
    is able to independently test invalid input dtypes and invalid output
    dtypes.

    Lastly, this commit also XFAILs "MobilenetV3Module_basic".
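
    Note: the tuple-per-tensor interface described above is easiest to see in
    a small sketch. The Python below is illustrative only; the name
    `conv2d_dtype`, the `torch.dtype` annotations (shown here for readability),
    and the use of `torch.promote_types` as a stand-in for the library's
    `promote_dtypes` helper are assumptions for the example, not the actual
    code in `abstract_interp_lib_gen.py`. It shows how every tensor operand of
    the op becomes one `(rank, dtype)` argument of its dtype function, so an
    optional tensor passed as None stays unambiguous.

        from typing import Optional, Tuple
        import torch

        # Hypothetical dtype function under the tuple-per-tensor interface:
        # each tensor operand arrives as a (rank, dtype) pair, giving a
        # one-to-one correspondence with the op's operands.
        def conv2d_dtype(input_rank_dtype: Tuple[int, torch.dtype],
                         weight_rank_dtype: Tuple[int, torch.dtype],
                         bias_rank_dtype: Optional[Tuple[int, torch.dtype]]) -> torch.dtype:
            _input_rank, input_dtype = input_rank_dtype
            _weight_rank, weight_dtype = weight_rank_dtype
            dtypes = [input_dtype, weight_dtype]
            # An optional tensor that is None simply shows up as None here,
            # which is the case the original two-arguments-per-tensor design
            # could not represent.
            if bias_rank_dtype is not None:
                _bias_rank, bias_dtype = bias_rank_dtype
                dtypes.append(bias_dtype)
            # Stand-in for the library's promote_dtypes helper: fold the
            # operand dtypes with PyTorch's type-promotion rules.
            result = dtypes[0]
            for d in dtypes[1:]:
                result = torch.promote_types(result, d)
            return result

        # Example: float32 input/weight with an int64 bias promotes to float32.
        print(conv2d_dtype((4, torch.float32), (4, torch.float32), (1, torch.int64)))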
commit 83d4e89d25f73778323d6080f301ed6834aa18b2 Author: Jiahao Li Date: Sat Jan 21 02:39:41 2023 +0800 Add dtype functions for floating point ops (#1813) commit 8cae5ba50710dd2952dd9fc06de0b770bfc36dc1 Author: Ramiro Leal-Cavazos Date: Mon Jan 16 14:32:23 2023 -0800 Add dtype functions for comparison ops (#1806) This commit adds dtype functions for comparison ops that always return a tensor of dtype `i1`. commit 5b77c151285585fac380548a518ecfa9d16003fc Author: Ramiro Leal-Cavazos Date: Mon Jan 16 20:27:49 2023 +0000 Add CI to `dtype-functions-staging` branch commit ac94ba22e2c7839680fd59beda65484546f8f4f9 Author: Ramiro Leal-Cavazos Date: Thu Jan 12 22:41:04 2023 +0000 Move dtype functions into their own section in lib gen file In order to easily keep track of the dtype functions that have been moved to `abstract_interp_lib_gen.py` and make it easier to add new ones, this commit groups all the dtype functions together, rather than having them interspersed between the shape functions. --- .github/workflows/buildAndTest.yml | 2 +- .../Transforms/AbstractInterpLibrary.cpp | 2713 ++++++++++++++++- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 701 +---- .../ReifyAbstractInterpCalculationsUtils.cpp | 12 +- ...implifyAbstractInterpCalculationsUtils.cpp | 252 ++ .../SimplifyAbstractInterpCalculationsUtils.h | 7 + .../Transforms/SimplifyDtypeCalculations.cpp | 7 + .../Transforms/SimplifyShapeCalculations.cpp | 242 +- .../build_tools/abstract_interp_lib_gen.py | 2194 ++++++++++++- .../jit_ir/build_tools/library_generator.py | 49 + .../jit_ir/build_tools/testing_framework.py | 20 +- test/Dialect/Torch/refine-types-branch.mlir | 34 - test/Dialect/Torch/refine-types-ops.mlir | 364 --- test/Dialect/Torch/refine-types.mlir | 175 -- 14 files changed, 4991 insertions(+), 1781 deletions(-) delete mode 100644 test/Dialect/Torch/refine-types-ops.mlir diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index 1739fde31560d..03c23e0335d20 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -2,7 +2,7 @@ name: Build and Test on: pull_request: - branches: [ main ] + branches: [ main, dtype-functions-staging ] push: branches: [ main ] workflow_dispatch: diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index b68d62c8b2ae8..35b82e614dd9b 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6111,6 +6111,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " return %3 : !torch.list\n" " }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.testing_framework._convert_dtype_to_int(%arg0: !torch.int) -> !torch.int {\n" +" return %arg0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.triu\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6159,33 +6162,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %int6 = torch.constant.int 6\n" -" %int5 = torch.constant.int 5\n" -" %int15 = 
torch.constant.int 15\n" -" %true = torch.constant.bool true\n" -" %int7 = torch.constant.int 7\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.eq.int %0#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %5 = torch.aten.eq.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %5 : !torch.bool\n" -" }\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %5 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %5 : !torch.bool\n" -" }\n" -" %4 = torch.prim.If %3 -> (!torch.int) {\n" -" torch.prim.If.yield %0#1 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %int6 : !torch.int\n" -" }\n" -" return %4 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6417,24 +6393,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" -" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" -" }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" -" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" -" return %0 : !torch.int\n" -" }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union) -> !torch.int {\n" -" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union -> !torch.tensor\n" -" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" -" return %1 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.leaky_relu\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7067,14 +7025,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple 
-> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7448,7 +7398,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.tuple, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.native_batch_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple, list, list> {\n" -" %0 = call @__torch__.torch.jit._shape_functions.native_batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.bool) -> !torch.tuple, list, list>\n" +" %none = torch.constant.none\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.If %arg5 -> (!torch.tuple, list, list>) {\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.aten.ge.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.tuple, list, list>) {\n" +" %4 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list\n" +" %6 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list\n" +" %8 = torch.prim.TupleConstruct %arg0, %5, %7 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" torch.prim.If.yield %8 : !torch.tuple, list, list>\n" +" } else {\n" +" %4 = torch.prim.ListConstruct : () -> !torch.list\n" +" %5 = torch.prim.ListConstruct : () -> !torch.list\n" +" %6 = torch.prim.TupleConstruct %arg0, %4, %5 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" torch.prim.If.yield %6 : !torch.tuple, list, list>\n" +" }\n" +" torch.prim.If.yield %3 : !torch.tuple, list, list>\n" +" } else {\n" +" %1 = torch.aten.__is__ %arg3, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %6 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" } else {\n" +" %6 = torch.prim.unchecked_cast %arg3 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" }\n" +" %3 = torch.aten.__is__ %arg4, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.list) {\n" +" %6 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" } else {\n" +" %6 = torch.prim.unchecked_cast %arg4 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" }\n" +" %5 = torch.prim.TupleConstruct %arg0, %2, %4 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" torch.prim.If.yield %5 : 
!torch.tuple, list, list>\n" +" }\n" " return %0 : !torch.tuple, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.constant_pad_nd\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" @@ -7644,85 +7634,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" -" %int4 = torch.constant.int 4\n" -" %int3 = torch.constant.int 3\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %int11 = torch.constant.int 11\n" -" %int7 = torch.constant.int 7\n" -" %int6 = torch.constant.int 6\n" -" %int10 = torch.constant.int 10\n" -" %true = torch.constant.bool true\n" -" %int9 = torch.constant.int 9\n" -" %0 = torch.prim.Uninitialized : !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %5 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %5 : !torch.bool\n" -" }\n" -" %4 = torch.prim.If %3 -> (!torch.int) {\n" -" torch.prim.If.yield %1#1 : !torch.int\n" -" } else {\n" -" %5 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.int) {\n" -" torch.prim.If.yield %int9 : !torch.int\n" -" } else {\n" -" %7 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.int) {\n" -" torch.prim.If.yield %int10 : !torch.int\n" -" } else {\n" -" %9 = torch.aten.eq.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" %10 = torch.prim.If %9 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %16 = torch.aten.eq.int %1#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %16 : !torch.bool\n" -" }\n" -" %11 = torch.prim.If %10 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %16 = torch.aten.eq.int %1#1, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %16 : !torch.bool\n" -" }\n" -" %12 = torch.prim.If %11 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %16 = torch.aten.eq.int %1#1, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %16 : !torch.bool\n" -" }\n" -" %13 = torch.prim.If %12 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %16 = torch.aten.eq.int %1#1, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %16 : !torch.bool\n" -" }\n" -" %14 = torch.prim.If %13 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %16 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %16 : !torch.bool\n" -" }\n" -" %15 = torch.prim.If %14 -> (!torch.int) {\n" -" torch.prim.If.yield %int9 : !torch.int\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : 
!torch.str, !torch.none\n" -" torch.prim.If.yield %0 : !torch.int\n" -" }\n" -" torch.prim.If.yield %15 : !torch.int\n" -" }\n" -" torch.prim.If.yield %8 : !torch.int\n" -" }\n" -" torch.prim.If.yield %6 : !torch.int\n" -" }\n" -" return %4 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" " %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" @@ -7767,15 +7678,2539 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union, %arg1: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %int6 = torch.constant.int 6\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%arg0) : (!torch.int) -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" %4 = torch.aten.ne.int %arg0, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" }\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %arg0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_float_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_float_dtypes() -> !torch.list {\n" +" %int7 = torch.constant.int 7\n" +" %int6 = torch.constant.int 6\n" +" %int15 = torch.constant.int 15\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> 
!torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() -> !torch.list {\n" +" %int10 = torch.constant.int 10\n" +" %int9 = torch.constant.int 9\n" +" %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = 
torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = func.call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" " %none = torch.constant.none\n" -" %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list>\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" -" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = 
torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" +" } else {\n" +" %7 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.int\n" +" }\n" +" %6 = torch.prim.If %5#0 -> (!torch.int) {\n" +" torch.prim.If.yield %5#1 : !torch.int\n" +" } else {\n" +" %7 = func.call @__torch__._get_dtype_of_floating_point_op(%1#1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = func.call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.abs\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %3:2 = torch.prim.If %2 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" +" } else {\n" +" %5 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %6:2 = torch.prim.If %5 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %6#0, %6#1 : !torch.bool, !torch.int\n" +" }\n" +" %4 = torch.prim.If %3#0 -> (!torch.int) {\n" +" torch.prim.If.yield %3#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli\"(%arg0: !torch.tuple, %arg1: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_not\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.broadcast_to\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ceil\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clone\"(%arg0: !torch.tuple, 
%arg1: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.copy\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cpu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cumsum\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.dropout\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expand_as\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expand\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fill.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.flatten.using_ints\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gather\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gelu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.str) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" +" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gelu\"(%arg0: !torch.tuple, %arg1: !torch.str) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardsigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardswish\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 
= torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %int0, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %2 = torch.aten.__contains__.int_list %1, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_put.hacked_twin\"(%arg0: !torch.tuple, %arg1: !torch.list>, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._index_put_impl\"(%arg0: !torch.tuple, %arg1: !torch.list>>, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_put\"(%arg0: !torch.tuple, %arg1: !torch.list>>, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index.Tensor_hacked_twin\"(%arg0: !torch.tuple, %arg1: !torch.list>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.list>>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.layer_norm\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call 
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" return %arg3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_select\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.narrow\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.neg\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.numpy_T\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pad\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.permute\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.prelu\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.relu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.repeat\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._reshape_alias\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = 
torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reshape\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.resize_\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.roll\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.round\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scatter_reduce.two\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.select.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.select_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.slice_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.slice.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" return %arg3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.square\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.squeeze.dim\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.squeeze\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.t\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.prim_Device\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.transpose.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.triu\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.uniform\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._unsafe_view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.unsqueeze\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d\"(%arg0: !torch.tuple, %arg1: 
!torch.list, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zero\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zero_\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union) -> !torch.int {\n" +" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union -> !torch.tensor\n" +" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %1#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = 
torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ge.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_not\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_xor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.le.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 :
!torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union<float, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" +" %int10 = torch.constant.int 10\n" +" %int7 = torch.constant.int 7\n" +" %int9 = torch.constant.int 9\n" +" %int6 = torch.constant.int 6\n" +" %int8 = torch.constant.int 8\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" } else {\n" +" %4 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int8 : !torch.int\n" +" } else {\n" +" %6 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" %8 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %int10 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %11 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n" +" %2 = call
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n" +" %3 = torch.prim.ListConstruct %1#1,
%0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%0#1) : (!torch.int) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%1#1) : (!torch.int) -> !torch.int\n" +" %4 = torch.aten.lt.int %2, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" }\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%arg0: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Cannot determine priority of dtype\"\n" +" %int15 = torch.constant.int 15\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %int5 = torch.constant.int 5\n" +" %int6 = torch.constant.int 6\n" +" %int7 = torch.constant.int 7\n" +" %int8 = torch.constant.int 8\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1 = torch.aten.eq.int %arg0, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" %5 = torch.aten.eq.int %arg0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int2 : !torch.int\n" +" } else {\n" +" %7 = torch.aten.eq.int %arg0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" torch.prim.If.yield %int3 : !torch.int\n" +" } else {\n" +" %9 = torch.aten.eq.int %arg0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %11 = torch.aten.eq.int %arg0, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" torch.prim.If.yield %int5 : !torch.int\n" +" } else {\n" +" %13 = torch.aten.eq.int %arg0, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %15 = torch.aten.eq.int %arg0, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" %17 = torch.aten.eq.int %arg0, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.int) {\n" +"
torch.prim.If.yield %int8 : !torch.int\n" +" } else {\n" +" %19 = torch.aten.eq.int %arg0, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %20 = torch.prim.If %19 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" %21 = torch.aten.eq.int %arg0, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.int) {\n" +" torch.prim.If.yield %int10 : !torch.int\n" +" } else {\n" +" %23 = torch.aten.eq.int %arg0, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %24 = torch.prim.If %23 -> (!torch.int) {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %24 : !torch.int\n" +" }\n" +" torch.prim.If.yield %22 : !torch.int\n" +" }\n" +" torch.prim.If.yield %20 : !torch.int\n" +" }\n" +" torch.prim.If.yield %18 : !torch.int\n" +" }\n" +" torch.prim.If.yield %16 : !torch.int\n" +" }\n" +" torch.prim.If.yield %14 : !torch.int\n" +" }\n" +" torch.prim.If.yield %12 : !torch.int\n" +" }\n" +" torch.prim.If.yield %10 : !torch.int\n" +" }\n" +" torch.prim.If.yield %8 : !torch.int\n" +" }\n" +" torch.prim.If.yield %6 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %10 = torch.aten.ne.int %4, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" }\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" }\n" +" return %7 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor_mode\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : 
(!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %10 = torch.aten.ne.int %4, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" }\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" }\n" +" return %7 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Result dtype for aten.floor_divide bool\"\n" +" %int11 = torch.constant.int 11\n" +" %str_0 = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = torch.aten.ne.int %8, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %9 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %8 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%0#1) : (!torch.int) -> 
!torch.int\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%1#1) : (!torch.int) -> !torch.int\n" +" %4 = torch.aten.lt.int %2, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" }\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.maximum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int15 = torch.constant.int 15\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int15, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %7 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %7 = torch.aten.ne.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int5 : !torch.int\n" +" } else {\n" +" %7 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %8 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %9 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%7, %8) : (!torch.list>, !torch.list) -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = 
torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mv\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call 
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : 
!torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.convolution_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.list, %arg7: !torch.bool, %arg8: !torch.list, %arg9: !torch.int, %arg10: !torch.list) -> !torch.tuple {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = 
torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.convolution_backward_overrideable\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.list) -> !torch.tuple {\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.eq.int %2#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" %9 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %8 = torch.prim.TupleConstruct %7#1, %7#1, %7#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %8 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bincount\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.int {\n" +" %int7 = torch.constant.int 7\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %5 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple, %arg1: 
!torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union, %arg4: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lerp.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.ne.int %2#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %8 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, 
!torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %7 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = 
torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, 
!torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%3) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.prim.ListConstruct %0#1, %3 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union, %arg4: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.self\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> 
!torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int4 = torch.constant.int 4\n" +" %false = torch.constant.bool false\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : 
!torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.native_layer_norm\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" %7 = torch.prim.TupleConstruct %0#1, %0#1, %6 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %7 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.native_batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" %3 = torch.prim.TupleConstruct %0#1, %0#1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.union, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" 
+" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.union, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 
-> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %6 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sum.dim_IntList\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mean.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" 
+" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.argmax\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.any.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.amax\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.max\"(%arg0) : (!torch.tuple) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.max\"(%arg0) : (!torch.tuple) -> !torch.int\n" +" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %none : !torch.none to !torch.optional>\n" +" %1 = call @\"__torch_mlir_dtype_fn.aten.mean.dim\"(%arg0, %0, %false, %arg1) : (!torch.tuple, !torch.optional>, !torch.bool, !torch.optional) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.int {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: 
!torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.var\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.float, %arg3: !torch.optional) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" %5 = torch.prim.unchecked_cast %arg4 : !torch.optional -> !torch.int\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" 
}\n" +" %11 = torch.prim.TupleConstruct %0#0, %5 : !torch.int, !torch.int -> !torch.tuple\n" +" %12 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%11, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %12 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.int\"(%arg0: !torch.int, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.bool\"(%arg0: !torch.bool, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zeros\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ones\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int6 = 
torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.empty.memory_format\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zeros_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ones_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.empty_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_zeros\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_ones\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_empty\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_empty_strided\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, 
%arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randn_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._to_copy\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.dtype\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.nvprims.convert_element_type\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.dtype_layout\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> 
(!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.device\"(%arg0: !torch.tuple, %arg1: !torch.Device, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool, %arg5: !torch.optional) -> !torch.int {\n" +" return %arg2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.other\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.type_as\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randn\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg2, %none : 
!torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.tuple) {\n" +" %5 = torch.prim.TupleConstruct %int6, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %5 : !torch.tuple\n" +" } else {\n" +" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.tuple) {\n" +" %7 = torch.prim.TupleConstruct %int7, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" } else {\n" +" %7 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" }\n" +" torch.prim.If.yield %6 : !torch.tuple\n" +" }\n" +" return %4 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var_mean\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.tuple) {\n" +" %5 = torch.prim.TupleConstruct %int6, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %5 : !torch.tuple\n" +" } else {\n" +" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" 
%6 = torch.prim.If %5 -> (!torch.tuple) {\n" +" %7 = torch.prim.TupleConstruct %int7, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" } else {\n" +" %7 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" }\n" +" torch.prim.If.yield %6 : !torch.tuple\n" +" }\n" +" return %4 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.atan2\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linear\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %6 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list>, !torch.int -> !torch.tuple\n" +" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple -> !torch.int, !torch.int\n" +" %8 = torch.aten.append.t %0, %7#0 : !torch.list>, !torch.int -> !torch.list>\n" +" %9 = torch.aten.append.t %1, %7#1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ScalarImplicit\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" 
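The aten.atan2 function directly above shows the other recurring pattern in this hunk: promote the two operand dtypes, then force a float32 result (code 6) when the promotion lands on an integer type. Below is a self-contained sketch of that pattern. The helpers are stand-ins for the library_generator.promote_dtypes and is_integer_dtype calls in the IR; the real promote_dtypes also receives the operand ranks (the first ListConstruct above), presumably so zero-rank operands can be weighted differently, which torch.promote_types here does not model.

    # Hedged sketch of the promote-then-float pattern used by aten.atan2 above.
    from typing import Tuple
    import torch

    def _is_integer_dtype(dtype: torch.dtype) -> bool:
        # Simplified stand-in for library_generator.is_integer_dtype.
        return not dtype.is_floating_point and not dtype.is_complex

    def atan2_result_dtype(self_rank_dtype: Tuple[int, torch.dtype],
                           other_rank_dtype: Tuple[int, torch.dtype]) -> torch.dtype:
        _, self_dtype = self_rank_dtype
        _, other_dtype = other_rank_dtype
        # Rank-unaware simplification of library_generator.promote_dtypes.
        promoted = torch.promote_types(self_dtype, other_dtype)
        # Integer promotions fall back to float32 (dtype code 6 in the IR).
        return torch.float32 if _is_integer_dtype(promoted) else promoted

    assert atan2_result_dtype((2, torch.int64), (2, torch.int32)) == torch.float32
    assert atan2_result_dtype((2, torch.float64), (2, torch.float32)) == torch.float64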
func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softmax.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._softmax\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.If %arg2 -> (!torch.int) {\n" +" %2 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.If %arg2 -> (!torch.int) {\n" +" %2 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log_softmax.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.embedding\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._embedding_bag\"(%arg0: 
!torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.int) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4, %int4, %int4 : !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.embedding_bag.padding_idx\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.optional) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4, %int4, %int4 : !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bucketize.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.prim.If %arg2 -> (!torch.int) {\n" +" torch.prim.If.yield %int3 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" return %0 : !torch.int\n" +" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index f95ea6e7fadee..8df95def9ee41 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -85,16 +85,6 @@ static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) { return failed(result) ? Type() : *result; } -static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype, - Type defaultDtype) { - int64_t dtypeInt; - if (matchPattern(optionalDtype, m_TorchConstantInt(&dtypeInt))) - return getTypeForDTypeInteger(context, dtypeInt); - else if (optionalDtype.getType().isa()) - return defaultDtype; - return Type(); -} - // Get the kind enum for `ValueKnowledge.kind`. static torch_upstream::TypeKind getTypeKind(Type type) { if (type.isa()) @@ -429,12 +419,6 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis< void visitAtenLinearOp(AtenLinearOp op, ArrayRef operands); - void visitAtenArangeStartStepOp(AtenArangeStartStepOp op); - void visitAtenArangeStartOp(AtenArangeStartOp op); - void visitAtenArangeOp(AtenArangeOp op); - void visitAtenArangeLikeOpHelper(Operation *op, std::optional start, - Value end, std::optional step, - Value dtype); void visitReductionAlongAllDimsOp(Operation *op, Type dtype, ArrayRef operands); void visitReductionAlongDimIntListOp(Operation *op, Value dim, Value keepdim, @@ -536,37 +520,6 @@ static Type getPromotedResultScalarType(ArrayRef scalarTypes) { return *result; } -// Returns most generic type Type() if the tensor dtype is unknown. -static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) { - if (!tensor->dtype) - return Type(); - torch_upstream::ResultTypeState state = {}; - // No need to check if rank is zero for tensor because scalar uses - // wrappedResult which is a lower priority than both dimResult and zeroResult. 
- state = updateResultTypeState(tensor, /*rankIsNonZero=*/std::nullopt, state, - /*skipRankCheck=*/true); - state = - updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state); - FailureOr result = - getTypeForScalarType(scalarType.getContext(), result_type(state)); - return failed(result) ? Type() : *result; -} - -static SmallVector> -getRankIsNonZeroArray(ValueRange values) { - SmallVector> rankIsNonZero; - for (Value v : values) { - if (auto tensorType = v.getType().dyn_cast()) { - if (tensorType.hasSizes()) { - rankIsNonZero.push_back(tensorType.getSizes().size() != 0); - } else { - rankIsNonZero.push_back(std::nullopt); - } - } - } - return rankIsNonZero; -} - // Normally, tensor dimensions need to be known at compile time to do type // promotion. `skipRankCheck`, when equal to true, can be used to indicate // special cases that tensor operands are guaranteed to be not zero dimension @@ -622,614 +575,6 @@ void TypeAnalysis::fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge, void TypeAnalysis::visitOperation(Operation *op, ArrayRef operands, ArrayRef results) { - - // These ops have results that are dynamically the same as their operands. - if (isa(op)) { - incorporateKnowledge(op->getResult(0), operands[0]->getValue()); - return; - } - - // Take dtype from first operand. - if (isa(op)) { - return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); - } - - // Dtype is always float32, except for bfloat16, float16, float64 and nullptr. - if (isa(op)) { - ValueKnowledge knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type dtype = operands[0]->getValue().dtype; - if (dtype) { - knowledge.dtype = Float32Type::get(op->getContext()); - if (dtype.isa()) - knowledge.dtype = dtype; - } - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Take dtype from second operand. - if (isa(op)) { - auto self = operands[1]->getValue(); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = self.dtype; - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Dtype is always i1. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = IntegerType::get(op->getContext(), 1); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Dtype is always si64. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote the two dtypes assuming non-zero rank. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote the two dtypes assuming possibly-zero rank. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultType( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}, - getRankIsNonZeroArray(op->getOperands())); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Dtype is always float32, except for bfloat16, float64 and nullptr after - // promotion and assuming possible-zero rank. 
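The visitOperation branches being removed here hard-code results whose dtype never depends on promotion: always i1 for the comparison-style ops, always si64 for index-producing ops, or simply the dtype of the first or second operand. In the generated library those cases collapse into trivial dtype functions; the aten.argmax entry earlier in this patch, for example, unconditionally returns dtype code 4 (int64). A hedged sketch of that shape of function, with illustrative names:

    # Counterpart to the removed "Dtype is always si64" branch; mirrors the
    # generated aten.argmax function above, which always returns code 4 (int64).
    from typing import Optional, Tuple
    import torch

    def argmax_result_dtype(self_rank_dtype: Tuple[int, torch.dtype],
                            dim: Optional[int] = None,
                            keepdim: bool = False) -> torch.dtype:
        # Index-producing reductions yield int64 regardless of the input dtype.
        return torch.int64

    assert argmax_result_dtype((3, torch.float32)) == torch.int64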
- if (isa(op)) { - ValueKnowledge knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type promotedDtype = getPromotedResultType( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}, - getRankIsNonZeroArray(op->getOperands())); - if (promotedDtype) { - knowledge.dtype = Float32Type::get(op->getContext()); - if (promotedDtype.isa()) - knowledge.dtype = promotedDtype; - } - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote three dtypes. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue(), - &operands[2]->getValue()}); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - if (auto linear = llvm::dyn_cast(op)) { - visitAtenLinearOp(linear, operands); - return; - } - - // Promote LHS with scalar RHS. - if (isa(op)) { - auto lhs = operands[0]->getValue(); - Value scalar = op->getOperand(1); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType()); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultType( - op->getContext(), {&operands[1]->getValue(), &operands[2]->getValue()}, - getRankIsNonZeroArray(op->getOperands().slice(1, 2))); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - Value lhsScalar = op->getOperand(1); - Value rhsScalar = op->getOperand(2); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getDefaultDtypeForTorchScalar(getPromotedResultScalarType( - {lhsScalar.getType(), rhsScalar.getType()})); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - auto lhs = operands[1]->getValue(); - Value scalar = op->getOperand(2); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType()); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - auto rhs = operands[2]->getValue(); - Value scalar = op->getOperand(1); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultDType(&rhs, scalar.getType()); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // 2 results take dtype from first operand. - if (isa(op)) { - auto self = operands[0]->getValue(); - auto result0Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result0Knowledge.dtype = self.dtype; - auto result1Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result1Knowledge.dtype = self.dtype; - incorporateKnowledge(op->getResult(0), result0Knowledge); - incorporateKnowledge(op->getResult(1), result1Knowledge); - return; - } - - // 3 results take dtype from first operand. 
- if (isa(op)) { - auto self = operands[0]->getValue(); - auto result0Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result0Knowledge.dtype = self.dtype; - auto result1Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result1Knowledge.dtype = self.dtype; - auto result2Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result2Knowledge.dtype = self.dtype; - incorporateKnowledge(op->getResult(0), result0Knowledge); - incorporateKnowledge(op->getResult(1), result1Knowledge); - incorporateKnowledge(op->getResult(2), result2Knowledge); - return; - } - - if (isa(op)) { - auto self = operands[0]->getValue(); - auto result0Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result0Knowledge.dtype = self.dtype; - auto result1Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result1Knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - incorporateKnowledge(op->getResult(0), result0Knowledge); - incorporateKnowledge(op->getResult(1), result1Knowledge); - return; - } - - if (auto arange = dyn_cast(op)) { - visitAtenArangeOp(arange); - return; - } - if (auto arangeStart = dyn_cast(op)) { - visitAtenArangeStartOp(arangeStart); - return; - } - if (auto arangeStartStep = dyn_cast(op)) { - visitAtenArangeStartStepOp(arangeStartStep); - return; - } - - if (auto sum = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - if (!defaultDtype) { - incorporateKnowledge( - sum.getResult(), - ValueKnowledge::getTensorPessimisticValueState(op->getContext())); - return; - } - - // If the input dtype is bool, the result type should be i64. - if (defaultDtype.isInteger(1)) - defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - Type dtype = getDtypeOrDefault(sum.getContext(), sum.getDtype(), defaultDtype); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = dtype; - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - if (auto sumDimIntList = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - if (!defaultDtype) { - incorporateKnowledge( - sumDimIntList.getResult(), - ValueKnowledge::getTensorPessimisticValueState(op->getContext())); - return; - } - // If the input dtype is bool, the result type should be i64. 
- if (defaultDtype.isInteger(1)) - defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - Type dtype = getDtypeOrDefault(sumDimIntList.getContext(), - sumDimIntList.getDtype(), defaultDtype); - visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.getDim(), - sumDimIntList.getKeepdim(), dtype, operands); - return; - } - if (auto meanDim = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = - getDtypeOrDefault(meanDim.getContext(), meanDim.getDtype(), defaultDtype); - visitReductionAlongDimIntListOp(meanDim, meanDim.getDim(), meanDim.getKeepdim(), - dtype, operands); - return; - } - if (auto argmax = dyn_cast(op)) { - Value dim = argmax.getDim(); - Type dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed); - if (dim.getType().isa()) { - visitReductionAlongAllDimsOp(op, dtype, operands); - return; - } - if (dim.getType().isa()) { - visitReductionAlongDimIntOp(argmax, argmax.getDim(), argmax.getKeepdim(), dtype, - operands); - return; - } - } - if (auto anyDim = dyn_cast(op)) { - Type dtype = operands[0]->getValue().dtype; - visitReductionAlongDimIntOp(anyDim, anyDim.getDim(), anyDim.getKeepdim(), dtype, - operands); - return; - } - if (auto maxDim = dyn_cast(op)) { - Type firstResDtype = operands[0]->getValue().dtype; - Type secondResDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - visitReductionAlongDimIntOp(maxDim, maxDim.getDim(), maxDim.getKeepdim(), - firstResDtype, operands); - visitReductionAlongDimIntOp(maxDim, maxDim.getDim(), maxDim.getKeepdim(), - secondResDtype, operands, /*resNum=*/1); - return; - } - if (auto mean = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = - getDtypeOrDefault(mean.getContext(), mean.getDtype(), defaultDtype); - visitReductionAlongAllDimsOp(mean, dtype, operands); - return; - } else if (isa(op)) { - Type dtype = operands[0]->getValue().dtype; - visitReductionAlongAllDimsOp(op, dtype, operands); - return; - } else if (isa(op)) { - auto input = operands[0]->getValue(); - visitReductionAlongAllDimsOp(op, input.dtype, operands); - return; - } - - if (auto tensorFloat = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorFloat); - return; - } else if (auto tensorInt = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorInt); - return; - } else if (auto tensorBool = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorBool); - return; - } - - if (auto tensor = dyn_cast(op)) { - visitAtenTensorOp(tensor); - return; - } - - if (auto zeros = dyn_cast(op)) { - visitConstantTensorAllocOp(zeros, /*dataType=*/{}); - return; - } else if (auto ones = dyn_cast(op)) { - visitConstantTensorAllocOp(ones, /*dataType=*/{}); - return; - } else if (auto emptyMemoryFormat = dyn_cast(op)) { - visitConstantTensorAllocOp(emptyMemoryFormat, - /*dataType=*/{}); - return; - } else if (auto full = dyn_cast(op)) { - visitConstantTensorAllocOp( - full, /*dataType=*/full.getFillValue().getType()); - return; - } else if (auto zerosLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(zerosLike, operands); - return; - } else if (auto onesLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(onesLike, operands); - return; - } else if (auto emptyLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(emptyLike, operands); - return; - } else if (auto fullLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(fullLike, operands); - return; - } else if (auto newZeros = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newZeros, operands); - 
return; - } else if (auto newOnes = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newOnes, operands); - return; - } else if (auto newEmpty = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newEmpty, operands); - return; - } else if (auto newEmptyStrided = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newEmptyStrided, - operands); - return; - } else if (auto randLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(randLike, operands); - return; - } else if (auto randLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(randLike, operands); - return; - } else if (auto toCopy = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(toCopy, operands); - return; - } - - if (auto toDtype = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtype, operands); - return; - } - - if (auto primsConvertElementType = dyn_cast(op)) { - visitAtenToDtypeLikeOp(primsConvertElementType, - operands); - return; - } - - if (auto toDtypeLayout = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtypeLayout, operands); - return; - } - - if (auto toDtype = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtype, operands); - return; - } - - if (auto toOther = dyn_cast(op)) { - visitTypeConversionOp(toOther, operands); - return; - } else if (auto typeAs = dyn_cast(op)) { - visitTypeConversionOp(typeAs, operands); - return; - } - - if (auto cat = dyn_cast(op)) { - visitAtenCatLikeOp(cat, operands); - return; - } else if (auto stack = dyn_cast(op)) { - visitAtenCatLikeOp(stack, operands); - return; - } - - if (auto shapeAsTensor = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - incorporateKnowledge(shapeAsTensor.getResult(), knowledge); - return; - } - - if (auto embedding = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = operands[0]->getValue().dtype; - incorporateKnowledge(embedding.getResult(), knowledge); - return; - } - - if (isa(op)) { - visitAtenEmbeddingBagOp(op); - return; - } - - if (auto softmaxIntOp = dyn_cast(op)) { - visitAtenSoftmaxLikeOp(softmaxIntOp, operands); - return; - } - if (auto _softmaxOp = dyn_cast(op)) { - visitAten_SoftmaxLikeOp(_softmaxOp, operands); - return; - } else if (auto _logSoftmaxOp = dyn_cast(op)) { - visitAten_SoftmaxLikeOp(_logSoftmaxOp, operands); - return; - } else if (auto logSoftmaxIntOp = dyn_cast(op)) { - visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands); - return; - } - - if (auto numToTensorOp = dyn_cast(op)) { - visitNumToTensorOp(numToTensorOp); - return; - } - - if (isa(op)) { - visitBinaryScalarOp(op, operands); - return; - } - - if (auto scalarImplicit = dyn_cast(op)) { - visitAtenScalarImplicitOp(scalarImplicit, operands); - return; - } - - if (auto vectorNorm = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = getDtypeOrDefault(vectorNorm.getContext(), vectorNorm.getDtype(), - defaultDtype); - visitReductionAlongDimIntListOp(vectorNorm, vectorNorm.getDim(), - vectorNorm.getKeepdim(), dtype, operands); - return; - } - - if (auto randIntLow = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - knowledge.dtype = - getDtypeOrDefault(op->getContext(), randIntLow.getDtype(), defaultDtype); - incorporateKnowledge(randIntLow.getResult(), knowledge); - return; - } - - if (auto randInt = dyn_cast(op)) { - auto 
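Several of the handlers above (embedding and a number of the allocation-like ops, for example) boil down to the simplest rule of all: the result dtype follows the first tensor operand. In the library that category collapses to a one-line function; aten〇copy〡dtype, added later in this patch, is representative:

    @check_dtype_function(_check_two_tensor_op())
    def aten〇copy〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], non_blocking: bool = False) -> int:
        self_rank, self_dtype = self_rank_dtype
        return self_dtype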
knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - knowledge.dtype = - getDtypeOrDefault(op->getContext(), randInt.getDtype(), defaultDtype); - incorporateKnowledge(randInt.getResult(), knowledge); - return; - } - - if (isa(op)) { - auto input = operands[0]->getValue(); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = input.dtype; - incorporateKnowledge(op->getResult(0), knowledge); - incorporateKnowledge(op->getResult(1), knowledge); - return; - } - - if (auto randn = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type defaultDtype = Float32Type::get(op->getContext()); - knowledge.dtype = - getDtypeOrDefault(op->getContext(), randn.getDtype(), defaultDtype); - incorporateKnowledge(randn.getResult(), knowledge); - return; - } - - // aten.sort produces two Tensor outputs. The first one is the sorted Tensor - // which will have the dtype same as that of the input Tensor, while the last - // Tensor comprises of sorted item's indices corresponding to the input - // Tensor. - if (auto sortOp = dyn_cast(op)) { - auto input = operands[0]->getValue(); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = input.dtype; - incorporateKnowledge(op->getResult(0), knowledge); - Type i64Type = IntegerType::get(op->getContext(), 64, IntegerType::Signed); - knowledge.dtype = i64Type; - incorporateKnowledge(op->getResult(1), knowledge); - return; - } - - if (auto randnGenerator = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type defaultDtype = Float32Type::get(op->getContext()); - knowledge.dtype = getDtypeOrDefault(op->getContext(), - randnGenerator.getDtype(), defaultDtype); - incorporateKnowledge(randnGenerator.getResult(), knowledge); - return; - } - - if (auto bucketize = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - bool outInt32; - if (matchPattern(bucketize.getOutInt32(), m_TorchConstantBool(&outInt32)) && - outInt32) { - knowledge.dtype = - IntegerType::get(op->getContext(), 32, IntegerType::Signed); - } else { - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - } - incorporateKnowledge(bucketize.getResult(), knowledge); - return; - } - // Otherwise, this is an unknown operation, so reset the state. setAllToEntryStates(results); return; @@ -1283,49 +628,6 @@ void TypeAnalysis::visitAtenEmbeddingBagOp(Operation *op) { return; } -// Arange like ops returns a 1-D tensor of size ceil(end - start). -void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op, - std::optional start, - Value end, - std::optional step, - Value dtype) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - int64_t dtypeInt; - if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) { - knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt); - } else if (dtype.getType().isa()) { - // From torch/_torch_docs.py: - // If `dtype` is not given, infer the data type from the other input - // arguments. If any of `start`, `end`, or `step` are floating-point, the - // `dtype` is inferred to be the default dtype, see - // `torch.get_default_dtype`. 
Otherwise, the `dtype` is inferred to - // be `torch.int64` - if ((start.has_value() && (*start).getType().isa()) || - end.getType().isa() || - (step.has_value() && (*step).getType().isa())) { - // TODO: Should get the dtype from torch.get_default_dtype(). - // For now, use float32 which is the initial default dtype. - knowledge.dtype = Float32Type::get(op->getContext()); - } else - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - } - incorporateKnowledge(op->getResult(0), knowledge); -} - -void TypeAnalysis::visitAtenArangeStartStepOp(AtenArangeStartStepOp op) { - visitAtenArangeLikeOpHelper(op, op.getStart(), op.getEnd(), op.getStep(), op.getDtype()); -} - -void TypeAnalysis::visitAtenArangeStartOp(AtenArangeStartOp op) { - visitAtenArangeLikeOpHelper(op, op.getStart(), op.getEnd(), {}, op.getDtype()); -} - -void TypeAnalysis::visitAtenArangeOp(AtenArangeOp op) { - visitAtenArangeLikeOpHelper(op, {}, op.getEnd(), {}, op.getDtype()); -} - void TypeAnalysis::visitReductionAlongAllDimsOp( Operation *op, Type dtype, ArrayRef operands) { auto knowledge = @@ -1594,8 +896,7 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { // the right thing forthose ops. // static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) { - return op->hasTrait() || - isa(op); + return op->hasTrait(); } // Some operations have extra verification logic regarding the relationship diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 6e3d054eb3437..e888d0892710b 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -176,12 +176,14 @@ FailureOr Torch::adjustFunctionArg( return b.create(loc, desiredType, operand).getResult(); } - // !torch.union is the type used for `Scalar` inputs. At - // compile time, such inputs will usually be resolved to an `int` or a `float` - // so we need to derefine to match the library function signature. + // !torch.union or !torch.union is the type used + // for (optional) `Scalar` inputs. At compile time, such inputs will usually + // be resolved to an `int` or a `float` so we need to derefine to match the + // library function signature. if (auto unionType = desiredType.dyn_cast()) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { - return containedType.isa(); + return containedType + .isa(); })) return b.create(loc, desiredType, operand).getResult(); } @@ -189,7 +191,7 @@ FailureOr Torch::adjustFunctionArg( // If the operand is NoneType, then we just need to derefine it to the // optional type in the function signature. 
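The docstring quoted above (from torch/_torch_docs.py) is the rule the removed arange visitors implemented: with no explicit dtype, a floating-point start, end, or step yields the default dtype (hard-coded to float32 here, with a TODO to use torch.get_default_dtype()), otherwise int64. In the style of the library's dtype functions the rule would look roughly as follows. The function name and signature below are illustrative only, since the actual arange dtype functions are not shown in this hunk; the helpers used (get_dtype_of_scalar, is_float_dtype) are ones this patch imports.

    def aten〇arange〇start_step〡dtype(start: Union[int, float], end: Union[int, float],
                                      step: Union[int, float] = 1,
                                      dtype: Optional[int] = None) -> int:
        # Illustrative sketch of the rule only; signature abbreviated.
        if dtype is not None:
            return dtype
        if any(is_float_dtype(get_dtype_of_scalar(x)) for x in (start, end, step)):
            # As in the removed C++, float32 stands in for torch.get_default_dtype().
            return torch.float32
        return torch.int64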
if (operandType.isa()) { - assert(desiredType.isa() && + assert(!desiredType.isa() && "Don't expect library functions to have NoneType parameters"); return b.create(loc, desiredType, operand).getResult(); } diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp index fd58ead003672..1a2d3d545cbe0 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp @@ -8,11 +8,248 @@ //===----------------------------------------------------------------------===// #include "SimplifyAbstractInterpCalculationsUtils.h" +#include "mlir/IR/IRMapping.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +namespace { +class FoldPrimUncheckedCastOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimUncheckedCastOp op, + PatternRewriter &rewriter) const override { + if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) { + return rewriter.notifyMatchFailure( + op, "input tensor type is not a valid subtype of result type"); + } + rewriter.replaceOp(op, op.getX()); + return success(); + } +}; +} // namespace + +namespace { +// TODO: Only unroll inside the shape calculation region. +// Maybe do this by only applying patterns and folding greedily on the ops +// inside the region + the shape.calculate op itself? +class FullyUnrollPrimLoopOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimLoopOp op, + PatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); + if (!op.isForLike()) + return rewriter.notifyMatchFailure(op, "Loop is not for-like"); + int64_t maxTripCount; + if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount))) + return rewriter.notifyMatchFailure( + op, "Expected `maxTripCount` to be a constant int"); + ; + SmallVector indices; + for (int64_t i = 0; i < maxTripCount; i++) { + // TODO: Add convenience builder. 
+ indices.push_back(rewriter.create( + loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i))); + } + Block *beforeBlock = op->getBlock(); + Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator()); + + SmallVector blocksToMerge; + IRMapping bvm; + // TODO: Helper for region().front() + auto condition = + cast(op.getRegion().front().getTerminator()); + for (int64_t i = 0; i < maxTripCount; i++) { + SmallVector iterArgs; + if (i == 0) { + llvm::append_range(iterArgs, op.getIterArgsInit()); + } else { + llvm::append_range( + iterArgs, llvm::map_range(condition.getIterArgs(), + [&](Value v) { return bvm.lookup(v); })); + } + bvm.clear(); + bvm.map(op.getRegion().front().getArgument(0), indices[i]); + bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs); + + op.getRegion().cloneInto(afterBlock->getParent(), + afterBlock->getIterator(), bvm); + Block *clonedBlock = bvm.lookup(&op.getRegion().front()); + rewriter.eraseOp(clonedBlock->getTerminator()); + blocksToMerge.push_back(clonedBlock); + } + + blocksToMerge.push_back(afterBlock); + for (Block *block : blocksToMerge) + rewriter.mergeBlocks(block, beforeBlock); + if (maxTripCount == 0) { + rewriter.replaceOp(op, op.getIterArgsInit()); + } else { + rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range( + condition.getIterArgs(), + [&](Value v) { return bvm.lookup(v); }))); + } + return success(); + } +}; +} // namespace + +namespace { +class AbstractlyInterpretListOpsWithinABlock + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimListConstructOp op, + PatternRewriter &rewriter) const override { + Block *block = op->getBlock(); + auto allUsers = llvm::to_vector<6>(op->getUsers()); + + // Sort the users into program order. + auto getParentInBlock = [&](Operation *op) { + while (op->getBlock() != block) + op = op->getParentOp(); + return op; + }; + // Use a stable sort for deterministic results when users are nested in two + // regions of the same parent op. + llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) { + return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs)); + }); + + // We cannot interpret all ops. So first do a check to see up until which + // point we can interpret. + int numUsersToInterpret = 0; + for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) { + Operation *user = allUsers[i]; + // If a user potentially mutates the list, then we require it to be in the + // same block for our simple abstract interpretation to work (we can't, + // for example, handle an "append" operation in a loop or other region). + // However, if the op is read-only, then from the purpose of our abstract + // interpretation, we can handle it effectively as though it was at the + // same position as the corresponding parent op in the block under + // consideration. + if (potentiallyMutatesListOperands(user)) { + if (user->getBlock() != block) + break; + } + } + + // Truncate the list of users to the number of users we're going to + // interpret. + allUsers.resize(numUsersToInterpret); + auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret); + + // For each mutating op (which must be in the same block), we save the + // current state of the list as a vector of Value's. These will then + // be converted to PrimListConstructOp's at the correct program points. 
+ SmallVector> listLiterals; + SmallVector runningList; + llvm::append_range(runningList, op->getOperands()); + bool generatedNewLiteral = false; + for (Operation *user : usersToInterpret) { + if (auto append = dyn_cast(user)) { + if (!append.use_empty()) + return rewriter.notifyMatchFailure( + op, "Expected `AtenAppendTOp` to not have users"); + if (append.getSelf() == op) { + runningList.push_back(append.getEl()); + generatedNewLiteral = true; + } + listLiterals.push_back(runningList); + continue; + } + if (auto insert = dyn_cast(user)) { + if (!insert.use_empty()) + return rewriter.notifyMatchFailure( + op, "Expected `AtenInsertTOp` to not have users"); + int64_t index; + if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index))) + return rewriter.notifyMatchFailure( + op, "Expected `idx` of `AtenInsertTOp` to be a constant int"); + // The index might be statically out of bounds. + if (index < 0 || index > static_cast(runningList.size())) + return rewriter.notifyMatchFailure( + op, "Index in `AtenInsertTOp` is out of bounds"); + if (insert.getSelf() == op) { + runningList.insert(runningList.begin() + index, insert.getEl()); + generatedNewLiteral = true; + } + listLiterals.push_back(runningList); + continue; + } + if (auto setItem = dyn_cast(user)) { + if (!setItem.use_empty()) + return rewriter.notifyMatchFailure( + op, "Expected `Aten_SetItemTOp` to not have users"); + std::optional indexOpt = matchLegalConstantIndexIntoListOfSize( + setItem.getIdx(), runningList.size()); + // The index might be statically out of bounds. + if (!indexOpt) + return rewriter.notifyMatchFailure( + op, "Index in `Aten_SetItemTOp` is out of bounds"); + if (setItem.getL() == op) { + runningList[*indexOpt] = setItem.getEl(); + generatedNewLiteral = true; + } + listLiterals.push_back(runningList); + continue; + } + // If this user potentially mutates the list and isn't handled above, then + // we can't abstractly interpret any further. + if (potentiallyMutatesListOperands(user)) + break; + } + + if (!generatedNewLiteral) + return rewriter.notifyMatchFailure(op, "No new literal created"); + + // Rewrite all users to use the appropriate list literals. + Value latestLiteral = rewriter.create( + op->getLoc(), op.getType(), op->getOperands()); + int nextLiteral = 0; + for (Operation *user : usersToInterpret) { + if (auto append = dyn_cast(user)) { + rewriter.setInsertionPoint(append); + latestLiteral = rewriter.create( + append->getLoc(), op.getType(), listLiterals[nextLiteral++]); + if (append.getSelf() == op) + rewriter.eraseOp(append); + continue; + } + if (auto insert = dyn_cast(user)) { + rewriter.setInsertionPoint(insert); + latestLiteral = rewriter.create( + insert->getLoc(), op.getType(), listLiterals[nextLiteral++]); + if (insert.getSelf() == op) + rewriter.eraseOp(insert); + continue; + } + if (auto setItem = dyn_cast(user)) { + rewriter.setInsertionPoint(setItem); + latestLiteral = rewriter.create( + setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]); + if (setItem.getL() == op) + rewriter.eraseOp(setItem); + continue; + } + for (OpOperand &opOperand : user->getOpOperands()) { + if (opOperand.get() == op.getResult()) { + opOperand.set(latestLiteral); + } + } + } + + // Any remaining uses should use the updated value of the latest literal. 
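The comments above describe the whole idea behind AbstractlyInterpretListOpsWithinABlock: walk the list's users in program order, keep the list contents as a running vector of values, and record a snapshot at every mutation so each user can later be pointed at an immutable prim.ListConstruct literal. A plain-Python model of that bookkeeping, purely illustrative and much simplified relative to the MLIR pattern:

    # Simplified model: each mutation is (kind, payload); a snapshot of the list
    # is recorded after every interpreted mutation, mirroring `listLiterals`.
    def interpret_list_mutations(initial_elements, mutations):
        running = list(initial_elements)
        snapshots = []
        for kind, payload in mutations:
            if kind == "append":
                running.append(payload)
            elif kind == "insert":
                index, element = payload
                if index < 0 or index > len(running):
                    return None  # statically out of bounds: give up, like the match failure above
                running.insert(index, element)
            elif kind == "set_item":
                index, element = payload
                running[index] = element
            else:
                break  # an unhandled potential mutation stops the interpretation
            snapshots.append(list(running))
        return snapshots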
+ rewriter.replaceOp(op, latestLiteral); + return success(); + } +}; +} // namespace + LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, int resultNum, Type newResultType, @@ -97,3 +334,18 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, return success(); } + +void mlir::torch::Torch::populateFoldPrimUncheckedCastOpPattern( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.insert(context); +} + +void mlir::torch::Torch::populateFullyUnrollPrimLoopOpPattern( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.insert(context); +} + +void mlir::torch::Torch::populateAbstractlyInterpretListOpsWithinABlockPattern( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.insert(context); +} diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h index 9c618d4a27a8f..172d27c00df8b 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h +++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h @@ -23,6 +23,13 @@ LogicalResult updateCalculateOpResultTypes(Operation *calculateOp, int resultNum, Type newResultType, PatternRewriter &rewriter); +void populateFoldPrimUncheckedCastOpPattern(RewritePatternSet &patterns, + MLIRContext *context); +void populateFullyUnrollPrimLoopOpPattern(RewritePatternSet &patterns, + MLIRContext *context); +void populateAbstractlyInterpretListOpsWithinABlockPattern( + RewritePatternSet &patterns, MLIRContext *context); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 3c4a334b5708e..43f2b22a3d66c 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -191,10 +191,17 @@ class SimplifyDtypeCalculationsPass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); + populateFullyUnrollPrimLoopOpPattern(patterns, context); + populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context); + populateFoldPrimUncheckedCastOpPattern(patterns, context); patterns.insert(context); patterns.insert(context); patterns.insert(context); + PrimIfOp::getCanonicalizationPatterns(patterns, context); + Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context); + PrimTupleUnpackOp::getCanonicalizationPatterns(patterns, context); + // TODO: Debug visitation order to make this more efficient. // A single linear scan should suffice. GreedyRewriteConfig config; diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index f8d3651d9a5c5..1669be7c4e62c 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -10,7 +10,6 @@ #include "PassDetail.h" #include "SimplifyAbstractInterpCalculationsUtils.h" -#include "mlir/IR/IRMapping.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -19,225 +18,6 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -namespace { -// TODO: Only unroll inside the shape calculation region. 
-// Maybe do this by only applying patterns and folding greedily on the ops -// inside the region + the shape.calculate op itself? -class FullyUnrollPrimLoopOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(PrimLoopOp op, - PatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - MLIRContext *context = op->getContext(); - if (!op.isForLike()) - return rewriter.notifyMatchFailure(op, "Loop is not for-like"); - int64_t maxTripCount; - if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount))) - return rewriter.notifyMatchFailure( - op, "Expected `maxTripCount` to be a constant int"); - ; - SmallVector indices; - for (int64_t i = 0; i < maxTripCount; i++) { - // TODO: Add convenience builder. - indices.push_back(rewriter.create( - loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i))); - } - Block *beforeBlock = op->getBlock(); - Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator()); - - SmallVector blocksToMerge; - IRMapping bvm; - // TODO: Helper for region().front() - auto condition = - cast(op.getRegion().front().getTerminator()); - for (int64_t i = 0; i < maxTripCount; i++) { - SmallVector iterArgs; - if (i == 0) { - llvm::append_range(iterArgs, op.getIterArgsInit()); - } else { - llvm::append_range( - iterArgs, llvm::map_range(condition.getIterArgs(), - [&](Value v) { return bvm.lookup(v); })); - } - bvm.clear(); - bvm.map(op.getRegion().front().getArgument(0), indices[i]); - bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs); - - op.getRegion().cloneInto(afterBlock->getParent(), afterBlock->getIterator(), - bvm); - Block *clonedBlock = bvm.lookup(&op.getRegion().front()); - rewriter.eraseOp(clonedBlock->getTerminator()); - blocksToMerge.push_back(clonedBlock); - } - - blocksToMerge.push_back(afterBlock); - for (Block *block : blocksToMerge) - rewriter.mergeBlocks(block, beforeBlock); - if (maxTripCount == 0) { - rewriter.replaceOp(op, op.getIterArgsInit()); - } else { - rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range( - condition.getIterArgs(), - [&](Value v) { return bvm.lookup(v); }))); - } - return success(); - } -}; -} // namespace - -namespace { -class AbstractlyInterpretListOpsWithinABlock - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(PrimListConstructOp op, - PatternRewriter &rewriter) const override { - Block *block = op->getBlock(); - auto allUsers = llvm::to_vector<6>(op->getUsers()); - - // Sort the users into program order. - auto getParentInBlock = [&](Operation *op) { - while (op->getBlock() != block) - op = op->getParentOp(); - return op; - }; - // Use a stable sort for deterministic results when users are nested in two - // regions of the same parent op. - llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) { - return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs)); - }); - - // We cannot interpret all ops. So first do a check to see up until which - // point we can interpret. - int numUsersToInterpret = 0; - for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) { - Operation *user = allUsers[i]; - // If a user potentially mutates the list, then we require it to be in the - // same block for our simple abstract interpretation to work (we can't, - // for example, handle an "append" operation in a loop or other region). 
- // However, if the op is read-only, then from the purpose of our abstract - // interpretation, we can handle it effectively as though it was at the - // same position as the corresponding parent op in the block under - // consideration. - if (potentiallyMutatesListOperands(user)) { - if (user->getBlock() != block) - break; - } - } - - // Truncate the list of users to the number of users we're going to - // interpret. - allUsers.resize(numUsersToInterpret); - auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret); - - // For each mutating op (which must be in the same block), we save the - // current state of the list as a vector of Value's. These will then - // be converted to PrimListConstructOp's at the correct program points. - SmallVector> listLiterals; - SmallVector runningList; - llvm::append_range(runningList, op->getOperands()); - bool generatedNewLiteral = false; - for (Operation *user : usersToInterpret) { - if (auto append = dyn_cast(user)) { - if (!append.use_empty()) - return rewriter.notifyMatchFailure( - op, "Expected `AtenAppendTOp` to not have users"); - if (append.getSelf() == op) { - runningList.push_back(append.getEl()); - generatedNewLiteral = true; - } - listLiterals.push_back(runningList); - continue; - } - if (auto insert = dyn_cast(user)) { - if (!insert.use_empty()) - return rewriter.notifyMatchFailure( - op, "Expected `AtenInsertTOp` to not have users"); - int64_t index; - if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index))) - return rewriter.notifyMatchFailure( - op, "Expected `idx` of `AtenInsertTOp` to be a constant int"); - // The index might be statically out of bounds. - if (index < 0 || index > static_cast(runningList.size())) - return rewriter.notifyMatchFailure( - op, "Index in `AtenInsertTOp` is out of bounds"); - if (insert.getSelf() == op) { - runningList.insert(runningList.begin() + index, insert.getEl()); - generatedNewLiteral = true; - } - listLiterals.push_back(runningList); - continue; - } - if (auto setItem = dyn_cast(user)) { - if (!setItem.use_empty()) - return rewriter.notifyMatchFailure( - op, "Expected `Aten_SetItemTOp` to not have users"); - std::optional indexOpt = matchLegalConstantIndexIntoListOfSize( - setItem.getIdx(), runningList.size()); - // The index might be statically out of bounds. - if (!indexOpt) - return rewriter.notifyMatchFailure( - op, "Index in `Aten_SetItemTOp` is out of bounds"); - if (setItem.getL() == op) { - runningList[*indexOpt] = setItem.getEl(); - generatedNewLiteral = true; - } - listLiterals.push_back(runningList); - continue; - } - // If this user potentially mutates the list and isn't handled above, then - // we can't abstractly interpret any further. - if (potentiallyMutatesListOperands(user)) - break; - } - - if (!generatedNewLiteral) - return rewriter.notifyMatchFailure(op, "No new literal created"); - - // Rewrite all users to use the appropriate list literals. 
- Value latestLiteral = rewriter.create( - op->getLoc(), op.getType(), op->getOperands()); - int nextLiteral = 0; - for (Operation *user : usersToInterpret) { - if (auto append = dyn_cast(user)) { - rewriter.setInsertionPoint(append); - latestLiteral = rewriter.create( - append->getLoc(), op.getType(), listLiterals[nextLiteral++]); - if (append.getSelf() == op) - rewriter.eraseOp(append); - continue; - } - if (auto insert = dyn_cast(user)) { - rewriter.setInsertionPoint(insert); - latestLiteral = rewriter.create( - insert->getLoc(), op.getType(), listLiterals[nextLiteral++]); - if (insert.getSelf() == op) - rewriter.eraseOp(insert); - continue; - } - if (auto setItem = dyn_cast(user)) { - rewriter.setInsertionPoint(setItem); - latestLiteral = rewriter.create( - setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]); - if (setItem.getL() == op) - rewriter.eraseOp(setItem); - continue; - } - for (OpOperand &opOperand : user->getOpOperands()) { - if (opOperand.get() == op.getResult()) { - opOperand.set(latestLiteral); - } - } - } - - // Any remaining uses should use the updated value of the latest literal. - rewriter.replaceOp(op, latestLiteral); - return success(); - } -}; -} // namespace - namespace { class DecomposeAtenSizeOp : public OpRewritePattern { public: @@ -266,22 +46,6 @@ class DecomposeAtenSizeOp : public OpRewritePattern { }; } // namespace -namespace { -class FoldPrimUncheckedCastOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(PrimUncheckedCastOp op, - PatternRewriter &rewriter) const override { - if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) { - return rewriter.notifyMatchFailure( - op, "input tensor type is not a valid subtype of result type"); - } - rewriter.replaceOp(op, op.getX()); - return success(); - } -}; -} // namespace - static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op, int resultNum, PatternRewriter &rewriter) { @@ -367,11 +131,11 @@ class SimplifyShapeCalculationsPass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns.insert(context); - patterns.insert(context); + populateFullyUnrollPrimLoopOpPattern(patterns, context); + populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context); + populateFoldPrimUncheckedCastOpPattern(patterns, context); patterns.insert(context); patterns.insert(context); - patterns.insert(context); PrimIfOp::getCanonicalizationPatterns(patterns, context); Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 90c6f30d66f9f..7623503ece7b5 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -12,7 +12,11 @@ import torch.jit._shape_functions as upstream_shape_functions from .testing_framework import Invocation, ErrorInvocation, TensorOfShape, LongTensorOfShape, NonZeroDTensorWithDtype, ZeroDTensorWithDtype, check_shape_function, check_dtype_function -from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar +from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar, is_integer_dtype, is_float_dtype, is_complex_dtype, 
get_priority_of_dtype, all_integer_dtypes, all_float_dtypes, all_complex_dtypes + +# ============================================================================== +# Shape Functions +# ============================================================================== # TODO: upstream this def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[int], include_last_offset: bool, mode: int): @@ -79,27 +83,6 @@ def aten〇exp〡shape(self: List[int]) -> List[int]: def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.float32)), - Invocation(NonZeroDTensorWithDtype(torch.float64)), - Invocation(NonZeroDTensorWithDtype(torch.bfloat16)), - Invocation(NonZeroDTensorWithDtype(torch.int64)), - Invocation(NonZeroDTensorWithDtype(torch.int32)), - Invocation(NonZeroDTensorWithDtype(torch.bool)), - Invocation(ZeroDTensorWithDtype(torch.float32)), - Invocation(ZeroDTensorWithDtype(torch.float64)), - Invocation(ZeroDTensorWithDtype(torch.bfloat16)), - Invocation(ZeroDTensorWithDtype(torch.int64)), - Invocation(ZeroDTensorWithDtype(torch.int32)), - Invocation(ZeroDTensorWithDtype(torch.bool)), -]) -def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: - self_rank, self_dtype = self_rank_dtype - if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16: - return self_dtype - else: - return torch.float32 - def aten〇sin〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -277,18 +260,6 @@ def aten〇pow〇Tensor_Tensor〡shape(self: List[int], exponent: List[int]) -> def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]: return upstream_shape_functions.unary(self) -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.float32), other=0), - Invocation(NonZeroDTensorWithDtype(torch.int64), other=0.0), - Invocation(NonZeroDTensorWithDtype(torch.float16), other=0.0), - Invocation(ZeroDTensorWithDtype(torch.float32), other=0), - Invocation(ZeroDTensorWithDtype(torch.int64), other=0.0), - Invocation(ZeroDTensorWithDtype(torch.float16), other=0.0) -]) -def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: - self_rank, self_dtype = self_rank_dtype - return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) - def aten〇leaky_relu〡shape(self: List[int], negative_slope: float = 0.01) -> List[int]: return upstream_shape_functions.unary(self) @@ -684,19 +655,6 @@ def aten〇div〇Tensor_mode〡shape(self: List[int], other: List[int], rounding def aten〇floor_divide〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float32)), - Invocation(ZeroDTensorWithDtype(torch.float64), NonZeroDTensorWithDtype(torch.float32)), - Invocation(ZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float64)), - Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)), -]) -def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - self_rank, self_dtype = self_rank_dtype - other_rank, other_dtype = other_rank_dtype - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] - return promote_dtypes(ranks, dtypes) - def 
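The hand-written invocation lists deleted here are being regrouped into the new Dtype Functions section added further down in this file, with their tests rebuilt from the _check_* helpers. The relocated expm1 entry, for example, becomes the following (quoted from later in this diff):

    @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
    def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
        self_rank, self_dtype = self_rank_dtype
        return _get_dtype_of_floating_point_op(self_dtype)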
aten〇atan2〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -841,40 +799,6 @@ def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optio def aten〇_convolution〇deprecated〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]: return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) -_convolution_deprecated_kwargs = { - "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], - "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False} -@check_dtype_function( - [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Same type - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.int32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different type - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different width - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.int32), # Different type and width - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.complex64), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.complex128), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs) -]) -def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: - input_rank, input_dtype = input_rank_dtype - weight_rank, weight_dtype = weight_rank_dtype - assert input_dtype == weight_dtype - assert input_dtype not in [torch.bool, torch.float16, torch.complex64, torch.complex128] - ranks: List[Optional[int]] = [input_rank, weight_rank] - dtypes = [input_dtype, weight_dtype] - return promote_dtypes(ranks, dtypes) - 
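The aten〇_convolution〇deprecated〡dtype function removed above is also the clearest illustration of how the rank/dtype-tuple interface handles optional tensors: the optional bias operand arrives as Optional[Tuple[int, int]], so a bias of None stays distinguishable from any other missing information, which is exactly the ambiguity the interface change described in the commit message set out to fix. A trimmed restatement of that function's shape (the name below is made up for illustration; the body is condensed from the removed code):

    def convolution_like_dtype(input_rank_dtype: Tuple[int, int],
                               weight_rank_dtype: Tuple[int, int],
                               bias_rank_dtype: Optional[Tuple[int, int]]) -> int:
        input_rank, input_dtype = input_rank_dtype
        weight_rank, weight_dtype = weight_rank_dtype
        # bias_rank_dtype is None exactly when the op's optional bias is None;
        # the result dtype only depends on input and weight.
        assert input_dtype == weight_dtype
        assert input_dtype not in [torch.bool, torch.float16, torch.complex64, torch.complex128]
        ranks: List[Optional[int]] = [input_rank, weight_rank]
        dtypes = [input_dtype, weight_dtype]
        return promote_dtypes(ranks, dtypes)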
def aten〇flip〡shape(self: List[int], dims: List[int]) -> List[int]: return self @@ -959,7 +883,7 @@ def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[in Invocation(TensorOfShape(2, 3), None, None, TensorOfShape(3), TensorOfShape(3), False, 1e-4, 1e-6), # Inference basic case. Invocation(TensorOfShape(2, 3, 4, 5, 6), None, None, None, None, True, 1e-4, 1e-6), # Training high-D case. Invocation(TensorOfShape(2, 3, 4, 5, 6), None, None, TensorOfShape(3), TensorOfShape(3), False, 1e-4, 1e-6), # Inference high-D case. - ErrorInvocation(TensorOfShape(2), None, None, None, None, True, 1e-4, 1e-6) # Dimensionality too low. + Invocation(TensorOfShape(2), None, None, None, None, True, 1e-4, 1e-6) # 1D input. ]) def aten〇native_batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float) -> Tuple[List[int], List[int], List[int]]: return upstream_shape_functions.native_batch_norm(input, weight, bias, running_mean, running_var, training) @@ -1058,35 +982,6 @@ def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]: def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: return self -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.complex64)), - Invocation(NonZeroDTensorWithDtype(torch.complex128)), - Invocation(NonZeroDTensorWithDtype(torch.float)), - Invocation(NonZeroDTensorWithDtype(torch.double)), - Invocation(NonZeroDTensorWithDtype(torch.bool)), - Invocation(NonZeroDTensorWithDtype(torch.uint8)), - Invocation(NonZeroDTensorWithDtype(torch.int8)), - Invocation(NonZeroDTensorWithDtype(torch.int16)), - Invocation(NonZeroDTensorWithDtype(torch.int32)), - Invocation(NonZeroDTensorWithDtype(torch.int64)), - ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)), - ErrorInvocation(NonZeroDTensorWithDtype(torch.bfloat16)), -]) -def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: - self_rank, self_dtype = self_rank_dtype - if self_dtype == torch.complex64 or self_dtype == torch.complex128: - return self_dtype - elif self_dtype == torch.float: - return torch.complex64 - elif self_dtype == torch.double: - return torch.complex128 - elif self_dtype == torch.bool or self_dtype == torch.uint8 or \ - self_dtype == torch.int8 or self_dtype == torch.int16 or \ - self_dtype == torch.int32 or self_dtype == torch.int64: - return torch.complex64 - else: - assert False, "Unsupported dtype" - class DummyClassType: def __init__(self): pass @@ -1141,16 +1036,2073 @@ def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: return [self[0], self[1], output_size[0], output_size[1]] -@check_dtype_function([ - Invocation(0.0, 0.0), # float, float - Invocation(0.0, 0), # float, int - Invocation(0, 0.0), # int, float - Invocation(0, 0), # int, int -]) -def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int: - ranks: List[Optional[int]] = [None, None] - dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)] - return promote_dtypes(ranks, dtypes) +# ============================================================================== +# Dtype Functions +# 
============================================================================== + +# All the torch types sorted in decreasing order of priority during type promotion. +_SORTED_TORCH_TYPES = [ + torch.complex128, torch.complex64, + torch.float64, torch.float32, torch.float16, torch.bfloat16, + torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool +] + +def _check_tensors_with_the_same_dtype( + num_of_tensors: Optional[int] = None, + tensor_shapes: Optional[list[tuple[int]]] = None, + error_types: Optional[set[int]] = None, *args, **kwargs): + """Create invocations where all tensors have the same dtype. + + This function generates invocations with `num_of_tensors` tensors + that all have the same dtype. It creates an invocation for every + possible dtype. For dtypes in `error_types`, the invocations are + error invocations. + + One can also specify the shapes of the tensors. Either `num_of_tensors` + or `tensor_shapes` must be specified whenever this function is called. + + The extra *args and **kwargs arguments are passed to the invocations. + """ + invocations = [] + for type_ in _SORTED_TORCH_TYPES: + tensors = [] + if tensor_shapes is None and num_of_tensors is not None: + tensors = [NonZeroDTensorWithDtype(type_)] * num_of_tensors + elif tensor_shapes is not None and num_of_tensors is None: + for tensor_shape in tensor_shapes: + tensors.append(TensorOfShape(*tensor_shape, dtype=type_)) + else: + assert False, \ + "Either `num_of_tensors` or `tensor_shapes` must be specified" + + if error_types is not None and type_ in error_types: + invocations.append(ErrorInvocation(*tensors, *args, **kwargs)) + else: + invocations.append(Invocation(*tensors, *args, **kwargs)) + return invocations + +def _check_two_tensor_op( + tensor_shapes: Optional[list[tuple[int]]] = None, + input_error_types: Optional[set[int]] = None, + output_error_types: Optional[set[int]] = None, **kwargs): + """Generate invocations for basic two-tensor dtype functions. + + This helper function is meant to be used to check dtype functions that + take two tensor operands and either return the promoted result or + return a constant dtype based on the tensor dtypes. + + The testing performed is thorough enough to be able to detect if dtypes + are invalid as inputs or as outputs to the PyTorch op. Invalid dtypes + must be specified in `input_error_types` and `output_error_types` to + ensure the invocations are error invocations. + """ + if tensor_shapes is None: + tensor_shapes = [(1,), (1,)] + shape_1, shape_2 = tensor_shapes + + if input_error_types is not None and output_error_types is not None: + assert len(input_error_types.intersection(output_error_types)) == 0, \ + "An invalid input type implies an invalid output type, " \ + "so there is no need to repeat the type in the `output_error_types` set" + all_error_types = set() + all_error_types |= set() if input_error_types is None else input_error_types + all_error_types |= set() if output_error_types is None else output_error_types + + def check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs): + """Create invocations where one tensor varies its dtype. + + This helper function creates invocations with two tensors where one + tensor varies its dtype while the other one stays constant. The varying + is done for both tensors and the varying is performed over every possible + dtype. + + This function helps identify when a dtype is an invalid input dtype + for dtype functions that do promotion. 
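As a usage note for the first helper defined above: _check_tensors_with_the_same_dtype(num_of_tensors=1) is shorthand for the per-dtype invocation lists that used to be written out by hand. For a unary op it expands to roughly the following, with error_types switching matching entries over to ErrorInvocation:

    # What _check_tensors_with_the_same_dtype(num_of_tensors=1) evaluates to:
    # one Invocation per dtype, each holding a single non-zero-rank tensor.
    invocations = [Invocation(NonZeroDTensorWithDtype(dtype)) for dtype in _SORTED_TORCH_TYPES]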
+ """ + # We will only create invocations for dtypes with priorities less than + # or equal to the highest priority valid type. By setting the non-varying + # tensor dtype to be the highest priority valid type, we ensure that + # every promotion results in a valid dtype. This allows the invocations + # to test in isolation assertions on input types. + constant_type = None + constant_type_index = None + for e, type_ in enumerate(_SORTED_TORCH_TYPES): + if type_ not in all_error_types: + constant_type = type_ + constant_type_index = e + break + assert constant_type is not None, \ + "Unable to find a constant type. Make sure the union of " \ + "`input_error_types` and `output_error_types` is not all possible types." + + invocations = [] + for type_ in _SORTED_TORCH_TYPES[constant_type_index:]: + if input_error_types is not None and type_ in input_error_types: + invocation_type = ErrorInvocation + else: + invocation_type = Invocation + invocations += [invocation_type(TensorOfShape(*shape_1, dtype=type_), TensorOfShape(*shape_2, dtype=constant_type), **kwargs), + invocation_type(TensorOfShape(*shape_1, dtype=constant_type), TensorOfShape(*shape_2, dtype=type_), **kwargs)] + return invocations + + same_dtype_invocations = _check_tensors_with_the_same_dtype( + tensor_shapes=tensor_shapes, error_types=all_error_types, **kwargs) + + varying_dtype_invocations = \ + check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs) + return same_dtype_invocations + varying_dtype_invocations + +def _get_dtype_of_floating_point_op(input_dtype: int) -> int: + if (is_float_dtype(input_dtype) and input_dtype != torch.float32) \ + or is_complex_dtype(input_dtype): + return input_dtype + return torch.float32 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇exp〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇reciprocal〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = 
self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇log〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇log2〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇rsqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇erf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return self_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=[0])) +def aten〇frobenius_norm〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int], keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if self_dtype == torch.complex128: + return torch.float64 + elif self_dtype == torch.complex64: + return torch.float32 + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return self_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.complex128: + return torch.float64 + elif self_dtype == torch.complex64: + return torch.float32 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) +def aten〇adaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype( + 
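Most of the unary math ops registered above funnel through _get_dtype_of_floating_point_op, so the whole family shares one rule: float64, float16, bfloat16, and the complex dtypes pass through unchanged, while every other input (float32, the integer dtypes, bool) produces float32. Spelled out as checks against the helper as written above:

    # Expected behavior of _get_dtype_of_floating_point_op, given its definition above.
    assert _get_dtype_of_floating_point_op(torch.float64) == torch.float64
    assert _get_dtype_of_floating_point_op(torch.bfloat16) == torch.bfloat16
    assert _get_dtype_of_floating_point_op(torch.complex64) == torch.complex64
    assert _get_dtype_of_floating_point_op(torch.float32) == torch.float32
    assert _get_dtype_of_floating_point_op(torch.int64) == torch.float32
    assert _get_dtype_of_floating_point_op(torch.bool) == torch.float32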
tensor_shapes=[(2, 3, 5), (3,), (3,), (3,), (3,)], training=False, momentum=0.1, eps=1e-5, cudnn_enabled=True)) +def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇bernoulli〡dtype(self_rank_dtype: Tuple[int, int], generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +def aten〇bernoulli〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], p_rank_dtype: Tuple[int, int], generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇bitwise_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[2, 2])) +def aten〇broadcast_to〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0)) +def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=0)) +def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=-1, max=1)) +def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇clone〡dtype(self_rank_dtype: Tuple[int, int], memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1])) +def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float] = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int: + 
self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇copy〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], non_blocking: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +# TODO: Cannot call .cpu() on meta tensor due to: Cannot copy out of meta tensor; no data! +def aten〇cpu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32)) +def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False)) +def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇expand_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[2, 2])) +def aten〇expand〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], implicit: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, value=0)) +def aten〇fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1,), ()])) +def aten〇fill〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_dim: int = 0, end_dim: int = -1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇floor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, 0, TensorOfShape(1, dtype=torch.int64))) +def aten〇gather〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], sparse_grad: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇gelu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], 
approximate: str = "none") -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇gelu〡dtype(self_rank_dtype: Tuple[int, int], approximate: str = "none") -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇hardsigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇hardswish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(min_val=0.2, max_val=0.5)) +def aten〇hardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float], max_val: Union[int, float]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + if is_integer_dtype(grad_output_dtype): + return torch.float32 + return grad_output_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.uint8, torch.bool})) +def aten〇hardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float] = -1, max_val: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype not in [torch.uint8, torch.bool] + return self_dtype + +_index_put_invocations = [ + # same dtype + Invocation(TensorOfShape(3, dtype=dtype), [TensorOfShape(3, dtype=torch.int64)], TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES +] + [ + # different dtypes + Invocation(TensorOfShape(3, dtype=dtype), [TensorOfShape(3, dtype=torch.int64)], TensorOfShape(3, dtype=torch.float32)) for dtype in _SORTED_TORCH_TYPES +] + [ + # index dtype + Invocation(TensorOfShape(3, dtype=torch.float32), [TensorOfShape(3, dtype=dtype)], TensorOfShape(3, dtype=torch.float32)) for dtype in _SORTED_TORCH_TYPES +] +@check_dtype_function(_index_put_invocations) +def aten〇index_put〇hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_index_put_invocations) +def aten〇_index_put_impl〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False, unsafe: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_index_put_invocations) +def aten〇index_put〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, 0, TensorOfShape(1, dtype=torch.int64))) +def aten〇index_select〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + 
+@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, [TensorOfShape(1, dtype=torch.int64)])) +def aten〇index〇Tensor_hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, [TensorOfShape(1, dtype=torch.int64)])) +def aten〇index〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={*all_integer_dtypes()}, normalized_shape=[1])) +def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]] = None, bias_rank_dtype: Optional[Tuple[int, int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enable: bool = True) -> int: + input_rank, input_dtype = input_rank_dtype + assert not is_integer_dtype(input_dtype) + return input_dtype + +@check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False)) +def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float], self_is_result: bool) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_two_tensor_op(dim=0, input_dtype=torch.float32) + + _check_two_tensor_op(dim=0, input_dtype=torch.float64)) +def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int: + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, TensorOfShape(1, dtype=torch.bool), 0)) +def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, TensorOfShape(1, dtype=torch.bool), 0)) +def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, TensorOfShape(1, dtype=torch.bool), TensorOfShape(dtype=torch.float32))) +def aten〇masked_fill〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +# TODO: This op cannot be run on the Meta backend. We should have the test suite default on CPU when the Meta backend does not work. 
+def aten〇masked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, start=0, length=1)) +def aten〇narrow〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start: int, length: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) +def aten〇neg〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇numpy_T〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1])) +def aten〇pad〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def aten〇permute〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇pow〇Tensor_Tensor〡dtype(self_rank_dtype: Tuple[int, int], exponent_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + exponent_rank, exponent_dtype = exponent_rank_dtype + ranks: List[Optional[int]] = [self_rank, exponent_rank] + dtypes = [self_dtype, exponent_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if promoted_dtype == torch.bool: + return torch.int64 + return promoted_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=2) + + [ErrorInvocation(TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.float64)), + ErrorInvocation(TensorOfShape(1, dtype=torch.float64), TensorOfShape(1, dtype=torch.float32))]) +def aten〇prelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert self_dtype == weight_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) +def 
aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇relu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, repeats=[1])) +def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1])) +def aten〇_reshape_alias〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, shape=[1])) +def aten〇reshape〡dtype(self_rank_dtype: Tuple[int, int], shape: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇resize_〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, shifts=[0], dims=[0])) +def aten〇roll〡dtype(self_rank_dtype: Tuple[int, int], shifts: List[int], dims: List[int] = ()) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype), "sum") for dtype in _SORTED_TORCH_TYPES]) +def aten〇scatter_reduce〇two〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], reduce: str, include_self: bool = True) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, index=0)) +def aten〇select〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(tensor_shapes=[(1, 1), (1,)], dim=0, index=0)) +def aten〇select_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], dim: int, index: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇silu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(dim=0)) +def aten〇slice_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 
1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) + + [Invocation(TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.float64), dim=0, input_dtype=torch.float32), + Invocation(TensorOfShape(1, dtype=torch.float64), TensorOfShape(1, dtype=torch.float32), dim=0, input_dtype=torch.float32)]) +def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int: + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇square〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇squeeze〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇squeeze〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇tanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + output_rank, output_dtype = output_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, output_rank] + dtypes = [grad_output_dtype, output_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, threshold=0, value=0)) +def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇t〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, device=torch.device("meta"))) +def aten〇to〇prim_Device〡dtype(self_rank_dtype: Tuple[int, int], device: Optional[device], dtype: Optional[int] = None, non_blocking: bool = False, copy: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim0=0, dim1=1)) +def aten〇transpose〇int〡dtype(self_rank_dtype: Tuple[int, int], dim0: int, dim1: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)])) +def aten〇triu〡dtype(self_rank_dtype: Tuple[int, int], diagonal: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., to: float = 1., generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def 
aten〇_unsafe_view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇unsqueeze〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 4, 8)], output_size=[4, 8], input_size=[1, 1, 2, 3])) +def aten〇upsample_nearest2d_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + return grad_output_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13])) +def aten〇upsample_nearest2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇zero〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇zero_〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function([Invocation(-1), Invocation(-1.0)]) +def prim〇abs〇Scalar〡dtype(a: Union[int, float]) -> int: + return get_dtype_of_scalar(a) + +@check_dtype_function(_check_tensors_with_the_same_dtype( + None, [(3,), (3, 4)], None, + TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)) + + [Invocation(TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, 4, dtype=torch.float64), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)), + Invocation(TensorOfShape(3, dtype=torch.float64), TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32))]) +def aten〇nll_loss_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], reduction: int, ignore_index: int, total_weight_rank_dtype: Tuple[int, int]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, grad_output_rank] + dtypes = [self_dtype, grad_output_dtype] + result = promote_dtypes(ranks, dtypes) + if result == torch.bool: + return torch.int64 + return result + +@check_dtype_function(_check_tensors_with_the_same_dtype( + None, [(2, 4, 7, 6), (2, 4, 6, 5)], None, + [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)) + + [ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float32), TensorOfShape(2, 4, 6, 5, dtype=torch.float64), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)), + ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float64), 
TensorOfShape(2, 4, 6, 5, dtype=torch.float32), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64))]) +def aten〇max_pool2d_with_indices_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices_rank_dtype: Tuple[int, int]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + assert grad_output_dtype == self_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇all〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return torch.uint8 if self_dtype == torch.uint8 else torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇any〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return torch.uint8 if self_dtype == torch.uint8 else torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇gt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇ge〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇logical_or〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇logical_xor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def 
aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇lt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇le〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function([ + Invocation(0.0, 0.0), # float, float + Invocation(0.0, 0), # float, int + Invocation(0, 0.0), # int, float + Invocation(0, 0), # int, int +]) +def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int: + ranks: List[Optional[int]] = [None, None] + dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bfloat16})) +def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if is_complex_dtype(self_dtype): + return self_dtype + elif self_dtype == torch.float16: + return torch.complex32 + elif self_dtype == torch.float32: + return torch.complex64 + elif self_dtype == torch.float64: + return torch.complex128 + elif is_integer_dtype(self_dtype): + return torch.complex64 + else: + assert False, "Unsupported dtype" + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = 
self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + + # Different width + [Invocation(TensorOfShape(2, 3, 4, dtype=torch.float64), + TensorOfShape(2, 4, 3, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float16), + TensorOfShape(2, 4, 3, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), + TensorOfShape(2, 4, 3, dtype=torch.int32))]) +def aten〇bmm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + mat2_rank, mat2_dtype = mat2_rank_dtype + self_rank, self_dtype = self_rank_dtype + mat2_priority = get_priority_of_dtype(mat2_dtype) + self_priority = get_priority_of_dtype(self_dtype) + return mat2_dtype if mat2_priority < self_priority else self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇div〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_complex_dtype(promoted_dtype) or \ + (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32): + return promoted_dtype + else: + return torch.float32 + +@check_dtype_function(_check_two_tensor_op(rounding_mode=None)) +def aten〇div〇Tensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rounding_mode: Optional[str]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_complex_dtype(promoted_dtype) or \ + (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32): + return promoted_dtype + else: + return torch.float32 + +@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool})) +def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + assert not is_complex_dtype(other_dtype), "`other` cannot be complex" + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert promoted_dtype != torch.bool, "Result dtype for aten.floor_divide cannot be bool" + return promoted_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + + # Different width + [Invocation(TensorOfShape(2, 3, 4, dtype=torch.float64), + TensorOfShape(2, 4, 3, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float16), + TensorOfShape(2, 
4, 3, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), + TensorOfShape(2, 4, 3, dtype=torch.int32))]) +def aten〇matmul〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + other_priority = get_priority_of_dtype(other_dtype) + self_priority = get_priority_of_dtype(self_dtype) + return other_dtype if other_priority < self_priority else self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇maximum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇minimum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4), (4, 3)]) + + # Different width + [Invocation(TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, 3, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(3, 4, dtype=torch.float16), + TensorOfShape(4, 3, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.int32))]) +def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + mat2_rank, mat2_dtype = mat2_rank_dtype + self_rank, self_dtype = self_rank_dtype + + float16_types = [torch.bfloat16, torch.float16] + if self_dtype in float16_types and mat2_dtype in float16_types and self_dtype != mat2_dtype: + return torch.float16 + + ranks: List[Optional[int]] = [self_rank, mat2_rank] + dtypes = [self_dtype, mat2_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op( + output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64})) +def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + target_rank, target_dtype = target_rank_dtype + ranks: List[Optional[int]] = [self_rank, target_rank] + dtypes = [self_dtype, target_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert not is_integer_dtype(promoted_dtype) + return promoted_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇mul〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4), (4,)]) + + # Different width + [Invocation(TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(3, 4, dtype=torch.float16), + TensorOfShape(4, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 
dtype=torch.int32))]) +def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + vec_rank, vec_dtype = vec_rank_dtype + ranks: List[Optional[int]] = [self_rank, vec_rank] + dtypes = [self_dtype, vec_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +# TODO: This op has inconsistent behavior when run on CPU vs META +# If the following check is made to pass, e2e tests fail +# @check_dtype_function(_check_two_tensor_op(threshold=0)) +def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + ranks: List[Optional[int]] = [self_rank, grad_output_rank] + dtypes = [self_dtype, grad_output_dtype] + return promote_dtypes(ranks, dtypes) + +# TODO: This op fails when using the meta backend with error: +# Op raised error 'convolution_overrideable not implemented. +# You are likely triggering this with tensor backend other than +# CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL +# to override this function ' but dtype function did not raise any error. +# +# This is similar to https://github.com/pytorch/pytorch/issues/97481 +def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +# TODO: This op fails when using the meta backend with error: +# Op raised error 'convolution_overrideable not implemented. +# You are likely triggering this with tensor backend other than +# CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL +# to override this function ' but dtype function did not raise any error. 
+# +# This is similar to https://github.com/pytorch/pytorch/issues/97481 +def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) + + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) +]) +def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) + + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) +]) +def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +convolution_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], **convolution_kwargs) + + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + 
TensorOfShape(1, dtype=torch.float32), **convolution_kwargs) +]) +def aten〇convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +convolution_backward_kwargs = { + "bias_sizes" : [1], "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1, "output_mask" : [True, True, True]} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1)], + **convolution_backward_kwargs) + + # dtype of first three tensors must be float + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.int32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # dtype of first three tensors must be float + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.int32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # dtype of first three tensors must be float + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.int32), **convolution_backward_kwargs), + # dtype of first three tensors must be float + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # grad_output, input, and weight must have same dtype + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # grad_output, input, and weight must have same dtype + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float64), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # grad_output, input, and weight must have same dtype + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float64), **convolution_backward_kwargs), +]) +def aten〇convolution_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_sizes: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[int, int, int]: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + return grad_output_dtype, grad_output_dtype, grad_output_dtype + +# TODO: Currently unable to test because the op `torch.ops.aten.convolution_backward_overrideable` +# fails to run on the CPU backend. A bug has been filed upstream: https://github.com/pytorch/pytorch/issues/97481 +# The dtype function for this op is unlikely to be different from `torch.ops.aten.convolution_backward` +# which is tested. 
+def aten〇convolution_backward_overrideable〡dtype(grad_output_rank_dtype: Tuple[int, int], input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[int, int, int]: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert grad_output_dtype == input_dtype + assert weight_dtype == input_dtype + assert is_float_dtype(input_dtype) and input_dtype != torch.float16 + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + return grad_output_dtype, grad_output_dtype, grad_output_dtype + +# TODO: This op cannot be run on the Meta backend. We should have the test suite default on CPU when the Meta backend does not work. +def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype: Optional[Tuple[int, int]] = None, minlength: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + assert is_integer_dtype(self_dtype) and self_dtype != torch.bool + if weights_rank_dtype is None: + return torch.int64 + return torch.float64 + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + + # Different width + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, 3, dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.int32)), + Invocation(TensorOfShape(3, 3, dtype=torch.int32), + TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32))]) +def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + mat1_rank, mat1_dtype = mat1_rank_dtype + mat2_rank, mat2_dtype = mat2_rank_dtype + + ranks: List[Optional[int]] = [self_rank, mat1_rank, mat2_rank] + dtypes = [self_dtype, mat1_dtype, mat2_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + + # Different width + [Invocation(TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float64), + TensorOfShape(4, 3, dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.int32)), + Invocation(TensorOfShape(4, 3, dtype=torch.int32), + TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32))]) +def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + end_rank, end_dtype = end_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + + ranks: List[Optional[int]] = [self_rank, end_rank, weight_rank] + dtypes = [self_dtype, end_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool}) + + # Different width + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float64), + TensorOfShape(3, 3, 
dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.int32)), + Invocation(TensorOfShape(3, 3, dtype=torch.int32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32))]) +def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + tensor1_rank, tensor1_dtype = tensor1_rank_dtype + tensor2_rank, tensor2_dtype = tensor2_rank_dtype + + assert self_dtype != torch.bool + assert tensor1_dtype != torch.bool + assert tensor2_dtype != torch.bool + + ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank] + dtypes = [self_dtype, tensor1_dtype, tensor2_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + + # Different width + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float64), + TensorOfShape(3, 3, dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.int32)), + Invocation(TensorOfShape(3, 3, dtype=torch.int32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32))]) +def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + tensor1_rank, tensor1_dtype = tensor1_rank_dtype + tensor2_rank, tensor2_dtype = tensor2_rank_dtype + + ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank] + dtypes = [self_dtype, tensor1_dtype, tensor2_dtype] + result = promote_dtypes(ranks, dtypes) + if is_integer_dtype(result): + return torch.float32 + return result + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, 
other=1.0)) +def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_integer_dtype(promoted_dtype): + return torch.float32 + return promoted_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0)) +def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_complex_dtype(self_dtype) + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0)) +def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(exponent)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}, negative_slope=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}, negative_slope=1.0)) +def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float] = 0.01) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + ranks: List[Optional[int]] = [self_rank, None] + negative_slope_dtype = get_dtype_of_scalar(negative_slope) + if is_float_dtype(negative_slope_dtype): + assert not is_integer_dtype(self_dtype) + dtypes = [self_dtype, negative_slope_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)]) + + [Invocation(TensorOfShape( + 1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.int16), TensorOfShape(1, 1, 1, dtype=torch.int32)), + Invocation( + TensorOfShape(1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.int64), TensorOfShape(1, 1, 1, dtype=torch.float16)), + Invocation( + TensorOfShape(1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 
1, dtype=torch.float16), TensorOfShape(1, 1, 1, dtype=torch.int64)), + Invocation( + TensorOfShape(1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, dtype=torch.float16))]) +def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int: + batch2_rank, batch2_dtype = batch2_rank_dtype + return batch2_dtype + +@check_dtype_function([ + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int16), NonZeroDTensorWithDtype(torch.int32)), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), NonZeroDTensorWithDtype(torch.float16)), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.int64)), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.bfloat16), NonZeroDTensorWithDtype(torch.float16))]) +def aten〇where〇self〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0.0), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0.0)]) +def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other: Union[int, float]) -> int: + if is_integer_dtype(get_dtype_of_scalar(self)) and is_integer_dtype(get_dtype_of_scalar(other)): + return torch.int64 + return torch.float32 + +@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int16), 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), 0.0), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float64), 0.0)]) +def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.int16)), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.int64)), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.float16)), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.float64))]) +def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [None, other_rank] + dtypes = [get_dtype_of_scalar(self), other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + [Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int64), + TensorOfShape(3, dtype=torch.float32), reduction=0, 
ignore_index=0), + ErrorInvocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int32), # target must be int64 + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + ErrorInvocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.float64), # target must be int64 + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + Invocation(TensorOfShape(2, 3, dtype=torch.float64), TensorOfShape(2, dtype=torch.int64), # self and weight must have same dtype + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + Invocation(TensorOfShape(2, 3, dtype=torch.int32), TensorOfShape(2, dtype=torch.int64), # self and weight must be float + TensorOfShape(3, dtype=torch.int32), reduction=0, ignore_index=0), + Invocation(TensorOfShape(2, 3, dtype=torch.complex64), TensorOfShape(2, dtype=torch.int64), # self and weight must be float + TensorOfShape(3, dtype=torch.complex64), reduction=0, ignore_index=0)]) +def aten〇nll_loss_forward〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], reduction: int, ignore_index: int) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + target_rank, target_dtype = target_rank_dtype + assert target_dtype == torch.int64 + return self_dtype, self_dtype + +@check_dtype_function( + [Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.float64), [3], TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float64), + TensorOfShape(3, dtype=torch.float32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float64), eps=0.0), + # Input must be float or complex + ErrorInvocation(TensorOfShape(2, 3, dtype=torch.int32), [3], TensorOfShape(3, dtype=torch.int32), + TensorOfShape(3, dtype=torch.int32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.complex64), [3], TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.complex128), [3], TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), eps=0.0), + ]) +def aten〇native_layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], eps: float) -> Tuple[int, int, int]: + input_rank, input_dtype = input_rank_dtype + assert not is_integer_dtype(input_dtype) + result_dtype = input_dtype + if input_dtype == torch.complex64: + result_dtype = torch.float32 + if input_dtype == torch.complex128: + result_dtype = torch.float64 + return input_dtype, input_dtype, result_dtype + +@check_dtype_function( + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + # Tensors with different dtype + Invocation(TensorOfShape(3, 3, dtype=torch.float64), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), 
training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float64), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float64), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float64), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float64), training=False, momentum=0.0, eps=0.0), + # Non-float tensors + Invocation(TensorOfShape(3, 3, dtype=torch.int32), TensorOfShape(3, dtype=torch.int32), + TensorOfShape(3, dtype=torch.int32), TensorOfShape(3, dtype=torch.int32), + TensorOfShape(3, dtype=torch.int32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.complex64), TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), training=False, momentum=0.0, eps=0.0), + ]) +def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], training: bool, momentum: float, eps: float) -> Tuple[int, int, int]: + input_rank, input_dtype = input_rank_dtype + result_dtype = input_dtype + if is_integer_dtype(input_dtype): + result_dtype = torch.float32 + return input_dtype, input_dtype, result_dtype + +@check_dtype_function([Invocation(end=0, dtype=None), # No floats + Invocation(end=0.0, dtype=None), # One float + ErrorInvocation(end=0, dtype=torch.complex64), # Dtype specified + Invocation(end=0, dtype=torch.float16), # Dtype specified + Invocation(end=0, dtype=torch.int16)]) # Dtype specified +def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(end)): + return torch.float32 + return torch.int64 + +@check_dtype_function([Invocation(start=0, end=10, dtype=None), # No floats + Invocation(start=0.0, end=10, dtype=None), # One float + Invocation(start=0, end=10.0, dtype=None), # One float + ErrorInvocation(start=0, end=10, dtype=torch.complex64), # Dtype specified + Invocation(start=0, end=10, dtype=torch.float16), # Dtype specified + Invocation(start=0, end=10, dtype=torch.int16)]) # Dtype specified +def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(start)) or \ + 
is_float_dtype(get_dtype_of_scalar(end)): + return torch.float32 + return torch.int64 + +@check_dtype_function([Invocation(start=0, end=10, step=1, dtype=None), # No floats + Invocation(start=0.0, end=10, step=1, dtype=None), # One float + Invocation(start=0, end=10.0, step=1, dtype=None), # One float + Invocation(start=0, end=10, step=1.0, dtype=None), # One float + ErrorInvocation(start=0, end=10, step=1, dtype=torch.complex64), # Dtype specified + Invocation(start=0, end=10, step=1, dtype=torch.float16), # Dtype specified + Invocation(start=0, end=10, step=1, dtype=torch.int16)]) # Dtype specified +def aten〇arange〇start_step〡dtype(start: Union[int, float], end: Union[int, float], step: Union[int, float] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(start)) or \ + is_float_dtype(get_dtype_of_scalar(end)) or \ + is_float_dtype(get_dtype_of_scalar(step)): + return torch.float32 + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇sum〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.complex64)) +def aten〇sum〇dim_IntList〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> int: + return aten〇sum〡dtype(self_rank_dtype, dtype) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=None) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.complex64) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dim=None, dtype=torch.int32)]) +def aten〇mean〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + result = aten〇sum〡dtype(self_rank_dtype, dtype) + assert not is_integer_dtype(result) + return result + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇argmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.uint8: + return self_dtype + return torch.bool + 
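[Editorial aside, not part of the patch: the scalar-handling pattern used in the dtype functions above (a rank of `None` standing in for a wrapped Python scalar passed to `promote_dtypes`, and `arange` defaulting to `torch.float32` when any argument is a float) mirrors eager PyTorch's promotion rules. The standalone sketch below checks a couple of those rules against eager PyTorch on the meta device; it uses only stock PyTorch APIs, and it assumes the process-wide default dtype has been left at `torch.float32`.]

import torch

# Illustration only: wrapped Python scalars promote like 0-dim tensors.
x = torch.ones(2, 2, dtype=torch.int32, device="meta")
assert torch.add(x, 1).dtype == torch.int32      # int tensor + int scalar -> int32
assert torch.add(x, 1.0).dtype == torch.float32  # int tensor + float scalar -> float32

# arange: integer-only arguments give int64; any float argument gives the
# default float dtype; an explicit dtype wins.
assert torch.arange(5, device="meta").dtype == torch.int64
assert torch.arange(0.0, 5, device="meta").dtype == torch.float32
assert torch.arange(0, 5, dtype=torch.float16, device="meta").dtype == torch.float16

[This manual check is essentially what the `check_dtype_function` decorator, updated later in this patch in testing_framework.py, appears to automate: it runs the real op on meta tensors and compares the resulting dtype against the value returned by the dtype function.]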
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇max〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), keepdim: bool = False) -> int: + return aten〇max〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: + return aten〇max〡dtype(self_rank_dtype), torch.int64 + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) +def aten〇mean〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None) -> int: + return aten〇mean〇dim〡dtype(self_rank_dtype, dim=None, dtype=dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇std〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.complex64: + return torch.float32 + if self_dtype == torch.complex128: + return torch.float64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None)) +def aten〇std〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇var〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None)) +def aten〇var〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[], correction=0.0)) +def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], correction: float, output_dtype: Optional[int] = None) -> int: + return aten〇std〡dtype(inp_rank_dtype) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, 
torch.complex64, torch.complex128}, dtype=torch.float64) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) +def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if dtype is not None: + assert not is_integer_dtype(dtype) + if is_complex_dtype(self_dtype): + assert is_complex_dtype(dtype) + return aten〇std〡dtype((self_rank, dtype)) + assert not is_complex_dtype(dtype) + return dtype + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function([Invocation(0.0), + Invocation(0.0, dtype=torch.int32), + Invocation(0.0, dtype=torch.float16), + Invocation(0.0, dtype=torch.complex64)]) +def aten〇tensor〇float〡dtype(t: float, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.float32 + return dtype + +@check_dtype_function([Invocation(0), + Invocation(0, dtype=torch.int32), + Invocation(0, dtype=torch.float16), + Invocation(0, dtype=torch.complex64)]) +def aten〇tensor〇int〡dtype(t: int, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.int64 + return dtype + +@check_dtype_function([Invocation(True), + Invocation(True, dtype=torch.int32), + Invocation(True, dtype=torch.float16), + Invocation(True, dtype=torch.complex64)]) +def aten〇tensor〇bool〡dtype(t: bool, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.bool + return dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇zeros〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇ones〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇empty〇memory_format〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1], 0.0), + Invocation([1], 0), + Invocation([1], 0.0, dtype=torch.int32), + Invocation([1], 0.0, dtype=torch.float16), + Invocation([1], 0.0, dtype=torch.complex64)]) +def aten〇full〡dtype(size: List[int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: 
Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + return dtype + fill_value_dtype = get_dtype_of_scalar(fill_value) + if is_float_dtype(fill_value_dtype): + return torch.float32 + return fill_value_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇zeros_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇ones_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.complex64)) +def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_zeros〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = 
self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_ones〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_empty〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.complex64)) +def aten〇new_empty_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇rand_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇randn_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + 
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇_to_copy〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇to〇dtype〡dtype(self_rank_dtype: Tuple[int, int], dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + return dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def nvprims〇convert_element_type〡dtype(a_rank_dtype: Tuple[int, int], dtype: int) -> int: + return dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇to〇dtype_layout〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.complex64)) +def aten〇to〇device〡dtype(self_rank_dtype: Tuple[int, int], device: device, dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + return dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇to〇other〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + other_rank, other_dtype = other_rank_dtype + return other_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇type_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + return other_dtype + +@check_dtype_function([Invocation(low=0, high=10, size=[1]), + Invocation(low=0, high=10, size=[1], dtype=torch.float32), + Invocation(low=0, high=10, size=[1], dtype=torch.int32), + ErrorInvocation(low=0, high=10, size=[1], dtype=torch.complex64)]) +def aten〇randint〇low〡dtype(low: int, high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.int64 + assert not is_complex_dtype(dtype) + return 
dtype + +@check_dtype_function([Invocation(size=[1]), + Invocation(size=[1], dtype=torch.float32), + ErrorInvocation(size=[1], dtype=torch.int32), + Invocation(size=[1], dtype=torch.complex64)]) +def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.float32 + assert not is_integer_dtype(dtype) + return dtype + +@check_dtype_function([Invocation(size=[1], generator=None), + Invocation(size=[1], generator=None, dtype=torch.float32), + ErrorInvocation(size=[1], generator=None, dtype=torch.int32), + Invocation(size=[1], generator=None, dtype=torch.complex64)]) +def aten〇randn〇generator〡dtype(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.float32 + assert not is_integer_dtype(dtype) + return dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes())) +def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if self_dtype == torch.complex64: + return torch.float32, self_dtype + if self_dtype == torch.complex128: + return torch.float64, self_dtype + return self_dtype, self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes())) +def aten〇var_mean〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if self_dtype == torch.complex64: + return torch.float32, self_dtype + if self_dtype == torch.complex128: + return torch.float64, self_dtype + return self_dtype, self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇atan2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_integer_dtype(promoted_dtype): + return torch.float32 + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇atan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.float32 + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + return input_dtype + +@check_dtype_function( + [Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), + Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), + Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32), + NonZeroDTensorWithDtype(torch.complex64)])]) +def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0) -> int: + ranks: 
List[Optional[int]] = [] + dtypes: List[int] = [] + assert len(tensors_rank_dtype) != 0 + for tensor_rank_dtype in tensors_rank_dtype: + tensor_rank, tensor_dtype = tensor_rank_dtype + ranks.append(tensor_rank) + dtypes.append(tensor_dtype) + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.int64 + +# Does not work on meta backend +#@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[()])) +def aten〇ScalarImplicit〡dtype(a_rank_dtype: Tuple[int, int]) -> int: + a_rank, a_dtype = a_rank_dtype + return a_dtype + +@check_dtype_function([Invocation(0), Invocation(0.0)]) +def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float]) -> int: + return get_dtype_of_scalar(a) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.complex64)) +def aten〇softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if dtype is None: + return self_dtype + return dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types=(all_integer_dtypes() + all_complex_dtypes() + [torch.bfloat16, torch.float32, torch.float64]), + dim=0, half_to_float=True)) +def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int: + self_rank, self_dtype = self_rank_dtype + if half_to_float: + assert self_dtype == torch.float16 + return torch.float32 + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types=(all_integer_dtypes() + all_complex_dtypes() + [torch.bfloat16, torch.float32, torch.float64]), + dim=0, half_to_float=True)) +def aten〇_log_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int: + self_rank, self_dtype = self_rank_dtype + if half_to_float: + assert self_dtype == torch.float16 + return torch.float32 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.complex64)) +def aten〇log_softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if dtype is None: + return self_dtype + return dtype + +# TODO: to test these functions, we need to be able to specify the tensor contents used in each invocation +def aten〇embedding〡dtype(weight_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> int: + weight_rank, weight_dtype = weight_rank_dtype + return weight_dtype + +# TODO: to test these functions, we need to be able to specify the tensor contents used in each invocation +def 
aten〇_embedding_bag〡dtype(weight_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], offsets_rank_dtype: Tuple[int, int], scale_grad_by_freq: bool = False, mode: int = 0, sparse: bool = False, per_sample_weights_rank_dtype: Optional[Tuple[int, int]] = None, include_last_offset: bool = False, padding_idx: int = -1) -> Tuple[int, int, int, int]: + weight_rank, weight_dtype = weight_rank_dtype + return weight_dtype, torch.int64, torch.int64, torch.int64 + +# TODO: to test these functions, we need to be able to specify the tensor contents used in each invocation +def aten〇embedding_bag〇padding_idx〡dtype(weight_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], offsets_rank_dtype: Tuple[int, int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights_rank_dtype: Optional[Tuple[int, int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[int, int, int, int]: + weight_rank, weight_dtype = weight_rank_dtype + return weight_dtype, torch.int64, torch.int64, torch.int64 + +@check_dtype_function(_check_two_tensor_op(out_int32=True) + _check_two_tensor_op(out_int32=False)) +def aten〇bucketize〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], boundaries_rank_dtype: Tuple[int, int], out_int32: bool = False, right: bool = False) -> int: + if out_int32: + return torch.int32 + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dimensions=[])) +def prims〇squeeze〡dtype(a_rank_dtype: Tuple[int, int], dimensions: List[int]) -> int: + a_rank, a_dtype = a_rank_dtype + return a_dtype + +# ============================================================================== +# Main +# ============================================================================== def _maybe_import_op_extensions(args: argparse.Namespace): extension_string = str.strip(args.pytorch_op_extensions) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py index f87a7019d6b9b..f25508897fa55 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py @@ -14,6 +14,55 @@ from .registry import Registry +def all_integer_dtypes() -> List[int]: + return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64] + +def is_integer_dtype(dtype: int) -> bool: + return dtype in all_integer_dtypes() + +def all_complex_dtypes() -> List[int]: + return [torch.complex64, torch.complex128] + +def is_complex_dtype(dtype: int) -> bool: + return dtype in all_complex_dtypes() + +def all_float_dtypes() -> List[int]: + return [torch.float16, torch.bfloat16, torch.float32, torch.float64] + +def is_float_dtype(dtype: int) -> bool: + return dtype in all_float_dtypes() + +def get_priority_of_dtype(dtype: int) -> int: + # If a loop is used to iterate over a list of sorted dtypes, TorchScript + # produces a loop with INT64_MAX max trip count, which causes problems + # during the loop unrolling that takes place when simplifying the dtype + # functions. Therefore, here we result to `if`s. 
+ if dtype == torch.bool: + return 0 + if dtype == torch.uint8: + return 1 + if dtype == torch.int8: + return 2 + if dtype == torch.int16: + return 3 + if dtype == torch.int32: + return 4 + if dtype == torch.int64: + return 5 + if dtype == torch.bfloat16: + return 6 + if dtype == torch.float16: + return 7 + if dtype == torch.float32: + return 8 + if dtype == torch.float64: + return 9 + if dtype == torch.complex64: + return 10 + if dtype == torch.complex128: + return 11 + assert False, "Cannot determine priority of dtype" + def get_dtype_of_scalar(scalar: Union[int, float]) -> int: # This is hacky. `NumToTensor` is the only PyTorch op for scalars # that when `jit.script`ed converts a float scalar to a tensor diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py index efd270b78a7f9..c9387cc4c8f79 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py @@ -86,7 +86,7 @@ def _recursively_transform_tensor_args( o: Any, tensor_transformer: Callable[[TensorOfShape], Any]) -> Any: """Replace `TensorOfShape` with the result of `tensor_transformer`""" - if o is None or isinstance(o, (float, int)): + if o is None or isinstance(o, (float, int, str)): return o if isinstance(o, TensorOfShape): return tensor_transformer(o) @@ -146,7 +146,7 @@ def to_dtype_function_args(self): def to_real_op_args(self): """Gets positional arguments appropriate for the real op.""" - tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype) + tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype).to("meta") return _recursively_transform_tensor_args(self.args, tensor_transformer) def __repr__(self) -> str: @@ -258,6 +258,15 @@ def decorator(f): return f return decorator +@torch.jit.script +def _convert_dtype_to_int(dtype: torch.dtype) -> int: + """Convert a PyTorch `dtype` into its underlying `int` representation. + + This works because in TorchScript there is no special type for `dtypes`; + they are simply `int`s. + """ + return dtype + def check_dtype_function(invocations: List[Invocation]): """Decorator that automatically tests a dtype function. @@ -281,7 +290,12 @@ def decorator(f): golden_dtype = torch.tensor([]).to(type(golden_result)).dtype else: raise ValueError(f"Unhandled return type {type(golden_result)}") - if result_dtype != golden_dtype: + # Some dtype funtions have default `dtype` parameters, which are + # represented as `int` values in the registry. In order to + # support returning the default `int` value, the comparisons of + # the result and golden dtypes are done using their underlying + # `int` representation. 
+ if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(golden_dtype): _report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}") return f return decorator diff --git a/test/Dialect/Torch/refine-types-branch.mlir b/test/Dialect/Torch/refine-types-branch.mlir index 87ff9657695a2..3c76ac95f9e93 100644 --- a/test/Dialect/Torch/refine-types-branch.mlir +++ b/test/Dialect/Torch/refine-types-branch.mlir @@ -117,37 +117,3 @@ func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.option } : (!torch.int, !torch.bool, !torch.optional) -> (!torch.optional) return %ret: !torch.optional } - -// ----- - -// CHECK-LABEL: func.func @f -// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[CAST]] : !torch.vtensor -func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { - %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor - cf.br ^bb1(%cast: !torch.vtensor) -^bb1(%arg1: !torch.vtensor): - %1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor - return %1 : !torch.vtensor -} - -// ----- - -// CHECK-LABEL: func.func @f -// CHECK: func.func private @callee -// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32> -func.func @f() { - builtin.module { - func.func private @callee(%arg0: !torch.vtensor) { - %1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor - return - } - func.func @caller(%arg0: !torch.vtensor<*,f32>) { - %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor - call @callee(%cast) : (!torch.vtensor) -> () - return - } - } - return -} diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir deleted file mode 100644 index 3c90de22893c7..0000000000000 --- a/test/Dialect/Torch/refine-types-ops.mlir +++ /dev/null @@ -1,364 +0,0 @@ -// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s - -// This file is for tests for individual ops that require a new transfer -// function (i.e. new code called from visitOperation). 
- -// ----- -// CHECK-LABEL: func.func @aten.arange.start$int64_dtype( -// CHECK-SAME: %[[START:.*]]: !torch.int, -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange.start -// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,si64> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,si64> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor { - %none = torch.constant.none - %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.arange.start$float32_dtype( -// CHECK-SAME: %[[START:.*]]: !torch.float, -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange.start -// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,f32> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor { - %none = torch.constant.none - %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.arange.start$specified_dtype( -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange -// CHECK-SAME: %[[END]], %[[CST6]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,f32> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor { - %int6 = torch.constant.int 6 - %none = torch.constant.none - %ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.linear( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[5,3],f32>, -// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor { -// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[ARG0]], %[[ARG1]], %[[ARG2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<*,f32> -// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RESULT]] : !torch.vtensor -func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: 
!torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
-  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor
-  return %1 : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @aten.sum.dim_IntList(
-// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,si64>) -> !torch.vtensor {
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
-// CHECK: %[[DIMLIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT_NEG1]]
-// CHECK-SAME: : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[RET:.*]] = torch.aten.sum.dim_IntList %[[T]], %[[DIMLIST]], %[[FALSE]], %[[NONE]]
-// CHECK-SAME: : !torch.vtensor<*,si64>, !torch.list<int>, !torch.bool, !torch.none
-// CHECK-SAME: -> !torch.vtensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,si64> to !torch.vtensor
-// CHECK: return %[[CAST]] : !torch.vtensor
-func.func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor {
-  %false = torch.constant.bool false
-  %none = torch.constant.none
-  %int0 = torch.constant.int 0
-  %int-1 = torch.constant.int -1
-  %dimList = torch.prim.ListConstruct %int0, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %ret = torch.aten.sum.dim_IntList %t, %dimList, %false, %none : !torch.vtensor<*,si64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor
-  return %ret : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @aten.any.dim(
-// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
-// CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[INT_NEG1]], %[[FALSE]] : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor<*,i1>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
-// CHECK: return %[[CAST]] : !torch.vtensor
-func.func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
-  %false = torch.constant.bool false
-  %int-1 = torch.constant.int -1
-  %ret = torch.aten.any.dim %t, %int-1, %false : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor
-  return %ret : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @aten.any(
-// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
-// CHECK: %[[RET:.*]] = torch.aten.any %[[T]] : !torch.vtensor<*,i1> -> !torch.vtensor<*,i1>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
-// CHECK: return %[[CAST]] : !torch.vtensor
-func.func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
-  %ret = torch.aten.any %t: !torch.vtensor<*,i1> -> !torch.vtensor
-  return %ret : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.zeros(
-// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
-  %none = torch.constant.none
-  %int2 = torch.constant.int 2
-  %sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
-  %ret = torch.aten.zeros %sizesList, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.type_as(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
-// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
-// CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor {
-  %ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor
-  return %ret: !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.cat(
-// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
-// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[?,1,4],f32>, !torch.tensor<[2,3,4],f32>) -> !torch.list<tensor>
-// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
-  %int1 = torch.constant.int 1
-  %tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<tensor>
-  %ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.cat$promote_type(
-// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>,
-// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor {
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list<tensor>
-// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor {
-  %int1 = torch.constant.int 1
-  %tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list<tensor>
-  %ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
-// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
-  %ret = torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten._shape_as_tensor$unknown_input_shape(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
-// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
-  %ret = torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.embedding(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
-// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[PADDING_IDX:.*]] = torch.constant.int 1
-// CHECK: %[[RET:.*]] = torch.aten.embedding %[[INPUT]], %[[INDEXES]], %[[PADDING_IDX]], %[[FALSE]], %[[FALSE]] : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor {
-  %false = torch.constant.bool false
-  %int1 = torch.constant.int 1
-  %ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor
-  return %ret: !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.tensor.float(
-// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
-  %none = torch.constant.none
-  %false = torch.constant.bool false
-  %ret = torch.aten.tensor.float %t, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.tensor.float$specified_dtype(
-// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[CST11:.*]] = torch.constant.int 11
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,i1>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,i1> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
-  %none = torch.constant.none
-  %int11 = torch.constant.int 11
-  %false = torch.constant.bool false
-  %ret = torch.aten.tensor.float %t, %int11, %none, %false : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.softmax.int(
-// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
-// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
-// CHECK: %[[DTYPE:.*]] = torch.constant.none
-// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
-// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[RET]] : !torch.tensor
-func.func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
-  %none = torch.constant.none
-  %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.softmax.int$specified_dtype(
-// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
-// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
-// CHECK: %[[DTYPE:.*]] = torch.constant.int 4
-// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<*,si64>
-// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[RET]] : !torch.tensor
-func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
-  %int4 = torch.constant.int 4
-  %ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Matrix(
-// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
-// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
-// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
-  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
-  return %0 : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Vector(
-// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
-// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<*,f32>) -> !torch.tensor {
-// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor {
-  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor
-  return %0 : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.to.dtype(
-// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
-// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
-// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
-// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
-// CHECK-SAME: -> !torch.tensor<*,si64>
-// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK-NEXT: return %[[RES]] : !torch.tensor
-func.func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
-  %none = torch.constant.none
-  %false = torch.constant.bool false
-  %int4 = torch.constant.int 4
-  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
-  return %0 : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar(
-// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
-// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
-  %0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
-  return %0: !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.tensor(
-// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[NONE]], %[[NONE]], %[[FALSE]]
-// CHECK-SAME: : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool
-// CHECK-SAME: -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
-  %none = torch.constant.none
-  %false = torch.constant.bool false
-  %ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.tensor$specified_dtype(
-// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[INT4:.*]] = torch.constant.int 4
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor {
-  %none = torch.constant.none
-  %int4 = torch.constant.int 4
-  %false = torch.constant.bool false
-  %ret = torch.aten.tensor %t, %int4, %none, %false : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor
-  return %ret : !torch.tensor
-}
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index 50d96b08e3881..df27220374960 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -6,160 +6,6 @@
 // Code for testing transfer functions for new ops (which is most changes)
 // should go in refine-types-ops.mlir.
 
-// -----
-// CHECK-LABEL: func.func @basic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
-// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[COS]] : !torch.vtensor<*,f32> to !torch.vtensor
-// CHECK: return %[[RESULT]] : !torch.vtensor
-func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
-  %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
-  return %1 : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @keep_existing_shape_information(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
-// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
-// CHECK: return %[[COS]] : !torch.vtensor<[2],f32>
-func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
-  %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
-  return %1 : !torch.vtensor<[2],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @propagate_through_multiple_ops(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
-// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: %[[COS2:.*]] = torch.aten.cos %[[COS1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: %[[COS3:.*]] = torch.tensor_static_info_cast %[[COS2]] : !torch.vtensor<*,f32> to !torch.vtensor
-// CHECK: return %[[COS3]] : !torch.vtensor
-func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
-  %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
-  %2 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
-  %3 = torch.aten.cos %2 : !torch.vtensor -> !torch.vtensor
-  return %3 : !torch.vtensor
-}
-
-// -----
-// Check rewriting logic in case of mixes of users that do/don't allow type
-// refinement.
-// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
-// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[COS0]] : !torch.vtensor<*,f32> to !torch.vtensor
-// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
-func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
-  %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
-  %3 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
-  return %1, %1 : !torch.vtensor, !torch.vtensor
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
-// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
-// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
-// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST:.*]] : !torch.vtensor<*,f32> to !torch.vtensor<*,f32>
-// CHECK: torch.overwrite.tensor.contents %[[CAST2]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
-func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
-  %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
-  %static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
-  %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
-  torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor
-  %static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor
-  %result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32>
-  return %result : !torch.vtensor<[2],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>,
-// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
-// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[2],f32> to !torch.vtensor<*,f32>
-// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
-// CHECK: %[[MUTABLE_COPY:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor<*,f32>
-// CHECK: torch.overwrite.tensor.contents %[[ARG0_ERASED]] overwrites %[[MUTABLE_COPY]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
-func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
-  %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
-  %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
-  %dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
-  torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor
-  %dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor
-  %result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32>
-  return %result : !torch.vtensor<[?],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @bf16_result_type(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
-// CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16>
-// CHECK: return %[[SQRT]] : !torch.vtensor<[2],bf16>
-func.func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
-  %1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16>
-  return %1 : !torch.vtensor<[2],bf16>
-}
-
-// -----
-// CHECK-LABEL: func.func @propagate_scalar_type(
-// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
-// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number
-// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int
-// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number
-// CHECK: return %[[RET]] : !torch.number
-func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
-  %num = torch.derefine %arg0 : !torch.int to !torch.number
-  %1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
-  return %1 : !torch.number
-}
-
-// -----
-// CHECK-LABEL: func.func @prim.dtype(
-// CHECK-SAME: %[[arg:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor {
-
-// CHECK: %[[zero:.*]] = torch.constant.int 0
-// CHECK: %[[false:.*]] = torch.constant.bool false
-
-// CHECK: %[[neg:.*]] = torch.aten.neg %[[arg]] : !torch.vtensor<*,bf16> -> !torch.vtensor<*,bf16>
-// CHECK: %[[dtype0:.*]] = torch.prim.dtype %[[neg]] : !torch.vtensor<*,bf16> -> !torch.int
-// CHECK: %[[device0:.*]] = torch.prim.device %[[neg]] : !torch.vtensor<*,bf16> -> !torch.Device
-// CHECK: %[[tensor:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype0]], %[[device0]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
-
-// CHECK: %[[dtype1:.*]] = torch.prim.dtype %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.int
-// CHECK: %[[device1:.*]] = torch.prim.device %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.Device
-// CHECK: %[[result:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype1]], %[[device1]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
-
-// CHECK: %[[cast:.*]] = torch.tensor_static_info_cast %[[result]] : !torch.vtensor<*,bf16> to !torch.vtensor
-// CHECK: return %[[cast]] : !torch.vtensor
-// CHECK: }
-
-func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> {
-  %zero = torch.constant.int 0
-  %false = torch.constant.bool false
-
-  // Op that requires type refinement
-  %neg = torch.aten.neg %arg : !torch.vtensor<*,bf16> -> !torch.vtensor<*,unk>
-
-  // Op whose processing requires type refinement on its source argument.
-  %dtype = torch.prim.dtype %neg : !torch.vtensor<*,unk> -> !torch.int
-  %device = torch.prim.device %neg : !torch.vtensor<*,unk> -> !torch.Device
-
-  // Another op that requires type refinement
-  %result = torch.aten.tensor.int %zero, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
-
-  // Repeat the above three ops a second time to ensure that the type refinement
-  // code works regardless of the number of alternating refinement+prim.dtype
-  // sequences.
-  %dtype2 = torch.prim.dtype %result : !torch.vtensor<*,unk> -> !torch.int
-  %device2 = torch.prim.device %result : !torch.vtensor<*,unk> -> !torch.Device
-  %result2 = torch.aten.tensor.int %zero, %dtype2, %device2, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
-
-  return %result2 : !torch.vtensor<*,unk>
-}
-
 // -----
 
 // Check that we don't crash on this input.
@@ -194,27 +40,6 @@ func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) {
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.zeros_like(
-// CHECK-SAME: %[[arg:.*]]: !torch.vtensor) {
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
-// CHECK: %[[ZEROS:.*]] = torch.aten.zeros_like %[[arg]], %[[INT6]], %[[INT0]], %[[CPU]], %[[FALSE]], %[[INT1]] : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor<*,f32>
-// CHECK: return
-func.func @torch.aten.zeros_like(%arg: !torch.vtensor) {
-  %int6 = torch.constant.int 6
-  %false = torch.constant.bool false
-  %int1 = torch.constant.int 1
-  %int0 = torch.constant.int 0
-  %cpu = torch.constant.device "cpu"
-  %2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor
-  return
-}
-
-// -----
-
 // The data-flow analysis does not always propagate information to the entire graph.
 // This results in some lattice elements being uninitialized, which must be properly
 // handled when using the lattice elements to rewrite the graph.