diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 253281b62538f..7b5530b1a20cc 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7692,6 +7692,31 @@ def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
   }];
 }
 
+def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchOptionalIntType:$storage_offset
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenAsStridedOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 583a6ed5c7094..1e58f8c01b547 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6094,6 +6094,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
+"    return %arg1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.expand\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 6ef4c7d0984d6..dcebd8438daf0 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -663,7 +663,7 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
           AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
           AtenZero_Op, AtenIndexTensorOp, Aten_IndexPutImplOp, AtenIndexPutOp,
-          AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp,
+          AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenAsStridedOp,
           AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
           AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
           AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 1d67b24e871c2..f718750f0f68d 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -199,7 +199,7 @@ bool Torch::isViewLikeOp(Operation *op) {
             AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
             AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
             TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
-            AtenNarrowOp, AtenToDeviceOp>(op);
+            AtenNarrowOp, AtenToDeviceOp, AtenAsStridedOp>(op);
 }
 
 Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 98edcb7902b35..9101285fce239 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -429,6 +429,9 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
 def aten〇roll〡shape(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
+    return size
+
 def aten〇expand〡shape(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
     return upstream_shape_functions.expand(self, size)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index a9b5cb91640a6..74f9c60a27d9d 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -515,6 +515,7 @@ def emit_with_mutating_variants(key, **kwargs):
 
     # Functionalization ops
     emit("aten::alias_copy : (Tensor) -> (Tensor)")
+    emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
     emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
     emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py
index c107ceb375619..16f09bafbd8bd 100644
--- a/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -8,6 +8,7 @@
 # to the backend contract.
 COMMON_TORCH_MLIR_LOWERING_XFAILS = {
     "QuantizedMLP_basic",
+    "AsStridedStaticModule_basic",
 }
 
 def register_all_tests():
diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index 7ac4be9e4fc50..fbd811ad20508 100644
--- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -657,6 +657,27 @@ def ViewNoChangeStaticModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+
+class AsStridedStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.as_strided(x, (2, 2), (1, 2), 1)
+
+@register_test_case(module_factory=lambda: AsStridedStaticModule())
+def AsStridedStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3))
+
+
+# ==============================================================================
+
+
 class ReshapeAliasExpandModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
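
Note (reviewer sketch, not part of the patch): the new shape transfer function can
simply return the `size` operand because the result shape of `aten::as_strided` is
exactly `size`; `stride` and `storage_offset` only select which storage elements the
view reads, with element (i, j) taken from storage index
`storage_offset + i * stride[0] + j * stride[1]`. A minimal Python check of that
semantics, assuming a recent PyTorch (purely illustrative, no torch-mlir involved):

    import torch

    x = torch.rand(3, 3)  # contiguous, so storage is the row-major flatten of x
    y = torch.ops.aten.as_strided(x, (2, 2), (1, 2), 1)
    # The result shape equals the `size` argument, which is what the new
    # shape function (return %arg1) encodes.
    assert list(y.shape) == [2, 2]
    # Element (1, 1) comes from storage index 1 + 1 * 1 + 1 * 2 == 4.
    assert y[1, 1].item() == x.flatten()[4].item()

The e2e test is added to COMMON_TORCH_MLIR_LOWERING_XFAILS because this patch only
registers the op (and its shape function, type refinement, and view-like
classification); no backend lowering for `aten.as_strided` is provided yet.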