Skip to content

Commit

Permalink
Adopt index diff map logic in threadwise_copy. Disabled by default.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed May 30, 2021
1 parent a938c75 commit cb930c3
Showing 1 changed file with 93 additions and 31 deletions.
124 changes: 93 additions & 31 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ inline void computeIndexDiffMap(
// coordinate for the next layer.
//===----------------------------------------------------------------------===//
inline void populateLayeredIndicesWithIndexDiffMap(
// FIXME. Study how to get rid of destType.
OpBuilder &b, Location loc, const ShapedType destType,
// FIXME. Study how to get rid of lowerType.
OpBuilder &b, Location loc, const ShapedType lowerType,
const ArrayAttr &layeredTransformMetadata,
const SmallVector<AffineMap> &layeredTransform,
const SmallVector<SmallVector<Value, 8>, 2> &layeredIndices,
Expand All @@ -177,11 +177,11 @@ inline void populateLayeredIndicesWithIndexDiffMap(
layeredTransformMetadata[layer].template cast<DictionaryAttr>();
} else {
// in case there is no metadata, populate the lower level shape.
SmallVector<Attribute, 4> destShapeAttr;
for (auto &v : destType.getShape())
destShapeAttr.push_back(b.getI32IntegerAttr(v));
SmallVector<Attribute, 4> lowerShapeAttr;
for (auto &v : lowerType.getShape())
lowerShapeAttr.push_back(b.getI32IntegerAttr(v));
transformMetadata = b.getDictionaryAttr({b.getNamedAttr(
"lower_layer_bounds", b.getArrayAttr({destShapeAttr}))});
"lower_layer_bounds", b.getArrayAttr({lowerShapeAttr}))});
}
AffineMap transform = layeredTransform[layer];
SmallVector<Value, 8> lowerIndicesOriginal = layeredIndices[layer + 1];
Expand Down Expand Up @@ -6668,10 +6668,9 @@ struct ThreadwiseCopyRewritePattern
else
srcLowerIndices = srcUpperIndices;

Value scalarValue;
// Load from source.
// Issue scalar load.
scalarValue = b.create<LoadOp>(loc, sourceType.getElementType(),
Value scalarValue = b.create<LoadOp>(loc, sourceType.getElementType(),
op.source(), srcLowerIndices);

// Convert from sourceElementType to destElementType if necessary.
Expand Down Expand Up @@ -6797,17 +6796,18 @@ struct ThreadwiseCopyRewritePattern
SmallVector<Value, 8> srcLowerIndices;
SmallVector<Value, 8> destUpperIndices;
SmallVector<Value, 8> destLowerIndices;
// Coordinates across the layers of transformations.
// If the vector is of size n, 0 is the top layer, and
// n-1 is the bottom layer.
SmallVector<SmallVector<Value, 8>, 2> layeredSourceIndices;
SmallVector<SmallVector<Value, 8>, 2> layeredDestIndices;
if (!legacyLoadAttr ||
!legacyLoadAttr.template cast<BoolAttr>().getValue()) {
// Compute high-level coordinate for dest memref.
for (unsigned i = 0; i < sourceCoordLength; ++i) {
srcUpperIndices.push_back(b.create<IndexCastOp>(
loc, sourceAndDestCoord[i], b.getIndexType()));
}
// Coordinates across the layers of transformations.
// If the vector is of size n, 0 is the top layer, and
// n-1 is the bottom layer.
SmallVector<SmallVector<Value, 8>, 2> layeredSourceIndices;

// Populate coorindates across the layers of transformations.
populateLayeredIndicesWithAffineMap(b, loc, layeredSourceIndices,
Expand All @@ -6824,11 +6824,6 @@ struct ThreadwiseCopyRewritePattern
loc, sourceAndDestCoord[i], b.getIndexType()));
}

// Coordinates across the layers of transformations.
// If the vector is of size n, 0 is the top layer, and
// n-1 is the bottom layer.
SmallVector<SmallVector<Value, 8>, 2> layeredDestIndices;

// Populate coorindates across the layers of transformations.
populateLayeredIndicesWithAffineMap(
b, loc, layeredDestIndices, destUpperIndices, layeredDestTransform);
Expand All @@ -6852,7 +6847,7 @@ struct ThreadwiseCopyRewritePattern
// Coordinates across the layers of transformations.
// If the vector is of size n, 0 is the top layer, and
// n-1 is the bottom layer.
SmallVector<SmallVector<Value, 8>, 2> layeredSourceIndices;
layeredSourceIndices.clear();

// Compute high-level coordinate for source memref.
// src_index = (iv_0, iv_1, ...) + sourceCoord
Expand All @@ -6868,6 +6863,7 @@ struct ThreadwiseCopyRewritePattern
}

