Skip to content

Commit

Permalink
[DT][NFC] Remove FailureOr<> from getEncodingInfo methods. (#19435)
Browse files Browse the repository at this point in the history
We are able to use an identity MaterializeEncodingInfo to represent the
"failure". Thus, we no longer need the `FailureOr` wrapper. The revision
removes the wrapper and updates the `lowerContractionOpWithEncoding`
function type signature. It does not need to pass a callback function.
Instead, we can pass the `IREE::Codegen::LayoutAttrInterface` which has
the method to query the materialization information.

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW authored Dec 11, 2024
1 parent 7177c29 commit a6da532
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 128 deletions.
11 changes: 4 additions & 7 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,13 @@ MaterializeEncodingTypeConverter::MaterializeEncodingTypeConverter(
// itself.
RankedTensorType tensorType =
transposeNarrowN ? transposeIfNarrowNResult(type) : type;
FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
getEncodingInfo(tensorType);
if (failed(maybeEncodingInfo) ||
IREE::Codegen::isIdentityLayout(maybeEncodingInfo.value())) {
MaterializeEncodingInfo encodingInfo = getEncodingInfo(tensorType);
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return dropEncoding(type);
}
auto encodingInfo = *maybeEncodingInfo;
auto packedType = cast<RankedTensorType>(tensor::PackOp::inferPackedType(
tensorType, maybeEncodingInfo->innerTileSizes,
maybeEncodingInfo->innerDimsPos, maybeEncodingInfo->outerDimsPerm));
tensorType, encodingInfo.innerTileSizes, encodingInfo.innerDimsPos,
encodingInfo.outerDimsPerm));

// There is no swizzle, we are already done. Typically the case on CPU.
if (!encodingInfo.swizzle) {
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class MaterializeEncodingTypeConverter : public TypeConverter {
return layoutAttr;
}

FailureOr<IREE::Codegen::MaterializeEncodingInfo>
IREE::Codegen::MaterializeEncodingInfo
getEncodingInfo(RankedTensorType type) const {
return layoutAttr.getEncodingInfo(type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,9 @@ struct GPUSetEncodingOpLoweringConversion
return success();
}

FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
MaterializeEncodingInfo encodingInfo =
converter->getEncodingInfo(encodingOp.getResultType());
if (failed(maybeEncodingInfo)) {
return rewriter.notifyMatchFailure(encodingOp,
"unhandled result encoding");
}
if (!maybeEncodingInfo->swizzle) {
if (!encodingInfo.swizzle) {
rewriter.replaceOp(encodingOp, packedValue.value());
return success();
}
Expand All @@ -128,18 +124,18 @@ struct GPUSetEncodingOpLoweringConversion
.getShape()
.take_front(origRank));
expandShapeShape.append(
getExpandedTileShape(maybeEncodingInfo->swizzle->expandShape));
getExpandedTileShape(encodingInfo.swizzle->expandShape));
RankedTensorType expandShapeType =
encodingOp.getSourceType().clone(expandShapeShape);

SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
origRank, maybeEncodingInfo->swizzle->expandShape);
SmallVector<ReassociationIndices> reassociation =
getReassociationIndices(origRank, encodingInfo.swizzle->expandShape);
auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeType, packedValue.value(), reassociation);

SmallVector<int64_t> transposePerm =
llvm::to_vector(llvm::seq<int64_t>(0, origRank));
for (auto perm : maybeEncodingInfo->swizzle->permutation) {
for (auto perm : encodingInfo.swizzle->permutation) {
transposePerm.push_back(origRank + perm);
}
SmallVector<OpFoldResult> transposeResultDims =
Expand Down Expand Up @@ -168,9 +164,9 @@ struct GPUUnsetEncodingOpLoweringConversion
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
getTypeConverter());

