Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Address comments
Browse files Browse the repository at this point in the history
Groverkss committed Aug 28, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 8e303b7 commit 14bad9d
Showing 3 changed files with 12 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -24,11 +24,10 @@ struct DropToLayoutUnitDims final
LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp,
PatternRewriter &rewriter) const override {
if (!toLayoutOp.hasTensorSemantics()) {
return failure();
return rewriter.notifyMatchFailure(toLayoutOp,
"requires tensor semanticS");
}

rewriter.setInsertionPoint(toLayoutOp);

Location loc = toLayoutOp.getLoc();
ShapedType inputTy = toLayoutOp.getType();
ArrayRef<int64_t> shape = inputTy.getShape();
@@ -67,6 +66,8 @@ struct DropToLayoutUnitDims final
toLayoutOp->getDiscardableAttrDictionary());

// Expand to preserve output shape using insert_slice.
// Here, since the shape comes from the result of a to_layout op, it will
// always be static.
Value dest =
rewriter.create<tensor::EmptyOp>(loc, shape, inputTy.getElementType());

Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@ LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
FailureOr<VectorContractOpInfo> opInfo =
VectorContractOpInfo::inferFromIndexingMaps(
contract.getIndexingMapsArray());
assert(succeeded(opInfo) && "contraction should have been infered");
assert(succeeded(opInfo) && "contraction should have been inferred");

auto layouts = schedule.getContractionLayout(opInfo.value(), contract);
if (failed(layouts)) {
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
Original file line number Diff line number Diff line change
@@ -28,6 +28,13 @@ std::pair<int, int> VectorContractOpInfo::getResultMNIndex() const {

FailureOr<VectorContractOpInfo>
VectorContractOpInfo::inferFromIndexingMaps(ArrayRef<AffineMap> maps) {
// Ensure all maps are projected permutations.
if (!llvm::all_of(maps, [](AffineMap map) {
return map.isProjectedPermutation(/*allowZeroInResults=*/true);
})) {
return failure();
}

auto maybeContractionDims = linalg::inferContractionDims(maps);
if (failed(maybeContractionDims)) {
return failure();

0 comments on commit 14bad9d

Please sign in to comment.