diff --git a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h index 9fcaa167edb0..f3c8ef1c81f8 100644 --- a/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h +++ b/mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h @@ -153,8 +153,8 @@ inline void computeIndexDiffMap( // coordinate for the next layer. //===----------------------------------------------------------------------===// inline void populateLayeredIndicesWithIndexDiffMap( - // FIXME. Study how to get rid of destType. - OpBuilder &b, Location loc, const ShapedType destType, + // FIXME. Study how to get rid of lowerType. + OpBuilder &b, Location loc, const ShapedType lowerType, const ArrayAttr &layeredTransformMetadata, const SmallVector &layeredTransform, const SmallVector, 2> &layeredIndices, @@ -177,11 +177,11 @@ inline void populateLayeredIndicesWithIndexDiffMap( layeredTransformMetadata[layer].template cast(); } else { // in case there is no metadata, populate the lower level shape. - SmallVector destShapeAttr; - for (auto &v : destType.getShape()) - destShapeAttr.push_back(b.getI32IntegerAttr(v)); + SmallVector lowerShapeAttr; + for (auto &v : lowerType.getShape()) + lowerShapeAttr.push_back(b.getI32IntegerAttr(v)); transformMetadata = b.getDictionaryAttr({b.getNamedAttr( - "lower_layer_bounds", b.getArrayAttr({destShapeAttr}))}); + "lower_layer_bounds", b.getArrayAttr({lowerShapeAttr}))}); } AffineMap transform = layeredTransform[layer]; SmallVector lowerIndicesOriginal = layeredIndices[layer + 1]; @@ -6668,10 +6668,9 @@ struct ThreadwiseCopyRewritePattern else srcLowerIndices = srcUpperIndices; - Value scalarValue; // Load from source. // Issue scalar load. - scalarValue = b.create(loc, sourceType.getElementType(), + Value scalarValue = b.create(loc, sourceType.getElementType(), op.source(), srcLowerIndices); // Convert from sourceElementType to destElementType if necessary. @@ -6797,6 +6796,11 @@ struct ThreadwiseCopyRewritePattern SmallVector srcLowerIndices; SmallVector destUpperIndices; SmallVector destLowerIndices; + // Coordinates across the layers of transformations. + // If the vector is of size n, 0 is the top layer, and + // n-1 is the bottom layer. + SmallVector, 2> layeredSourceIndices; + SmallVector, 2> layeredDestIndices; if (!legacyLoadAttr || !legacyLoadAttr.template cast().getValue()) { // Compute high-level coordinate for dest memref. @@ -6804,10 +6808,6 @@ struct ThreadwiseCopyRewritePattern srcUpperIndices.push_back(b.create( loc, sourceAndDestCoord[i], b.getIndexType())); } - // Coordinates across the layers of transformations. - // If the vector is of size n, 0 is the top layer, and - // n-1 is the bottom layer. - SmallVector, 2> layeredSourceIndices; // Populate coorindates across the layers of transformations. populateLayeredIndicesWithAffineMap(b, loc, layeredSourceIndices, @@ -6824,11 +6824,6 @@ struct ThreadwiseCopyRewritePattern loc, sourceAndDestCoord[i], b.getIndexType())); } - // Coordinates across the layers of transformations. - // If the vector is of size n, 0 is the top layer, and - // n-1 is the bottom layer. - SmallVector, 2> layeredDestIndices; - // Populate coorindates across the layers of transformations. populateLayeredIndicesWithAffineMap( b, loc, layeredDestIndices, destUpperIndices, layeredDestTransform); @@ -6852,7 +6847,7 @@ struct ThreadwiseCopyRewritePattern // Coordinates across the layers of transformations. // If the vector is of size n, 0 is the top layer, and // n-1 is the bottom layer. - SmallVector, 2> layeredSourceIndices; + layeredSourceIndices.clear(); // Compute high-level coordinate for source memref. // src_index = (iv_0, iv_1, ...) + sourceCoord @@ -6868,6 +6863,7 @@ struct ThreadwiseCopyRewritePattern } // Populate coorindates across the layers of transformations. + SmallVector, 2> layeredSourceIndices; populateLayeredIndicesWithAffineMap(b, loc, layeredSourceIndices, srcUpperIndices, layeredSourceTransform); @@ -6876,7 +6872,41 @@ struct ThreadwiseCopyRewritePattern srcLowerIndices = layeredSourceIndices[layeredSourceIndices.size() - 1]; } else { - // TBD insert index diff map codes here. + // New approach. Use index diff map. + + // Coordinates across the layers of transformations. + // If the vector is of size n, 0 is the top layer, and + // n-1 is the bottom layer. + SmallVector, 2> layeredSourceDiffs; + SmallVector, 2> layeredSourceIndicesUpdated; + + // Populate coorindates across the layers of transformations. + ArrayAttr layeredSourceTransformMetadata; + if (srcTransformSpec) { + Attribute metadataAttr = srcTransformSpec.get("metadata"); + if (metadataAttr) + layeredSourceTransformMetadata = + metadataAttr.template cast(); + } + SmallVector srcTopDiff = loopIVsPerAccessOrder; + layeredSourceDiffs.push_back(srcTopDiff); + // Progressively apply index diff maps across all coordinate + // transformation layers. + populateLayeredIndicesWithIndexDiffMap( + b, loc, /*lowerType=*/sourceType, layeredSourceTransformMetadata, + layeredSourceTransform, layeredSourceIndices, srcTopDiff, + layeredSourceDiffs, layeredSourceIndicesUpdated); + + // Fetch low-level coordinate. + SmallVector srcLowerIndicesUpdated = + layeredSourceIndicesUpdated[layeredSourceIndicesUpdated.size() - + 1]; + // computeIndexDiffMap by default emit indices of type i32, convert to + // index type. + srcLowerIndices.clear(); + for (auto &v : srcLowerIndicesUpdated) + srcLowerIndices.push_back( + b.create(loc, v, b.getIndexType())); } // Pre-populate srcLowerOOBIndices. It will be modified inside @@ -6989,10 +7019,7 @@ struct ThreadwiseCopyRewritePattern if (legacyStoreAttr && legacyStoreAttr.template cast().getValue()) { - // Coordinates across the layers of transformations. - // If the vector is of size n, 0 is the top layer, and - // n-1 is the bottom layer. - SmallVector, 2> layeredDestIndices; + layeredDestIndices.clear(); // Compute high-level coordinate for dest memref. // dst_index = (iv_0, iv_1, ...) + destCoord @@ -7008,6 +7035,7 @@ struct ThreadwiseCopyRewritePattern } // Populate coorindates across the layers of transformations. + SmallVector, 2> layeredDestIndices; populateLayeredIndicesWithAffineMap(b, loc, layeredDestIndices, destUpperIndices, layeredDestTransform); @@ -7015,13 +7043,47 @@ struct ThreadwiseCopyRewritePattern // Fetch low-level coordinate. destLowerIndices = layeredDestIndices[layeredDestIndices.size() - 1]; - // Store to dest. - // Issue scalar store. - b.create(loc, convertedScalarValue, op.dest(), - destLowerIndices); } else { - // TBD insert index diff map codes here. - } + // New approach. Use index diff map. + + // Coordinates across the layers of transformations. + // If the vector is of size n, 0 is the top layer, and + // n-1 is the bottom layer. + SmallVector, 2> layeredDestDiffs; + SmallVector, 2> layeredDestIndicesUpdated; + + // Populate coorindates across the layers of transformations. + ArrayAttr layeredDestTransformMetadata; + if (destTransformSpec) { + Attribute metadataAttr = destTransformSpec.get("metadata"); + if (metadataAttr) + layeredDestTransformMetadata = + metadataAttr.template cast(); + } + SmallVector destTopDiff = loopIVsPerAccessOrder; + layeredDestDiffs.push_back(destTopDiff); + // Progressively apply index diff maps across all coordinate + // transformation layers. + populateLayeredIndicesWithIndexDiffMap( + b, loc, /*lowerType=*/destType, layeredDestTransformMetadata, + layeredDestTransform, layeredDestIndices, destTopDiff, + layeredDestDiffs, layeredDestIndicesUpdated); + + // Fetch low-level coordinate. + SmallVector destLowerIndicesUpdated = + layeredDestIndicesUpdated[layeredDestIndicesUpdated.size() - 1]; + // computeIndexDiffMap by default emit indices of type i32, convert to + // index type. + destLowerIndices.clear(); + for (auto &v : destLowerIndicesUpdated) + destLowerIndices.push_back( + b.create(loc, v, b.getIndexType())); + } + + // Store to dest. + // Issue scalar store. + b.create(loc, convertedScalarValue, op.dest(), + destLowerIndices); // increase IVs bool toIncreaseNextDigit = true; @@ -7272,7 +7334,7 @@ struct ThreadwiseCopyV2RewritePattern // Progressively apply index diff maps across all coordinate // transformation layers. populateLayeredIndicesWithIndexDiffMap( - b, loc, /*destType=*/destType, layeredSourceTransformMetadata, + b, loc, /*lowerType=*/sourceType, layeredSourceTransformMetadata, layeredSourceTransform, layeredSourceIndices, srcTopDiff, layeredSourceDiffs, layeredSourceIndicesUpdated); @@ -7305,7 +7367,7 @@ struct ThreadwiseCopyV2RewritePattern // Progressively apply index diff maps across all coordinate // transformation layers. populateLayeredIndicesWithIndexDiffMap( - b, loc, /*destType=*/destType, layeredDestTransformMetadata, + b, loc, /*lowerType=*/destType, layeredDestTransformMetadata, layeredDestTransform, layeredDestIndices, destTopDiff, layeredDestDiffs, layeredDestIndicesUpdated);