FailureOr<MaterializeEncodingInfo> maybeEncodingInfo =
MaterializeEncodingInfo encodingInfo =
converter->getEncodingInfo(unsetEncodingOp.getSource().getType());
if (failed(maybeEncodingInfo)) {
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
Type targetType =
getTypeConverter()->convertType(unsetEncodingOp.getSourceType());
Value result = rewriter.createOrFold<tensor::CastOp>(
Expand All @@ -181,35 +177,34 @@ struct GPUUnsetEncodingOpLoweringConversion

Location loc = unsetEncodingOp.getLoc();
Value unpackSrc = adaptor.getSource();
if (maybeEncodingInfo->swizzle) {
if (encodingInfo.swizzle) {
int targetRank = unsetEncodingOp.getResultType().getRank();
auto srcConvertedType =
cast<RankedTensorType>(adaptor.getSource().getType());
SmallVector<OpFoldResult> emptyShape =
tensor::getMixedSizes(rewriter, loc, adaptor.getSource());
emptyShape.resize(targetRank);
for (auto i :
getExpandedTileShape(maybeEncodingInfo->swizzle->expandShape)) {
for (auto i : getExpandedTileShape(encodingInfo.swizzle->expandShape)) {
emptyShape.push_back(rewriter.getIndexAttr(i));
}
auto emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, emptyShape, unsetEncodingOp.getSourceType().getElementType());

SmallVector<int64_t> transposePerm =
llvm::to_vector(llvm::seq<int64_t>(0, targetRank));
for (auto perm : maybeEncodingInfo->swizzle->permutation) {
for (auto perm : encodingInfo.swizzle->permutation) {
transposePerm.push_back(targetRank + perm);
}
auto invertedTransposePerm = invertPermutationVector(transposePerm);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, adaptor.getSource(), emptyTensor, invertedTransposePerm);

SmallVector<ReassociationIndices> reassociation = getReassociationIndices(
targetRank, maybeEncodingInfo->swizzle->expandShape);
targetRank, encodingInfo.swizzle->expandShape);
SmallVector<int64_t> unpackSrcShape(
srcConvertedType.getShape().take_front(targetRank));
unpackSrcShape.append(maybeEncodingInfo->innerTileSizes.begin(),
maybeEncodingInfo->innerTileSizes.end());
unpackSrcShape.append(encodingInfo.innerTileSizes.begin(),
encodingInfo.innerTileSizes.end());
RankedTensorType unpackSrcType =
unsetEncodingOp.getResultType().clone(unpackSrcShape);
unpackSrc = rewriter.create<tensor::CollapseShapeOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,11 @@ FailureOr<Value> lowerSetEncodingOpToPackOp(
Value source, const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
RankedTensorType resultType = encodingOp.getResultType();
FailureOr<MaterializeEncodingInfo> encodingInfo =
MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(resultType);
if (failed(encodingInfo)) {
return rewriter.notifyMatchFailure(encodingOp, "unhandled result encoding");
}

// Shortcut to avoid creating new operations.
if (IREE::Codegen::isIdentityLayout(encodingInfo.value())) {
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return source;
}

Expand All @@ -142,13 +139,13 @@ FailureOr<Value> lowerSetEncodingOpToPackOp(
return failure();
}
if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) {
transposeInPlace(*encodingInfo);
transposeInPlace(encodingInfo);
}

// Create `tensor.empty` operation for the result of the pack operation.
Location loc = encodingOp.getLoc();
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
rewriter, loc, resultType, *encodingInfo, materializeEncodingValueFn);
rewriter, loc, resultType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
encodingOp, "failed to generate runtime tile size query");
Expand All @@ -158,14 +155,14 @@ FailureOr<Value> lowerSetEncodingOpToPackOp(
SmallVector<OpFoldResult> sourceDims =
tensor::getMixedSizes(rewriter, loc, source);
SmallVector<OpFoldResult> resultDims = tensor::PackOp::getResultShape(
rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo->innerDimsPos,
encodingInfo->outerDimsPerm);
rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo.innerDimsPos,
encodingInfo.outerDimsPerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
resultType.getElementType());
return rewriter
.create<tensor::PackOp>(loc, source, emptyOp, encodingInfo->innerDimsPos,
.create<tensor::PackOp>(loc, source, emptyOp, encodingInfo.innerDimsPos,
*innerTileSizesOfr, paddingValue,
encodingInfo->outerDimsPerm)
encodingInfo.outerDimsPerm)
.getResult();
}

