Skip to content

Commit

Permalink
[HOTFIX] Temporarily disable index diff map when gemmKExtra > 0. (#261)
Browse files Browse the repository at this point in the history
Use the legacy approach which expands affine maps from top level to bottom
level in case gemmKExtra > 0.

This avoids the C=3 layer in resnet50 from failing after applying index diff
maps.

This commit is considered a HACK. A better fix shall be devised later on.
  • Loading branch information
whchung authored Jun 9, 2021
1 parent 70690b9 commit 3fd1610
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -7037,6 +7037,12 @@ struct ThreadwiseCopyRewritePattern
// false : use the faster index diff map.
auto legacyLoadAttr = op->getAttr("legacy_load");
auto legacyStoreAttr = op->getAttr("legacy_store");
bool legacyLoad =
(legacyLoadAttr &&
legacyLoadAttr.template cast<BoolAttr>().getValue() == true);
bool legacyStore =
(legacyStoreAttr &&
legacyStoreAttr.template cast<BoolAttr>().getValue() == true);

Optional<AffineMap> composedSourceTransform;
Optional<AffineMap> composedDestTransform;
Expand Down Expand Up @@ -7067,6 +7073,36 @@ struct ThreadwiseCopyRewritePattern
return failure();
}

// FIXME. XXX.
// Workaround to obtain gemmKExtra attribute.
// And use it to override legacy load/store debug switch.
auto overrideLoadStoreHack =
[](const DictionaryAttr &transformSpec) -> bool {
if (transformSpec) {
Attribute metadataAttr = transformSpec.get("metadata");
if (metadataAttr) {
ArrayAttr layeredTransformMetadata =
metadataAttr.template cast<ArrayAttr>();
for (unsigned iter = 0; iter < layeredTransformMetadata.size();
++iter) {
DictionaryAttr dictAttr =
layeredTransformMetadata[iter].template cast<DictionaryAttr>();
auto gemmKExtraAttr = dictAttr.get("gemmKExtra");
if (gemmKExtraAttr) {
auto gemmKExtra =
gemmKExtraAttr.template cast<IntegerAttr>().getInt();
if (gemmKExtra > 0) {
return true;
}
}
}
}
}
return false;
};
legacyLoad = overrideLoadStoreHack(srcTransformSpec);
legacyStore = overrideLoadStoreHack(destTransformSpec);

// Populate the vector to hold source and dest coordinate.
SmallVector<Value, 8> sourceCoord;
SmallVector<Value, 8> destCoord;
Expand Down Expand Up @@ -7200,8 +7236,7 @@ struct ThreadwiseCopyRewritePattern
// wthe the metadata.
// Only do such computation in the new approach where index diff maps
// would be used.
if (!legacyLoadAttr ||
(legacyLoadAttr.template cast<BoolAttr>().getValue() == false)) {
if (legacyLoad == false) {
// Populate coorindates across the layers of transformations.
if (srcTransformSpec) {
Attribute metadataAttr = srcTransformSpec.get("metadata");
Expand Down Expand Up @@ -7232,8 +7267,7 @@ struct ThreadwiseCopyRewritePattern
// wthe the metadata.
// Only do such computation in the new approach where index diff maps
// would be used.
if (!legacyStoreAttr ||
(legacyStoreAttr.template cast<BoolAttr>().getValue() == false)) {
if (legacyStore == false) {
// Populate coorindates across the layers of transformations.
if (destTransformSpec) {
Attribute metadataAttr = destTransformSpec.get("metadata");
Expand Down Expand Up @@ -7274,8 +7308,7 @@ struct ThreadwiseCopyRewritePattern
bool toExit = false;
do {
// Use the old logic in case "legacy_load" attribute is specified.
if (legacyLoadAttr &&
(legacyLoadAttr.template cast<BoolAttr>().getValue() == true)) {
if (legacyLoad == true) {
computeTopAndBottomIndicesWithAffineMap(
b, loc, srcUpperIndices, srcLowerIndices, sourceCoord,
loopIVsPerAccessOrder, dimAccessOrder, layeredSourceTransform);
Expand All @@ -7298,8 +7331,7 @@ struct ThreadwiseCopyRewritePattern
b, loc, scalarValue, sourceElementType, destElementType);

// Use the old logic in case "legacy_store" attribute is specified.
if (legacyStoreAttr &&
(legacyStoreAttr.template cast<BoolAttr>().getValue() == true)) {
if (legacyStore == true) {
computeTopAndBottomIndicesWithAffineMap(
b, loc, destUpperIndices, destLowerIndices, destCoord,
loopIVsPerAccessOrder, dimAccessOrder, layeredDestTransform);
Expand Down

0 comments on commit 3fd1610

Please sign in to comment.