[VectorDistribution] Add support for multi-subgroup attention #18188

Merged: 4 commits, Sep 11, 2024
@@ -598,7 +598,8 @@ static void enforceLayoutToLayoutOp(
}

// Enforce the result layout on init.
-  ChangeResult changed = input->resolve(toLayout.getLayout());
+  ChangeResult changed = input->resolveWithPossibleConflict(
+      toLayout.getLayout(), getOpOperand(toLayout, 0));
update(input, changed);
}

12 changes: 10 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -706,8 +706,8 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,

// TODO: Currently, we are forcing the number of subgroups to be 1. This can
// be fixed by teaching vector distribution to handle chained matmuls.
-  GPUMMAHeuristicSeeds pvMatmulSeeds = {/*bestSubgroupCountPerWorkgroup=*/1,
-                                        /*bestMNTileCountPerSubgroup=*/8,
+  GPUMMAHeuristicSeeds pvMatmulSeeds = {/*bestSubgroupCountPerWorkgroup=*/4,
+                                        /*bestMNTileCountPerSubgroup=*/4,
/*bestKTileCountPerSubgroup=*/4};

LDBG("Attention Vector Distribution Config");
@@ -744,6 +744,14 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
return failure();
}

+  // TODO: Due to a bug in layout configuration, we cannot set the warp count
+  // on the N dimension. This is, however, OK because we generally do not want
+  // to distribute subgroups on the N dimension anyway.
+  if (schedule->nWarpCount != 1) {
+    schedule->nTileCount *= schedule->nWarpCount;
+    schedule->nWarpCount = 1;
+  }

LDBG("Target Subgroup size: " << targetSubgroupSize);
LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
<< schedule->kSize << "]");
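
A minimal standalone sketch of the fold above, with hypothetical numbers; the struct name and values are stand-ins for the schedule returned by the GPU MMA heuristics, not code from this PR.

    #include <cstdint>

    // Standalone sketch of the N-dimension fold above.
    struct ScheduleSketch {
      int64_t nWarpCount;
      int64_t nTileCount;
    };

    void foldNWarpsIntoNTiles(ScheduleSketch &s) {
      if (s.nWarpCount != 1) {
        // e.g. nWarpCount = 2, nTileCount = 4  ->  nWarpCount = 1, nTileCount = 8:
        // the same total N coverage, but no subgroup distribution along N.
        s.nTileCount *= s.nWarpCount;
        s.nWarpCount = 1;
      }
    }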
@@ -160,6 +160,37 @@ LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
return success();
}

+LogicalResult setAttentionMatmulAnchor(IREE::GPU::MMAScheduleAttr schedule,
+                                       RewriterBase &rewriter,
+                                       linalg::LinalgOp contract) {
+  // TODO: Add SIMT fallback.
+  if (!schedule) {
+    return contract->emitError("missing mma schedule for contraction");
+  }
+
+  if (contract->hasAttr("attention_qk_matmul")) {
+    // The subgroup_n count for the QK matmul is always 1, because its N
+    // dimension is the reduction dimension of the second matmul. The
+    // schedule's subgroup_n count really belongs to that second (PV) matmul.
+    IREE::GPU::MMAScheduleAttr qkSchedule =
+        rewriter.getAttr<IREE::GPU::MMAScheduleAttr>(
+            schedule.getIntrinsic(),
+            /*subgroup_m_count=*/schedule.getSubgroupMCount(),
+            /*subgroup_n_count=*/1);
+    return setContractionAnchor(qkSchedule, rewriter, contract);
+  }
+
+  if (contract->hasAttr("attention_pv_matmul")) {
+    // The schedule's subgroup_n count applies directly to the PV matmul, so
+    // use the schedule as is.
+    return setContractionAnchor(schedule, rewriter, contract);
+  }
+
+  return contract->emitError("attention matmul should have either "
+                             "attention_qk_matmul or attention_pv_matmul set");
+}

struct LLVMGPUConfigureTensorLayoutsPass final
: impl::LLVMGPUConfigureTensorLayoutsPassBase<
LLVMGPUConfigureTensorLayoutsPass> {
@@ -181,10 +212,16 @@ struct LLVMGPUConfigureTensorLayoutsPass final
// now, layout setting for other problems like reductions is TODO.
SmallVector<linalg::LinalgOp> contracts;
SmallVector<linalg::LinalgOp> convs;
+    SmallVector<linalg::LinalgOp> attentionMatmuls;

func->walk([&](linalg::LinalgOp linalgOp) {
if (linalg::isaContractionOpInterface(linalgOp)) {
-        contracts.push_back(linalgOp);
+        if (linalgOp->hasAttr("attention_qk_matmul") ||
+            linalgOp->hasAttr("attention_pv_matmul")) {
+          attentionMatmuls.push_back(linalgOp);
+        } else {
+          contracts.push_back(linalgOp);
+        }
} else if (succeeded(linalg::inferConvolutionDims(linalgOp))) {
convs.push_back(linalgOp);
}
@@ -203,6 +240,13 @@ struct LLVMGPUConfigureTensorLayoutsPass final
return signalPassFailure();
}
}

+    for (linalg::LinalgOp attentionMatmul : attentionMatmuls) {
+      if (failed(setAttentionMatmulAnchor(scheduleAttr, rewriter,
+                                          attentionMatmul))) {
+        return signalPassFailure();
+      }
+    }
}
};
} // namespace
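
For context on why setAttentionMatmulAnchor above forces subgroup_n_count to 1 for the QK matmul, here is the shape bookkeeping, using the sizes from the attention_20x4096x64x4096x64 tests below; the M/K1/K2/N names are illustrative and not taken from the code.

    #include <cstdio>

    // Shape bookkeeping for the two attention matmuls (sizes from the
    // attention_20x4096x64x4096x64 lit test; dimension names are illustrative).
    int main() {
      constexpr int M = 4096, K1 = 64, K2 = 4096, N = 64;
      // QK matmul: S = Q (M x K1) * K^T (K1 x K2)  ->  S is M x K2, so QK's "N"
      //            dimension is K2.
      // PV matmul: acc = P (M x K2) * V (K2 x N)   ->  acc is M x N, and K2 is
      //            its reduction dimension.
      // Putting subgroups on QK's N would therefore split the PV reduction (and
      // the row-wise online softmax) across subgroups, which the layout
      // configuration currently does not support; subgroups are distributed
      // along M only.
      std::printf("Q: %d x %d, K: %d x %d, V: %d x %d; S: %d x %d, acc: %d x %d\n",
                  M, K1, K2, K1, K2, N, M, K2, M, N);
    }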
@@ -342,9 +342,9 @@ func.func @matmul_dynamic_dim() {

// -----

-// CHECK: #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 32, 0, 64, 64]{{\]}}
+// CHECK: #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 64, 0, 64, 64]{{\]}}
// CHECK: #iree_codegen.translation_info<LLVMGPUVectorDistribute
-// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 1
+// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 1
// CHECK-NOT: prefetch_shared_memory = true

// CHECK-LABEL: func.func @attention_20x4096x64x4096x64()
@@ -377,9 +377,9 @@ func.func @attention_20x4096x64x4096x64() {

// -----

-// CHECK: #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[16, 0, 32, 16]{{\]}}
+// CHECK: #[[$TILE_SIZES:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[32, 0, 16, 32]{{\]}}
// CHECK: #iree_codegen.translation_info<LLVMGPUVectorDistribute
-// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 1
+// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 1
// CHECK-NOT: prefetch_shared_memory = true

// CHECK-LABEL: func.func @attention_large_head_dim_shared_mem()
@@ -705,9 +705,9 @@ hal.executable private @attention_20x4096x64x4096x64 {

// Basic test to make sure we can handle attention

-// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64
+// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 64
// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
-// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 1>
+// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 1>
// Prefetching is disabled for attention for now
// CHECK-NOT: gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true

@@ -716,7 +716,7 @@ hal.executable private @attention_20x4096x64x4096x64 {

// CHECK: scf.for %{{.*}} = %c0 to %c4096 step %c64
// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x4x1x1x4x1xf16>)
-// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+// CHECK-COUNT-48: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield

// -----
@@ -759,14 +759,14 @@ hal.executable private @attention_multiple_m_transpose {
}
}

-// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64
+// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 2, 1] subgroup_size = 64
// CHECK-SAME: mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
-// CHECK-SAME: subgroup_m_count = 1, subgroup_n_count = 1>
+// CHECK-SAME: subgroup_m_count = 2, subgroup_n_count = 1>

// CHECK-LABEL: func.func @attention_multiple_m_transpose()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]

// CHECK: scf.for %{{.*}} = %c0 to %c72 step %c1
// CHECK-SAME: -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf16>)
-// CHECK-COUNT-128: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+// CHECK-COUNT-96: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: scf.yield
@@ -297,6 +297,9 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
+  // TODO: We shouldn't be relying on such attributes. We need a better
+  // mechanism to identify attention matmuls.
+  s.getDefiningOp()->setAttr("attention_qk_matmul", b.getUnitAttr());

Review comment (Collaborator): NIT: Any thoughts on making this a standardized/registered attribute in LinalgExt dialect?

Review comment (Contributor): This is an anti-pattern. We shouldn't rely on such attributes, so I don't want to "codify" them. Let's land this for now but plan to unwind this in the medium term. Could you add a note here that we shouldn't be relying on such attributes?


if (qETy.getIntOrFloatBitWidth() <= 8) {
// For low bit-depth types we perform post Q @ K scaling. This is to avoid
@@ -365,6 +368,10 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {

// newAcc = P @ V + newAcc
newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);
+  // TODO: We shouldn't be relying on such attributes. We need a better
+  // mechanism to identify attention matmuls.
+  newAcc.getDefiningOp()->setAttr("attention_pv_matmul", b.getUnitAttr());

return SmallVector<Value>{newAcc, newMax, newSum};
}
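
The tagging itself is plain MLIR discardable-attribute usage. Below is a self-contained sketch of the set/query pattern that connects this decomposition to the layout-configuration pass above; the helper names are hypothetical and not from this PR.

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/Operation.h"

    // Hypothetical helpers mirroring the pattern above: the decomposition tags
    // the generated matmul with a discardable UnitAttr, and a later pass keys
    // off that tag to pick a schedule.
    void tagAttentionMatmul(mlir::Operation *matmul, mlir::OpBuilder &b,
                            bool isQKMatmul) {
      matmul->setAttr(isQKMatmul ? "attention_qk_matmul" : "attention_pv_matmul",
                      b.getUnitAttr());
    }

    bool isAttentionMatmul(mlir::Operation *op) {
      return op->hasAttr("attention_qk_matmul") ||
             op->hasAttr("attention_pv_matmul");
    }

As the review discussion above notes, this attribute-based handshake is a stop-gap rather than a registered mechanism, and is expected to be unwound later.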