diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index c7cd67f521fd..f28083a633d3 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -145,6 +145,8 @@
     "ViewExpandOnesModule_basic",
     "ViewExpandOnesBeforeAndAfterModule_basic",
     "ViewExpandOnesMiddleModule_basic",
+    "ViewExpandCollapseModule_basic",
+    "ViewExpandCollapseWithOnesModule_basic",
     "ViewCollapseInferredDimModule_basic",
     "ViewExpandInferredDimModule_basic",
     "ViewNoChangeStaticModule_basic",
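The two tests enabled above exercise `aten.view` ops that mix expansion and collapse in a single call, which the rewritten lowering below now supports. A quick PyTorch-only illustration of the supported versus still-rejected patterns (shapes taken from the doc comment and tests in this patch; the snippet itself is illustrative and not part of the patch):

```python
import torch

t = torch.rand(2, 4, 8, 16, 4)
# Now supported by the lowering: every dim group either collapses
# ([2, 4] -> 8) or expands (8 -> [2, 4], 4 -> [2, 2]).
t.view(8, 2, 4, 16, 2, 2)

u = torch.rand(2, 3)
# Legal in PyTorch, but still rejected by the lowering: [2, 3] -> [3, 2]
# neither expands nor collapses any dimension.
u.view(3, 2)
```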
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index d54b600c9ee4..57383095a115 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -126,6 +126,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
     strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
   return success();
 }
+
 namespace {
 class ConvertAtenFlattenUsingIntsOp
     : public OpConversionPattern<AtenFlattenUsingIntsOp> {
@@ -184,12 +185,84 @@ class ConvertAtenFlattenUsingIntsOp
 namespace {
 /// The `ConvertAtenViewOp` conversion pattern converts `aten.View` op to
-/// `linalg.TensorExpandShape` op only when one or multiple static dimensions
-/// are expanded. All the other cases of `aten.View` op need to be handled.
+/// one `linalg.TensorExpandShape` op for all expanded dimensions and one
+/// `linalg.TensorCollapseShape` op for all collapsed dimensions. Cases where
+/// there is neither an expansion nor a collapse of dimensions
+/// (e.g. [2, 3] -> [3, 2]) are not handled. Additionally, certain dynamic
+/// dimension cases rely on naive assumptions or aren't supported.
 /// TODO: Handle all the other cases of `aten.View` op.
 class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
+
+  // Helper for filling in the remaining un-collapsed dims when the
+  // input/output dim is adjacent to the next boundary dim. Additionally,
+  // this computes the size of a collapsed dynamic dim if necessary.
+  static LogicalResult
+  collapseToSingleDimHelper(AtenViewOp op, ConversionPatternRewriter &rewriter,
+                            int64_t collapseDim, int64_t maxCollapseDim,
+                            int64_t startExpandDim, int64_t maxExpandDim,
+                            SmallVector<int64_t> &collapseShape,
+                            const SmallVector<int64_t> &expandShape,
+                            ReassociationIndices &expandIndices) {
+    int64_t collapseDimSize = 1;
+    for (auto i : llvm::seq(startExpandDim, maxExpandDim)) {
+      expandIndices.push_back(i);
+      if (collapseDimSize == kUnknownSize)
+        continue;
+
+      int64_t expandedDimSize = expandShape[i];
+      if (expandedDimSize == kUnknownSize) {
+        collapseDimSize = kUnknownSize;
+        continue;
+      }
+      collapseDimSize *= expandedDimSize;
+    }
+    int64_t rawCollapseDimSize = collapseShape[collapseDim];
+    if (rawCollapseDimSize != kUnknownSize && collapseDimSize != kUnknownSize &&
+        collapseDimSize != rawCollapseDimSize) {
+      return rewriter.notifyMatchFailure(
+          op, "desired size is not compatible with the input tensor size");
+    }
+    collapseShape[collapseDim] = collapseDimSize;
+    return success();
+  }
+
+  // Helper to find the minimal set of dims to collapse with the same
+  // number of elements as collapseDim. This function assumes the size
+  // of the collapsed dim is never dynamic.
+  static LogicalResult
+  minimallyCollapseDimHelper(AtenViewOp op,
+                             ConversionPatternRewriter &rewriter,
+                             int64_t collapseDim, int64_t maxCollapseDim,
+                             int64_t startExpandDim, int64_t maxExpandDim,
+                             const SmallVector<int64_t> &collapseShape,
+                             const SmallVector<int64_t> &expandShape,
+                             ReassociationIndices &expandIndices) {
+    int64_t collapseDimSize = collapseShape[collapseDim];
+    int64_t expandedSize = 1;
+
+    for (auto i : llvm::seq(startExpandDim, maxExpandDim)) {
+      int64_t expandDimSize = expandShape[i];
+      if (expandDimSize == kUnknownSize ||
+          collapseDimSize % (expandedSize *= expandDimSize)) {
+        return rewriter.notifyMatchFailure(
+            op, "desired size is not compatible with the input tensor size");
+      }
+      expandIndices.push_back(i);
+      if (expandedSize == collapseDimSize)
+        return success();
+
+      if (expandedSize > collapseDimSize) {
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only supports expanding and collapsing "
+                "in view");
+      }
+    }
+
+    return rewriter.notifyMatchFailure(
+        op, "total number of elements mismatch in the expansion");
+  }
+
   LogicalResult
   matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -213,10 +286,6 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: input rank 0 is not supported");
 
-    bool isCollapse = inputRank > resultRank ? true : false;
-    int64_t collapsedRank = isCollapse ? resultRank : inputRank;
-    int64_t expandedRank = isCollapse ? inputRank : resultRank;
-
     // Extract the desired output size as a list of integers. This list should
     // have been created using the operation `torch.prim.ListConstruct`.
     SmallVector<Value> outputSizeTorchInt;
@@ -232,43 +301,26 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
           op, "desired size list length mismatches with the result type rank");
     }
 
-    SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
-    ArrayRef<Value> expandedShapeInt =
-        llvm::makeArrayRef(isCollapse ? inputSize : outputSizeInt);
-    ArrayRef<Value> collapsedShapeInt =
-        llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSize);
-
-    // Currently, we only handle the expanding or collapsing cases or the
-    // identity cases where the rank and shape of the input and result are
-    // equal, and the input itself is the result. We do not handle expanding And
-    // collapsing happening at the same time or cases where it's neither
+    // Currently, we only handle the cases where each dimension is either
+    // being expanded or collapsed. We do not handle cases where it's neither
     // collapsing nor expanding like view of [2,3] for 3x2 tensor.
-    // TODO: For the expanding And collapsing case, we will need to identify
-    // which dimensions are collapsing and which are expanding and do it in two
-    // steps.
     // TODO: For neither collapsing nor expanding, we could find a intermediate
     // shape to collapse and then expanded to the target shape. Like [2,3] =>
     // [6] => [3, 2].
-    if (inputRank == resultRank) {
-      for (unsigned i = 0; i < inputRank; i++)
-        checkDimEqualHelper(rewriter, loc, inputSize[i], outputSizeInt[i]);
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
-      return success();
-    }
 
     // Iterate through the view op size list to do the following:
     //
     // 1. Combine output size list and input tensor type info to get the most
     // static outputShape.
     //
-    // 2. Fill in the reassociation for size list item where the output dim size
-    // is got from `torch.aten.size.int(inputTensor, inputDim)`. We naively
-    // assume this means the corresponding dimension is not expanded or
+    // 2. Mark dims in unchangedDims for size list items where the output dim
+    // size comes from a `torch.aten.size.int(inputTensor, inputDim)`. We
+    // naively assume this means the corresponding dimension is not expanded or
     // collapsed. Note this may technically not always be true.
     // TODO: think of a way better way to at least detect when this assumption
-    // is violated.
+    // is violated for the cases of dynamic dimensions.
     SmallVector<int64_t> outputShape(resultRank, kUnknownSize);
-    SmallVector<ReassociationIndices> reassociation(collapsedRank);
+    SmallVector<ReassociationIndices> unchangedDims;
     llvm::Optional<int64_t> inferredDimension;
     for (auto en : llvm::enumerate(outputSizeTorchInt)) {
       int64_t inputDim;
@@ -277,9 +329,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
       // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
       if (matchPattern(en.value(),
                        m_TorchTensorSizeInt(op.self(), &inputDim))) {
-        auto collapsedDim = isCollapse ? outputDim : inputDim;
-        auto expandedDim = isCollapse ? inputDim : outputDim;
-        reassociation[collapsedDim].push_back(expandedDim);
+        unchangedDims.emplace_back();
+        unchangedDims.back().push_back(inputDim);
+        unchangedDims.back().push_back(outputDim);
         if (!inputType.isDynamicDim(inputDim)) {
           outputShape[outputDim] = inputShape[inputDim];
           continue;
@@ -298,6 +350,11 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
       }
     }
 
+    // Mark the end of the input/output shapes
+    unchangedDims.emplace_back();
+    unchangedDims.back().push_back(inputRank);
+    unchangedDims.back().push_back(resultRank);
+
     // Use static information of input tensor to determine size of inferred
     // dimension in output shape.
     //
@@ -334,139 +391,208 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
           numOfElements / outputKnownNumOfElements;
     }
 
-    SmallVector<int64_t> collapsedShape =
-        isCollapse ? outputShape : llvm::to_vector(inputShape);
-    SmallVector<int64_t> expandedShape =
-        isCollapse ? llvm::to_vector(inputShape) : outputShape;
-
-    // The while loop does the following:
-    // 1. Fill in the reassociation indices for dimensions that are expanded.
-    // Check the interval dimensions between two unchanged dims in the
-    // collapsedShape. If the interval is size 1, associate all the dims
-    // in the expandedShape shape until the next unchanged dim. If the interval
-    // is larger than size 1, figure out the associations with assumptions that
-    // dynamic dimensions are not splitted.
-    // 2. Set collapsedShape and expandedShape following the requirements by
+    SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
+    ArrayRef<Value> outputShapeInt = llvm::makeArrayRef(outputSizeInt);
+    ArrayRef<Value> inputShapeInt = llvm::makeArrayRef(inputSize);
+
+    // Association indices for expand/collapse ops. These two vectors
+    // are populated such that two entries at the same index correspond
+    // to an expand or collapse. For example,
+    //
+    // inputAssociations:  [[0, 1], [2]]
+    // outputAssociations: [[0], [1, 2, 3]]
+    //
+    // indicates that the first two dims of the input tensor
+    // are collapsed into the first dim of the output, and the
+    // third dim of the input is expanded into the last three dims
+    // of the output.
+    SmallVector<ReassociationIndices> inputAssociations;
+    SmallVector<ReassociationIndices> outputAssociations;
+
+    SmallVector<int64_t> inputShapeVec = llvm::to_vector(inputShape);
+
+    // The for loop does the following:
+    // 1. Attempt to match the indices from inputDim and outputDim to the next
+    // boundary found from `torch.aten.size.int(inputTensor, inputDim)`, or
+    // until (inputRank, resultRank) if there is no such op. Look at the first
+    // dimension of the input and output and collapse the larger one by finding
+    // a minimal set of opposing indices with the same number of elements. If
+    // the number of dims to the next boundary is 1, then we assume all
+    // remaining opposing dims must collapse into it.
+    // 2. For handling of dynamic dimensions, we first assume they are only
+    // split if we can easily compute the correct size.
+    // e.g. [2, -1] -> [2, 3, 4]
+    // This mainly happens at the edges of boundaries. Otherwise we try to match
+    // the dynamic dimension with the one across from it and give up if we can't
+    // reason about how the dimensions are associated.
+    // e.g. [-1, -1] -> [2, 3, 4]
+    // 3. Set inputShapeVec and outputShape following the requirements by
     // tensor.expand_shape verification code:
     // a. As long as one or more of the related dimensions in the expanded
     // shape is dynamic the collapsed dimension is dynamic.
     // b. If all of the related dimensions are static, the collapsed
     // dimension must be static. In other words, if a collapsed dimension is
     // dynamic, at least one of the related dimensions need to be dynamic.
-    int64_t collapsedDim = 0, expandedDim = 0;
-    while (collapsedDim < collapsedRank && expandedDim < expandedRank) {
-      // Not empty means the associations has been filled in and the dimension
-      // is unchanged.
-      if (!reassociation[collapsedDim].empty()) {
-        if (expandedDim != reassociation[collapsedDim][0])
-          return op.emitOpError("Unsupported: expanded dims are off from the "
-                                "expected dim got from reassociation");
-        collapsedDim++;
-        expandedDim++;
-        continue;
-      }
-
-      // Collect the dims that are collapsed until hitting the next dim that's
-      // unchanged.
-      SmallVector<int64_t> collapsedDims;
-      while (collapsedDim < collapsedRank &&
-             reassociation[collapsedDim].empty()) {
-        collapsedDims.push_back(collapsedDim);
-        collapsedDim++;
-      }
-      // the next reassociation is for a dim that's unchanged.
-      int64_t expandedDimNext = collapsedDim != collapsedRank
-                                    ? reassociation[collapsedDim][0]
-                                    : expandedRank;
-      if (collapsedDims.size() == 1) {
-        int64_t collapsedDimSize = 1;
-        int64_t collapsedDim = collapsedDims[0];
-        for (auto i : llvm::seq(expandedDim, expandedDimNext)) {
-          reassociation[collapsedDim].push_back(i);
-          if (collapsedDimSize == kUnknownSize)
-            continue;
-
-          int64_t expandedDimSize = expandedShape[i];
-          if (expandedDimSize == kUnknownSize) {
-            collapsedDimSize = kUnknownSize;
-            continue;
+    int64_t inputDim = 0, outputDim = 0;
+    for (auto boundary : unchangedDims) {
+      // We assume dims specified by AtenSizeInt ops are unchanged
+      int64_t nextUnchangedInput = boundary[0];
+      int64_t nextUnchangedOutput = boundary[1];
+
+      bool hasDynamic = false;
+      while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
+        inputAssociations.emplace_back();
+        outputAssociations.emplace_back();
+
+        // outputDim is next to the boundary
+        if (outputDim == nextUnchangedOutput - 1) {
+          if (hasDynamic && inputDim != nextUnchangedInput - 1) {
+            return rewriter.notifyMatchFailure(
+                op, "found ambiguous collapse of dynamic input sizes (e.g. "
" + "[-1, -1, -1] -> [-1, -1])"); } - collapsedDimSize *= expandedShape[i]; + outputAssociations.back().push_back(outputDim); + if (failed(collapseToSingleDimHelper( + op, rewriter, outputDim, nextUnchangedOutput, inputDim, + nextUnchangedInput, outputShape, inputShapeVec, + inputAssociations.back()))) + return failure(); + outputDim = nextUnchangedOutput; + inputDim = nextUnchangedInput; + continue; } - // To meet both requirements from tensor.expand_shape verification code. - collapsedShape[collapsedDim] = collapsedDimSize; - expandedDim = expandedDimNext; - continue; - } - // collpasedDims are expanded to [expandedDim, expandedDimNext) - if (expandedDimNext - expandedDim < (int64_t)collapsedDims.size()) - op.emitError("unimplemented: mixed of expanding and collapsing " - "operations for view"); - for (auto collapsedDim : collapsedDims) { - if (collapsedShape[collapsedDim] == kUnknownSize) { - if (expandedDim >= expandedDimNext) { + // inputDim is next to the boundary + if (inputDim == nextUnchangedInput - 1) { + if (hasDynamic && inputShape[inputDim] == kUnknownSize) { return rewriter.notifyMatchFailure( - op, - "desired size is not compatible with the input tensor size"); - } - checkDimEqualHelper(rewriter, loc, collapsedShapeInt[collapsedDim], - expandedShapeInt[expandedDim]); - // To meet the second requirement from tensor.expand_shape - // verification code. - expandedShape[expandedDim] = kUnknownSize; - reassociation[collapsedDim].push_back(expandedDim++); - } else { - int64_t remainingSizeToExpand = collapsedShape[collapsedDim]; - // A do-while loop is used here to handle the cases where the - // collapsed shape tensor has a dimension of size 1. - do { - int64_t expandedDimSize = expandedShape[expandedDim]; - if (expandedDim >= expandedDimNext || - expandedShape[expandedDim] == kUnknownSize || - remainingSizeToExpand % expandedDimSize != 0) { - return rewriter.notifyMatchFailure( - op, "total number of elements mismatch in the expansion"); - } - reassociation[collapsedDim].push_back(expandedDim++); - remainingSizeToExpand /= expandedDimSize; - } while (remainingSizeToExpand != 1); - - // If all dims until `expandedDimNext` are of size 1, then group those - // with the reassociation for the current `collapsedDim`. - auto expandedShapeSlice = - llvm::makeArrayRef(expandedShape) - .slice(expandedDim, expandedDimNext - expandedDim); - if (llvm::all_of(expandedShapeSlice, - [](int64_t val) { return val == 1; })) { - reassociation[collapsedDim].append( - llvm::to_vector(llvm::seq(expandedDim, expandedDimNext))); - expandedDim = expandedDimNext; + op, "found ambiguous expand of dynamic sizes (e.g. 
+                op, "found ambiguous expand of dynamic sizes (e.g. [-1, -1] -> "
+                    "[-1, -1, -1])");
           }
+          inputAssociations.back().push_back(inputDim);
+          if (failed(collapseToSingleDimHelper(
+                  op, rewriter, inputDim, nextUnchangedInput, outputDim,
+                  nextUnchangedOutput, inputShapeVec, outputShape,
+                  outputAssociations.back())))
+            return failure();
+          outputDim = nextUnchangedOutput;
+          inputDim = nextUnchangedInput;
+          continue;
         }
+
+        int64_t inputMatchingDimSize = inputShapeVec[inputDim];
+        int64_t outputMatchingDimSize = outputShape[outputDim];
+
+        // If the input is dynamic, first assume it is not split
+        if (inputMatchingDimSize == kUnknownSize) {
+          checkDimEqualHelper(rewriter, loc, inputShapeInt[inputDim],
+                              outputShapeInt[outputDim]);
+          outputShape[outputDim] = kUnknownSize;
+          inputAssociations.back().push_back(inputDim++);
+          outputAssociations.back().push_back(outputDim++);
+          hasDynamic = true;
+          continue;
+        }
+
+        // inputDim size is larger; try to collapse onto it
+        if (inputMatchingDimSize >= outputMatchingDimSize) {
+          inputAssociations.back().push_back(inputDim);
+          if (failed(minimallyCollapseDimHelper(
+                  op, rewriter, inputDim, nextUnchangedInput, outputDim,
+                  nextUnchangedOutput, inputShapeVec, outputShape,
+                  outputAssociations.back())))
+            return failure();
+          hasDynamic = false;
+          outputDim = outputAssociations.back().back() + 1;
+          inputDim++;
+          continue;
+        }
+
+        // outputDim is larger; try to collapse onto it
+        outputAssociations.back().push_back(outputDim);
+        if (failed(minimallyCollapseDimHelper(
+                op, rewriter, outputDim, nextUnchangedOutput, inputDim,
+                nextUnchangedInput, outputShape, inputShapeVec,
+                inputAssociations.back())))
+          return failure();
+        hasDynamic = false;
+        inputDim = inputAssociations.back().back() + 1;
+        outputDim++;
+        continue;
       }
+
+      if (inputDim != nextUnchangedInput || outputDim != nextUnchangedOutput) {
+        return rewriter.notifyMatchFailure(
+            op, "could not match input tensor shape to output shape; "
+                "potentially unsupported view shape");
+      }
+
+      // Append the associations for the dims matching `aten.size.int`
+      if (nextUnchangedInput != inputRank &&
+          nextUnchangedOutput != resultRank) {
+        inputAssociations.emplace_back();
+        outputAssociations.emplace_back();
+        inputAssociations.back().push_back(inputDim++);
+        outputAssociations.back().push_back(outputDim++);
+      }
+    }
+
+    // Check if the shapes already match up to dynamic sizes. If so, we can just
+    // cast as the result type because the previous loop sets up the necessary
+    // dim checks in case of dynamic sizes.
+    if (llvm::all_of(
+            inputAssociations,
+            [](ReassociationIndices indices) { return indices.size() == 1; }) &&
+        llvm::all_of(outputAssociations, [](ReassociationIndices indices) {
+          return indices.size() == 1;
+        })) {
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
+      return success();
     }
-    if (collapsedDim != collapsedRank || expandedDim != expandedRank)
-      return rewriter.notifyMatchFailure(op, "view shape is not supported");
 
     Type adjustedResultType =
-        RankedTensorType::get(isCollapse ? collapsedShape : expandedShape,
-                              resultType.getElementType());
+        RankedTensorType::get(outputShape, resultType.getElementType());
     Type adjustedInputType =
-        RankedTensorType::get(isCollapse ? expandedShape : collapsedShape,
-                              resultType.getElementType());
+        RankedTensorType::get(inputShapeVec, resultType.getElementType());
     Value castedInput =
         rewriter.create<tensor::CastOp>(loc, adjustedInputType, input);
-    Value result =
-        isCollapse
-            ? rewriter
-                  .create<tensor::CollapseShapeOp>(loc, adjustedResultType,
-                                                   castedInput, reassociation)
-                  .result()
-            : rewriter
-                  .create<tensor::ExpandShapeOp>(loc, adjustedResultType,
-                                                 castedInput, reassociation)
-                  .result();
+    llvm::Optional<Value> expandedInput;
+    llvm::Optional<Value> collapsedInput;
+
+    if (llvm::any_of(inputAssociations, [](ReassociationIndices indices) {
+          return indices.size() > 1;
+        })) {
+      SmallVector<int64_t> intermediateShape;
+      for (auto i : llvm::seq(0, (int)inputAssociations.size())) {
+        if (inputAssociations[i].size() > 1) {
+          intermediateShape.push_back(outputShape[outputAssociations[i][0]]);
+        } else {
+          intermediateShape.push_back(inputShapeVec[inputAssociations[i][0]]);
+        }
+      }
+      Type intermediateResultType =
+          RankedTensorType::get(intermediateShape, resultType.getElementType());
+      expandedInput =
+          rewriter
+              .create<tensor::CollapseShapeOp>(loc, intermediateResultType,
+                                               castedInput, inputAssociations)
+              .result();
+    }
+
+    if (llvm::any_of(outputAssociations, [](ReassociationIndices indices) {
+          return indices.size() > 1;
+        })) {
+      collapsedInput = rewriter
+                           .create<tensor::ExpandShapeOp>(
+                               loc, adjustedResultType,
+                               expandedInput.hasValue() ? expandedInput.value()
+                                                        : castedInput,
+                               outputAssociations)
+                           .result();
+    }
+
+    Value result = collapsedInput.hasValue() ? collapsedInput.value()
+                                             : expandedInput.value();
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
     return success();
   }
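For review purposes, here is a sketch of the association matching implemented above: a hypothetical Python model covering only the fully static case (no dynamic dims, no `aten.size.int` boundaries, and with the divisibility/failure checks omitted). The name `match_view_associations` is illustrative and not part of the patch.

```python
import math

# Hypothetical model of the static-shape association matching done by
# ConvertAtenViewOp: grow the smaller side until both groups cover the
# same number of elements, mirroring minimallyCollapseDimHelper.
def match_view_associations(in_shape, out_shape):
    assert math.prod(in_shape) == math.prod(out_shape)
    in_assoc, out_assoc = [], []
    i = j = 0
    while i < len(in_shape) and j < len(out_shape):
        in_group, out_group = [i], [j]
        in_size, out_size = in_shape[i], out_shape[j]
        while in_size != out_size:
            if in_size < out_size:
                i += 1
                in_size *= in_shape[i]
                in_group.append(i)
            else:
                j += 1
                out_size *= out_shape[j]
                out_group.append(j)
        in_assoc.append(in_group)
        out_assoc.append(out_group)
        i, j = i + 1, j + 1
    return in_assoc, out_assoc

# ViewExpandCollapseModule's [2, 4, 8, 16, 4] -> [8, 2, 4, 16, 2, 2] yields
# ([[0, 1], [2], [3], [4]], [[0], [1, 2], [3], [4, 5]]): input dims 0-1
# collapse into output dim 0, input dim 2 expands into output dims 1-2, etc.
print(match_view_associations([2, 4, 8, 16, 4], [8, 2, 4, 16, 2, 2]))
```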
diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index c99f32a937ec..0396ca9431cf 100644
--- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -55,15 +55,15 @@ def __init__(self):
     @export
     @annotate_args([
         None,
-        ([1, 3], torch.float32, True),
+        ([2, 1, 16, 1, 1], torch.float32, True),
     ])
     def forward(self, a):
-        return a.view(1, 1, 3, 1, 1)
+        return a.view(1, 2, 1, 16, 1, 1, 1, 1)
 
 @register_test_case(module_factory=lambda: ViewExpandOnesBeforeAndAfterModule())
 def ViewExpandOnesBeforeAndAfterModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(1, 3))
+    module.forward(tu.rand(2, 1, 16, 1, 1))
 
 # ==============================================================================
 
@@ -164,6 +164,82 @@ def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ViewExpandCollapseWithOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 4, 8, 8], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 1, 1, 4, 64)
+
+@register_test_case(module_factory=lambda: ViewExpandCollapseWithOnesModule())
+def ViewExpandCollapseWithOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 8, 8))
+
+# ==============================================================================
+
+class ViewExpandCollapseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 4, 8, 16, 4], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(8, 2, 4, 16, 2, 2)
+
+@register_test_case(module_factory=lambda: ViewExpandCollapseModule())
+def ViewExpandCollapseModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 8, 16, 4))
+
+# ==============================================================================
+
+class ViewDynamicExpandCollapseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 4, -1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 1, 4, 64)
+
+@register_test_case(module_factory=lambda: ViewDynamicExpandCollapseModule())
+def ViewDynamicExpandCollapseModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 8, 8))
+
+# ==============================================================================
+
+class ViewDynamicExpandCollapseWithAtenIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(2, 1, a.size(1), 64)
+
+@register_test_case(module_factory=lambda: ViewDynamicExpandCollapseWithAtenIntModule())
+def ViewDynamicExpandCollapseWithAtenIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 8, 8))
+
+# ==============================================================================
+
 class View1DFoldModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
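End-to-end, the rewritten lowering goes through the fully collapsed intermediate shape: input dim groups are collapsed first (`tensor.collapse_shape`), then output dim groups are expanded (`tensor.expand_shape`). For the `ViewExpandCollapseModule` shapes, the intermediate computed by `intermediateShape` above is [8, 8, 16, 4] (one entry per association group). A plain-PyTorch analogue of the two steps, assumed here for illustration only:

```python
import torch

a = torch.rand(2, 4, 8, 16, 4)
# Collapse step: input groups [[0, 1], [2], [3], [4]] -> intermediate shape.
collapsed = a.reshape(8, 8, 16, 4)
# Expand step: output groups [[0], [1, 2], [3], [4, 5]] -> final shape.
expanded = collapsed.reshape(8, 2, 4, 16, 2, 2)
assert torch.equal(expanded, a.view(8, 2, 4, 16, 2, 2))
```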