Support aten.stack op and decompose it into unsqueeze & cat (llvm#1747)
li-plus authored Mar 11, 2023
1 parent d310bb1 commit 4912c39
Showing 10 changed files with 180 additions and 32 deletions.
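
At its core the patch implements the standard rewrite of `torch.stack` as a concatenation of inputs that have each been unsqueezed at the stack dimension. A minimal PyTorch sketch of that equivalence (the helper name `stack_via_unsqueeze_cat` is illustrative, not part of the patch):

```python
import torch

def stack_via_unsqueeze_cat(tensors, dim=0):
    # Insert a new size-1 dimension at `dim` in every input, then
    # concatenate along that dimension; this matches torch.stack.
    return torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)

x, y, z = torch.rand(2, 3), torch.rand(2, 3), torch.rand(2, 3)
assert torch.equal(stack_via_unsqueeze_cat([x, y, z], dim=1),
                   torch.stack([x, y, z], dim=1))
```

The same equivalence holds for negative `dim` values, since both `unsqueeze` and `stack` normalize them against the output rank.
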
4 changes: 4 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -278,6 +278,9 @@
"FlattenRank0Module_basic",
"TensorsConcatNegativeDimModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorsStackModule_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"LiftFreshCopyModule_basic",
"Mlp2LayerModuleNoBias_basic",
"NumelModule_basic",
@@ -805,6 +808,7 @@
"SubIntModule_basic",
"TensorsConcatNegativeDimModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToBool_basic",
"TensorToFloatZeroRank_basic",
49 changes: 25 additions & 24 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7107,30 +7107,6 @@ def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [
let hasFolder = 1;
}

def Torch_AtenStackOp : Torch_Op<"aten.stack", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors,
Torch_IntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenStackOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenSumOp : Torch_Op<"aten.sum", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -8882,6 +8858,31 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [
let hasFolder = 1;
}

def Torch_AtenStackOp : Torch_Op<"aten.stack", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors,
Torch_IntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenStackOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
AllowsTypeRefinement
]> {
11 changes: 11 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2170,6 +2170,17 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
return list.getElements()[0];
}

//===----------------------------------------------------------------------===//
// AtenStackOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
return nullptr;
return list.getElements()[0];
}

//===----------------------------------------------------------------------===//
// AtenSliceTensorOp
//===----------------------------------------------------------------------===//
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7632,6 +7632,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
49 changes: 47 additions & 2 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -195,8 +195,8 @@ static FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter,
} else {
unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
}
Type unsqueezedType =
inputType.getWithSizesAndDtype(unsqueezedShape, inputType.getDtype());
Type unsqueezedType = inputType.getWithSizesAndDtype(
unsqueezedShape, inputType.getOptionalDtype());
Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
op->getLoc(), unsqueezedType, input, dim);
return unsqueezed;
@@ -1055,6 +1055,50 @@ class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
};
} // namespace

// Decompose `aten.stack` into `aten.unsqueeze` and `aten.cat`.
namespace {
class DecomposeAtenStackOp : public OpRewritePattern<AtenStackOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenStackOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> tensors;
if (!getListConstructElements(op.getTensors(), tensors)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: the tensor list is not from list construct");
}
// Ensure all tensors have known sizes
for (Value tensor : tensors) {
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
if (!tensorType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: one tensor does not have known sizes");
}
}

SmallVector<Value> unsqueezedTensors;
for (Value tensor : tensors) {
auto unsqueezedInfo = unsqueezeTensor(rewriter, op, tensor, op.getDim());
if (failed(unsqueezedInfo)) {
return rewriter.notifyMatchFailure(
op, "cannot generate unsqueeze tensor op");
}
unsqueezedTensors.push_back(*unsqueezedInfo);
}

Type listElemType =
op.getType().cast<BaseTensorType>().getWithSizesAndDtype(
/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value unsqueezedTensorList = rewriter.create<PrimListConstructOp>(
op.getLoc(), listType, unsqueezedTensors);
rewriter.replaceOpWithNewOp<AtenCatOp>(op, op.getType(),
unsqueezedTensorList, op.getDim());
return success();
}
};
} // namespace

// Decompose aten.roll into aten.slice and aten.cat ops.
// https://pytorch.org/docs/stable/generated/torch.roll.html
namespace {
@@ -3873,6 +3917,7 @@ class DecomposeComplexOpsPass
DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -342,6 +342,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenEmptyLikeOp>();
target.addIllegalOp<AtenOnesLikeOp>();
target.addIllegalOp<AtenZerosLikeOp>();
target.addIllegalOp<AtenStackOp>();
target.addIllegalOp<AtenRollOp>();
target.addIllegalOp<AtenRepeatOp>();
target.addIllegalOp<AtenExpandOp>();
15 changes: 10 additions & 5 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -458,7 +458,8 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis<
void visitAtenToDtypeLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
template <typename OpTy>
void visitTypeConversionOp(OpTy op, ArrayRef<const ValueState *> operands);
void visitAtenCatOp(AtenCatOp op, ArrayRef<const ValueState *> operands);
template <typename OpTy>
void visitAtenCatLikeOp(OpTy op, ArrayRef<const ValueState *> operands);

template <typename OpTy>
void visitAtenSoftmaxLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
@@ -1071,7 +1072,10 @@ void TypeAnalysis::visitOperation(Operation *op,
}

if (auto cat = dyn_cast<AtenCatOp>(op)) {
visitAtenCatOp(cat, operands);
visitAtenCatLikeOp<AtenCatOp>(cat, operands);
return;
} else if (auto stack = dyn_cast<AtenStackOp>(op)) {
visitAtenCatLikeOp<AtenStackOp>(stack, operands);
return;
}

@@ -1417,12 +1421,13 @@ void TypeAnalysis::visitTypeConversionOp(
// `torch.aten.cat` concatenates the given sequence of tensors in the given
// dimension, and `torch.aten.stack` stacks them along a new dimension. For
// dtype refinement both behave identically: the result dtype is the promotion
// of the list elements' dtypes, so one templated visitor handles both ops.
void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
ArrayRef<const ValueState *> operands) {
template <typename OpTy>
void TypeAnalysis::visitAtenCatLikeOp(OpTy op,
ArrayRef<const ValueState *> operands) {
auto tensorList = op.getTensors();
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
auto listConstruct = tensorList.template getDefiningOp<PrimListConstructOp>();
if (!listConstruct) {
incorporateKnowledge(op.getResult(), knowledge);
return;
3 changes: 3 additions & 0 deletions python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -1015,6 +1015,9 @@ def aten〇index〇Tensor_hacked_twin〡shape(self: List[int], indices: List[Lis
def aten〇cat〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
return upstream_shape_functions.cat(tensors, dim)

def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]:
return upstream_shape_functions.stack(tensors, dim)
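
For reference, the upstream stack shape function requires every input shape to be identical and inserts a new dimension of size `len(tensors)` at the (normalized) stack dimension. A rough Python sketch of those semantics, not the upstream implementation itself:

```python
from typing import List

def stack_shape_sketch(tensor_shapes: List[List[int]], dim: int = 0) -> List[int]:
    # All inputs to stack must share one shape.
    assert len(tensor_shapes) > 0
    first = tensor_shapes[0]
    assert all(shape == first for shape in tensor_shapes)
    # Normalize a negative dim against the output rank (input rank + 1).
    out_rank = len(first) + 1
    if dim < 0:
        dim += out_rank
    assert 0 <= dim < out_rank
    # The new dimension's size is the number of stacked tensors.
    return first[:dim] + [len(tensor_shapes)] + first[dim:]
```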

def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]:
return self

2 changes: 1 addition & 1 deletion python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -489,7 +489,6 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)")
emit("aten::sum : (Tensor, int?) -> (Tensor)")
emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)")
emit("aten::max : (Tensor) -> (Tensor)")
@@ -563,6 +562,7 @@ def emit_with_mutating_variants(key, **kwargs):

# List ops.
emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)", has_folder=True)
emit("aten::append.t : (t[], t) -> (t[])")
emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True)
emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True)
74 changes: 74 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -647,6 +647,80 @@ def TensorsConcatPromoteDTypeModule_basic(module, tu: TestUtils):
# ==============================================================================


class TensorsStackModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, x, y, z):
return torch.stack([x, y, z], 1)


@register_test_case(module_factory=lambda: TensorsStackModule())
def TensorsStackModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4), tu.rand(2, 3, 4), tu.rand(2, 3, 4))


# ==============================================================================


class TensorsStackNegativeDimModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, x, y, z):
return torch.stack([x, y, z], dim=-2)


@register_test_case(module_factory=lambda: TensorsStackNegativeDimModule())
def TensorsStackNegativeDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4), tu.rand(2, 3, 4), tu.rand(2, 3, 4))


# ==============================================================================


class TensorsStackPromoteDTypeModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.bool, True),
([-1, -1, -1], torch.int32, True),
([-1, -1, -1], torch.int64, True),
])
def forward(self, x, y, z):
return torch.stack([x, y, z], dim=-2)


@register_test_case(module_factory=lambda: TensorsStackPromoteDTypeModule())
def TensorsStackPromoteDTypeModule_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 3, 4, low=0, high=2).bool(),
tu.randint(2, 3, 4, low=0, high=100).int(),
tu.randint(2, 3, 4, low=0, high=100).long())
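
The promote-dtype test feeds bool, int32, and int64 inputs, so under PyTorch's type-promotion rules the stacked result should be int64, and the decomposed unsqueeze-plus-cat sequence has to preserve that. A quick illustrative sanity check, not part of the test suite:

```python
import torch

a = torch.zeros(2, 3, 4, dtype=torch.bool)
b = torch.zeros(2, 3, 4, dtype=torch.int32)
c = torch.zeros(2, 3, 4, dtype=torch.int64)
# bool and int32 promote up to int64, so the stacked result is int64.
assert torch.stack([a, b, c], dim=-2).dtype == torch.int64
```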


# ==============================================================================


class GatherModule(torch.nn.Module):

def __init__(self):