Expand All @@ -174,20 +171,17 @@ FailureOr<Value> lowerUnsetEncodingToUnpackOp(
Value packedValue, const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
RankedTensorType sourceType = encodingOp.getSourceType();
FailureOr<MaterializeEncodingInfo> encodingInfo =
MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(sourceType);
if (failed(encodingInfo)) {
return rewriter.notifyMatchFailure(encodingOp, "unhandled source encoding");
}

// Shortcut to avoid creating new operations.
if (IREE::Codegen::isIdentityLayout(encodingInfo.value())) {
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return packedValue;
}

auto encoding = IREE::Encoding::getEncodingAttr(sourceType);
if (typeConverter.getTransposeNarrowN() && isNarrowNResult(encoding)) {
transposeInPlace(*encodingInfo);
transposeInPlace(encodingInfo);
}
// Create an `tensor.empty` for the result of the unpack operation.
Location loc = encodingOp.getLoc();
Expand All @@ -197,15 +191,15 @@ FailureOr<Value> lowerUnsetEncodingToUnpackOp(
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resultDims,
sourceType.getElementType());
FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
rewriter, loc, sourceType, *encodingInfo, materializeEncodingValueFn);
rewriter, loc, sourceType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
encodingOp, "failed to generate runtime tile size query");
}
return rewriter
.create<tensor::UnPackOp>(loc, packedValue, emptyOp,
encodingInfo->innerDimsPos, *innerTileSizesOfr,
encodingInfo->outerDimsPerm)
encodingInfo.innerDimsPos, *innerTileSizesOfr,
encodingInfo.outerDimsPerm)
.getResult();
}

Expand All @@ -217,22 +211,23 @@ lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp,
const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
auto emptyType = cast<RankedTensorType>(emptyOp->getResultTypes()[0]);
FailureOr<MaterializeEncodingInfo> encodingInfo =
MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(emptyType);
Location loc = emptyOp.getLoc();
if (failed(encodingInfo)) {
Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
loc, emptyOp.getMixedSizes(), emptyType.getElementType());
return newEmptyOp;
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return rewriter
.create<tensor::EmptyOp>(loc, emptyOp.getMixedSizes(),
emptyType.getElementType())
.getOperation();
}

if (typeConverter.getTransposeNarrowN() &&
isNarrowNResult(IREE::Encoding::getEncodingAttr(emptyType))) {
transposeInPlace(*encodingInfo);
transposeInPlace(encodingInfo);
}

