Skip to content

Commit

Permalink
[MLIR][TORCH] Add e2e support for aten.as_stride
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis authored and AmosLewis committed Dec 22, 2022
1 parent 9dc09ac commit 89ec255
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 2 deletions.
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7692,6 +7692,31 @@ def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
}];
}

def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchOptionalIntType:$storage_offset
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenAsStridedOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6094,6 +6094,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.expand\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
AtenZero_Op, AtenIndexTensorOp, Aten_IndexPutImplOp, AtenIndexPutOp,
AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp,
AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenAsStridedOp,
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ bool Torch::isViewLikeOp(Operation *op) {
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenToDeviceOp>(op);
AtenNarrowOp, AtenToDeviceOp, AtenAsStridedOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,9 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
def aten〇roll〡shape(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
return size

def aten〇expand〡shape(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
return upstream_shape_functions.expand(self, size)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ def emit_with_mutating_variants(key, **kwargs):

# Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)")
emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
Expand Down
1 change: 1 addition & 0 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# to the backend contract.
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"AsStridedStaticModule_basic",
}

def register_all_tests():
Expand Down
22 changes: 22 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/reshape_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,28 @@ def ViewNoChangeStaticModule_basic(module, tu: TestUtils):

# ==============================================================================


class AsStridedStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 3], torch.float32, True),
])

def forward(self, x):
return torch.ops.aten.as_strided(x, (2, 2), (1, 2), 1)

@register_test_case(module_factory=lambda: AsStridedStaticModule())
def AsStridedStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3))


# ==============================================================================


class ReshapeAliasExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 89ec255

Please sign in to comment.