Skip to content

Commit

Permalink
Fix fp8 benchmark on single GPU (#3032)
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr authored Aug 22, 2024
1 parent ad3f574 commit c0cf860
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion benchmarks/fp8/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def evaluate_model(model, dataloader, metric, accelerator=None):
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
references = batch["labels"]
if accelerator is not None and accelerator.num_processes > 1:
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
predictions, references = accelerator.gather_for_metrics((predictions, references))
metric.add_batch(predictions=predictions, references=references)
return metric.compute()

0 comments on commit c0cf860

Please sign in to comment.