Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch-dialect] emit aten.as_strided op #2280

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@
}

STABLEHLO_PASS_SET = {
"AsStridedStaticModule_basic",
"AliasModule_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
Expand Down Expand Up @@ -810,6 +811,7 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AsStridedStaticModule_basic",
"AliasModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"ConstantBoolParameterModule_basic",
Expand Down
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -8028,6 +8028,32 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
let hasFolder = 1;
}

def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchOptionalIntType:$storage_offset
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenAsStridedOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasFolder = 1;
}

def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
46 changes: 46 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,52 @@ OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
return getOperand(0);
}

//===----------------------------------------------------------------------===//
// AtenAsStrideOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAsStridedOp::fold(FoldAdaptor adaptor) {
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
if (!inputType || !inputType.hasSizes() || !inputType.areAllSizesKnown())
return nullptr;

auto outType = getType().dyn_cast<BaseTensorType>();
if (!outType || !outType.hasSizes() || !outType.areAllSizesKnown())
return nullptr;

int64_t storageOffset;
if (!getStorageOffset().getType().isa<Torch::NoneType>()) {
if (!matchPattern(getStorageOffset(), m_TorchConstantInt(&storageOffset)) ||
storageOffset != 0)
return nullptr;
}

// Check if the shapes of input tensor and output tensor are totally same.
ArrayRef<int64_t> inputSizes = inputType.getSizes();
ArrayRef<int64_t> outSizes = outType.getSizes();
if (inputSizes.size() != outSizes.size())
return nullptr;
for (int i = 0, e = inputSizes.size(); i < e; ++i) {
if (inputSizes[i] != outSizes[i])
return nullptr;
}

// Check if the elements of output tensor are fetched sequentially from input
// tensor's storage.
SmallVector<int64_t> strides;
if (!matchPattern(getStride(), m_TorchListOfConstantInts(strides)))
return nullptr;

if (strides.size() != inputSizes.size() || strides[strides.size() - 1] != 1)
return nullptr;
for (int i = inputSizes.size() - 2; i >= 0; --i) {
if (strides[i] != inputSizes[i + 1] * strides[i + 1])
return nullptr;
}

return getOperand(0);
}

//===----------------------------------------------------------------------===//
// PrimsViewOfOp
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 7 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6824,6 +6824,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.resize_\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.max_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -8520,6 +8523,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.as_strided\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.roll\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
118 changes: 118 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4585,6 +4585,123 @@ class DecomposeAtenSignOp : public OpRewritePattern<AtenSignOp> {
};
} // namespace

namespace {
// Decompose `aten.as_strided` into `aten.flatten.using_ints`, `aten.gather` and
// `aten.view`.
// This decomposition is a little hacky. Since aten.as_strided is a
// view-like op, while aten.gather is not. But it might be okay
// in torch-mlir, since we've already assumed view-like ops to be of value
// semantics.
class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAsStridedOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto inputType = op.getSelf().getType().dyn_cast<BaseTensorType>();
if (!inputType || !inputType.hasSizes() || !inputType.areAllSizesKnown()) {
return rewriter.notifyMatchFailure(
op, "only handle input tensor with static shape information");
}
ArrayRef<int64_t> inputSizes = inputType.getSizes();

SmallVector<int64_t> outSizes;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outSizes))) {
return rewriter.notifyMatchFailure(
op, "out size must be a list of constant integers");
}
SmallVector<int64_t> strides;
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strides))) {
return rewriter.notifyMatchFailure(
op, "strides must be a list of constant integers");
}
if (strides.size() != outSizes.size()) {
return rewriter.notifyMatchFailure(op,
"the stride size is expected to be "
"the same as that of output tensor");
}
int64_t storageOffset = 0;
if (!op.getStorageOffset().getType().isa<Torch::NoneType>() &&
!matchPattern(op.getStorageOffset(),
m_TorchConstantInt(&storageOffset))) {
return rewriter.notifyMatchFailure(
op, "storage offset must be a constant integer");
}

int64_t inputTotalSize = 1;
for (auto inputDimSize : inputSizes) {
inputTotalSize *= inputDimSize;
}

auto flattenInputType = inputType.getWithSizesAndDtype(
SmallVector<int64_t>(1, inputTotalSize), inputType.getOptionalDtype());
Value startDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value endDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputSizes.size() - 1));
Value flattenInput = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
loc, flattenInputType, op.getSelf(), startDim, endDim);

int64_t outTotalSize = 1;
for (auto outDimSize : outSizes) {
outTotalSize *= outDimSize;
}

