Skip to content

Commit

Permalink
Add logic to verify vector.contract distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
Groverkss committed Aug 26, 2024
1 parent 4d3471b commit fea5dd4
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,57 @@ namespace {
using namespace mlir::iree_compiler::IREE::VectorExt;
using VectorValue = TypedValue<VectorType>;

static LogicalResult isSubgroupLayoutCompatible(
IREE::GPU::MMAAttr::SingleSubgroupLayout subgroupLayout,
NestedLayoutAttr layout, int64_t dim1, int64_t dim2) {
SmallVector<int64_t> element = {layout.getElementsPerThread()[dim1],
layout.getElementsPerThread()[dim2]};
SmallVector<int64_t> thread = {layout.getThreadsPerOuter()[dim1],
layout.getThreadsPerOuter()[dim2]};
SmallVector<int64_t> tstrides = {layout.getThreadStrides()[dim1],
layout.getThreadStrides()[dim2]};
SmallVector<int64_t> outer = {layout.getOutersPerBatch()[dim1],
layout.getOutersPerBatch()[dim2]};

if (subgroupLayout.element != element) {
return failure();
}
if (subgroupLayout.thread != thread) {
return failure();
}
if (subgroupLayout.tstrides != tstrides) {
return failure();
}
if (subgroupLayout.outer != outer) {
return failure();
}

return success();
}

static LogicalResult isIntrinsicLayoutCompatible(VectorContractOpInfo &opInfo,
IREE::GPU::MMAAttr intrinsic,
NestedLayoutAttr lhsLayout,
NestedLayoutAttr rhsLayout,
NestedLayoutAttr accLayout) {
auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
auto [accM, accN] = opInfo.getResultMNIndex();
if (failed(isSubgroupLayoutCompatible(intrinsic.getASingleSubgroupLayout(),
lhsLayout, lhsM, lhsK))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(intrinsic.getBSingleSubgroupLayout(),
rhsLayout, rhsK, rhsN))) {
return failure();
}
if (failed(isSubgroupLayoutCompatible(intrinsic.getCSingleSubgroupLayout(),
accLayout, accM, accN))) {
return failure();
}
return success();
}

/// Distributes `vector.contract` ops with nested layouts.
struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
using OpDistributionPattern::OpDistributionPattern;
Expand Down Expand Up @@ -63,6 +114,12 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
return rewriter.notifyMatchFailure(
contractOp, "missing nested layout for contraction rhs");
}
NestedLayoutAttr accLayout =
dyn_cast<NestedLayoutAttr>(signature[resultValue]);
if (!accLayout) {
return rewriter.notifyMatchFailure(
contractOp, "missing nested layout for contraction acc");
}

// We assume there is an decision made before regarding which mfma intrinsic
// to use and it is attached as an attribute to this contract op.
Expand All @@ -73,6 +130,14 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
contractOp, "missing iree.amdgpu.mma intrinsic attribute");
}

// Check if the given intrinsic can be distributed with the given
// layouts.
if (failed(isIntrinsicLayoutCompatible(opDetail, mmaAttr, lhsLayout,
rhsLayout, accLayout))) {
return rewriter.notifyMatchFailure(
contractOp, "the intrinsic does not match the expected layouts");
}

SmallVector<int64_t> distShape = resultLayout.getDistributedShape();
LLVM_DEBUG({
llvm::dbgs() << "distributed shape: [";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ builtin.module attributes { transform.with_named_sequence } {
elements_per_thread = [1, 4],

subgroup_strides = [1, 1],
thread_strides = [32, 1]
thread_strides = [1, 32]
>

// C: shape = 32x64, layout = layoutC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 16},
return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*strides=*/{1, 0},
/*element=*/{1, 16}};
}
}
Expand Down Expand Up @@ -598,7 +598,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
}
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*strides=*/{0, 1},
/*element=*/{16, 1}};
}
}
Expand All @@ -624,7 +624,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
/*element=*/{1, 1}};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
return {/*outer=*/{16, 1}, /*thread=*/{1, 16}, /*strides=*/{16, 1},
return {/*outer=*/{16, 1}, /*thread=*/{1, 16}, /*strides=*/{0, 1},
/*element=*/{1, 1}};
}
}
Expand Down

0 comments on commit fea5dd4

Please sign in to comment.