Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Groverkss committed Aug 28, 2024
1 parent 8e303b7 commit 14bad9d
Showing 3 changed files with 12 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -24,11 +24,10 @@ struct DropToLayoutUnitDims final
LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp,
PatternRewriter &rewriter) const override {
if (!toLayoutOp.hasTensorSemantics()) {
return failure();
return rewriter.notifyMatchFailure(toLayoutOp,
"requires tensor semanticS");
}

rewriter.setInsertionPoint(toLayoutOp);

Location loc = toLayoutOp.getLoc();
ShapedType inputTy = toLayoutOp.getType();
ArrayRef<int64_t> shape = inputTy.getShape();
@@ -67,6 +66,8 @@ struct DropToLayoutUnitDims final
toLayoutOp->getDiscardableAttrDictionary());

// Expand to preserve output shape using insert_slice.
// Here, since the shape comes from the result of a to_layout op, it will
// always be static.
Value dest =
rewriter.create<tensor::EmptyOp>(loc, shape, inputTy.getElementType());

Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@ LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
FailureOr<VectorContractOpInfo> opInfo =
VectorContractOpInfo::inferFromIndexingMaps(
contract.getIndexingMapsArray());
assert(succeeded(opInfo) && "contraction should have been infered");
assert(succeeded(opInfo) && "contraction should have been inferred");

auto layouts = schedule.getContractionLayout(opInfo.value(), contract);
if (failed(layouts)) {
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/VectorOpUtils.cpp
Original file line number Diff line number Diff line change
@@ -28,6 +28,13 @@ std::pair<int, int> VectorContractOpInfo::getResultMNIndex() const {

FailureOr<VectorContractOpInfo>
VectorContractOpInfo::inferFromIndexingMaps(ArrayRef<AffineMap> maps) {
// Ensure all maps are projected permutations.
if (!llvm::all_of(maps, [](AffineMap map) {
return map.isProjectedPermutation(/*allowZeroInResults=*/true);
})) {
return failure();
}

auto maybeContractionDims = linalg::inferContractionDims(maps);
if (failed(maybeContractionDims)) {
return failure();

0 comments on commit 14bad9d

Please sign in to comment.