Fix attention + enable VMMA (#43)
Update attention to work with ToM and support VMMA:

1. Update translation info to use "pipeline"
2. Update attention IREE IR to specify the QK and PV MMA schedules separately so that it works with ToM IREE
3. Refactor to use enum.Enum to represent intrinsics
4. Add VMMA support and helper functions to maximize perf

---------

Signed-off-by: Stanley Winata <[email protected]>
Co-authored-by: saienduri <[email protected]>
raikonenfnu and saienduri authored Jan 23, 2025
1 parent f0bd8a1 commit 87c0c8c
Showing 1 changed file with 87 additions and 18 deletions.
105 changes: 87 additions & 18 deletions attentionbench/attention_utils.py
@@ -2,8 +2,67 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
+from enum import Enum


+class IntrinsicType(Enum):
+    """
+    Formatting for different target intrinsics:
+        <kind>_<elem-type-C>_<M>x<N>x<K>_<elem-type-A>[_<elem-type-B>]
+    Values: 0xABCD where:
+    * A = vendor:
+      * 1 = AMD
+      * 2 = NVIDIA
+    * B = architecture. When an intrinsic exists in multiple architectures,
+      this should be the architecture it was introduced in, as long as it
+      still has the same semantics. If a new architecture breaks an existing
+      intrinsic's semantics, we can use that field for versioning.
+      * For AMD:
+        * 0 = CDNA1
+        * 1 = CDNA2
+        * 2 = CDNA3
+        * 8 = RDNA3
+    * C = element type of A-matrix:
+      * 0 = 64-bit float (e.g. IEEE754 double precision)
+      * 1 = 32-bit float (e.g. IEEE754 single precision, and "xf32" fast variants)
+      * 2 = 16-bit float (incl. IEEE754 half and bf16)
+      * 3 = 8-bit float (incl. f8E5M2, f8E4M3, and "FNUZ" variants)
+      * C = 8-bit integer (any signedness)
+    * D enumerates intrinsics that share the same 0xABC* bits.
+    """
+    # Intrinsics introduced in CDNA1
+    MFMA_F32_16x16x16_F16 = 0x1020
+    MFMA_F32_32x32x8_F16 = 0x1021
+    VMFMA_F32_32x32x16_F16 = 0x1022
+    MFMA_I32_16x16x16_I8 = 0x10C0
+    MFMA_I32_32x32x8_I8 = 0x10C1
+
+    # Intrinsics introduced in CDNA3
+    MFMA_F32_16x16x32_F8 = 0x1230
+    MFMA_F32_32x32x16_F8 = 0x1231
+    MFMA_I32_16x16x32_I8 = 0x12C0
+    MFMA_I32_32x32x16_I8 = 0x12C1
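As a quick sanity check on the encoding (illustrative only, not part of this commit), the 0xABCD scheme above can be decoded nibble by nibble:

    # Illustrative only: decode VMFMA_F32_32x32x16_F16 = 0x1022 per the docstring.
    v = IntrinsicType.VMFMA_F32_32x32x16_F16.value
    vendor = (v >> 12) & 0xF   # 1 -> AMD
    arch = (v >> 8) & 0xF      # 0 -> introduced in CDNA1
    a_type = (v >> 4) & 0xF    # 2 -> 16-bit float A-matrix
    variant = v & 0xF          # 2 -> third intrinsic sharing the 0x102* bits
    assert (vendor, arch, a_type, variant) == (1, 0, 2, 2)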


+def get_intrinsic_string(intrinsic: IntrinsicType):
+    match intrinsic:
+        case IntrinsicType.VMFMA_F32_32x32x16_F16:
+            return f"#iree_gpu.virtual_mma_layout<intrinsic = {intrinsic.name}>"
+        case _:
+            return f"#iree_gpu.mma_layout<{intrinsic.name}>"

+def get_pv_intrinsic(intrinsic: IntrinsicType):
+    """
+    QK intrinsics and PV intrinsics can differ. Mostly used for
+    selecting VMFMA for QK to maximize contiguous reads from shared memory.
+    """
+    match intrinsic:
+        case IntrinsicType.VMFMA_F32_32x32x16_F16:
+            return IntrinsicType.MFMA_F32_32x32x8_F16
+        case _:
+            return intrinsic
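A small usage sketch (illustrative only, not part of this commit) of the two helpers above: a VMFMA intrinsic on the QK matmul falls back to its plain-MFMA sibling on the PV matmul, and the two render as different MMA layout attributes:

    # Illustrative only: VMFMA on QK maps to plain MFMA on PV.
    qk = IntrinsicType.VMFMA_F32_32x32x16_F16
    assert get_pv_intrinsic(qk) is IntrinsicType.MFMA_F32_32x32x8_F16
    print(get_intrinsic_string(qk))
    # -> #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_32x32x16_F16>
    print(get_intrinsic_string(get_pv_intrinsic(qk)))
    # -> #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>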

@dataclass
class AttentionConfig:
B: int
@@ -71,20 +130,11 @@ def get_lowering_config(self) -> str:
+ "{ "
+ f"workgroup = [{', '.join(map(str, self.wg_tiles))}], "
+ f"reduction = [{', '.join(map(str, self.reduction_tiles))}],"
+ f"promote_operands = [0, 1, 2]"
+ f"promote_operands = [1, 2]"
+ " }"
+ f">"
)

-    def get_mma_schedule(self) -> str:
-        return (
-            f"#iree_gpu.mma_schedule<"
-            + f"intrinsic = #iree_gpu.mma_layout<{self.intrinsic}>"
-            + f", subgroup_m_count = {self.M_warp}"
-            + f", subgroup_n_count = {self.N_warp}"
-            + f">"
-        )
-
    def get_translation_info(self) -> str:
        llvm_func_attrs = []
        if self.waves_per_eu:
@@ -93,11 +143,10 @@ def get_translation_info(self) -> str:
            llvm_func_attrs += [f'"denormal-fp-math-f32" = "preserve-sign"']
        return (
            f"#iree_codegen.translation_info<"
-            + f"LLVMGPUVectorDistribute"
+            + f"pipeline = LLVMGPUVectorDistribute"
            + f" workgroup_size = [{self.N_warp * self.M_warp * 64}]"
            + f" subgroup_size = 64"
-            + f" ,{{mma_schedule = {self.get_mma_schedule()}"
-            + f" , llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
+            + f" , {{llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
            + f"}}"
            + f">"
        )
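For reference, a minimal sketch (illustrative only, not part of this commit) of what the method above renders. The positional field order of TuningSpec is inferred from the calls further below in this file, so treat it as an assumption; workgroup_size comes out as [4 * 1 * 64] = [256]:

    # Illustrative only; assumes TuningSpec order (wg_tiles, reduction_tiles,
    # M_warp, N_warp, intrinsic, waves_per_eu, ...) inferred from usage below.
    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1,
                      IntrinsicType.VMFMA_F32_32x32x16_F16, 2, True)
    print(spec.get_translation_info())
    # Roughly:
    #   #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
    #     workgroup_size = [256] subgroup_size = 64 , {llvm_func_attrs = { ... }}>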
@@ -110,6 +159,26 @@ def get_compilation_info(self) -> str:
+ f">"
)

+    def get_qk_config_info(self) -> str:
+        return (
+            f"#iree_gpu.lowering_config<{{"
+            + f"mma_kind = {get_intrinsic_string(self.intrinsic)}"
+            + f", subgroup_m_count = {self.M_warp}"
+            + f", subgroup_n_count = {self.N_warp}"
+            + f", promote_operands = [1]"
+            + f"}}>"
+        )
+
+    def get_pv_config_info(self) -> str:
+        return (
+            f"#iree_gpu.lowering_config<{{"
+            + f"mma_kind = {get_intrinsic_string(get_pv_intrinsic(self.intrinsic))}"
+            + f", subgroup_m_count = {self.M_warp}"
+            + f", subgroup_n_count = {self.N_warp}"
+            + f", promote_operands = [1]"
+            + f"}}>"
+        )
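A companion sketch (again illustrative, reusing the spec from the sketch above, so M_warp = 4, N_warp = 1) showing how the two new configs diverge once a VMFMA intrinsic is selected: QK takes the virtual layout, PV its plain-MFMA sibling:

    # Illustrative only, reusing the illustrative spec from above:
    print(spec.get_qk_config_info())
    #   #iree_gpu.lowering_config<{mma_kind =
    #     #iree_gpu.virtual_mma_layout<intrinsic = VMFMA_F32_32x32x16_F16>,
    #     subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1]}>
    print(spec.get_pv_config_info())
    #   #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
    #     subgroup_m_count = 4, subgroup_n_count = 1, promote_operands = [1]}>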


def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
    shapes = f"""\
@@ -136,11 +205,11 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
func.func @main(%Q : !Q, %K : !K, %V : !V) -> !O {{
  %scale = arith.constant 1.0 : !dtype
  %empty = tensor.empty() : !O
  %O = iree_linalg_ext.attention
       {{ indexing_maps = [#Q, #K, #V, #S, #O]
       ,decomposition_config = {{
-        qk_attrs = {{attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [0, 1]}}>}},
-        pv_attrs = {{attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [1]}}>}}
+        qk_attrs = {{attention_qk_matmul, lowering_config = {tuning.get_qk_config_info()}}},
+        pv_attrs = {{attention_pv_matmul, lowering_config = {tuning.get_pv_config_info()}}}
       }}
       {",compilation_info = #tuning" if tuning and config.dtype == "f16" else ""}
  }}
@@ -168,7 +237,7 @@ def compile_attention_config(

    # TODO: Use different tuning specs for different configs. This is just a
    # general tuning config that worked well for sdxl shapes.
-    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, "MFMA_F32_32x32x8_F16", 2, True)
+    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, IntrinsicType.VMFMA_F32_32x32x16_F16, 2, True)
    # Generate mlir content
    mlir_content = generate_mlir(config, spec)

@@ -211,5 +280,5 @@ def compile_attention_config(
# Dummy test generation
if __name__ == "__main__":
    config = AttentionConfig(20, 4096, 64, 64, 4096, "f16")
-    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, "MFMA_F32_32x32x8_F16", 2, True)
+    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, IntrinsicType.VMFMA_F32_32x32x16_F16, 2, True)
    print(generate_mlir(config, spec))
