diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index ffc9a6dbb74f..630c6e9abc3b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -1538,6 +1538,51 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ }]; } +def Torch_AtenFracOp : Torch_Op<"aten.frac", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::frac : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFracOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenFracOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenFrac_Op : Torch_Op<"aten.frac_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::frac_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFrac_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenFrac_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ AllowsTypeRefinement, HasValueSemantics, @@ -3455,6 +3500,53 @@ def Torch_AtenFill_TensorOp : Torch_Op<"aten.fill_.Tensor", [ }]; } +def Torch_AtenCopysignTensorOp : Torch_Op<"aten.copysign.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::copysign.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCopysignTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCopysignTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenCopysign_TensorOp : Torch_Op<"aten.copysign_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::copysign_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCopysign_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCopysign_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [ AllowsTypeRefinement, ReadOnly @@ -3905,6 +3997,53 @@ def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ }]; } +def Torch_AtenLdexpTensorOp : Torch_Op<"aten.ldexp.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::ldexp.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLdexpTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenLdexpTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenSignbitOp : Torch_Op<"aten.signbit", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::signbit : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSignbitOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSignbitOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index ead29d59a59e..fd1a1a11e552 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9349,6 +9349,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.frac\"(%arg0: !torch.list) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.signbit\"(%arg0: !torch.list) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.ldexp.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.copysign.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.__and__.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -13215,6 +13229,93 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.frac\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.signbit\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ldexp.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %false = torch.constant.bool false\n" +" %int7 = torch.constant.int 7\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.aten.eq.int %0#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %10 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" %13 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %13 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" torch.prim.If.yield %13 : !torch.int\n" +" }\n" +" torch.prim.If.yield %12 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.copysign.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %false = torch.constant.bool false\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.__and__.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 004aaa5a77e5..46feee41b3c2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -8012,6 +8012,122 @@ class DecomposeAtenTruncOp : public OpRewritePattern { }; } // namespace +namespace { +// decompose `signbit(x)` to `view.dtype(x, si32/si64) < 0 ` +class DecomposeAtenSignbitOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSignbitOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + + auto operandTy = dyn_cast(self.getType()); + auto resultTy = dyn_cast(op.getType()); + if (!operandTy || !operandTy.hasDtype() || !resultTy || + !resultTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, + "operand and result must have dtype"); + } + + if (isa(operandTy.getDtype())) { + mlir::IntegerType intType = rewriter.getIntegerType( + operandTy.getDtype().getIntOrFloatBitWidth(), /*isSigned*/ true); + Value dtype = getDtypeIntValueForType(rewriter, loc, intType); + Value view = rewriter.create( + loc, + operandTy.getWithSizesAndDtype(operandTy.getOptionalSizes(), intType), + self, dtype); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value shift = rewriter.create(loc, resultTy, view, zero); + rewriter.replaceOp(op, shift); + return success(); + } else if (isa(operandTy.getDtype())) { + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value shift = rewriter.create(loc, resultTy, self, zero); + rewriter.replaceOp(op, shift); + } + return failure(); + } +}; +} // namespace + +namespace { +// decompose `frac(x)` to `x - trunc(x)` +class DecomposeAtenFracOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFracOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + auto resultTy = op.getType(); + + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value trunc = rewriter.create(loc, resultTy, self); + rewriter.replaceOpWithNewOp(op, resultTy, self, trunc, + /*alpha=*/one); + return success(); + } +}; +} // namespace + +namespace { +// decompose `copysign(x, y)` to `signbit(y) ? -abs(x) : abs(x)` +class DecomposeAtenCopysignTensorOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenCopysignTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value other = op.getOther(); + auto selfTy = self.getType(); + auto otherTy = cast(other.getType()); + auto resultTy = op.getType(); + + Value signbit = rewriter.create( + loc, + otherTy.getWithSizesAndDtype(otherTy.getOptionalSizes(), + rewriter.getI1Type()), + other); + Value abs = rewriter.create(loc, selfTy, self); + Value neg = rewriter.create(loc, selfTy, abs); + rewriter.replaceOpWithNewOp(op, resultTy, signbit, neg, + abs); + return success(); + } +}; +} // namespace + +namespace { +// decompose `ldexp(x, y)` to `x * 2^y` +class DecomposeAtenLdexpTensorOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLdexpTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value other = op.getOther(); + + auto otherTy = dyn_cast(other.getType()); + auto resultTy = dyn_cast(op.getType()); + if (!resultTy || !resultTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result must have dtype"); + } + + Value exp2 = rewriter.create( + loc, + resultTy.getWithSizesAndDtype(otherTy.getOptionalSizes(), + resultTy.getDtype()), + other); + rewriter.replaceOpWithNewOp(op, resultTy, self, exp2); + return success(); + } +}; +} // namespace + namespace { // decompose `fmod(x, y)` to `x - trunc(x/y) * y` class DecomposeAtenFmodTensorOp : public OpRewritePattern { @@ -9159,10 +9275,16 @@ class DecomposeAtenExp2Op : public OpRewritePattern { Location loc = op.getLoc(); Value self = op.getSelf(); + auto resultTy = dyn_cast(op.getType()); + if (!resultTy || !resultTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result must have dtype"); + } + auto two = rewriter.create(loc, rewriter.getI64IntegerAttr(2)); - rewriter.replaceOpWithNewOp(op, op.getType(), two, self); - + Value to = convertTensorToDtype(rewriter, loc, self, resultTy.getDtype()); + Value pow = rewriter.create(loc, resultTy, two, to); + rewriter.replaceOp(op, pow); return success(); } }; @@ -10263,6 +10385,10 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 86ea382fe8b6..4bca74470772 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -541,6 +541,10 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 90479cf7f0a0..fefd7f279b26 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -530,6 +530,8 @@ "ElementwiseRreluTrainStaticModule_basic", "ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "ElementwiseSignbitModule_basic", + "ElementwiseCopysignModule_basic", } FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { @@ -3272,6 +3274,12 @@ "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseUnaryIntModule_basic", "ElementwiseFloatTensorGtIntTensorModule_basic", + "ElementwiseSignbitModule_basic", + "ElementwiseSignbitIntModule_basic", + "ElementwiseFracModule_basic", + "ElementwiseCopysignModule_basic", + "ElementwiseLdexpModule_basic", + "Exp2StaticIntModule_basic", "MaskedFillTensorFloatValueModule_basic", "NativeDropoutTrainModule_basic", "NativeDropoutTrainStaticShapeModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 06437574d8f0..e53b60cbc7f8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1550,6 +1550,18 @@ def aten〇floor_divide〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇atan2〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇frac〡shape(self: List[int]) -> List[int]: + return self + +def aten〇signbit〡shape(self: List[int]) -> List[int]: + return self + +def aten〇ldexp〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + +def aten〇copysign〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇__and__〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -3910,6 +3922,42 @@ def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[ self_rank, self_dtype = self_rank_dtype return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) +def aten〇frac〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇signbit〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇ldexp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + if self_dtype == torch.double and is_complex_dtype(other_dtype): + return other_dtype + elif is_complex_dtype(self_dtype) and other_dtype == torch.double: + return self_dtype + elif is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype): + return torch.float + else: + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇copysign〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + if is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype): + return torch.float + else: + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index b02b3a776e3a..4b8c2d0609dd 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -329,6 +329,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::atanh : (Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", + "aten::frac : (Tensor) -> (Tensor)", "aten::bitwise_not : (Tensor) -> (Tensor)", "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::logical_or : (Tensor, Tensor) -> (Tensor)", @@ -370,6 +371,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::zero : (Tensor) -> (Tensor)", "aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::copysign.Tensor : (Tensor, Tensor) -> (Tensor)", ]: emit_with_mutating_variants(key) # Shape manipulations: @@ -418,6 +420,12 @@ def emit_with_mutating_variants(key, **kwargs): has_canonicalizer=True, has_folder=True, ) + emit( + "aten::ldexp.Tensor : (Tensor, Tensor) -> (Tensor)", + ) + emit( + "aten::signbit : (Tensor) -> (Tensor)", + ) emit_with_mutating_variants( "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 88a269a09f38..6a591c483e82 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2839,6 +2839,130 @@ def ElementwiseTruncIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSignbitModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 8], torch.float32, True), + ] + ) + def forward(self, a): + return torch.signbit(a) + + +@register_test_case(module_factory=lambda: ElementwiseSignbitModule()) +def ElementwiseSignbitModule_basic(module, tu: TestUtils): + module.forward( + torch.tensor( + [[-torch.inf, torch.inf, torch.nan, -torch.nan, 2.3, -2.3, 0.0, -0.0]] + ) + ) + + +class ElementwiseSignbitIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4], torch.int32, True), + ] + ) + def forward(self, a): + return torch.signbit(a) + + +@register_test_case(module_factory=lambda: ElementwiseSignbitIntModule()) +def ElementwiseSignbitIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseFracModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 6], torch.float32, True), + ] + ) + def forward(self, a): + return torch.frac(a) + + +@register_test_case(module_factory=lambda: ElementwiseFracModule()) +def ElementwiseFracModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[2.3, -2.3, 0.0, -0.0, 2.0, -2.0]])) + + +# ============================================================================== + + +class ElementwiseCopysignModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 1], torch.float32, True), + ([1, 6], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.copysign(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseCopysignModule()) +def ElementwiseCopysignModule_basic(module, tu: TestUtils): + module.forward( + torch.tensor([[1.0]]), + torch.tensor([[2.3, -2.3, 0.0, -0.0, torch.inf, -torch.inf]]), + ) + + +# ============================================================================== + + +class ElementwiseLdexpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 6], torch.float32, True), + ([1, 1], torch.int64, True), + ] + ) + def forward(self, a, b): + return torch.ldexp(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseLdexpModule()) +def ElementwiseLdexpModule_basic(module, tu: TestUtils): + module.forward( + torch.tensor([[2.3, -2.3, 0.0, -0.0, 4.5, -4.5]]), + torch.tensor([[2]]), + ) + + +# ============================================================================== + + class ElementwiseSignModule(torch.nn.Module): def __init__(self): super().__init__() @@ -2930,6 +3054,26 @@ def Exp2StaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2)) +class Exp2StaticIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4], torch.int64, True), + ] + ) + def forward(self, x): + return torch.ops.aten.exp2(x) + + +@register_test_case(module_factory=lambda: Exp2StaticIntModule()) +def Exp2StaticIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-20, high=20)) + + # ==============================================================================