Skip to content

Commit

Permalink
Add Op for torch.aten.unfold (#3772)
Browse files Browse the repository at this point in the history
# Description

Implementation of the op for `torch.aten.unfold`: [TorchToLinalg Op
Support #347](nod-ai/SHARK-ModelDev#849)

Documentation of op can be found here: [PyTorch
Docs](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html)

For this op, we apply a sliding window of some `size` along a single
`dimension`, with `step` in between iterations.

`Declaration: aten::unfold(Tensor(a) self, int dimension, int size, int
step) -> Tensor(a)`

The resulting `unfolded` tensor modifies the shape of `dimension` to be
equal to the number of blocks that the sliding windows extracts/inserts,
with an additional dimension of `size` appended (the number of cols of
the output tensor directly translates from the size of the sliding
window).

So if we had a tensor of rank 3 (A x B x C), with dimension = 1, size =
2 and step = 2:

    (A x B x C) |=> (A x (B - size) // step + 1 x C x size)

After extracting the window from the input tensor, we insert the (1 x
size) slice into the output tensor. We can make this simpler by mapping
the output indices from the input indices, like they do in the official
implementation:

[PyTorch
Code](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/lowering.py#L1694)
  • Loading branch information
stbaione authored Oct 8, 2024
1 parent 7830c00 commit d49eabb
Show file tree
Hide file tree
Showing 8 changed files with 414 additions and 2 deletions.
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13692,6 +13692,31 @@ def Torch_AtenViewCopyDtypeOp : Torch_Op<"aten.view_copy.dtype", [
}];
}

def Torch_AtenUnfoldOp : Torch_Op<"aten.unfold", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::unfold : (Tensor, int, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dimension,
Torch_IntType:$size,
Torch_IntType:$step
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUnfoldOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenUnfoldOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
164 changes: 163 additions & 1 deletion lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2611,6 +2611,167 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern<AtenDiagEmbedOp> {
};
} // namespace

namespace {
class ConvertAtenUnfoldOp : public OpConversionPattern<AtenUnfoldOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenUnfoldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto self = adaptor.getSelf();
RankedTensorType selfType = cast<RankedTensorType>(self.getType());

int64_t dimension;
if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dimension))) {
return rewriter.notifyMatchFailure(op,
"only support constant int dimension");
}
int64_t size;
if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) {
return rewriter.notifyMatchFailure(op, "only support constant int size");
}
int64_t step;
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
return rewriter.notifyMatchFailure(op, "only support constant int step");
}

if (step <= 0) {
return rewriter.notifyMatchFailure(op, "step must be greater than zero.");
}

int64_t selfRank = selfType.getRank();

// Zero-Rank case
if (selfRank == 0) {
// Empty tensor
if (size == 0) {
RankedTensorType resultType =
RankedTensorType::get({0}, selfType.getElementType());
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, resultType.getShape(), resultType.getElementType());

rewriter.replaceOp(op, emptyTensor);
return success();
}

Value unsqueezedSelf = rewriter.create<tensor::ExpandShapeOp>(
loc, RankedTensorType::get({1}, selfType.getElementType()), self,
ArrayRef<ReassociationIndices>{});
rewriter.replaceOp(op, unsqueezedSelf);
return success();
}

auto shape = selfType.getShape();

if (dimension < 0) {
dimension = toPositiveDim(dimension, selfRank);
}
if (!isValidDim(dimension, selfRank)) {
return rewriter.notifyMatchFailure(op, "dimension out of range");
}

Value dimSize = rewriter.create<tensor::DimOp>(loc, self, dimension);

Value sizeValue = rewriter.create<arith::ConstantIndexOp>(loc, size);
Value sizeCheck = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ule, sizeValue, dimSize);
rewriter.create<cf::AssertOp>(
loc, sizeCheck,
rewriter.getStringAttr("size must be <= target dimension"));

/* Calculate output shape of unfold op:
* https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
* outputShape[dimension] is set to numBlocks, with size appended as an
* additional dimension
*/
SmallVector<OpFoldResult> outputShape;
for (int64_t i = 0; i < selfRank; i++) {
if (i == dimension) {
outputShape.push_back(getDynamicOrStaticNumBlocks(
rewriter, loc, shape[dimension], dimSize, size, step));
} else if (shape[i] == ShapedType::kDynamic) {
outputShape.push_back(
OpFoldResult(rewriter.create<tensor::DimOp>(loc, self, i)));
} else {
outputShape.push_back(rewriter.getIndexAttr(shape[i]));
}
}
outputShape.push_back(rewriter.getIndexAttr(size));

// Empty tensor to insert values into
Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, outputShape, selfType.getElementType());

/**
* Use reindexing to map output indices to input indices
* i.e. In output of rank 3 case:
* (i, j, k) => (i', j') where i' = i * step + k and j' = j
* if dimension == 0
* (i, j, k) => (i', j') where i' = i and j' = j * step + k
* if dimension == 1
*/
MLIRContext *context = rewriter.getContext();
SmallVector<AffineExpr> outputExprs;
for (int dim = 0; dim < selfRank; ++dim) {
if (dim == dimension) {
auto idxLast = getAffineDimExpr(selfRank, context);
auto idxDimension = getAffineDimExpr(dimension, context);

AffineExpr dimIdx =
idxLast + idxDimension * rewriter.getAffineConstantExpr(step);
outputExprs.push_back(dimIdx);
} else {
outputExprs.push_back(getAffineDimExpr(dim, context));
}
}

