From 2f45959f0d834c3786bde31d93edd4c323418087 Mon Sep 17 00:00:00 2001
From: Shivam Gupta
Date: Wed, 28 Dec 2022 08:51:33 +0530
Subject: [PATCH] Prelu lowering to linalg (#1712)

Prelu lowering to linalg
---
 e2e_testing/xfail_sets.py                      |  1 +
 .../Dialect/Torch/IR/GeneratedTorchOps.td      | 24 +++++++++++++++++
 .../TorchToLinalg/Uncategorized.cpp            | 26 +++++++++++++++++--
 .../Transforms/AbstractInterpLibrary.cpp       |  4 +++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp   |  2 +-
 .../build_tools/abstract_interp_lib_gen.py     |  3 +++
 .../jit_ir/build_tools/torch_ods_gen.py        |  1 +
 .../test_suite/elementwise.py                  | 22 ++++++++++++++++
 8 files changed, 80 insertions(+), 3 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 06a620e9f7c6..37786e82426a 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -763,6 +763,7 @@
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
     "ElementwiseClampModule_basic",
+    "ElementwisePreluModule_basic",
     "IouOfModule_basic",
     "MobilenetV3Module_basic",
     "NativeBatchNormNoneWeightModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index af603e24923c..6bc69127dea2 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3441,6 +3441,30 @@ def Torch_AtenSoftplusOp : Torch_Op<"aten.softplus", [
   }];
 }
 
+def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::prelu : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$weight
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenPreluOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenPreluOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenUniformOp : Torch_Op<"aten.uniform", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 8c94cdcb39bb..ba6e21276ae4 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -301,6 +301,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
         b.create<arith::MulFOp>(loc, negativePart, scale);
     return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
   }
+  if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
+    if (!prelu.getType()
+             .cast<BaseTensorType>()
+             .getDtype()
+             .isa<mlir::FloatType>()) {
+      prelu.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    Type elementType = payloadArgs[0].getType();
+    Value constZero =
+        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
+    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
+                                         payloadArgs[0], constZero);
+    Value positivePart =
+        b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
+    Value negativePart =
+        b.create<arith::SelectOp>(loc, pred, constZero, payloadArgs[0]);
+    Value scale = convertScalarToDtype(b, loc, payloadArgs[1], elementType);
+    Value scaledNegativePart =
+        b.create<arith::MulFOp>(loc, negativePart, scale);
+    return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
+  }
   if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
     if (!gelu.getType()
              .cast<BaseTensorType>()
@@ -1054,7 +1076,7 @@ class ConvertElementwiseOp : public ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
+    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenPreluOp, AtenGeluOp,
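Note on the lowering above: the payload computed for each element mirrors the leaky_relu payload directly before it, i.e. PReLU(x) = x for x > 0 and weight * x otherwise, built from one compare, two selects, a multiply, and an add. A minimal NumPy sketch of the same scalar computation (the prelu_payload helper is illustrative, not part of the patch):

    import numpy as np

    def prelu_payload(x: np.ndarray, weight: float) -> np.ndarray:
        # select(pred, x, 0) and select(pred, 0, x), as in the payload above
        positive_part = np.where(x > 0, x, 0.0)
        negative_part = np.where(x > 0, 0.0, x)
        # only the negative part is scaled before recombining
        return positive_part + weight * negative_part

    print(prelu_payload(np.array([-2.0, 3.0]), 0.25))  # [-0.5  3. ]

The UGT (unordered greater-than) predicate evaluates to true for NaN inputs, so NaN propagates through the identity branch unscaled, consistent with the leaky_relu lowering this is modeled on.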
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ ... @@
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.prelu\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.gather\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg2) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 3f3f37154d0f..e7adf792cf84 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -663,7 +663,7 @@ void TypeAnalysis::visitOperation(Operation *op,
         AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
         AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
         AtenZero_Op, AtenIndexTensorOp, Aten_IndexPutImplOp, AtenIndexPutOp,
-        AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp,
+        AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenPreluOp,
         AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
         AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
         AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index e135c2f91fae..60dde6857a87 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -274,6 +274,9 @@ def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[i
 def aten〇leaky_relu〡shape(self: List[int], negative_slope: float = 0.01) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇gather〡shape(self: List[int], dim: int, index: List[int], sparse_grad: bool = False) -> List[int]:
     return upstream_shape_functions.unary(index)
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index ab836c779490..e6b042b12b7a 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -323,6 +323,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
     emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
     emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
+    emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
 
     # Random number generation
     emit_with_mutating_variants("aten::uniform : (Tensor, float, float, Generator?) -> (Tensor)")
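The shape transfer function added above reuses upstream_shape_functions.unary, so the result shape is exactly the shape of self; weight only affects values (aten::prelu expects weight to hold either a single shared slope or one slope per channel). A quick eager-mode check of that contract (an illustrative snippet, not part of the patch):

    import torch

    x = torch.randn(5, 4, 3, 2, 1)
    w = torch.rand(1)  # one shared slope, broadcast over all channels
    y = torch.ops.aten.prelu(x, w)
    assert y.shape == x.shape  # result shape tracks `self`, hence unary()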
-> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index c1ab3d1404ce..8a72db157353 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -413,6 +413,28 @@ def ElementwiseLeakyReluModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwisePreluModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, x, weight): + return torch.ops.aten.prelu(x, weight) + +@register_test_case(module_factory=lambda: ElementwisePreluModule()) +def ElementwisePreluModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3, 2, 1, low=-1, high=1), tu.rand(1) ) + + +# ============================================================================== + + class ElementwiseGeluModule(torch.nn.Module): def __init__(self):