Skip to content

Commit

Permalink
[CodeGen][NFC] Remove unused encoding utils. (#16892)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhanW authored Mar 26, 2024
1 parent b96adf6 commit e3ced3a
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 55 deletions.
46 changes: 0 additions & 46 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,52 +103,6 @@ static unsigned mapDimToRoleIndex(int64_t dimPos, EncodingAttr encoding) {
return idx.value();
}

std::optional<SmallVector<int64_t>>
getPermutationToCanonicalMatmulShape(EncodingAttr encoding) {
FailureOr<linalg::ContractionDimensions> cDims =
getEncodingContractionDims(encoding);
if (failed(cDims)) {
return std::nullopt;
}
// Only support at most 1 Batch, M, N, K dimensions for now
if (cDims->m.size() > 1 || cDims->n.size() > 1 || cDims->k.size() > 1 ||
cDims->batch.size() > 1) {
return std::nullopt;
}
SmallVector<int64_t> perm;
EncodingRole role = encoding.getRole().getValue();
// Add batch dim
if (!cDims->batch.empty()) {
perm.push_back(mapDimToRoleIndex(cDims->batch[0], encoding));
}
// Add M dim
if (role != EncodingRole::RHS && cDims->m.size() == 1) {
perm.push_back(mapDimToRoleIndex(cDims->m[0], encoding));
}
// Add K dim
if (role != EncodingRole::RESULT) {
perm.push_back(mapDimToRoleIndex(cDims->k[0], encoding));
}
// Add N dim
if (role != EncodingRole::LHS && cDims->n.size() == 1) {
perm.push_back(mapDimToRoleIndex(cDims->n[0], encoding));
}
return perm;
}

RankedTensorType getCanonicalMatmulTypeWithEncoding(RankedTensorType type) {
auto encoding = getEncodingAttr(type);
if (!encoding) {
return type;
}
auto perm = getPermutationToCanonicalMatmulShape(encoding);
if (!perm) {
return type;
}
return RankedTensorType::get(applyPermutation(type.getShape(), perm.value()),
type.getElementType(), encoding);
}

RankedTensorType getOriginalTypeWithEncoding(RankedTensorType type) {
auto encoding = getEncodingAttr(type);
if (!encoding) {
Expand Down
9 changes: 0 additions & 9 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,6 @@ class OpMaterializeEncodingPattern : public OpConversionPattern<OpTy> {
/// Otherwise, returns null.
IREE::LinalgExt::EncodingAttr getEncodingAttr(RankedTensorType type);

/// Get the permutation that permutes the input shape to the canonical
/// matmul input shape based on the IndexingMaps encoding attribute.
std::optional<SmallVector<int64_t>>
getPermutationToCanonicalMatmulShape(IREE::LinalgExt::EncodingAttr encoding);

/// Returns a RankedTensorType that has been transposed into the canonical
/// form for an ordinary matmul/batch_matmul op.
RankedTensorType getCanonicalMatmulTypeWithEncoding(RankedTensorType type);

/// Returns the ContractionDimensions for the encoding user_indexing_maps.
FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(IREE::LinalgExt::EncodingAttr encoding);
Expand Down

0 comments on commit e3ced3a

Please sign in to comment.