Skip to content

Commit

Permalink
Fix condition when GA loss bug fix is not performed (#35651)
Browse files Browse the repository at this point in the history
* fix condition when GA loss bug fix is not performed

* max loss diff is 2.29

* fix typo

* add an extra validation that loss should not vary too much
  • Loading branch information
techkang authored and ArthurZucker committed Jan 20, 2025
1 parent 612bfd0 commit b00807f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
11 changes: 2 additions & 9 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3672,10 +3672,7 @@ def training_step(
return loss_mb.reduce_mean().detach().to(self.args.device)

with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

del inputs
if (
Expand Down Expand Up @@ -3709,7 +3706,7 @@ def training_step(
scaled_loss.backward()
else:
# Finally we need to normalize the loss for reporting
if num_items_in_batch is None:
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
loss = loss / self.args.gradient_accumulation_steps

self.accelerator.backward(loss, **kwargs)
Expand Down Expand Up @@ -5157,10 +5154,6 @@ def get_batch_samples(self, epoch_iterator, num_batches):
except StopIteration:
break

# Keep default behavior the same
if not self.model_accepts_loss_kwargs:
return batch_samples, None

if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:
Expand Down
9 changes: 8 additions & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,14 @@ def tokenize_function(examples):
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")

# max diff broken should be very off
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 2")

loss_base = sum(base_loss_callback.losses)
loss_broken = sum(broken_loss_callback.losses)

# mean/sum loss should not vary too much.
relative_diff = abs(loss_base - loss_broken) / max(loss_base, loss_broken)
self.assertLess(relative_diff, 0.1, f"Relative difference {relative_diff} is not within 0.1")

@slow
def test_gradient_accumulation_loss_alignment_with_loss_func(self):
Expand Down

0 comments on commit b00807f

Please sign in to comment.