diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 004428bd343e25..c76d489e281190 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -774,6 +774,94 @@ struct ConvertIllegalShapeCastOpsToTransposes } }; +/// Returns an iterator over the dims (inc scalability) of a VectorType. +static auto getDims(VectorType vType) { + return llvm::zip_equal(vType.getShape(), vType.getScalableDims()); +} + +/// Helper to drop (fixed-size) unit dims from a VectorType. +static VectorType dropUnitDims(VectorType vType) { + SmallVector scalableFlags; + SmallVector dimSizes; + for (auto dim : getDims(vType)) { + if (dim == std::make_tuple(1, false)) + continue; + auto [size, scalableFlag] = dim; + dimSizes.push_back(size); + scalableFlags.push_back(scalableFlag); + } + return VectorType::get(dimSizes, vType.getElementType(), scalableFlags); +} + +/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the +/// shape_cast only drops unit dimensions. +/// +/// This simplifies the transpose making it possible for other legalization +/// rewrites to handle it. +/// +/// Example: +/// +/// BEFORE: +/// ```mlir +/// %0 = vector.transpose %vector, [3, 0, 1, 2] +/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> +/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32> +/// ``` +/// +/// AFTER: +/// ```mlir +/// %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> +/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> +/// ``` +struct SwapShapeCastOfTranspose : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { + auto transposeOp = + shapeCastOp.getSource().getDefiningOp(); + if (!transposeOp) + return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp"); + + auto resultType = shapeCastOp.getResultVectorType(); + if (resultType.getRank() <= 1) + return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low"); + + if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType())) + return rewriter.notifyMatchFailure( + shapeCastOp, "ShapeCastOp changes non-unit dimension(s)"); + + auto transposeSourceVectorType = transposeOp.getSourceVectorType(); + auto transposeSourceDims = + llvm::to_vector(getDims(transposeSourceVectorType)); + + // Construct a map from dimIdx -> number of dims dropped before dimIdx. + SmallVector droppedDimsBefore(transposeSourceVectorType.getRank()); + int64_t droppedDims = 0; + for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) { + droppedDimsBefore[i] = droppedDims; + if (dim == std::make_tuple(1, false)) + ++droppedDims; + } + + // Drop unit dims from transpose permutation. + auto perm = transposeOp.getPermutation(); + SmallVector newPerm; + for (int64_t idx : perm) { + if (transposeSourceDims[idx] == std::make_tuple(1, false)) + continue; + newPerm.push_back(idx - droppedDimsBefore[idx]); + } + + auto loc = shapeCastOp.getLoc(); + auto newShapeCastOp = rewriter.create( + loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector()); + rewriter.replaceOpWithNewOp(shapeCastOp, + newShapeCastOp, newPerm); + return success(); + } +}; + /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use /// the ZA state. This workaround rewrite to support these transposes when ZA is /// available. @@ -939,7 +1027,8 @@ struct VectorLegalizationPass patterns.add(context); + SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>( + context); // Note: These two patterns are added with a high benefit to ensure: // - Masked outer products are handled before unmasked ones // - Multi-tile writes are lowered as a store loop (if possible) diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir index 458906a1879829..adc02adb6e974c 100644 --- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir +++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir @@ -646,3 +646,29 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref return } + +// ----- + +// CHECK-LABEL: @swap_shape_cast_of_transpose( +// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>) +func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> { + // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> + // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> + // CHECK: return %[[TRANSPOSE]] + %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> + %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32> + return %1 : vector<[4]x4xf32> +} + +// ----- + +// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after( +// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>) +func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> { + // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32> + // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> + // CHECK: return %[[TRANSPOSE]] + %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32> + %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32> + return %1 : vector<[4]x4xf32> +}