// Populate coorindates across the layers of transformations.
SmallVector<SmallVector<Value, 8>, 2> layeredSourceIndices;
populateLayeredIndicesWithAffineMap(b, loc, layeredSourceIndices,
srcUpperIndices,
layeredSourceTransform);
Expand All @@ -6876,7 +6872,41 @@ struct ThreadwiseCopyRewritePattern
srcLowerIndices =
layeredSourceIndices[layeredSourceIndices.size() - 1];
} else {
// TBD insert index diff map codes here.
// New approach. Use index diff map.

// Coordinates across the layers of transformations.
// If the vector is of size n, 0 is the top layer, and
// n-1 is the bottom layer.
SmallVector<SmallVector<int64_t, 8>, 2> layeredSourceDiffs;
SmallVector<SmallVector<Value, 8>, 2> layeredSourceIndicesUpdated;

// Populate coorindates across the layers of transformations.
ArrayAttr layeredSourceTransformMetadata;
if (srcTransformSpec) {
Attribute metadataAttr = srcTransformSpec.get("metadata");
if (metadataAttr)
layeredSourceTransformMetadata =
metadataAttr.template cast<ArrayAttr>();
}
SmallVector<int64_t, 8> srcTopDiff = loopIVsPerAccessOrder;
layeredSourceDiffs.push_back(srcTopDiff);
// Progressively apply index diff maps across all coordinate
// transformation layers.
populateLayeredIndicesWithIndexDiffMap(
b, loc, /*lowerType=*/sourceType, layeredSourceTransformMetadata,
layeredSourceTransform, layeredSourceIndices, srcTopDiff,
layeredSourceDiffs, layeredSourceIndicesUpdated);

// Fetch low-level coordinate.
SmallVector<Value, 8> srcLowerIndicesUpdated =
layeredSourceIndicesUpdated[layeredSourceIndicesUpdated.size() -
1];
// computeIndexDiffMap by default emit indices of type i32, convert to
// index type.
srcLowerIndices.clear();
for (auto &v : srcLowerIndicesUpdated)
srcLowerIndices.push_back(
b.create<IndexCastOp>(loc, v, b.getIndexType()));
}

// Pre-populate srcLowerOOBIndices. It will be modified inside
Expand Down Expand Up @@ -6989,10 +7019,7 @@ struct ThreadwiseCopyRewritePattern

if (legacyStoreAttr &&
legacyStoreAttr.template cast<BoolAttr>().getValue()) {
// Coordinates across the layers of transformations.
// If the vector is of size n, 0 is the top layer, and
// n-1 is the bottom layer.
SmallVector<SmallVector<Value, 8>, 2> layeredDestIndices;
layeredDestIndices.clear();

// Compute high-level coordinate for dest memref.
// dst_index = (iv_0, iv_1, ...) + destCoord
Expand All @@ -7008,20 +7035,55 @@ struct ThreadwiseCopyRewritePattern
}

// Populate coorindates across the layers of transformations.
SmallVector<SmallVector<Value, 8>, 2> layeredDestIndices;
populateLayeredIndicesWithAffineMap(b, loc, layeredDestIndices,
destUpperIndices,
layeredDestTransform);

// Fetch low-level coordinate.
destLowerIndices = layeredDestIndices[layeredDestIndices.size() - 1];

// Store to dest.
// Issue scalar store.
b.create<StoreOp>(loc, convertedScalarValue, op.dest(),
destLowerIndices);
} else {
// TBD insert index diff map codes here.
}
// New approach. Use index diff map.

// Coordinates across the layers of transformations.
// If the vector is of size n, 0 is the top layer, and
// n-1 is the bottom layer.
SmallVector<SmallVector<int64_t, 8>, 2> layeredDestDiffs;
SmallVector<SmallVector<Value, 8>, 2> layeredDestIndicesUpdated;

// Populate coorindates across the layers of transformations.
ArrayAttr layeredDestTransformMetadata;
if (destTransformSpec) {
Attribute metadataAttr = destTransformSpec.get("metadata");
if (metadataAttr)
layeredDestTransformMetadata =
metadataAttr.template cast<ArrayAttr>();
}
SmallVector<int64_t, 8> destTopDiff = loopIVsPerAccessOrder;
layeredDestDiffs.push_back(destTopDiff);
// Progressively apply index diff maps across all coordinate
// transformation layers.
populateLayeredIndicesWithIndexDiffMap(
b, loc, /*lowerType=*/destType, layeredDestTransformMetadata,
layeredDestTransform, layeredDestIndices, destTopDiff,
layeredDestDiffs, layeredDestIndicesUpdated);

// Fetch low-level coordinate.
SmallVector<Value, 8> destLowerIndicesUpdated =
layeredDestIndicesUpdated[layeredDestIndicesUpdated.size() - 1];
// computeIndexDiffMap by default emit indices of type i32, convert to
// index type.
destLowerIndices.clear();
for (auto &v : destLowerIndicesUpdated)
destLowerIndices.push_back(
b.create<IndexCastOp>(loc, v, b.getIndexType()));
}

// Store to dest.
// Issue scalar store.
b.create<StoreOp>(loc, convertedScalarValue, op.dest(),
destLowerIndices);

// increase IVs
bool toIncreaseNextDigit = true;
Expand Down Expand Up @@ -7272,7 +7334,7 @@ struct ThreadwiseCopyV2RewritePattern
// Progressively apply index diff maps across all coordinate
// transformation layers.
populateLayeredIndicesWithIndexDiffMap(
b, loc, /*destType=*/destType, layeredSourceTransformMetadata,
b, loc, /*lowerType=*/sourceType, layeredSourceTransformMetadata,
layeredSourceTransform, layeredSourceIndices, srcTopDiff,
layeredSourceDiffs, layeredSourceIndicesUpdated);

Expand Down Expand Up @@ -7305,7 +7367,7 @@ struct ThreadwiseCopyV2RewritePattern
// Progressively apply index diff maps across all coordinate
// transformation layers.
populateLayeredIndicesWithIndexDiffMap(
b, loc, /*destType=*/destType, layeredDestTransformMetadata,
b, loc, /*lowerType=*/destType, layeredDestTransformMetadata,
layeredDestTransform, layeredDestIndices, destTopDiff,
layeredDestDiffs, layeredDestIndicesUpdated);

Expand Down

0 comments on commit cb930c3

Please sign in to comment.