int64_t outputRank = selfRank + 1;
auto inputAffineMap = AffineMap::get(outputRank, 0, outputExprs, context);
auto outputAffineMap =
AffineMap::getMultiDimIdentityMap(outputRank, context);

SmallVector<utils::IteratorType> iteratorTypes(
outputRank, utils::IteratorType::parallel);

Value result =
rewriter
.create<linalg::GenericOp>(
loc, outputTensor.getType(), self, outputTensor,
ArrayRef({inputAffineMap, outputAffineMap}), iteratorTypes,
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
b.create<linalg::YieldOp>(nestedLoc, args[0]);
})
.getResult(0);

rewriter.replaceOp(op, result);
return success();
}

private:
OpFoldResult getDynamicOrStaticNumBlocks(OpBuilder &rewriter, Location loc,
int64_t shapeDim, Value dimSize,
int64_t size, int64_t step) const {
/**
* numBlocks = (shape[dimension] - size) // step + 1
*/
if (shapeDim == ShapedType::kDynamic) {
Value numBlocksSubOp = rewriter.create<arith::SubIOp>(
loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, size));
Value numBlocksDivOp = rewriter.create<arith::DivUIOp>(
loc, numBlocksSubOp,
rewriter.create<arith::ConstantIndexOp>(loc, step));
Value numBlocks = rewriter.create<arith::AddIOp>(
loc, rewriter.create<arith::ConstantIndexOp>(loc, 1), numBlocksDivOp);
return OpFoldResult(numBlocks);
}

int64_t staticNumBlocks = (shapeDim - size) / step + 1;
return rewriter.getIndexAttr(staticNumBlocks); // Use static value
}
};
} // namespace

namespace {
class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
public:
Expand Down Expand Up @@ -2679,7 +2840,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
/*benefit=*/200);
patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
/*benefit=*/100);

target.addIllegalOp<AtenUnfoldOp>();
patterns.add<ConvertAtenUnfoldOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeDimOp>();
Expand Down
77 changes: 77 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15588,6 +15588,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.unfold\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %str = torch.constant.str \"size must be less than or equal to {}\"\n"
" %false = torch.constant.bool false\n"
" %str_0 = torch.constant.str \"AssertionError: size must be less than or equal to 1\"\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: \"\n"
" %str_2 = torch.constant.str \"dimension out of range of {}\"\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %6 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
" %7 = torch.aten.add.str %str_1, %6 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %7, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.le.int %arg2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %5 : !torch.list<int>\n"
" } else {\n"
" %3 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" %15 = torch.aten.add.int %arg1, %0 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %15 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" }\n"
" %5 = torch.aten.ge.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
" %15 = torch.aten.lt.int %4, %0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %15 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n"
" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.aten.__getitem__.t %arg0, %4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %8 = torch.aten.le.int %arg2, %7 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %15 = torch.aten.format(%str, %7) : !torch.str, !torch.int -> !torch.str\n"
" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %9 = torch.aten.sub.int %7, %arg2 : !torch.int, !torch.int -> !torch.int\n"
" %10 = torch.aten.floordiv.int %9, %arg3 : !torch.int, !torch.int -> !torch.int\n"
" %11 = torch.aten.add.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %12 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %13 = torch.aten._set_item.t %12, %4, %11 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" %14 = torch.aten.append.t %12, %arg2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield %12 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.unfold\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
"}\n"
"";
// clang-format on
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ bool Torch::isViewLikeOp(Operation *op) {
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
AtenPixelShuffleOp, AtenDiagonalOp>(op);
AtenPixelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
Expand Down
9 changes: 9 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,11 @@
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"Unfold_Module_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Dynamic_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
Expand Down Expand Up @@ -3158,6 +3163,10 @@
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"UnfoldModule_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Dynamic_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5559,7 +5559,45 @@ def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int,
return torch.qint8
return torch.qint32

@check_shape_function([
Invocation(TensorOfShape(), 0, 1, 1), # Rank Zero.
Invocation(TensorOfShape(), 0, 0, 1), # Rank Zero, size of 0.
Invocation(TensorOfShape(6, 4), 0, 2, 1), # Basic case.
Invocation(TensorOfShape(6, 4, 2), 0, 2, 1), # Basic case.
Invocation(TensorOfShape(6, 4), -1, 2, 1), # Negative Dimension.
Invocation(TensorOfShape(6, 4, 2), -1, 2, 1), # Negative Dimension.
])
def aten〇unfold〡shape(self: List[int], dimension: int, size: int, step: int) -> List[int]:
ndim = len(self)

# Rank zero tensor
if ndim == 0:
assert dimension == 0, f"dimension out of range of {ndim}"
assert size <= 1, "size must be less than or equal to 1"
return [size]

dim = dimension
if dim < 0:
dim += ndim

assert (dim >= 0 and dim < ndim), f"dimension out of range of {ndim}"

size_dim = self[dim]
assert size <= size_dim, f"size must be less than or equal to {size_dim}"

num_blocks = (size_dim - size) // step + 1

out = upstream_shape_functions._copy(self)
out[dim] = num_blocks
out.append(size)
return out

@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, dimension=0, size=1, step=1)
)
def aten〇unfold〡dtype(self_rank_dtype: Tuple[int, int], dimension: int, size: int, step: int) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::unsqueeze_copy : (Tensor, int) -> (Tensor)")
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
emit("aten::unfold : (Tensor, int, int, int) -> (Tensor)")
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)")
emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)")
Expand Down
Loading

0 comments on commit d49eabb

Please sign in to comment.