FailureOr<SmallVector<OpFoldResult>> innerTileSizesOfr = getInnerTileSizesOfr(
rewriter, loc, emptyType, *encodingInfo, materializeEncodingValueFn);
rewriter, loc, emptyType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizesOfr)) {
return rewriter.notifyMatchFailure(
emptyOp, "failed to generate runtime tile size query");
Expand All @@ -241,9 +236,9 @@ lowerOpWithEncoding(RewriterBase &rewriter, tensor::EmptyOp emptyOp,
SmallVector<OpFoldResult> sourceDims = emptyOp.getMixedSizes();
(void)foldDynamicIndexList(sourceDims);
SmallVector<OpFoldResult> newShape = tensor::PackOp::getResultShape(
rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo->innerDimsPos,
encodingInfo->outerDimsPerm);
newShape = getSwizzledShape(newShape, *encodingInfo);
rewriter, loc, sourceDims, *innerTileSizesOfr, encodingInfo.innerDimsPos,
encodingInfo.outerDimsPerm);
newShape = getSwizzledShape(newShape, encodingInfo);
Operation *newEmptyOp = rewriter.create<tensor::EmptyOp>(
loc, newShape, emptyType.getElementType());
return newEmptyOp;
Expand All @@ -262,10 +257,10 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
return rewriter.notifyMatchFailure(genericOp,
"Output indexing map is not identity");
}
FailureOr<MaterializeEncodingInfo> outMaterializeEncodingInfo =
MaterializeEncodingInfo outMaterializeEncodingInfo =
typeConverter.getEncodingInfo(
cast<RankedTensorType>(outputOperand->get().getType()));
if (failed(outMaterializeEncodingInfo)) {
if (IREE::Codegen::isIdentityLayout(outMaterializeEncodingInfo)) {
return rewriter.notifyMatchFailure(
genericOp, "MaterializeEncodingInfo failed for output");
}
Expand All @@ -277,20 +272,20 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
// Compute the new indexing maps for the packed layout. This assumes that
// the output map is identity, and that all iterator types are parallel.
SmallVector<int64_t> outInnerDimsPos =
outMaterializeEncodingInfo->innerDimsPos;
outMaterializeEncodingInfo.innerDimsPos;
SmallVector<int64_t> outInverseOuterDimsPerm =
invertPermutationVector(outMaterializeEncodingInfo->outerDimsPerm);
invertPermutationVector(outMaterializeEncodingInfo.outerDimsPerm);
SmallVector<AffineMap> packedIndexingMaps;
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
FailureOr<MaterializeEncodingInfo> materializeEncodingInfo =
MaterializeEncodingInfo materializeEncodingInfo =
typeConverter.getEncodingInfo(
cast<RankedTensorType>(inputOperand->get().getType()));
if (failed(materializeEncodingInfo)) {
if (IREE::Codegen::isIdentityLayout(materializeEncodingInfo)) {
return rewriter.notifyMatchFailure(
genericOp, "MaterializeEncodingInfo failed for input");
}
SmallVector<int64_t> innerDimsPos = materializeEncodingInfo->innerDimsPos;
SmallVector<int64_t> outerDimsPerm = materializeEncodingInfo->outerDimsPerm;
ArrayRef<int64_t> innerDimsPos = materializeEncodingInfo.innerDimsPos;
ArrayRef<int64_t> outerDimsPerm = materializeEncodingInfo.outerDimsPerm;
AffineMap inputMap = genericOp.getMatchingIndexingMap(inputOperand);
// Permute result dims to the input packed domain, and map dims to the
// output packed domain.
Expand Down Expand Up @@ -388,28 +383,28 @@ static FailureOr<SmallVector<OpFoldResult>> getPackedDimsForDispatchTensor(
return failure();
}

FailureOr<MaterializeEncodingInfo> encodingInfo =
MaterializeEncodingInfo encodingInfo =
typeConverter.getEncodingInfo(boundTensorType);
if (failed(encodingInfo)) {
if (IREE::Codegen::isIdentityLayout(encodingInfo)) {
return failure();
}
if (typeConverter.getTransposeNarrowN() &&
isNarrowNResult(IREE::Encoding::getEncodingAttr(boundTensorType))) {
transposeInPlace(*encodingInfo);
transposeInPlace(encodingInfo);
}

SmallVector<OpFoldResult> targetShape =
getMixedValues(boundTensorType.getShape(), dynamicDims, builder);
auto innerTileSizes = getInnerTileSizesOfr(
builder, loc, boundTensorType, *encodingInfo, materializeEncodingValueFn);
builder, loc, boundTensorType, encodingInfo, materializeEncodingValueFn);
if (failed(innerTileSizes)) {
return failure();
}
SmallVector<OpFoldResult> convertedTargetShape =
tensor::PackOp::getResultShape(builder, loc, targetShape, *innerTileSizes,
encodingInfo->innerDimsPos,
encodingInfo->outerDimsPerm);
return getSwizzledShape(convertedTargetShape, *encodingInfo);
encodingInfo.innerDimsPos,
encodingInfo.outerDimsPerm);
return getSwizzledShape(convertedTargetShape, encodingInfo);
}

/// For `dispatchTensorType` that bind a `RankedTensorType` with encoding,
Expand Down Expand Up @@ -756,17 +751,10 @@ class MaterializeContractionOp
return success();
}

// TODO(hanchung): This is a transition state for moving the implementation
// details to backend attributes. We won't need the function type argument
// after all the backends that support encodings implement the attribute.
auto getEncodingInfoWrapper =
[&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
return converter->getEncodingInfo(type);
};
FailureOr<Operation *> convertedOp =
IREE::Codegen::lowerContractionOpWithEncoding(
rewriter, op, operands, converter->getTransposeNarrowN(),
getEncodingInfoWrapper);
converter->getLayoutAttr());
if (failed(convertedOp)) {
return failure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,5 @@ struct MaterializeEncodingInfo {
std::optional<TileSwizzle> swizzle;
};

using ResolveEncodingInfoFn =
std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType type)>;

} // namespace mlir::iree_compiler::IREE::Codegen
#endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_IR_IREECODEGENTYPES_H_
Loading

0 comments on commit a6da532

Please sign in to comment.