From 981ac88758df33d0238b7a781630125774b165b0 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Wed, 1 Feb 2023 22:30:27 +0000 Subject: [PATCH] Add dtype functions for two tensor promotion ops (#1831) This commit adds dtype functions for ops in RefineTypes under the category of "Promote the two dtypes". The only ops not added here are convolution ops, since they take an optional tensor argument, and the dtype pipeline currently does not correctly handle that case. I will add a follow up patch fixing this. This commit also adds two helper functions that perform a very thorough testing of dtype functions. The helper function `_check_two_tensor_op` is able to independently test invalid input dtypes and invalid output dtypes. Lastly, this commit also XFAILs "MobilenetV3Module_basic". --- e2e_testing/xfail_sets.py | 5 + .../Transforms/AbstractInterpLibrary.cpp | 974 ++++++++++++++++-- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 21 +- .../Transforms/SimplifyDtypeCalculations.cpp | 2 + .../build_tools/abstract_interp_lib_gen.py | 515 +++++++-- .../jit_ir/build_tools/library_generator.py | 10 + .../test_suite/__init__.py | 1 + test/Dialect/Torch/refine-types-ops.mlir | 24 - 8 files changed, 1285 insertions(+), 267 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index ae1a9b8bebfe..836ed7a08c98 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -85,6 +85,11 @@ "ElementwisePreluModule_basic", # error: op lowering missing. Issue: https://github.com/llvm/torch-mlir/issues/1792 "StdCorrectionKeepDimModule_basic", + + # Dtype function transition failures + "MobilenetV3Module_basic", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", } MHLO_PASS_SET = { diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index c0a0df5ee9f1..682048b93654 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7351,80 +7351,225 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %4 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" " %int6 = torch.constant.int 6\n" -" %int5 = torch.constant.int 5\n" -" %int15 = torch.constant.int 15\n" -" %int7 = torch.constant.int 7\n" -" %0 = torch.prim.ListConstruct %int7, %int15, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" %0 = call 
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%arg0) : (!torch.int) -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" %4 = torch.aten.ne.int %arg0, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" }\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" " torch.prim.If.yield %arg0 : !torch.int\n" " } else {\n" " torch.prim.If.yield %int6 : !torch.int\n" " }\n" -" return %2 : !torch.int\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %int7 = torch.constant.int 7\n" +" %int6 = torch.constant.int 6\n" +" %int15 = torch.constant.int 15\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %int10 = torch.constant.int 10\n" +" %int9 = torch.constant.int 9\n" +" %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %str = torch.constant.str \"AssertionError: `self` cannot have float16 dtype\"\n" +" %int5 = torch.constant.int 5\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.ne.int %arg1, 
%int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %3 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" " %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" 
" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %int11 = torch.constant.int 11\n" -" %int3 = torch.constant.int 3\n" -" %int4 = torch.constant.int 4\n" 
-" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" -" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %4 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -7434,14 +7579,57 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" " return %3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %int11 = torch.constant.int 11\n" -" %int3 = torch.constant.int 3\n" +" %int5 = torch.constant.int 5\n" +" %false = torch.constant.bool false\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %6 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %5 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" " %int4 = torch.constant.int 4\n" -" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" +" %int3 = torch.constant.int 3\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list 
%0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" " %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" @@ -7449,20 +7637,61 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %3 : !torch.int\n" +" %3 = torch.aten.eq.int %arg1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %4:2 = torch.prim.If %3 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" +" } else {\n" +" %6 = torch.aten.eq.int %arg1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %7:2 = torch.prim.If %6 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %7#0, %7#1 : !torch.bool, !torch.int\n" +" }\n" +" %5 = torch.prim.If %4#0 -> (!torch.int) {\n" +" torch.prim.If.yield %4#1 : !torch.int\n" +" } else {\n" +" %6 = func.call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %6 : !torch.int\n" +" }\n" +" return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" -" return %0 : !torch.int\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.ne.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" -" return %int11 : !torch.int\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" -" return %int11 : !torch.int\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.eq.int %arg1, %int0 : 
!torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" @@ -7474,18 +7703,66 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " return %int11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " return %int11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " return %int11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" 
+" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " return %int11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" @@ -7506,10 +7783,38 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " return %int11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " return %int11 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union) -> !torch.int {\n" @@ -7537,94 +7842,537 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.int, %arg4: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" -" %int4 = torch.constant.int 4\n" -" %int3 = torch.constant.int 3\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %int11 = torch.constant.int 11\n" -" %int7 = torch.constant.int 7\n" -" %int6 = torch.constant.int 6\n" " %int10 = torch.constant.int 10\n" -" %true = torch.constant.bool true\n" +" %int7 = torch.constant.int 7\n" " %int9 = torch.constant.int 9\n" +" %int6 = torch.constant.int 6\n" " %0 = torch.prim.Uninitialized : !torch.int\n" -" %1 = torch.aten.eq.int %arg1, %int9 : !torch.int, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %4 = torch.aten.eq.int %arg1, %int10 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %4 : 
!torch.bool\n" -" }\n" -" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" " torch.prim.If.yield %arg1 : !torch.int\n" " } else {\n" -" %4 = torch.aten.eq.int %arg1, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" %3 = torch.aten.eq.int %arg1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" " torch.prim.If.yield %int9 : !torch.int\n" " } else {\n" -" %6 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" %5 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" " torch.prim.If.yield %int10 : !torch.int\n" " } else {\n" -" %8 = torch.aten.eq.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" %9 = torch.prim.If %8 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %15 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" }\n" -" %10 = torch.prim.If %9 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %15 = torch.aten.eq.int %arg1, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" }\n" -" %11 = torch.prim.If %10 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %15 = torch.aten.eq.int %arg1, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" }\n" -" %12 = torch.prim.If %11 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %15 = torch.aten.eq.int %arg1, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" }\n" -" %13 = torch.prim.If %12 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %15 = torch.aten.eq.int %arg1, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" }\n" -" %14 = torch.prim.If %13 -> (!torch.int) {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" " torch.prim.If.yield %int9 : !torch.int\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield %0 : !torch.int\n" " }\n" -" torch.prim.If.yield %14 : !torch.int\n" +" torch.prim.If.yield %8 : !torch.int\n" " }\n" -" torch.prim.If.yield %7 : !torch.int\n" +" torch.prim.If.yield %6 : !torch.int\n" " }\n" -" torch.prim.If.yield %5 : !torch.int\n" +" torch.prim.If.yield %4 : !torch.int\n" " }\n" -" return %3 : !torch.int\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: `self` cannot have bool dtype\"\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.aten.ne.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, 
!torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %arg1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.union) -> !torch.int {\n" +" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, 
!torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Expected `other` to have integer dtype\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: Expected `self` to have integer dtype\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: `self` and `mat2` must have the same dtype\"\n" +" %str_0 = torch.constant.str 
\"AssertionError: Expected dtype of `mat2` to not be float16 or bool\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" +" %int11 = torch.constant.int 11\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %8 = torch.aten.ne.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" }\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor_mode\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.optional) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : 
(!torch.list>, !torch.list) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %8 = torch.aten.ne.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" }\n" +" return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Result dtype for aten.floor_divide bool\"\n" +" %int11 = torch.constant.int 11\n" +" %str_0 = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" %7 = torch.aten.ne.int %6, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: `self` and `other` must have the same dtype\"\n" +" %str_0 = torch.constant.str \"AssertionError: Expected dtype of `other` to not be float16 or bool\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" +" %int11 = torch.constant.int 11\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg1 : 
!torch.list, !torch.int -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.maximum\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" 
torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: `self` and `mat2` must have the same dtype\"\n" +" %str_0 = torch.constant.str \"AssertionError: Expected dtype of `mat2` to not be float16 or bool\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" +" %int11 = torch.constant.int 11\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Expected promoted dtype to be float but not `bfloat16`\"\n" +" %false = torch.constant.bool false\n" +" %int15 = torch.constant.int 15\n" +" %str_0 = torch.constant.str \"AssertionError: `target` cannot be complex\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = 
torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" %7 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %9 = torch.aten.ne.int %6, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" " %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" " %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" " return %2 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mv\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: `self` and `vec` must have the same dtype\"\n" +" %str_0 = torch.constant.str \"AssertionError: Expected dtype of `vec` to not be float16 or bool\"\n" " %none = torch.constant.none\n" -" %0 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" -" %2 = torch.prim.ListConstruct %arg1, %1 : (!torch.int, !torch.int) -> !torch.list\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %3 : !torch.int\n" +" %str_1 = torch.constant.str \"AssertionError: Expected dtype of `self` to not be float16 or bool\"\n" +" %int11 = torch.constant.int 11\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.ListConstruct %int5, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.aten.__contains__.int_list %3, %arg3 : !torch.list, !torch.int -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" 
torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.eq.int %arg1, %arg3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %8 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %9 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%7, %8) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %9 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.union) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: `other` cannot be of bool dtype\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: `self` cannot be of bool dtype\"\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.aten.ne.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1 = torch.aten.ne.int %arg3, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.union) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Result dtype for aten.threshold_backward cannot be bool or float16\"\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: `grad_output` cannot be complex\"\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg1) : (!torch.int) -> !torch.bool\n" +" %1 = torch.aten.__not__ %0 : !torch.bool -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg3) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" +" %5 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call 
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%4, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" %7 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = torch.aten.__contains__.int_list %7, %6 : !torch.list, !torch.int -> !torch.bool\n" +" %9 = torch.aten.__not__ %8 : !torch.bool -> !torch.bool\n" +" torch.prim.If %9 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %6 : !torch.int\n" " }\n" "}\n" ""; diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 94436925e5cf..34acd2c80744 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -693,10 +693,9 @@ void TypeAnalysis::visitOperation(Operation *op, } // Promote the two dtypes assuming non-zero rank. - if (isa(op)) { + if (isa(op)) { auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( @@ -705,20 +704,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - // Promote the two dtypes assuming possibly-zero rank. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultType( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}, - getRankIsNonZeroArray(op->getOperands())); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - // Dtype is always float32, except for bfloat16, float64 and nullptr after // promotion and assuming possible-zero rank. if (isa(op)) { diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 9a29b976a6e9..4e0411552cf0 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -190,6 +190,8 @@ class SimplifyDtypeCalculationsPass patterns.insert(context); patterns.insert(context); + PrimIfOp::getCanonicalizationPatterns(patterns, context); + // TODO: Debug visitation order to make this more efficient. // A single linear scan should suffice. 
GreedyRewriteConfig config; diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index d95aba4b2af4..65282ea94096 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -12,7 +12,7 @@ import torch.jit._shape_functions as upstream_shape_functions from .testing_framework import Invocation, ErrorInvocation, TensorOfShape, LongTensorOfShape, NonZeroDTensorWithDtype, ZeroDTensorWithDtype, check_shape_function, check_dtype_function -from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar +from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar, is_integer_dtype, is_float_dtype, is_complex_dtype # ============================================================================== # Shape Functions @@ -1017,180 +1017,295 @@ def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], s # Dtype Functions # ============================================================================== -def _get_invocations_for_op_with_tensor_arg_followed_by(*args): - """Generate invocations that thoroughly test the first tensor arg of the op. +# All the torch types sorted in decreasing order of priority during type promotion. +_SORTED_TORCH_TYPES = [ + torch.complex128, torch.complex64, + torch.float64, torch.float32, torch.float16, torch.bfloat16, + torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool +] - This is meant to be used by ops where the entire dtype computation involves - at most the first tensor argument of the op. If an dtype function uses other - arguments, custom invocations should be created to test the logic of the - dtype function instead of using this helper function. +def _check_tensors_with_the_same_dtype( + num_of_tensors: Optional[int] = None, + tensor_shapes: Optional[list[tuple[int]]] = None, + error_types: Optional[set[int]] = None, *args, **kwargs): + """Create invocations where all tensors have the same dtype. + + This function generates invocations with `num_of_tensors` tensors + that all have the same dtype. It creates an invocation for every + possible dtype. For dtypes in `error_types`, the invocations are + error invocations. + + One can also specify the shapes of the tensors. Either `num_of_tensors` + or `tensor_shapes` must be specified whenever this function is called. + + The extra *args and **kwargs arguments are passed to the invocations. 
""" - return [ - Invocation(NonZeroDTensorWithDtype(torch.float32), *args), - Invocation(NonZeroDTensorWithDtype(torch.float64), *args), - Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args), - Invocation(NonZeroDTensorWithDtype(torch.int64), *args), - Invocation(NonZeroDTensorWithDtype(torch.int32), *args), - Invocation(NonZeroDTensorWithDtype(torch.bool), *args), - Invocation(ZeroDTensorWithDtype(torch.float32), *args), - Invocation(ZeroDTensorWithDtype(torch.float64), *args), - Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args), - Invocation(ZeroDTensorWithDtype(torch.int64), *args), - Invocation(ZeroDTensorWithDtype(torch.int32), *args), - Invocation(ZeroDTensorWithDtype(torch.bool), *args), - ] - -def _get_invocations_for_fp_only_op_with_tensor_arg_followed_by(*args): - """Generate invocations for floating point only op.""" - return [ - Invocation(NonZeroDTensorWithDtype(torch.float32), *args), - Invocation(NonZeroDTensorWithDtype(torch.float64), *args), - Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args), - ErrorInvocation(NonZeroDTensorWithDtype(torch.int64), *args), - ErrorInvocation(NonZeroDTensorWithDtype(torch.int32), *args), - ErrorInvocation(NonZeroDTensorWithDtype(torch.bool), *args), - Invocation(ZeroDTensorWithDtype(torch.float32), *args), - Invocation(ZeroDTensorWithDtype(torch.float64), *args), - Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args), - ErrorInvocation(ZeroDTensorWithDtype(torch.int64), *args), - ErrorInvocation(ZeroDTensorWithDtype(torch.int32), *args), - ErrorInvocation(ZeroDTensorWithDtype(torch.bool), *args), - ] + invocations = [] + for type_ in _SORTED_TORCH_TYPES: + tensors = [] + if tensor_shapes is None and num_of_tensors is not None: + tensors = [NonZeroDTensorWithDtype(type_)] * num_of_tensors + elif tensor_shapes is not None and num_of_tensors is None: + for tensor_shape in tensor_shapes: + tensors.append(TensorOfShape(*tensor_shape, dtype=type_)) + else: + assert False, \ + "Either `num_of_tensors` or `tensor_shapes` must be specified" + + if error_types is not None and type_ in error_types: + invocations.append(ErrorInvocation(*tensors, *args, **kwargs)) + else: + invocations.append(Invocation(*tensors, *args, **kwargs)) + return invocations + +def _check_two_tensor_op( + input_error_types: Optional[set[int]] = None, + output_error_types: Optional[set[int]] = None, **kwargs): + """Generate invocations for basic two-tensor dtype functions. + + This helper function is meant to be used to check dtype functions that + take two tensor operands and either return the promoted result or + return a constant dtype based on the tensor dtypes. + + The testing performed is thorough enough to be able to detect if dtypes + are invalid as inputs or as outputs to the PyTorch op. Invalid dtypes + must be specified in `input_error_types` and `output_error_types` to + ensure the invocations are error invocations. + """ + if input_error_types is not None and output_error_types is not None: + assert len(input_error_types.intersection(output_error_types)) == 0, \ + "An invalid input type implies an invalid output type, " \ + "so there is no need to repeat the type in the `output_error_types` set" + all_error_types = set() + all_error_types |= set() if input_error_types is None else input_error_types + all_error_types |= set() if output_error_types is None else output_error_types + + def check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs): + """Create invocations where one tensor varies its dtype. 
+ + This helper function creates invocations with two tensors where one + tensor varies its dtype while the other one stays constant. The varying + is done for both tensors and the varying is performed over every possible + dtype. + + This function helps identify when a dtype is an invalid input dtype + for dtype functions that do promotion. + """ + # We will only create invocations for dtypes with priorities less than + # or equal to the highest priority valid type. By setting the non-varying + # tensor dtype to be the highest priority valid type, we ensure that + # every promotion results in a valid dtype. This allows the invocations + # to test in isolation assertions on input types. + constant_type = None + constant_type_index = None + for e, type_ in enumerate(_SORTED_TORCH_TYPES): + if type_ not in all_error_types: + constant_type = type_ + constant_type_index = e + break + assert constant_type is not None, \ + "Unable to find a constant type. Make sure the union of " \ + "`input_error_types` and `output_error_types` is not all possible types." + + invocations = [] + for type_ in _SORTED_TORCH_TYPES[constant_type_index:]: + tensor_1 = NonZeroDTensorWithDtype(type_) + tensor_2 = NonZeroDTensorWithDtype(constant_type) + if input_error_types is not None and type_ in input_error_types: + invocations += [ErrorInvocation(tensor_1, tensor_2, **kwargs), + ErrorInvocation(tensor_2, tensor_1, **kwargs)] + else: + invocations += [Invocation(tensor_1, tensor_2, **kwargs), + Invocation(tensor_2, tensor_1, **kwargs)] + return invocations + + same_dtype_invocations = _check_tensors_with_the_same_dtype( + num_of_tensors=2, error_types=all_error_types, **kwargs) + + varying_dtype_invocations = \ + check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs) + return same_dtype_invocations + varying_dtype_invocations def _get_dtype_of_floating_point_op(input_dtype: int) -> int: - if input_dtype in (torch.float64, torch.bfloat16, torch.float16): + if (is_float_dtype(input_dtype) and input_dtype != torch.float32) \ + or is_complex_dtype(input_dtype): return input_dtype return torch.float32 -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128})) def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + assert self_dtype != torch.float16, "`self` cannot have float16 dtype" return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return 
_get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int: return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128})) def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int: + assert not is_complex_dtype(self_dtype) and self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, + torch.int64, torch.float16, torch.complex64, torch.complex128})) def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: - assert self_dtype not in (torch.int64, torch.int32, 
torch.bool) + assert not is_complex_dtype(self_dtype) and not is_integer_dtype(self_dtype) and self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by([0])) +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=[0])) def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int: - assert self_dtype not in (torch.int64, torch.int32, torch.bool) + assert not is_integer_dtype(self_dtype) + if self_dtype == torch.complex128: + return torch.float64 + elif self_dtype == torch.complex64: + return torch.float32 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})) def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int: + assert self_dtype != torch.float16 return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int: - return torch.bool + return torch.uint8 if self_dtype == torch.uint8 else torch.bool -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇any〡dtype(self_rank: int, self_dtype: int) -> int: - return torch.bool + return torch.uint8 if self_dtype == torch.uint8 else torch.bool -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0)) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) def aten〇eq〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: return torch.bool -@check_dtype_function( - _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float))) +@check_dtype_function(_check_two_tensor_op()) def aten〇eq〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: return torch.bool -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0)) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) def aten〇ge〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0)) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) def aten〇gt〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool @check_dtype_function( - 
_get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float))) + _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) def aten〇gt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + assert not is_complex_dtype(other_dtype), "`other` cannot be complex" return torch.bool -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0)) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) def aten〇le〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool -@check_dtype_function( - _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float))) +@check_dtype_function(_check_two_tensor_op()) def aten〇logical_and〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: return torch.bool -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇logical_not〡dtype(self_rank: int, self_dtype: int) -> int: return torch.bool -@check_dtype_function( - _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float))) +@check_dtype_function(_check_two_tensor_op()) def aten〇logical_or〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: return torch.bool -@check_dtype_function( - _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float))) +@check_dtype_function(_check_two_tensor_op()) def aten〇logical_xor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: return torch.bool -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0)) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0)) def aten〇lt〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" return torch.bool @check_dtype_function( - _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float))) + _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) def aten〇lt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + assert not is_complex_dtype(other_dtype), "`other` cannot be complex" return torch.bool -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0)) +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) def aten〇ne〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int: return torch.bool @@ -1205,55 +1320,231 @@ def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int: dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)] return promote_dtypes(ranks, 
dtypes) -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.complex64)), - Invocation(NonZeroDTensorWithDtype(torch.complex128)), - Invocation(NonZeroDTensorWithDtype(torch.float)), - Invocation(NonZeroDTensorWithDtype(torch.double)), - Invocation(NonZeroDTensorWithDtype(torch.bool)), - Invocation(NonZeroDTensorWithDtype(torch.uint8)), - Invocation(NonZeroDTensorWithDtype(torch.int8)), - Invocation(NonZeroDTensorWithDtype(torch.int16)), - Invocation(NonZeroDTensorWithDtype(torch.int32)), - Invocation(NonZeroDTensorWithDtype(torch.int64)), - ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)), - ErrorInvocation(NonZeroDTensorWithDtype(torch.bfloat16)), -]) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.float16, torch.bfloat16})) def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: - if self_dtype == torch.complex64 or self_dtype == torch.complex128: + if is_complex_dtype(self_dtype): return self_dtype elif self_dtype == torch.float: return torch.complex64 elif self_dtype == torch.double: return torch.complex128 - elif self_dtype == torch.bool or self_dtype == torch.uint8 or \ - self_dtype == torch.int8 or self_dtype == torch.int16 or \ - self_dtype == torch.int32 or self_dtype == torch.int64: + elif is_integer_dtype(self_dtype): return torch.complex64 else: assert False, "Unsupported dtype" -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float32)), - Invocation(ZeroDTensorWithDtype(torch.float64), NonZeroDTensorWithDtype(torch.float32)), - Invocation(ZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float64)), - Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)), -]) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.bool}, other=0.0) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={torch.bool}, other=0)) +def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int: + assert self_dtype != torch.bool, "`self` cannot have bool dtype" + return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) + +@check_dtype_function( + _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇__and__〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" + assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇add〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, alpha: Union[int, float] = 1) -> int: + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇bitwise_and〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + assert is_integer_dtype(self_dtype), 
"Expected `self` to have integer dtype" + assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇bitwise_or〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" + assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32, + torch.float64, torch.complex64, torch.complex128})) +def aten〇bitwise_xor〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype" + assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype" + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(2, 3, 4), (2, 4, 3)], error_types={torch.float16, torch.bool}) + + # Different width + [ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float64), + TensorOfShape(2, 4, 3, dtype=torch.float32)), + # Different type + ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32), + TensorOfShape(2, 4, 3, dtype=torch.int32))]) +def aten〇bmm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtype: int) -> int: + assert self_dtype not in [torch.float16, torch.bool], \ + "Expected dtype of `self` to not be float16 or bool" + assert mat2_dtype not in [torch.float16, torch.bool], \ + "Expected dtype of `mat2` to not be float16 or bool" + assert self_dtype == mat2_dtype, "`self` and `mat2` must have the same dtype" + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇div〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_complex_dtype(promoted_dtype) or \ + (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32): + return promoted_dtype + else: + return torch.float32 + +@check_dtype_function(_check_two_tensor_op(rounding_mode=None)) +def aten〇div〇Tensor_mode〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, rounding_mode: Optional[str]) -> int: + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_complex_dtype(promoted_dtype) or \ + (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32): + return promoted_dtype + else: + return torch.float32 + +@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool})) def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + assert not is_complex_dtype(other_dtype), 
"`other` cannot be complex" + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert promoted_dtype != torch.bool, "Result dtype for aten.floor_divide bool" + return promoted_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(2, 3, 4), (2, 4, 3)], error_types={torch.float16, torch.bool}) + + # Different width + [ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float64), + TensorOfShape(2, 4, 3, dtype=torch.float32)), + # Different type + ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32), + TensorOfShape(2, 4, 3, dtype=torch.int32))]) +def aten〇matmul〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + assert self_dtype not in [torch.float16, torch.bool], \ + "Expected dtype of `self` to not be float16 or bool" + assert other_dtype not in [torch.float16, torch.bool], \ + "Expected dtype of `other` to not be float16 or bool" + assert self_dtype == other_dtype, "`self` and `other` must have the same dtype" + return self_dtype + +@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) +def aten〇maximum〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + assert not is_complex_dtype(other_dtype), "`other` cannot be complex" ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.float32), other=0), - Invocation(NonZeroDTensorWithDtype(torch.int64), other=0.0), - Invocation(NonZeroDTensorWithDtype(torch.float16), other=0.0), - Invocation(ZeroDTensorWithDtype(torch.float32), other=0), - Invocation(ZeroDTensorWithDtype(torch.int64), other=0.0), - Invocation(ZeroDTensorWithDtype(torch.float16), other=0.0) -]) -def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int: - return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) +@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) +def aten〇minimum〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + assert not is_complex_dtype(other_dtype), "`other` cannot be complex" + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(3, 4), (4, 3)], error_types={torch.float16, torch.bool}) + + # Different width + [ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, 3, dtype=torch.float32)), + # Different type + ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.int32))]) +def aten〇mm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtype: int) -> int: + assert self_dtype not in [torch.float16, torch.bool], \ + "Expected dtype of `self` to not be float16 or bool" + assert mat2_dtype not in [torch.float16, torch.bool], \ + "Expected dtype of `mat2` to not be float16 or bool" + assert self_dtype == mat2_dtype, "`self` and `mat2` must have the same dtype" + return self_dtype + +@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, 
torch.complex128}, + output_error_types={torch.bool, torch.uint8, torch.int8, torch.int16, + torch.int32, torch.int64, torch.bfloat16})) +def aten〇mse_loss〡dtype(self_rank: int, self_dtype: int, target_rank: int, target_dtype: int, reduction: int = 1) -> int: + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + assert not is_complex_dtype(target_dtype), "`target` cannot be complex" + ranks: List[Optional[int]] = [self_rank, target_rank] + dtypes = [self_dtype, target_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert is_float_dtype(promoted_dtype) and promoted_dtype != torch.bfloat16, \ + "Expected promoted dtype to be float but not `bfloat16`" + return promoted_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇mul〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(3, 4), (4,)], error_types={torch.float16, torch.bool}) + + # Different width + [ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, dtype=torch.float32)), + # Different type + ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, dtype=torch.int32))]) +def aten〇mv〡dtype(self_rank: int, self_dtype: int, vec_rank: int, vec_dtype: int) -> int: + assert self_dtype not in [torch.float16, torch.bool], \ + "Expected dtype of `self` to not be float16 or bool" + assert vec_dtype not in [torch.float16, torch.bool], \ + "Expected dtype of `vec` to not be float16 or bool" + assert self_dtype == vec_dtype, "`self` and `vec` must have the same dtype" + ranks: List[Optional[int]] = [self_rank, vec_rank] + dtypes = [self_dtype, vec_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op(input_error_types={torch.bool})) +def aten〇sub〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, alpha: Union[int, float] = 1) -> int: + assert self_dtype != torch.bool, "`self` cannot be of bool dtype" + assert other_dtype != torch.bool, "`other` cannot be of bool dtype" + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool, torch.float16}, threshold=0)) +def aten〇threshold_backward〡dtype(grad_output_rank: int, grad_output_dtype: int, self_rank: int, self_dtype: int, threshold: Union[int, float]) -> int: + assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex" + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert promoted_dtype not in [torch.bool, torch.float16], \ + "Result dtype for aten.threshold_backward cannot be bool or float16" + return promoted_dtype # ============================================================================== # Main diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py index de499a2e25b4..fa962f41d64c 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ 
b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py @@ -14,6 +14,16 @@ from .registry import Registry +def is_integer_dtype(dtype: int) -> bool: + return dtype in [torch.bool, torch.uint8, torch.int8, + torch.int16, torch.int32, torch.int64] + +def is_complex_dtype(dtype: int) -> bool: + return dtype in [torch.complex64, torch.complex128] + +def is_float_dtype(dtype: int) -> bool: + return dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64] + def get_dtype_of_scalar(scalar: Union[int, float]) -> int: # This is hacky. `NumToTensor` is the only PyTorch op for scalars # that when `jit.script`ed converts a float scalar to a tensor diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index 7d133a87f9eb..d3795307fbf1 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -9,6 +9,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = { "QuantizedMLP_basic", "NormalizeModule_basic", + "MobilenetV3Module_basic", } def register_all_tests(): diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir index 81c93b511091..e058e0d6773e 100644 --- a/test/Dialect/Torch/refine-types-ops.mlir +++ b/test/Dialect/Torch/refine-types-ops.mlir @@ -235,30 +235,6 @@ func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, return %ret : !torch.tensor } -// ----- -// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Matrix( -// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>, -// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor { -// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor { - %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor - return %0 : !torch.tensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Vector( -// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>, -// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<*,f32>) -> !torch.tensor { -// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor { - %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor - return %0 : !torch.tensor -} - // ----- // CHECK-LABEL: func.func @torch.aten.to.dtype( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
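As a usage sketch: a new promotion-style dtype function built on the helpers added above would look roughly like the following. The op name aten〇foo〡dtype and its error set are hypothetical and chosen only for illustration; the code assumes it is written inside abstract_interp_lib_gen.py next to the functions above, so the module's existing imports (check_dtype_function, promote_dtypes, is_complex_dtype, List, Optional, torch) and the _check_two_tensor_op helper are already in scope.

# Hypothetical example only; not an op touched by this patch.
# Reject complex inputs up front, then defer to PyTorch's promotion rules.
@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}))
def aten〇foo〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)

The decorator makes the generated invocations cover every dtype pair reachable through promotion, with the complex dtypes exercised as ErrorInvocations, which is the same pattern the promotion-based dtype functions in this patch follow.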