Skip to content

Commit

Permalink
Add llama3 8b prefill gemm shapes (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhar authored Jan 14, 2025
1 parent a70b3df commit ac30afc
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions gemmbench/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,10 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str, raw_accumulators: bool)
(5120, 16384, 640, "13b"),
(3456, 16384, 5120, "13b"),
(5120, 16384, 1728, "13b"),
(512, 4096, 14336, "8b_prefill"),
(512, 14336, 4096, "8b_prefill"),
(512, 4096, 4096, "8b_prefill"),
(512, 1024, 4096, "8b_prefill"),
]

GPT4 = [
Expand Down Expand Up @@ -682,6 +686,28 @@ def is_compute_bound(M: int, N: int, K: int, dtype: str, raw_accumulators: bool)
]


def llama8b_prefill(dtype: str) -> list[GemmConfig]:
    """Return the LLAMA 8b prefill GEMM configs for ``dtype``.

    Selects every ``(m, n, k, model)`` entry of ``LLAMA`` whose model tag is
    ``"8b_prefill"`` and emits two configs per shape: the first with
    ``raw_accumulators=False`` and the second with ``raw_accumulators=True``
    when choosing the result element type.

    Args:
        dtype: Element type name passed to ``GemmConfig`` and used to look up
            the default accumulator and result element types.

    Returns:
        The configs in ``LLAMA`` order, two per matching shape.
    """
    configs: list[GemmConfig] = []
    for m, n, k, model in LLAMA:
        # LLAMA mixes shapes for several model variants; keep only prefill.
        if model != "8b_prefill":
            continue
        for raw_accumulators in (False, True):
            configs.append(
                GemmConfig(
                    m,
                    n,
                    k,
                    # NOTE(review): presumably the transpose flags for the two
                    # operands -- confirm against the GemmConfig definition.
                    "T",
                    "N",
                    dtype,
                    get_default_accumulator_element_type(dtype),
                    get_default_result_element_type(dtype, raw_accumulators),
                )
            )
    return configs


def llama13bmatvec(dtype: str) -> list[GemmConfig]:
configs = []
"""LLAMA 13b, single batch, FP16."""
Expand Down Expand Up @@ -1010,6 +1036,8 @@ def square(dtype: str) -> list[GemmConfig]:


def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
llama8b_prefill_configs = llama8b_prefill("f16")

llama13bmatvec_configs: list[GemmConfig] = []
llama13bmatvec_configs += llama13bmatvec("f16")
llama13bmatvec_configs += llama13bmatvecbf16("bf16")
Expand Down Expand Up @@ -1041,6 +1069,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]:
square_configs: list[GemmConfig] = square("f16") + square("bf16") + square("i8")

all_configs: list[tuple[str, GemmConfig]] = []
all_configs += [("llama8b_prefill", x) for x in llama8b_prefill_configs]
all_configs += [("llama13bmatvec", x) for x in llama13bmatvec_configs]
all_configs += [("llama70bmatvec", x) for x in llama70bmatvec_configs]
all_configs += [("llama13bskinny", x) for x in llama13bskinny_configs]
Expand Down

0 comments on commit ac30afc

Please sign in to comment.