diff --git a/docs/adding_abstract_interpretation_functions.md b/docs/adding_abstract_interpretation_functions.md index 68e1b98b78ac..b5e427e1adfd 100644 --- a/docs/adding_abstract_interpretation_functions.md +++ b/docs/adding_abstract_interpretation_functions.md @@ -21,7 +21,7 @@ We will use the example of adding support for the `torch.aten.tanh` op. function signatures are: - `def aten〇tanh〡shape(self: List[int]) -> List[int]:` - - `def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:` + - `def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:` Note the use of `〇` as a separator since `.` or `::` aren't legal in a Python identifier. diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index a371e291d66a..5b91bd6532e1 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -5979,31 +5979,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int5 = torch.constant.int 5\n" " %int15 = torch.constant.int 15\n" " %true = torch.constant.bool true\n" " %int7 = torch.constant.int 7\n" -" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %4 : !torch.bool\n" +" %5 = torch.aten.eq.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" " }\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %4 : !torch.bool\n" +" %5 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" " }\n" -" %3 = torch.prim.If %2 -> (!torch.int) {\n" -" torch.prim.If.yield %arg1 : !torch.int\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" " } else {\n" " torch.prim.If.yield %int6 : !torch.int\n" " }\n" -" return %3 : !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" @@ -6236,13 +6237,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" -" %0 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" -" %2 = torch.prim.ListConstruct %arg1, %1 : (!torch.int, !torch.int) -> !torch.list\n" -" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %3 : !torch.int\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" " %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" @@ -6972,11 +6974,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" -" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list>\n" -" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %2 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" @@ -7162,6 +7166,36 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" +" %int10 = torch.constant.int 10\n" +" %int9 = torch.constant.int 9\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.ListConstruct %int11, %int5, %int9, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %8 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -7639,7 +7673,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.int, %arg4: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" " %int4 = torch.constant.int 4\n" @@ -7654,68 +7688,69 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %true = torch.constant.bool true\n" " %int9 = torch.constant.int 9\n" " %0 = torch.prim.Uninitialized : !torch.int\n" -" %1 = torch.aten.eq.int %arg1, %int9 : !torch.int, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %4 = torch.aten.eq.int %arg1, %int10 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %4 : !torch.bool\n" +" %5 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" " }\n" -" %3 = torch.prim.If %2 -> (!torch.int) {\n" -" torch.prim.If.yield %arg1 : !torch.int\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" " } else {\n" -" %4 = torch.aten.eq.int %arg1, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" %5 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" " torch.prim.If.yield %int9 : !torch.int\n" " } else {\n" -" %6 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" %7 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" " torch.prim.If.yield %int10 : !torch.int\n" " } else {\n" -" %8 = torch.aten.eq.int %arg1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" %9 = torch.prim.If %8 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %15 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" }\n" +" %9 = torch.aten.eq.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" " %10 = torch.prim.If %9 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %15 = torch.aten.eq.int %arg1, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" +" %16 = torch.aten.eq.int %1#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %16 : !torch.bool\n" " }\n" " %11 = torch.prim.If %10 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %15 = torch.aten.eq.int %arg1, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" +" %16 = torch.aten.eq.int %1#1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %16 : !torch.bool\n" " }\n" " %12 = torch.prim.If %11 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %15 = torch.aten.eq.int %arg1, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" +" %16 = torch.aten.eq.int %1#1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %16 : !torch.bool\n" " }\n" " %13 = torch.prim.If %12 -> (!torch.bool) {\n" " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" -" %15 = torch.aten.eq.int %arg1, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" +" %16 = torch.aten.eq.int %1#1, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %16 : !torch.bool\n" +" }\n" +" %14 = torch.prim.If %13 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %16 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %16 : !torch.bool\n" " }\n" -" %14 = torch.prim.If %13 -> (!torch.int) {\n" +" %15 = torch.prim.If %14 -> (!torch.int) {\n" " torch.prim.If.yield %int9 : !torch.int\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield %0 : !torch.int\n" " }\n" -" torch.prim.If.yield %14 : !torch.int\n" +" torch.prim.If.yield %15 : !torch.int\n" " }\n" -" torch.prim.If.yield %7 : !torch.int\n" +" torch.prim.If.yield %8 : !torch.int\n" " }\n" -" torch.prim.If.yield %5 : !torch.int\n" +" torch.prim.If.yield %6 : !torch.int\n" " }\n" -" return %3 : !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index ad668ecad585..c3bfe7d8c8d8 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -714,9 +714,8 @@ void TypeAnalysis::visitOperation(Operation *op, // Promote the two dtypes assuming non-zero rank. if (isa(op)) { + Aten_ConvolutionOp, AtenMvOp, AtenConvolutionOverrideableOp, + AtenConvTranspose2dInputOp, AtenMseLossOp>(op)) { auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( diff --git a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp index e66cc76fe082..9eac538743b4 100644 --- a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp @@ -19,55 +19,25 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -static bool isTensorTypeOrWrappedTensorType(Type type) { - // Allowing tuples as arguments to dtype calculation functions can cause - // issues. For example, if an argument is a tuple of tensors and ints, there - // would be no way of differentiating the original ints from the ints created - // to represent the dtype and rank of the tensors. Therefore, to avoid this - // and keep things simple, the tuple type is not allowed. This works well in - // practice, since PyTorch op signatures don't seem to take tuples as inputs. - assert(!type.isa() && - "dtype calculation functions are expected to not have tuples of " - "tensors as arguments"); - - if (type.isa()) - return true; - - if (auto optionalType = type.dyn_cast()) { - return isTensorTypeOrWrappedTensorType(optionalType.getContainedType()); - } else if (auto listType = type.dyn_cast()) { - return isTensorTypeOrWrappedTensorType(listType.getContainedType()); - } else { - return false; - } -} - // Massage the op operands to match the dtype function signature. // The dtype function generally takes the same operands as the op, with a few -// systematic modifications, such as replacing tensors with a rank and dtype -// argument. +// systematic modifications, such as replacing each tensor with a tuple of +// its rank and dtype. static FailureOr> dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, ValueRange originalOperands, func::FuncOp dtypeFunc) { - // Turns a tensor operand into an operand representing the rank of the tensor - auto rankArgAdjuster = [](OpBuilder &b, Location loc, Value operand, - Type desiredType) -> Value { - if (desiredType.isa() && - operand.getType().isa()) { - auto sizeListType = - Torch::ListType::get(Torch::IntType::get(b.getContext())); - Value size = b.create(loc, sizeListType, operand); - return b.create(loc, desiredType, size); - } - return operand; - }; - - // Turns a tensor operand into an operand representing the dtype of the tensor + // Turn every tensor into a tuple of (tensor_rank, tensor_dtype) auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand, Type desiredType) -> Value { - if (desiredType.isa() && + if (desiredType.isa() && operand.getType().isa()) { - return b.create(loc, desiredType, operand); + Type intType = Torch::IntType::get(b.getContext()); + Type sizeListType = Torch::ListType::get(intType); + Value size = b.create(loc, sizeListType, operand); + Value rank = b.create(loc, intType, size); + Value dtype = b.create(loc, intType, operand); + return b.create(loc, desiredType, + ArrayRef{rank, dtype}); } return operand; }; @@ -79,26 +49,11 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, "`dtypeFunc` should have at least one argument for each argument in " "`originalOperands`"); Type desiredType = desiredTypes.front(); - if (isTensorTypeOrWrappedTensorType(operand.getType())) { - assert(desiredTypes.size() >= 2 && - "`dtypeFunc` should have two arguments for each tensor argument " - "in `originalOperands`"); - FailureOr rankArg, dtypeArg; - if (failed(rankArg = adjustFunctionArg(b, loc, operand, desiredType, - rankArgAdjuster))) - return failure(); - desiredTypes = desiredTypes.drop_front(); - desiredType = desiredTypes.front(); - if (failed(dtypeArg = adjustFunctionArg(b, loc, operand, desiredType, - dtypeArgAdjuster))) - return failure(); - dtypeFuncArgs.append({*rankArg, *dtypeArg}); - } else { - FailureOr otherArg; - if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType))) - return failure(); - dtypeFuncArgs.push_back(*otherArg); - } + FailureOr otherArg; + if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType, + dtypeArgAdjuster))) + return failure(); + dtypeFuncArgs.push_back(*otherArg); desiredTypes = desiredTypes.drop_front(); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 4f0e848d5f8d..cfa21fe152af 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -89,7 +89,8 @@ def aten〇expm1〡shape(self: List[int]) -> List[int]: Invocation(ZeroDTensorWithDtype(torch.int32)), Invocation(ZeroDTensorWithDtype(torch.bool)), ]) -def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int: +def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16: return self_dtype else: @@ -280,7 +281,8 @@ def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1 Invocation(ZeroDTensorWithDtype(torch.int64), other=0.0), Invocation(ZeroDTensorWithDtype(torch.float16), other=0.0) ]) -def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int: +def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) def aten〇leaky_relu〡shape(self: List[int], negative_slope: float = 0.01) -> List[int]: @@ -679,7 +681,9 @@ def aten〇floor_divide〡shape(self: List[int], other: List[int]) -> List[int]: Invocation(ZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float64)), Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)), ]) -def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int: +def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype ranks: List[Optional[int]] = [self_rank, other_rank] dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) @@ -819,6 +823,40 @@ def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optio def aten〇_convolution〇deprecated〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]: return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) +_convolution_deprecated_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], + "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False} +@check_dtype_function( + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Same type + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.int32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different type + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different width + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.int32), # Different type and width + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.complex64), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.complex128), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs) +]) +def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert input_dtype == weight_dtype + assert input_dtype not in [torch.bool, torch.float16, torch.complex64, torch.complex128] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + def aten〇flip〡shape(self: List[int], dims: List[int]) -> List[int]: return self @@ -1035,7 +1073,8 @@ def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)), ErrorInvocation(NonZeroDTensorWithDtype(torch.bfloat16)), ]) -def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: +def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: + self_rank, self_dtype = self_rank_dtype if self_dtype == torch.complex64 or self_dtype == torch.complex128: return self_dtype elif self_dtype == torch.float: diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py index 81c669e1abe3..550b47802e76 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py @@ -86,7 +86,7 @@ def _pytype_to_dtype_fn_pytype(pytype: str) -> str: """ # Dtype functions only care about the rank and dtype of tensors. if "Tensor" in pytype: - return pytype.replace("Tensor", "int") + return pytype.replace("Tensor", "Tuple[int, int]") return _pytype_to_fn_pytype_common(pytype) def _pytype_to_decomposition_fn_pytype(pytype: str) -> str: @@ -232,8 +232,7 @@ def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: default = _get_default_value(arg) parameter_name = _rename_python_keyword_parameter_name(arg["name"]) if "Tensor" in arg["pytype"]: - return ", ".join([f"{parameter_name}_rank: {pytype}{default}", - f"{parameter_name}_dtype: {pytype}{default}"]) + return f"{parameter_name}_rank_dtype: {pytype}{default}" return f"{parameter_name}: {pytype}{default}" def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: @@ -241,7 +240,7 @@ def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: # results of type `number`. Here we handle this case because # `_pytype_to_dtype_fn_pytype` will replace `number` with # `Union[int, float]`. - if arg["pytype"] == "number": + if arg["pytype"] in ["number", "Tensor"]: return "int" return _pytype_to_dtype_fn_pytype(arg["pytype"]) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py index 1e82a59706b7..efd270b78a7f 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py @@ -96,36 +96,6 @@ def _recursively_transform_tensor_args( return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o) raise Exception(f"Unhandled type {type(o)}") -def _convert_to_dtype_function_args(arguments: Iterable[Any]) -> List[Any]: - """Converts an Invocation argument to a dtype function argument. - - TensorOfShape is replaced with two ints representing the rank - and dtype of the tensor, respectively. - """ - def contains_tensor(o: Any) -> bool: - if o is None or isinstance(o, (float, int)): - return False - if isinstance(o, TensorOfShape): - return True - if isinstance(o, (list, tuple)): - for elem in o: - if contains_tensor(elem): - return True - return False - raise Exception(f"Unhandled type {type(o)}") - - result = [] - for arg in arguments: - if contains_tensor(arg): - rank_arg = _recursively_transform_tensor_args( - arg, lambda x: len(x.shape)) - dtype_arg = _recursively_transform_tensor_args( - arg, lambda x: x.dtype) - result += [rank_arg, dtype_arg] - else: - result.append(arg) - return result - class Invocation: """Representation of a single op invocation (i.e. list of args to the op). @@ -135,8 +105,8 @@ class Invocation: Specifically, this class has special knowledge of `TensorOfShape` and translates it appropriately to either a tensor (for the real op), a - `List[int]` for the shape function, and two `int`s representing - the tensor rank and dtype in the case of a dtype function. + `List[int]` for the shape function, and a tuple with two `int`s + representing the tensor rank and dtype in the case of a dtype function. This class also tracks whether the invocation is expected to raise an exception for greater precision when interpreting errors raised during @@ -170,7 +140,9 @@ def to_shape_function_args(self): def to_dtype_function_args(self): """Gets positional arguments appropriate for a dtype function.""" - return _convert_to_dtype_function_args(self.args) + tensor_transformer = lambda o: (len(o.shape), o.dtype) + return _recursively_transform_tensor_args( + self.args, tensor_transformer) def to_real_op_args(self): """Gets positional arguments appropriate for the real op.""" diff --git a/test/Dialect/Torch/reify-dtype-calculations.mlir b/test/Dialect/Torch/reify-dtype-calculations.mlir index 8ed19c68a159..265497ddf324 100644 --- a/test/Dialect/Torch/reify-dtype-calculations.mlir +++ b/test/Dialect/Torch/reify-dtype-calculations.mlir @@ -12,7 +12,8 @@ // CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list // CHECK: %[[RANK:.*]] = torch.aten.len.t %[[SIZE]] : !torch.list -> !torch.int // CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG]] : !torch.vtensor -> !torch.int -// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK]], %[[DTYPE]]) : (!torch.int, !torch.int) -> !torch.int +// CHECK: %[[RANK_DTYPE:.*]] = torch.prim.TupleConstruct %[[RANK]], %[[DTYPE]] : !torch.int, !torch.int -> !torch.tuple +// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.expm1(%[[RANK_DTYPE]]) : (!torch.tuple) -> !torch.int // CHECK: torch.dtype.calculate.yield.dtypes %[[RESULT_DTYPE]] : !torch.int // CHECK: } : !torch.vtensor // CHECK: return %[[RESULT:.*]] : !torch.vtensor @@ -38,6 +39,21 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor) // ----- +// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten._convolution.deprecated( + +// CHECK-LABEL: func.func @op_with_optional_tensor_arg$none( +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[OPTIONAL_TUPLE:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional> +// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten._convolution.deprecated({{.*}}, %[[OPTIONAL_TUPLE]], {{.*}}) : ({{.*}}, !torch.optional>, {{.*}}) -> !torch.int +func.func @op_with_optional_tensor_arg$none(%input: !torch.vtensor, %weight: !torch.vtensor, %stride: !torch.list, %padding: !torch.list, %dilation: !torch.list, %transposed: !torch.bool, %output_padding: !torch.list, %groups: !torch.int) -> !torch.vtensor { + %bias_none = torch.constant.none + %false = torch.constant.bool false + %0 = torch.aten._convolution.deprecated %input, %weight, %bias_none, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %false, %false, %false : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor + return %0 : !torch.vtensor +} + +// ----- + // CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide( // CHECK-LABEL: func.func @turn_tensors_into_rank_and_dtype_args( @@ -46,10 +62,12 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor) // CHECK: %[[SIZE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list // CHECK: %[[RANK0:.*]] = torch.aten.len.t %[[SIZE0]] : !torch.list -> !torch.int // CHECK: %[[DTYPE0:.*]] = torch.prim.dtype %[[ARG0]] : !torch.vtensor -> !torch.int +// CHECK: %[[RANK_DTYPE0:.*]] = torch.prim.TupleConstruct %[[RANK0]], %[[DTYPE0]] : !torch.int, !torch.int -> !torch.tuple // CHECK: %[[SIZE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list // CHECK: %[[RANK1:.*]] = torch.aten.len.t %[[SIZE1]] : !torch.list -> !torch.int // CHECK: %[[DTYPE1:.*]] = torch.prim.dtype %[[ARG1]] : !torch.vtensor -> !torch.int -// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK0]], %[[DTYPE0]], %[[RANK1]], %[[DTYPE1]]) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.int +// CHECK: %[[RANK_DTYPE1:.*]] = torch.prim.TupleConstruct %[[RANK1]], %[[DTYPE1]] : !torch.int, !torch.int -> !torch.tuple +// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK_DTYPE0]], %[[RANK_DTYPE1]]) : (!torch.tuple, !torch.tuple) -> !torch.int func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor return %0 : !torch.vtensor