From d54bf0ef9868c6ee5890d9424f982d7742f91e12 Mon Sep 17 00:00:00 2001
From: Casper
Date: Fri, 16 Feb 2024 18:18:56 +0100
Subject: [PATCH] Add multi-GPU benchmark of Mixtral (#353)

---
 README.md                | 15 +++++++++++++++
 awq/modules/fused/moe.py |  3 ++-
 docs/index.md            |  1 +
 examples/benchmark.py    |  4 +++-
 4 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 96d2cb0d..945e26a0 100644
--- a/README.md
+++ b/README.md
@@ -217,6 +217,21 @@ These benchmarks showcase the speed and memory usage of processing context (pref
 | DeepSeek | 33B | 🔵GEMM | 1 | 64 | 64 | 1160.18 | 40.29 | 18.92 GB (80.00%) |
 | DeepSeek | 33B | 🔵GEMM | 1 | 2048 | 2048 | 1012.1 | 34.0093 | 19.87 GB (84.02%) |
 
+### Multi-GPU
+
+GPU: 2x NVIDIA GeForce RTX 4090
+
+| Model   | Size  | Version | Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM)     |
+|--------:|------:|--------------:|-------------:|-----------------:|----------------:|-------------------:|------------------:|:------------------|
+| Mixtral | 46.7B | 🔵GEMM | 1 | 32 | 32 | 149.742 | 93.406 | 25.28 GB (53.44%) |
+| Mixtral | 46.7B | 🔵GEMM | 1 | 64 | 64 | 1489.64 | 93.184 | 25.32 GB (53.53%) |
+| Mixtral | 46.7B | 🔵GEMM | 1 | 128 | 128 | 2082.95 | 92.9444 | 25.33 GB (53.55%) |
+| Mixtral | 46.7B | 🔵GEMM | 1 | 256 | 256 | 2428.59 | 91.5187 | 25.35 GB (53.59%) |
+| Mixtral | 46.7B | 🔵GEMM | 1 | 512 | 512 | 2633.11 | 89.1457 | 25.39 GB (53.67%) |
+| Mixtral | 46.7B | 🔵GEMM | 1 | 1024 | 1024 | 2598.95 | 84.6753 | 25.75 GB (54.44%) |
+| Mixtral | 46.7B | 🔵GEMM | 1 | 2048 | 2048 | 2446.15 | 77.0516 | 27.98 GB (59.15%) |
+| Mixtral | 46.7B | 🔵GEMM | 1 | 4096 | 4096 | 1985.78 | 77.5689 | 34.65 GB (73.26%) |
+
 ## Reference
 
 If you find AWQ useful or relevant to your research, you can cite their [paper](https://arxiv.org/abs/2306.00978):
diff --git a/awq/modules/fused/moe.py b/awq/modules/fused/moe.py
index b252ce46..bab3bf55 100644
--- a/awq/modules/fused/moe.py
+++ b/awq/modules/fused/moe.py
@@ -52,7 +52,8 @@ def apply_moe_weights(
         topk: int,
         renormalize: bool,
     ) -> torch.Tensor:
-        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1024
+        # NOTE: DISABLED FOR NOW
+        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1e9 #1024
         if FP16_MATMUL_HEURISTIC_CONDITION:
             dequant_w1 = awq_ext.dequantize_weights_cuda(
                 w1.qweight, w1.scales, w1.qzeros, 0, 0, 0, False
diff --git a/docs/index.md b/docs/index.md
index 9129030d..37c5abd8 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -8,6 +8,7 @@ Example inference speed (RTX 4090, Ryzen 9 7950X, 64 tokens):
 - Vicuna 7B (GEMV kernel): 198.848 tokens/s
 - Mistral 7B (GEMM kernel): 156.317 tokens/s
 - Mistral 7B (ExLlamaV2 kernel): 188.865 tokens/s
+- Mixtral 46.7B (GEMM kernel): 93 tokens/s (2x 4090)
 
 ## Installation notes
 
diff --git a/examples/benchmark.py b/examples/benchmark.py
index dd47ea7e..ed71af45 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -116,6 +116,7 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si
         else:
             raise RuntimeError(ex)
 
+    total_memory_used = 0
     if successful_generate:
         # number of tokens in context / time for processing context * batch size
         prefill_tokens_per_second = input_ids.shape[1] / context_time * batch_size
@@ -127,6 +128,7 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si
 
         for device in range(torch.cuda.device_count()):
             memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
+            total_memory_used += memory_used
             memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
             print(f" ** Max Memory (device: {device}): {memory_used:.2f} GB ({memory_pct:.2f}%)")
     else:
@@ -144,7 +146,7 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si
         "Decode Length": n_generate,
         "Prefill tokens/s": prefill_tokens_per_second,
         "Decode tokens/s": decode_tokens_per_second,
-        "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
+        "Memory (VRAM)": f"{total_memory_used:.2f} GB ({memory_pct:.2f}%)"
     }, version

 def main(args):
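
Note on the "Memory (VRAM)" figure in the new multi-GPU table: the patched examples/benchmark.py sums the peak allocation of every visible CUDA device into total_memory_used and reports that sum, while the percentage printed next to it comes from the last device iterated. Below is a minimal standalone sketch of that accounting for reference; the helper name report_multi_gpu_memory is illustrative and not part of the patch, but the torch calls mirror the ones in the diff above.

```python
# Illustrative sketch only -- not part of the patch. It mirrors the per-device
# loop in the patched examples/benchmark.py: sum the peak allocation of every
# visible CUDA device and print each device's share of its own VRAM.
import torch

def report_multi_gpu_memory() -> float:
    total_memory_used = 0.0
    for device in range(torch.cuda.device_count()):
        # Peak bytes ever allocated on this device, converted to GB.
        memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
        total_memory_used += memory_used
        # Percentage of this device's own total VRAM.
        memory_pct = memory_used / (
            torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
        ) * 100
        print(f" ** Max Memory (device: {device}): {memory_used:.2f} GB ({memory_pct:.2f}%)")
    return total_memory_used
```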