Error with checking args.eval_accumulation_steps to gather tensors (huggingface#25819)

* Update trainer.py (fix the step check against args.eval_accumulation_steps when gathering tensors)

While the deprecated code has the correct check (line 3772): 
"if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:"

The current code does not (line 3196):
"if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients:"

We need to check "(step + 1) % args.eval_accumulation_steps == 0". Hence, line 3196 should be modified to:
"if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0 and self.accelerator.sync_gradients:"

* Fix error with checking args.eval_accumulation_steps to gather tensors
chaumng authored and parambharat committed Sep 26, 2023
1 parent f6292cb commit 7a476d3
Showing 1 changed file with 5 additions and 1 deletion: src/transformers/trainer.py
@@ -3193,7 +3193,11 @@ def evaluation_loop(
             self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
 
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
-            if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients:
+            if (
+                args.eval_accumulation_steps is not None
+                and (step + 1) % args.eval_accumulation_steps == 0
+                and self.accelerator.sync_gradients
+            ):
                 if losses_host is not None:
                     losses = nested_numpify(losses_host)
                     all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
