Skip to content

Commit

Permalink
[MLIR][TORCH] Add e2e support for aten.as_stride
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis authored and AmosLewis committed Jan 3, 2023
1 parent a88e376 commit 9e031e6
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 4 deletions.
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"ElementwiseFlattenBroadcastModule_basic",
"FlattenRank0Module_basic",
"UniformModule_basic",
"AsStridedStaticModule_basic",
# error: failed to materialize conversion for result #0 of operation 'torch.aten.t' that remained live after conversion
"TModuleRank1_basic",
# error: unsupported by backend contract: tensor with unknown rank
Expand Down Expand Up @@ -590,6 +591,7 @@
"TypePromotionSameCategoryDifferentWidthModule_basic",
"TypePromotionZeroRankHigherCategoryModule_basic",
"GatherStaticModule_basic",
"AsStridedStaticModule_basic",
"LiftFreshCopyModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDimIntListFloatModule_basic",
Expand Down
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7742,6 +7742,31 @@ def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
}];
}

def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchOptionalIntType:$storage_offset
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenAsStridedOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
85 changes: 85 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -3093,6 +3094,89 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
"unimplemented: broadcasts other than same rank or zero ranked tensor.");
}

template <>
LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
AtenAsStridedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// The algorithm is commented by example: torch.as_strided(x, (2, 2), (1, 2),
// 1) Not a tensor type.
auto input = adaptor.getSelf();
auto inputTensorType =
adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
if (!inputTensorType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");

SmallVector<int64_t> outputSize;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outputSize)))
return rewriter.notifyMatchFailure(
op, "Non-const size for as_strided unsupported.");

SmallVector<int64_t> strides;
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strides)))
return rewriter.notifyMatchFailure(
op, "Non-const strides for as_strided unsupported.");

int64_t storageOffset;
if (!matchPattern(op.getStorageOffset(), m_TorchConstantInt(&storageOffset)))
return rewriter.notifyMatchFailure(
op, "storageOffset must be a Scalar constant");

auto outType = getTypeConverter()->convertType(op.getType());
auto resultType = outType.dyn_cast<ShapedType>();

int outputNums = 1, inputNums = 1;
auto inputTensorRank = inputTensorType.getShape().size(); // 3
// inputNums: total number of inputs
// flattened inputNums = 1 * 3*3 = 9
for (size_t i = 0; i < inputTensorRank; i++) {
inputNums *= inputTensorType.getShape()[i];
}
// outputNums: number of values in output tensor
for (size_t i = 0; i < outputSize.size(); i++) {
outputNums *= outputSize[i];
}
SmallVector<int64_t, 3> gatherValuesShape({1, inputNums, 1}); // {1,9,1}
SmallVector<int64_t, 2> gatherIndicesShape({1, outputNums}); // {1,4}
SmallVector<int64_t, 3> gatherResultShape({1, outputNums, 1}); // {1,4,1}

// create gather indices:
int32_t flattenedStrideLength{1};
for (size_t i = 0; i < strides.size(); i++) {
flattenedStrideLength *= strides[i];
}
SmallVector<int32_t> flattenedIndices;
for (int i = 0; i < outputNums / 2; i++) {
flattenedIndices.push_back(storageOffset + i);
flattenedIndices.push_back(storageOffset + i + flattenedStrideLength);
}
auto flattenedIndicesValue = tosa::getConstTensor<int32_t>(
rewriter, op, flattenedIndices, llvm::makeArrayRef(gatherIndicesShape));

if (!flattenedIndicesValue)
return rewriter.notifyMatchFailure(op, "Fail to create flatten indices");

// {3,3} -> {1,9,1}
auto gatherValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(gatherValuesShape,
inputTensorType.getElementType()),
input, rewriter.getI64ArrayAttr(gatherValuesShape));

// {1,9,1}, {1,4} -> {1,4}
auto tosaGatherOp = tosa::CreateOpAndInfer<tosa::GatherOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(gatherResultShape, resultType.getElementType()),
gatherValuesReshapeOp.getResult(), flattenedIndicesValue.value());

// {1,4} -> {2,2}
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, outType, tosaGatherOp.getResult(),
rewriter.getI64ArrayAttr(resultType.getShape()));

return success();
}

template <>
LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
AtenGatherOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -4130,6 +4214,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenMaxDimOp);
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenClampOp);
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6115,6 +6115,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.expand\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down
7 changes: 4 additions & 3 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,10 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenPreluOp,
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp, AtenAsStridedOp,
AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
AtenUpsampleNearest2dBackwardOp, AtenLeakyReluBackwardOp>(op)) {
AtenUpsampleNearest2dBackwardOp, AtenTanhOp, AtenLeakyReluBackwardOp>(
op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

Expand Down Expand Up @@ -782,7 +783,7 @@ void TypeAnalysis::visitOperation(Operation *op,
// Promote LHS with scalar RHS.
if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
AtenLeakyReluOp, AtenRemainderScalarOp, AtenRsubScalarOp>(op)) {
auto lhs = operands[0]->getValue();
Value scalar = op->getOperand(1);
auto knowledge =
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ bool Torch::isViewLikeOp(Operation *op) {
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenToDeviceOp>(op);
AtenNarrowOp, AtenToDeviceOp, AtenAsStridedOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,9 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
def aten〇roll〡shape(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
return size

def aten〇expand〡shape(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
return upstream_shape_functions.expand(self, size)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def emit_with_mutating_variants(key, **kwargs):

# Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)")
emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
Expand Down
1 change: 1 addition & 0 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# to the backend contract.
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"AsStridedStaticModule_basic",
}

def register_all_tests():
Expand Down
22 changes: 22 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/reshape_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,28 @@ def ViewNoChangeStaticModule_basic(module, tu: TestUtils):

# ==============================================================================


class AsStridedStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 3], torch.float32, True),
])

def forward(self, x):
return torch.ops.aten.as_strided(x, (2, 2), (1, 2), 1)

@register_test_case(module_factory=lambda: AsStridedStaticModule())
def AsStridedStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3))


# ==============================================================================


class ReshapeAliasExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
26 changes: 26 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,32 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten
return %0 : !torch.vtensor<[3,5],i1>
}

// -----
// CHECK-LABEL: func.func @torch.aten.as_strided(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,2],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() {value = dense<{{\[\[}}0, 2, 1, 3]]> : tensor<1x4xi32>} : () -> tensor<1x4xi32>
// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [1, 9, 1]} : (tensor<3x3xf32>) -> tensor<1x9x1xf32>
// CHECK: %[[VAL_9:.*]] = "tosa.gather"(%[[VAL_8]], %[[VAL_7]]) : (tensor<1x9x1xf32>, tensor<1x4xi32>) -> tensor<1x4x1xf32>
// CHECK: %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_9]]) {new_shape = [2, 2]} : (tensor<1x4x1xf32>) -> tensor<2x2xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[2,2],f32>
// CHECK: }
func.func @torch.aten.as_strided(%arg0: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,2],f32> {
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.as_strided %arg0, %0, %1, %int0 : !torch.vtensor<[3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[2,2],f32>
return %2 : !torch.vtensor<[2,2],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.gather(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,
Expand Down

0 comments on commit 9e031e6

Please sign in to comment.