From 8cae5ba50710dd2952dd9fc06de0b770bfc36dc1 Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Mon, 16 Jan 2023 14:32:23 -0800
Subject: [PATCH] Add dtype functions for comparison ops (#1806)

This commit adds dtype functions for comparison ops that always return
a tensor of dtype `i1`.
---
 .../Transforms/AbstractInterpLibrary.cpp      | 280 ++++++++++++++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  12 -
 .../build_tools/abstract_interp_lib_gen.py    | 104 ++++++-
 test/Dialect/Torch/refine-types-ops.mlir      |  26 --
 4 files changed, 370 insertions(+), 52 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index b1a35b112b6e..88b0555ea79f 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -4164,6 +4164,226 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    } : (!torch.int, !torch.bool) -> ()\n"
 "    return %5 : !torch.list<int>\n"
 "  }\n"
+"  func.func @__torch__.torch.jit._shape_functions.movedim(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int-1 = torch.constant.int -1\n"
+"    %true = torch.constant.bool true\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.le.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      torch.prim.If.yield %arg0 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      %4 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      %5 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"      torch.prim.Loop %5, %true, init() {\n"
+"      ^bb0(%arg3: !torch.int):\n"
+"        %23 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %24 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        %25 = torch.prim.If %24 -> (!torch.int) {\n"
+"          torch.prim.If.yield %int1 : !torch.int\n"
+"        } else {\n"
+"          torch.prim.If.yield %0 : !torch.int\n"
+"        }\n"
+"        %26 = torch.aten.neg.int %25 : !torch.int -> !torch.int\n"
+"        %27 = torch.aten.sub.int %25, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"        %28 = torch.aten.lt.int %23, %26 : !torch.int, !torch.int -> !torch.bool\n"
+"        %29 = torch.prim.If %28 -> (!torch.bool) {\n"
+"          torch.prim.If.yield %true : !torch.bool\n"
+"        } else {\n"
+"          %45 = torch.aten.gt.int %23, %27 : !torch.int, !torch.int -> !torch.bool\n"
+"          torch.prim.If.yield %45 : !torch.bool\n"
+"        }\n"
+"        %30 = torch.aten.__not__ %29 : !torch.bool -> !torch.bool\n"
+"        torch.prim.If %30 -> () {\n"
+"          torch.prim.If.yield\n"
+"        } else {\n"
+"          torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"          torch.prim.If.yield\n"
+"        }\n"
+"        %31 = torch.aten.lt.int %23, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        %32 = torch.prim.If %31 -> (!torch.int) {\n"
+"          %45 = torch.aten.add.int %23, %25 : !torch.int, !torch.int -> !torch.int\n"
+"          torch.prim.If.yield %45 : !torch.int\n"
+"        } else {\n"
+"          torch.prim.If.yield %23 : !torch.int\n"
+"        }\n"
+"        %33 = torch.aten.append.t %3, %32 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"        %34 = torch.aten.__getitem__.t %arg2, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %35 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        %36 = torch.prim.If %35 -> (!torch.int) {\n"
+"          torch.prim.If.yield %int1 : !torch.int\n"
+"        } else {\n"
+"          torch.prim.If.yield %0 : !torch.int\n"
+"        }\n"
+"        %37 = torch.aten.neg.int %36 : !torch.int -> !torch.int\n"
+"        %38 = torch.aten.sub.int %36, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"        %39 = torch.aten.lt.int %34, %37 : !torch.int, !torch.int -> !torch.bool\n"
+"        %40 = torch.prim.If %39 -> (!torch.bool) {\n"
+"          torch.prim.If.yield %true : !torch.bool\n"
+"        } else {\n"
+"          %45 = torch.aten.gt.int %34, %38 : !torch.int, !torch.int -> !torch.bool\n"
+"          torch.prim.If.yield %45 : !torch.bool\n"
+"        }\n"
+"        %41 = torch.aten.__not__ %40 : !torch.bool -> !torch.bool\n"
+"        torch.prim.If %41 -> () {\n"
+"          torch.prim.If.yield\n"
+"        } else {\n"
+"          torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"          torch.prim.If.yield\n"
+"        }\n"
+"        %42 = torch.aten.lt.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        %43 = torch.prim.If %42 -> (!torch.int) {\n"
+"          %45 = torch.aten.add.int %34, %36 : !torch.int, !torch.int -> !torch.int\n"
+"          torch.prim.If.yield %45 : !torch.int\n"
+"        } else {\n"
+"          torch.prim.If.yield %34 : !torch.int\n"
+"        }\n"
+"        %44 = torch.aten.append.t %4, %43 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      torch.prim.Loop %0, %true, init() {\n"
+"      ^bb0(%arg3: !torch.int):\n"
+"        %23 = torch.aten.append.t %6, %int-1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      %7 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      torch.prim.Loop %0, %true, init() {\n"
+"      ^bb0(%arg3: !torch.int):\n"
+"        %23 = torch.aten.append.t %7, %arg3 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      %8 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      torch.prim.Loop %0, %true, init() {\n"
+"      ^bb0(%arg3: !torch.int):\n"
+"        %23 = torch.aten.append.t %8, %arg3 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      %9 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"      torch.prim.Loop %9, %true, init() {\n"
+"      ^bb0(%arg3: !torch.int):\n"
+"        %23 = torch.aten.__getitem__.t %3, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %24 = torch.aten.__getitem__.t %4, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %25 = torch.aten._set_item.t %6, %24, %23 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
+"        %26 = torch.aten.__getitem__.t %3, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %27 = torch.aten._set_item.t %7, %26, %int-1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
+"        %28 = torch.aten.__getitem__.t %4, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %29 = torch.aten._set_item.t %8, %28, %int-1 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      %10 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      %11 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      %12 = torch.aten.len.t %7 : !torch.list<int> -> !torch.int\n"
+"      torch.prim.Loop %12, %true, init() {\n"
+"      ^bb0(%arg3: !torch.int):\n"
+"        %23 = torch.aten.__getitem__.t %7, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %24 = torch.aten.ne.int %23, %int-1 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If %24 -> () {\n"
+"          %25 = torch.aten.append.t %10, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"          torch.prim.If.yield\n"
+"        } else {\n"
+"          torch.prim.If.yield\n"
+"        }\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      %13 = torch.aten.len.t %8 : !torch.list<int> -> !torch.int\n"
+"      torch.prim.Loop %13, %true, init() {\n"
+"      ^bb0(%arg3: !torch.int):\n"
+"        %23 = torch.aten.__getitem__.t %8, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %24 = torch.aten.ne.int %23, %int-1 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If %24 -> () {\n"
+"          %25 = torch.aten.append.t %11, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"          torch.prim.If.yield\n"
+"        } else {\n"
+"          torch.prim.If.yield\n"
+"        }\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      %14 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"      %15 = torch.aten.sub.int %0, %14 : !torch.int, !torch.int -> !torch.int\n"
+"      torch.prim.Loop %15, %true, init() {\n"
+"      ^bb0(%arg3: !torch.int):\n"
+"        %23 = torch.aten.__getitem__.t %10, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %24 = torch.aten.__getitem__.t %11, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %25 = torch.aten._set_item.t %6, %24, %23 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      %16 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"      %17 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
+"      %18 = torch.aten.eq.int %16, %17 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If %18 -> () {\n"
+"        torch.prim.If.yield\n"
+"      } else {\n"
+"        torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield\n"
+"      }\n"
+"      %19 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
+"      %20 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      %21 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      torch.prim.Loop %19, %true, init() {\n"
+"      ^bb0(%arg3: !torch.int):\n"
+"        %23 = torch.aten.__getitem__.t %6, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %24 = torch.aten.le.int %19, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        %25 = torch.prim.If %24 -> (!torch.int) {\n"
+"          torch.prim.If.yield %int1 : !torch.int\n"
+"        } else {\n"
+"          torch.prim.If.yield %19 : !torch.int\n"
+"        }\n"
+"        %26 = torch.aten.neg.int %25 : !torch.int -> !torch.int\n"
+"        %27 = torch.aten.sub.int %25, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"        %28 = torch.aten.lt.int %23, %26 : !torch.int, !torch.int -> !torch.bool\n"
+"        %29 = torch.prim.If %28 -> (!torch.bool) {\n"
+"          torch.prim.If.yield %true : !torch.bool\n"
+"        } else {\n"
+"          %36 = torch.aten.gt.int %23, %27 : !torch.int, !torch.int -> !torch.bool\n"
+"          torch.prim.If.yield %36 : !torch.bool\n"
+"        }\n"
+"        %30 = torch.aten.__not__ %29 : !torch.bool -> !torch.bool\n"
+"        torch.prim.If %30 -> () {\n"
+"          torch.prim.If.yield\n"
+"        } else {\n"
+"          torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"          torch.prim.If.yield\n"
+"        }\n"
+"        %31 = torch.aten.lt.int %23, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        %32 = torch.prim.If %31 -> (!torch.int) {\n"
+"          %36 = torch.aten.add.int %23, %25 : !torch.int, !torch.int -> !torch.int\n"
+"          torch.prim.If.yield %36 : !torch.int\n"
+"        } else {\n"
+"          torch.prim.If.yield %23 : !torch.int\n"
+"        }\n"
+"        %33 = torch.aten.append.t %20, %32 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"        %34 = torch.aten.__getitem__.t %arg0, %32 : !torch.list<int>, !torch.int -> !torch.int\n"
+"        %35 = torch.aten.append.t %21, %34 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      %22 = torch.aten.__range_length %int1, %19, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"      torch.prim.Loop %22, %true, init() {\n"
+"      ^bb0(%arg3: !torch.int):\n"
+"        %23 = torch.aten.__derive_index %arg3, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"        torch.prim.Loop %23, %true, init() {\n"
+"        ^bb0(%arg4: !torch.int):\n"
+"          %24 = torch.aten.__getitem__.t %20, %23 : !torch.list<int>, !torch.int -> !torch.int\n"
+"          %25 = torch.aten.__getitem__.t %20, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
+"          %26 = torch.aten.ne.int %24, %25 : !torch.int, !torch.int -> !torch.bool\n"
+"          torch.prim.If %26 -> () {\n"
+"            torch.prim.If.yield\n"
+"          } else {\n"
+"            torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"            torch.prim.If.yield\n"
+"          }\n"
+"          torch.prim.Loop.condition %true, iter()\n"
+"        } : (!torch.int, !torch.bool) -> ()\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      torch.prim.If.yield %21 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @__torch__.torch.jit._shape_functions.view(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %str = torch.constant.str \"AssertionError: invalid shape\"\n"
 "    %false = torch.constant.bool false\n"
@@ -7130,6 +7350,66 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
 "    return %4 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<int, float>) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.eq.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<int, float>) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<int, float>) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<int, float>) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.logical_not\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.logical_xor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<int, float>) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<int, float>) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union<int, float>, %arg1: !torch.union<int, float>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>\n"
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 6d66995cc1b8..414e3766059a 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -699,18 +699,6 @@ void TypeAnalysis::visitOperation(Operation *op,
     return;
   }
 
-  // Dtype is always i1.
-  if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
-          AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
-          AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-          AtenLogicalXorOp, AtenLogicalNotOp>(op)) {
-    auto knowledge =
-        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = IntegerType::get(op->getContext(), 1);
-    incorporateKnowledge(op->getResult(0), knowledge);
-    return;
-  }
-
   // Dtype is always si64.
   if (isa(op)) {
     auto knowledge =
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index d18e5a0e05e2..5502fe551a00 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -1017,6 +1017,95 @@ def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], s
 # Dtype Functions
 # ==============================================================================
 
+def _get_invocations_for_op_with_tensor_arg_followed_by(*args):
+    """Generate invocations that thoroughly test the first tensor arg of the op.
+
+    This is meant to be used by ops where the entire dtype computation involves
+    at most the first tensor argument of the op. If a dtype function uses other
+    arguments, custom invocations should be created to test the logic of the
+    dtype function instead of using this helper function.
+    """
+    return [
+        Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
+        Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
+        Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
+        Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
+        Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
+        Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
+        Invocation(ZeroDTensorWithDtype(torch.float32), *args),
+        Invocation(ZeroDTensorWithDtype(torch.float64), *args),
+        Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
+        Invocation(ZeroDTensorWithDtype(torch.int64), *args),
+        Invocation(ZeroDTensorWithDtype(torch.int32), *args),
+        Invocation(ZeroDTensorWithDtype(torch.bool), *args),
+]
+
+@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
+def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int:
+    return torch.bool
+
+@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
+def aten〇any〡dtype(self_rank: int, self_dtype: int) -> int:
+    return torch.bool
+
+@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
+def aten〇eq〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
+    return torch.bool
+
+@check_dtype_function(
+    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
+def aten〇eq〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
+    return torch.bool
+
+@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
+def aten〇ge〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
+    return torch.bool
+
+@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
+def aten〇gt〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
+    return torch.bool
+
+@check_dtype_function(
+    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
+def aten〇gt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
+    return torch.bool
+
+@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
+def aten〇le〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
+    return torch.bool
+
+@check_dtype_function(
+    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
+def aten〇logical_and〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
+    return torch.bool
+
+@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
+def aten〇logical_not〡dtype(self_rank: int, self_dtype: int) -> int:
+    return torch.bool
+
+@check_dtype_function(
+    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
+def aten〇logical_or〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
+    return torch.bool
+
+@check_dtype_function(
+    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
+def aten〇logical_xor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
+    return torch.bool
+
+@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
+def aten〇lt〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
+    return torch.bool
+
+@check_dtype_function(
+    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
+def aten〇lt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
+    return torch.bool
+
+@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
+def aten〇ne〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
+    return torch.bool
+
 @check_dtype_function([
     Invocation(0.0, 0.0), # float, float
     Invocation(0.0, 0), # float, int
@@ -1078,20 +1167,7 @@ def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int
 def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
     return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
 
-@check_dtype_function([
-    Invocation(NonZeroDTensorWithDtype(torch.float32)),
-    Invocation(NonZeroDTensorWithDtype(torch.float64)),
-    Invocation(NonZeroDTensorWithDtype(torch.bfloat16)),
-    Invocation(NonZeroDTensorWithDtype(torch.int64)),
-    Invocation(NonZeroDTensorWithDtype(torch.int32)),
-    Invocation(NonZeroDTensorWithDtype(torch.bool)),
-    Invocation(ZeroDTensorWithDtype(torch.float32)),
-    Invocation(ZeroDTensorWithDtype(torch.float64)),
-    Invocation(ZeroDTensorWithDtype(torch.bfloat16)),
-    Invocation(ZeroDTensorWithDtype(torch.int64)),
-    Invocation(ZeroDTensorWithDtype(torch.int32)),
-    Invocation(ZeroDTensorWithDtype(torch.bool)),
-])
+@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
 def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
     if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
         return self_dtype
diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir
index 261b0ecc8efc..81c93b511091 100644
--- a/test/Dialect/Torch/refine-types-ops.mlir
+++ b/test/Dialect/Torch/refine-types-ops.mlir
@@ -92,32 +92,6 @@ func.func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor {
   return %ret : !torch.vtensor
 }
 
-// -----
-// CHECK-LABEL: func.func @aten.any.dim(
-// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1
-// CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[INT_NEG1]], %[[FALSE]] : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor<*,i1>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
-// CHECK: return %[[CAST]] : !torch.vtensor
-func.func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
-  %false = torch.constant.bool false
-  %int-1 = torch.constant.int -1
-  %ret = torch.aten.any.dim %t, %int-1, %false : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor
-  return %ret : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @aten.any(
-// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
-// CHECK: %[[RET:.*]] = torch.aten.any %[[T]] : !torch.vtensor<*,i1> -> !torch.vtensor<*,i1>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
-// CHECK: return %[[CAST]] : !torch.vtensor
-func.func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
-  %ret = torch.aten.any %t: !torch.vtensor<*,i1> -> !torch.vtensor
-  return %ret : !torch.vtensor
-}
-
 // -----
 // CHECK-LABEL: func.func @torch.aten.zeros(
 // CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
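
Reviewer note (not part of the patch): the constant `11` hard-coded in the generated
dtype functions is `torch.bool`'s index in PyTorch's `ScalarType` enum, i.e. the
`i1` dtype named in the commit message. As a minimal eager-mode sketch, with
illustrative names only, mirroring the dtype coverage of
`_get_invocations_for_op_with_tensor_arg_followed_by`, plain PyTorch confirms the
invariant these functions encode:

    import torch

    # Illustrative check only: the comparison and logical ops covered by this
    # patch produce torch.bool for every input dtype exercised above.
    for dtype in (torch.float32, torch.float64, torch.bfloat16,
                  torch.int64, torch.int32, torch.bool):
        t = torch.ones(3, dtype=dtype)
        assert (t == 0).dtype == torch.bool          # aten.eq.Scalar
        assert (t > t).dtype == torch.bool           # aten.gt.Tensor
        assert t.logical_not().dtype == torch.bool   # aten.logical_not
        assert t.any().dtype == torch.bool           # aten.any
        assert t.all().dtype == torch.bool           # aten.all

Because the result dtype is independent of the operand dtypes, each dtype
function can ignore its rank/dtype arguments and unconditionally return
`torch.bool`, which is also why the corresponding hard-coded `i1` rule in
RefineTypes.cpp becomes redundant and is deleted.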