From f593580a5d1a6a11807923b9b8d69d00d6910a93 Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Mon, 13 Feb 2023 17:56:09 -0800
Subject: [PATCH] Change dtype function interface to take a tuple of ints for
 each tensor

The original design for the dtype functions outlined in
https://github.com/llvm/torch-mlir/issues/1462 was unable to properly
handle ops that take optional tensors as inputs when the optional
tensor has a value of None. By the time the op gets imported into
torch-mlir, if an optional value is None, all information about the
original type is lost from the op type signature, so torch-mlir has no
way of knowing whether a value of None came from an optional tensor or
not. This distinction was crucial in the original design, since each
tensor argument had to be turned into two separate arguments for the
dtype function.

This commit changes the interface to dtype functions such that each
tensor is turned into a tuple of two ints, the first representing the
rank of the tensor and the second the dtype of the tensor. Since there
is now a one-to-one correspondence between the operands of an op and
the operands of its dtype function, there is no ambiguity about which
operand of the op corresponds to which operand of the dtype function.

To test the implementation, this commit defines a dtype function for a
convolution op (`aten._convolution.deprecated`), which takes one
optional tensor as an argument.
---
 ...dding_abstract_interpretation_functions.md |   2 +-
 .../Transforms/AbstractInterpLibrary.cpp      | 141 +++++++++++-------
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |   5 +-
 .../Transforms/ReifyDtypeCalculations.cpp     |  77 ++--------
 .../build_tools/abstract_interp_lib_gen.py    |  47 +++++-
 .../importer/jit_ir/build_tools/registry.py   |   7 +-
 .../jit_ir/build_tools/testing_framework.py   |  38 +----
 .../Torch/reify-dtype-calculations.mlir       |  22 ++-
 8 files changed, 178 insertions(+), 161 deletions(-)

diff --git a/docs/adding_abstract_interpretation_functions.md b/docs/adding_abstract_interpretation_functions.md
index 68e1b98b78ac..b5e427e1adfd 100644
--- a/docs/adding_abstract_interpretation_functions.md
+++ b/docs/adding_abstract_interpretation_functions.md
@@ -21,7 +21,7 @@ We will use the example of adding support for the `torch.aten.tanh` op.
 function signatures are:
 
  - `def aten〇tanh〡shape(self: List[int]) -> List[int]:`
- - `def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:`
+ - `def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:`
 
 Note the use of `〇` as a separator since `.` or `::` aren't legal in a
 Python identifier.
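To make the new calling convention concrete before diving into the generated
library, here is a minimal, self-contained Python sketch (not part of the
patch) that mirrors the `aten〇floor_divide〡dtype` function changed below.
Every tensor operand now arrives as a single `(rank, dtype)` tuple instead of
two separate ints, so the function body starts by unpacking each tuple. The
`promote_dtypes` stand-in here is an assumption made only so the example runs
on its own; the real helper lives in `library_generator` and also uses the
ranks when applying promotion rules.

    from typing import List, Optional, Tuple

    import torch

    # Stand-in for the promote_dtypes helper from library_generator
    # (assumption: the real helper also consults the ranks, since
    # zero-rank tensors promote like scalars; this sketch only defers
    # to PyTorch's tensor-tensor promotion rules).
    def promote_dtypes(ranks: List[Optional[int]], dtypes: List[int]) -> int:
        return torch.promote_types(dtypes[0], dtypes[1])

    def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int],
                                 other_rank_dtype: Tuple[int, int]) -> int:
        # One operand of the op corresponds to exactly one operand of the
        # dtype function: a (rank, dtype) tuple, unpacked up front.
        self_rank, self_dtype = self_rank_dtype
        other_rank, other_dtype = other_rank_dtype
        ranks: List[Optional[int]] = [self_rank, other_rank]
        dtypes = [self_dtype, other_dtype]
        return promote_dtypes(ranks, dtypes)

    # A rank-3 float32 tensor divided by a rank-0 int64 tensor stays float32.
    assert aten〇floor_divide〡dtype((3, torch.float32), (0, torch.int64)) == torch.float32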
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index a371e291d66a..5b91bd6532e1 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -5979,31 +5979,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int6 = torch.constant.int 6\n"
 "    %int5 = torch.constant.int 5\n"
 "    %int15 = torch.constant.int 15\n"
 "    %true = torch.constant.bool true\n"
 "    %int7 = torch.constant.int 7\n"
-"    %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
-"    %1 = torch.prim.If %0 -> (!torch.bool) {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.eq.int %0#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
 "      torch.prim.If.yield %true : !torch.bool\n"
 "    } else {\n"
-"      %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
-"      torch.prim.If.yield %4 : !torch.bool\n"
+"      %5 = torch.aten.eq.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %5 : !torch.bool\n"
 "    }\n"
-"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
 "      torch.prim.If.yield %true : !torch.bool\n"
 "    } else {\n"
-"      %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
-"      torch.prim.If.yield %4 : !torch.bool\n"
+"      %5 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %5 : !torch.bool\n"
 "    }\n"
-"    %3 = torch.prim.If %2 -> (!torch.int) {\n"
-"      torch.prim.If.yield %arg1 : !torch.int\n"
+"    %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"      torch.prim.If.yield %0#1 : !torch.int\n"
 "    } else {\n"
 "      torch.prim.If.yield %int6 : !torch.int\n"
 "    }\n"
-"    return %3 : !torch.int\n"
+"    return %4 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
@@ -6236,13 +6237,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
-"    %0 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
-"    %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
-"    %2 = torch.prim.ListConstruct %arg1, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"    %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
-"    return %3 : !torch.int\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
+"    %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
+"    %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
 "  }\n"
 "  func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list<optional<int>>, %arg1: !torch.list<int>) -> !torch.int {\n"
 "    %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
@@ -6972,11 +6974,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
-"    %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
-"    %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"    %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
-"    return %2 : !torch.int\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
@@ -7162,6 +7166,36 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n"
+"    %int10 = torch.constant.int 10\n"
+"    %int9 = torch.constant.int 9\n"
+"    %int5 = torch.constant.int 5\n"
+"    %int11 = torch.constant.int 11\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.prim.ListConstruct %int11, %int5, %int9, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"    %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %5 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %6 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %7 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %8 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -7639,7 +7673,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.int, %arg4: !torch.optional<str>) -> !torch.int {\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n"
 "    %int4 = torch.constant.int 4\n"
@@ -7654,68 +7688,69 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %true = torch.constant.bool true\n"
 "    %int9 = torch.constant.int 9\n"
 "    %0 = torch.prim.Uninitialized : !torch.int\n"
-"    %1 = torch.aten.eq.int %arg1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
-"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
+"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
 "      torch.prim.If.yield %true : !torch.bool\n"
 "    } else {\n"
-"      %4 = torch.aten.eq.int %arg1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
-"      torch.prim.If.yield %4 : !torch.bool\n"
+"      %5 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %5 : !torch.bool\n"
 "    }\n"
-"    %3 = torch.prim.If %2 -> (!torch.int) {\n"
-"      torch.prim.If.yield %arg1 : !torch.int\n"
+"    %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"      torch.prim.If.yield %1#1 : !torch.int\n"
 "    } else {\n"
-"      %4 = torch.aten.eq.int %arg1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
-"      %5 = torch.prim.If %4 -> (!torch.int) {\n"
+"      %5 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
+"      %6 = torch.prim.If %5 -> (!torch.int) {\n"
 "        torch.prim.If.yield %int9 : !torch.int\n"
 "      } else {\n"
-"        %6 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
-"        %7 = torch.prim.If %6 -> (!torch.int) {\n"
+"        %7 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
+"        %8 = torch.prim.If %7 -> (!torch.int) {\n"
 "          torch.prim.If.yield %int10 : !torch.int\n"
 "        } else {\n"
-"          %8 = torch.aten.eq.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
-"          %9 = torch.prim.If %8 -> (!torch.bool) {\n"
-"            torch.prim.If.yield %true : !torch.bool\n"
-"          } else {\n"
-"            %15 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
-"            torch.prim.If.yield %15 : !torch.bool\n"
-"          }\n"
+"          %9 = torch.aten.eq.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
 "          %10 = torch.prim.If %9 -> (!torch.bool) {\n"
 "            torch.prim.If.yield %true : !torch.bool\n"
 "          } else {\n"
-"            %15 = torch.aten.eq.int %arg1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
-"            torch.prim.If.yield %15 : !torch.bool\n"
+"            %16 = torch.aten.eq.int %1#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"            torch.prim.If.yield %16 : !torch.bool\n"
 "          }\n"
 "          %11 = torch.prim.If %10 -> (!torch.bool) {\n"
 "            torch.prim.If.yield %true : !torch.bool\n"
 "          } else {\n"
-"            %15 = torch.aten.eq.int %arg1, %int2 : !torch.int, !torch.int -> !torch.bool\n"
-"            torch.prim.If.yield %15 : !torch.bool\n"
+"            %16 = torch.aten.eq.int %1#1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"            torch.prim.If.yield %16 : !torch.bool\n"
 "          }\n"
 "          %12 = torch.prim.If %11 -> (!torch.bool) {\n"
 "            torch.prim.If.yield %true : !torch.bool\n"
 "          } else {\n"
-"            %15 = torch.aten.eq.int %arg1, %int3 : !torch.int, !torch.int -> !torch.bool\n"
-"            torch.prim.If.yield %15 : !torch.bool\n"
+"            %16 = torch.aten.eq.int %1#1, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"            torch.prim.If.yield %16 : !torch.bool\n"
 "          }\n"
 "          %13 = torch.prim.If %12 -> (!torch.bool) {\n"
 "            torch.prim.If.yield %true : !torch.bool\n"
 "          } else {\n"
-"            %15 = torch.aten.eq.int %arg1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
-"            torch.prim.If.yield %15 : !torch.bool\n"
+"            %16 = torch.aten.eq.int %1#1, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+"            torch.prim.If.yield %16 : !torch.bool\n"
+"          }\n"
+"          %14 = torch.prim.If %13 -> (!torch.bool) {\n"
+"            torch.prim.If.yield %true : !torch.bool\n"
+"          } else {\n"
+"            %16 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
+"            torch.prim.If.yield %16 : !torch.bool\n"
 "          }\n"
-"          %14 = torch.prim.If %13 -> (!torch.int) {\n"
+"          %15 = torch.prim.If %14 -> (!torch.int) {\n"
 "            torch.prim.If.yield %int9 : !torch.int\n"
 "          } else {\n"
 "            torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 "            torch.prim.If.yield %0 : !torch.int\n"
 "          }\n"
-"          torch.prim.If.yield %14 : !torch.int\n"
+"          torch.prim.If.yield %15 : !torch.int\n"
 "        }\n"
-"        torch.prim.If.yield %7 : !torch.int\n"
+"        torch.prim.If.yield %8 : !torch.int\n"
 "      }\n"
-"      torch.prim.If.yield %5 : !torch.int\n"
+"      torch.prim.If.yield %6 : !torch.int\n"
 "    }\n"
-"    return %3 : !torch.int\n"
+"    return %4 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index ad668ecad585..c3bfe7d8c8d8 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -714,9 +714,8 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Promote the two dtypes assuming non-zero rank.
   if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
-          Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
-          AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp,
-          AtenMseLossOp>(op)) {
+          Aten_ConvolutionOp, AtenMvOp, AtenConvolutionOverrideableOp,
+          AtenConvTranspose2dInputOp, AtenMseLossOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
diff --git a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp
index e66cc76fe082..9eac538743b4 100644
--- a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp
+++ b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp
@@ -19,55 +19,25 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
-static bool isTensorTypeOrWrappedTensorType(Type type) {
-  // Allowing tuples as arguments to dtype calculation functions can cause
-  // issues. For example, if an argument is a tuple of tensors and ints, there
-  // would be no way of differentiating the original ints from the ints created
-  // to represent the dtype and rank of the tensors. Therefore, to avoid this
-  // and keep things simple, the tuple type is not allowed. This works well in
-  // practice, since PyTorch op signatures don't seem to take tuples as inputs.
-  assert(!type.isa<Torch::TupleType>() &&
-         "dtype calculation functions are expected to not have tuples of "
-         "tensors as arguments");
-
-  if (type.isa<Torch::BaseTensorType>())
-    return true;
-
-  if (auto optionalType = type.dyn_cast<Torch::OptionalType>()) {
-    return isTensorTypeOrWrappedTensorType(optionalType.getContainedType());
-  } else if (auto listType = type.dyn_cast<Torch::ListType>()) {
-    return isTensorTypeOrWrappedTensorType(listType.getContainedType());
-  } else {
-    return false;
-  }
-}
-
 // Massage the op operands to match the dtype function signature.
 // The dtype function generally takes the same operands as the op, with a few
-// systematic modifications, such as replacing tensors with a rank and dtype
-// argument.
+// systematic modifications, such as replacing each tensor with a tuple of
+// its rank and dtype.
 static FailureOr<SmallVector<Value>>
 dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
                          ValueRange originalOperands, func::FuncOp dtypeFunc) {
-  // Turns a tensor operand into an operand representing the rank of the tensor
-  auto rankArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
-                            Type desiredType) -> Value {
-    if (desiredType.isa<Torch::IntType>() &&
-        operand.getType().isa<Torch::BaseTensorType>()) {
-      auto sizeListType =
-          Torch::ListType::get(Torch::IntType::get(b.getContext()));
-      Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
-      return b.create<AtenLenTOp>(loc, desiredType, size);
-    }
-    return operand;
-  };
-
-  // Turns a tensor operand into an operand representing the dtype of the tensor
+  // Turn every tensor into a tuple of (tensor_rank, tensor_dtype)
   auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
                              Type desiredType) -> Value {
-    if (desiredType.isa<Torch::IntType>() &&
+    if (desiredType.isa<Torch::TupleType>() &&
         operand.getType().isa<Torch::BaseTensorType>()) {
-      return b.create<PrimDtypeOp>(loc, desiredType, operand);
+      Type intType = Torch::IntType::get(b.getContext());
+      Type sizeListType = Torch::ListType::get(intType);
+      Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
+      Value rank = b.create<AtenLenTOp>(loc, intType, size);
+      Value dtype = b.create<PrimDtypeOp>(loc, intType, operand);
+      return b.create<PrimTupleConstructOp>(loc, desiredType,
+                                            ArrayRef{rank, dtype});
    }
     return operand;
   };
@@ -79,26 +49,11 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
           "`dtypeFunc` should have at least one argument for each argument in "
           "`originalOperands`");
    Type desiredType = desiredTypes.front();
-    if (isTensorTypeOrWrappedTensorType(operand.getType())) {
-      assert(desiredTypes.size() >= 2 &&
-             "`dtypeFunc` should have two arguments for each tensor argument "
-             "in `originalOperands`");
-      FailureOr<Value> rankArg, dtypeArg;
-      if (failed(rankArg = adjustFunctionArg(b, loc, operand, desiredType,
-                                             rankArgAdjuster)))
-        return failure();
-      desiredTypes = desiredTypes.drop_front();
-      desiredType = desiredTypes.front();
-      if (failed(dtypeArg = adjustFunctionArg(b, loc, operand, desiredType,
-                                              dtypeArgAdjuster)))
-        return failure();
-      dtypeFuncArgs.append({*rankArg, *dtypeArg});
-    } else {
-      FailureOr<Value> otherArg;
-      if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType)))
-        return failure();
-      dtypeFuncArgs.push_back(*otherArg);
-    }
+    FailureOr<Value> otherArg;
+    if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType,
+                                            dtypeArgAdjuster)))
+      return failure();
+    dtypeFuncArgs.push_back(*otherArg);
     desiredTypes = desiredTypes.drop_front();
   }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 4f0e848d5f8d..cfa21fe152af 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -89,7 +89,8 @@ def aten〇expm1〡shape(self: List[int]) -> List[int]:
     Invocation(ZeroDTensorWithDtype(torch.int32)),
     Invocation(ZeroDTensorWithDtype(torch.bool)),
 ])
-def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
+def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
     if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
         return self_dtype
     else:
@@ -280,7 +281,8 @@ def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1
     Invocation(ZeroDTensorWithDtype(torch.int64), other=0.0),
     Invocation(ZeroDTensorWithDtype(torch.float16), other=0.0)
 ])
-def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
+def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
+    self_rank, self_dtype = self_rank_dtype
     return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
 
 def aten〇leaky_relu〡shape(self: List[int], negative_slope: float = 0.01) -> List[int]:
@@ -679,7 +681,9 @@ def aten〇floor_divide〡shape(self: List[int], other: List[int]) -> List[int]:
     Invocation(ZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float64)),
     Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)),
 ])
-def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
+def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    other_rank, other_dtype = other_rank_dtype
     ranks: List[Optional[int]] = [self_rank, other_rank]
     dtypes = [self_dtype, other_dtype]
     return promote_dtypes(ranks, dtypes)
@@ -819,6 +823,40 @@ def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optio
 def aten〇_convolution〇deprecated〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]:
     return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
 
+_convolution_deprecated_kwargs = {
+    "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0],
+    "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False}
+@check_dtype_function(
+    [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Same type
+                TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
+     ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.int32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different type
+                TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
+     ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different width
+                TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
+     ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.int32), # Different type and width
+                TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
+     ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.complex64), TensorOfShape(1, 1, 1, 1, dtype=torch.float32),
+                TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
+     ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.complex128),
+                TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
+     ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32),
+                TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
+     ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool),
+                TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
+     ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32),
+                TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs),
+     ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16),
+                TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs)
+])
+def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    weight_rank, weight_dtype = weight_rank_dtype
+    assert input_dtype == weight_dtype
+    assert input_dtype not in [torch.bool, torch.float16, torch.complex64, torch.complex128]
+    ranks: List[Optional[int]] = [input_rank, weight_rank]
+    dtypes = [input_dtype, weight_dtype]
+    return promote_dtypes(ranks, dtypes)
+
 def aten〇flip〡shape(self: List[int], dims: List[int]) -> List[int]:
     return self
 
@@ -1035,7 +1073,8 @@ def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = 
     ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)),
     ErrorInvocation(NonZeroDTensorWithDtype(torch.bfloat16)),
 ])
-def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
+def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
     if self_dtype == torch.complex64 or self_dtype == torch.complex128:
         return self_dtype
     elif self_dtype == torch.float:
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py
index 81c669e1abe3..550b47802e76 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py
@@ -86,7 +86,7 @@ def _pytype_to_dtype_fn_pytype(pytype: str) -> str:
     """
     # Dtype functions only care about the rank and dtype of tensors.
     if "Tensor" in pytype:
-        return pytype.replace("Tensor", "int")
+        return pytype.replace("Tensor", "Tuple[int, int]")
     return _pytype_to_fn_pytype_common(pytype)
 
 def _pytype_to_decomposition_fn_pytype(pytype: str) -> str:
@@ -232,8 +232,7 @@ def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
     default = _get_default_value(arg)
     parameter_name = _rename_python_keyword_parameter_name(arg["name"])
     if "Tensor" in arg["pytype"]:
-        return ", ".join([f"{parameter_name}_rank: {pytype}{default}",
-                          f"{parameter_name}_dtype: {pytype}{default}"])
+        return f"{parameter_name}_rank_dtype: {pytype}{default}"
     return f"{parameter_name}: {pytype}{default}"
 
 def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
@@ -241,7 +240,7 @@ def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
     # results of type `number`. Here we handle this case because
     # `_pytype_to_dtype_fn_pytype` will replace `number` with
     # `Union[int, float]`.
- if arg["pytype"] == "number": + if arg["pytype"] in ["number", "Tensor"]: return "int" return _pytype_to_dtype_fn_pytype(arg["pytype"]) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py index 1e82a59706b7..efd270b78a7f 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py @@ -96,36 +96,6 @@ def _recursively_transform_tensor_args( return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o) raise Exception(f"Unhandled type {type(o)}") -def _convert_to_dtype_function_args(arguments: Iterable[Any]) -> List[Any]: - """Converts an Invocation argument to a dtype function argument. - - TensorOfShape is replaced with two ints representing the rank - and dtype of the tensor, respectively. - """ - def contains_tensor(o: Any) -> bool: - if o is None or isinstance(o, (float, int)): - return False - if isinstance(o, TensorOfShape): - return True - if isinstance(o, (list, tuple)): - for elem in o: - if contains_tensor(elem): - return True - return False - raise Exception(f"Unhandled type {type(o)}") - - result = [] - for arg in arguments: - if contains_tensor(arg): - rank_arg = _recursively_transform_tensor_args( - arg, lambda x: len(x.shape)) - dtype_arg = _recursively_transform_tensor_args( - arg, lambda x: x.dtype) - result += [rank_arg, dtype_arg] - else: - result.append(arg) - return result - class Invocation: """Representation of a single op invocation (i.e. list of args to the op). @@ -135,8 +105,8 @@ class Invocation: Specifically, this class has special knowledge of `TensorOfShape` and translates it appropriately to either a tensor (for the real op), a - `List[int]` for the shape function, and two `int`s representing - the tensor rank and dtype in the case of a dtype function. + `List[int]` for the shape function, and a tuple with two `int`s + representing the tensor rank and dtype in the case of a dtype function. 
    This class also tracks whether the invocation is expected to raise an
     exception for greater precision when interpreting errors raised during
@@ -170,7 +140,9 @@ def to_shape_function_args(self):
 
     def to_dtype_function_args(self):
         """Gets positional arguments appropriate for a dtype function."""
-        return _convert_to_dtype_function_args(self.args)
+        tensor_transformer = lambda o: (len(o.shape), o.dtype)
+        return _recursively_transform_tensor_args(
+            self.args, tensor_transformer)
 
     def to_real_op_args(self):
         """Gets positional arguments appropriate for the real op."""
diff --git a/test/Dialect/Torch/reify-dtype-calculations.mlir b/test/Dialect/Torch/reify-dtype-calculations.mlir
index 8ed19c68a159..265497ddf324 100644
--- a/test/Dialect/Torch/reify-dtype-calculations.mlir
+++ b/test/Dialect/Torch/reify-dtype-calculations.mlir
@@ -12,7 +12,8 @@
 // CHECK:           %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int>
 // CHECK:           %[[RANK:.*]] = torch.aten.len.t %[[SIZE]] : !torch.list<int> -> !torch.int
 // CHECK:           %[[DTYPE:.*]] = torch.prim.dtype %[[ARG]] : !torch.vtensor -> !torch.int
-// CHECK:           %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK]], %[[DTYPE]]) : (!torch.int, !torch.int) -> !torch.int
+// CHECK:           %[[RANK_DTYPE:.*]] = torch.prim.TupleConstruct %[[RANK]], %[[DTYPE]] : !torch.int, !torch.int -> !torch.tuple<int, int>
+// CHECK:           %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK_DTYPE]]) : (!torch.tuple<int, int>) -> !torch.int
 // CHECK:           torch.dtype.calculate.yield.dtypes %[[RESULT_DTYPE]] : !torch.int
 // CHECK:         } : !torch.vtensor
 // CHECK:         return %[[RESULT:.*]] : !torch.vtensor
@@ -38,6 +39,21 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor)
 
 // -----
 
+// CHECK-LABEL:   func.func private @__torch_mlir_dtype_fn.aten._convolution.deprecated(
+
+// CHECK-LABEL:   func.func @op_with_optional_tensor_arg$none(
+// CHECK:           %[[NONE:.*]] = torch.constant.none
+// CHECK:           %[[OPTIONAL_TUPLE:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<tuple<int, int>>
+// CHECK:           {{.*}} = func.call @__torch_mlir_dtype_fn.aten._convolution.deprecated({{.*}}, %[[OPTIONAL_TUPLE]], {{.*}}) : ({{.*}}, !torch.optional<tuple<int, int>>, {{.*}}) -> !torch.int
+func.func @op_with_optional_tensor_arg$none(%input: !torch.vtensor, %weight: !torch.vtensor, %stride: !torch.list<int>, %padding: !torch.list<int>, %dilation: !torch.list<int>, %transposed: !torch.bool, %output_padding: !torch.list<int>, %groups: !torch.int) -> !torch.vtensor {
+  %bias_none = torch.constant.none
+  %false = torch.constant.bool false
+  %0 = torch.aten._convolution.deprecated %input, %weight, %bias_none, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %false, %false, %false : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor
+  return %0 : !torch.vtensor
+}
+
+// -----
+
 // CHECK-LABEL:   func.func private @__torch_mlir_dtype_fn.aten.floor_divide(
 
 // CHECK-LABEL:   func.func @turn_tensors_into_rank_and_dtype_args(
@@ -46,10 +62,12 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor)
 // CHECK:           %[[SIZE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
 // CHECK:           %[[RANK0:.*]] = torch.aten.len.t %[[SIZE0]] : !torch.list<int> -> !torch.int
 // CHECK:           %[[DTYPE0:.*]] = torch.prim.dtype %[[ARG0]] : !torch.vtensor -> !torch.int
+// CHECK:           %[[RANK_DTYPE0:.*]] = torch.prim.TupleConstruct %[[RANK0]], %[[DTYPE0]] : !torch.int, !torch.int -> !torch.tuple<int, int>
 // CHECK:           %[[SIZE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
 // CHECK:           %[[RANK1:.*]] = torch.aten.len.t %[[SIZE1]] : !torch.list<int> -> !torch.int
 // CHECK:           %[[DTYPE1:.*]] = torch.prim.dtype %[[ARG1]] : !torch.vtensor -> !torch.int
-// CHECK:           {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK0]], %[[DTYPE0]], %[[RANK1]], %[[DTYPE1]]) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.int
+// CHECK:           %[[RANK_DTYPE1:.*]] = torch.prim.TupleConstruct %[[RANK1]], %[[DTYPE1]] : !torch.int, !torch.int -> !torch.tuple<int, int>
+// CHECK:           {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK_DTYPE0]], %[[RANK_DTYPE1]]) : (!torch.tuple<int, int>, !torch.tuple<int, int>) -> !torch.int
 func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
   %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
   return %0 : !torch.vtensor
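To round out the picture, here is a short sketch of the testing-framework side
of this change (not part of the patch; `TensorOfShape` is reduced to a minimal
stand-in, and the helper below is a flat-list simplification of the real
`_recursively_transform_tensor_args`, which also recurses into nested lists
and tuples). The `tensor_transformer` lambda is the one installed by this
patch in `Invocation.to_dtype_function_args`:

    import torch

    class TensorOfShape:
        # Minimal stand-in for the testing framework's TensorOfShape.
        def __init__(self, *shape, dtype=torch.float32):
            self.shape = list(shape)
            self.dtype = dtype

    # The transformer from the patch: one tensor -> one (rank, dtype) tuple.
    tensor_transformer = lambda o: (len(o.shape), o.dtype)

    def to_dtype_function_args(args):
        # Simplified: the real helper recurses into lists/tuples as well.
        return [tensor_transformer(a) if isinstance(a, TensorOfShape) else a
                for a in args]

    args = [TensorOfShape(1, 1, 1, 1, dtype=torch.float32), None, [1, 1], 1]
    print(to_dtype_function_args(args))
    # [(4, torch.float32), None, [1, 1], 1]
    # An optional tensor that is None simply stays None, which is exactly the
    # case the old two-ints-per-tensor interface could not represent.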