// Generate index for gather op;
DenseSet<int64_t> visitedStoragePlaces;
SmallVector<int64_t> gatherIndicies;
for (int64_t i = 0; i < outTotalSize; ++i) {
int64_t newI = i;
int64_t index = 0;
for (int64_t d = outSizes.size() - 1; d >= 0; --d) {
index += newI % outSizes[d] * strides[d];
newI /= outSizes[d];
}

// We can not handle the situation when the view is "overlapped"
// Ref: https://pytorch.org/docs/stable/generated/torch.as_strided.html
if (visitedStoragePlaces.find(index) != visitedStoragePlaces.end()) {
return rewriter.notifyMatchFailure(
op, "multiple indices of new tensor are mapped to the same storage "
"location");
}

visitedStoragePlaces.insert(index);
gatherIndicies.push_back(index + storageOffset);
}
auto gatherOpType = inputType.getWithSizesAndDtype(
{outTotalSize}, inputType.getOptionalDtype());
Value gatherDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));

Value gatherIndex = rewriter.create<Torch::ValueTensorLiteralOp>(
loc,
inputType.getWithSizesAndDtype({outTotalSize},
rewriter.getIntegerType(64, true)),
DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(outTotalSize)},
rewriter.getIntegerType(64, true)),
gatherIndicies));
Value sparseGrad = rewriter.create<Torch::ConstantBoolOp>(loc, false);
auto gatherOp = rewriter.create<Torch::AtenGatherOp>(
loc, gatherOpType, flattenInput, gatherDim, gatherIndex, sparseGrad);

SmallVector<Value> viewSizes;
for (auto outDimSize : outSizes) {
viewSizes.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(outDimSize)));
}
auto viewSizesListConstructOp = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
ValueRange(viewSizes));

rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), gatherOp,
viewSizesListConstructOp);
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
Expand Down Expand Up @@ -4754,6 +4871,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAsStridedOp>(patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenTopkOp>();
target.addIllegalOp<AtenScalarTensorOp>();
target.addIllegalOp<AtenScatterValueOp>();
target.addIllegalOp<AtenAsStridedOp>();
for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context));
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ bool Torch::isViewLikeOp(Operation *op) {
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
PrimsViewOfOp, AtenRealOp, AtenImagOp, AtenViewAsComplexOp>(op);
PrimsViewOfOp, AtenRealOp, AtenImagOp, AtenViewAsComplexOp,
AtenAsStridedOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,9 @@ def aten〇_unsafe_view〡shape(self: List[int], size: List[int]) -> List[int]:
def aten〇resize_〡shape(self: List[int], size: List[int], memory_format: Optional[int] = None) -> List[int]:
return size

def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
return size

def aten〇max_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> List[int]:
return upstream_shape_functions.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode)

Expand Down Expand Up @@ -1732,6 +1735,11 @@ def aten〇resize_〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], me
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1]))
def aten〇as_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, shifts=[0], dims=[0]))
def aten〇roll〡dtype(self_rank_dtype: Tuple[int, int], shifts: List[int], dims: List[int] = ()) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)", has_folder=True)
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)", has_folder=True)
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
Expand Down
21 changes: 20 additions & 1 deletion python/torch_mlir_e2e_test/test_suite/reshape_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,4 +710,23 @@ def forward(self, a):

@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4))
module.forward(tu.rand(2, 4))

# ==============================================================================

class AsStridedStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 3], torch.float32, True),
])

def forward(self, x):
return torch.ops.aten.as_strided(x, (2, 2), (1, 2))

@register_test_case(module_factory=lambda: AsStridedStaticModule())
def AsStridedStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3))
26 changes: 26 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2065,3 +2065,29 @@ func.func @torch.aten.add$fold() -> !torch.float {
%0 = torch.aten.add %float1, %float2 : !torch.float, !torch.float -> !torch.float
return %0 : !torch.float
}

// CHECK-LABEL: func.func @torch.aten.as_strided$none_storage_offset(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[3,3],f32> {
// CHECK: return %[[ARG]] : !torch.vtensor<[3,3],f32>
func.func @torch.aten.as_strided$none_storage_offset(%arg0: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[3,3],f32> {
%int3 = torch.constant.int 3
%int1 = torch.constant.int 1
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.none -> !torch.vtensor<[3,3],f32>
return %2 : !torch.vtensor<[3,3],f32>
}

// CHECK-LABEL: func.func @torch.aten.as_strided$zero_storage_offset(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[3,3],f32> {
// CHECK: return %[[ARG]] : !torch.vtensor<[3,3],f32>
func.func @torch.aten.as_strided$zero_storage_offset(%arg0: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[3,3],f32> {
%int3 = torch.constant.int 3
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.as_strided %arg0, %0, %1, %int0 : !torch.vtensor<[3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[3,3],f32>
return %2 : !torch.vtensor<[3,3],f32>
}