From c0cf860dc67af9d7b8e02104202c9c38e75c35ee Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Thu, 22 Aug 2024 16:54:32 -0400 Subject: [PATCH] Fix fp8 benchmark on single GPU (#3032) --- benchmarks/fp8/fp8_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/fp8/fp8_utils.py b/benchmarks/fp8/fp8_utils.py index e8b3b2eb8a2..d28702e05ff 100644 --- a/benchmarks/fp8/fp8_utils.py +++ b/benchmarks/fp8/fp8_utils.py @@ -109,7 +109,8 @@ def evaluate_model(model, dataloader, metric, accelerator=None): with torch.no_grad(): outputs = model(**batch) predictions = outputs.logits.argmax(dim=-1) + references = batch["labels"] if accelerator is not None and accelerator.num_processes > 1: - predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"])) + predictions, references = accelerator.gather_for_metrics((predictions, references)) metric.add_batch(predictions=predictions, references=references) return metric.compute()