Skip to content

Commit

Permalink
Add on_optimizer_step to callback options (#31095)
Browse files Browse the repository at this point in the history
* Modified test

* Added on_optimizer_step to callbacks

* Move callback after step is called

* Added on optimizer step callback
  • Loading branch information
dhruvbpai authored May 29, 2024
1 parent 4af705c commit 5c88253
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2306,6 +2306,8 @@ def _inner_training_loop(

self.optimizer.step()

self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,12 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
"""
pass

def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    """
    Hook invoked once the optimizer has stepped and before the gradients are reset to zero. Handy for
    inspecting gradients. The default implementation is a no-op; override it in a subclass to react.
    """
    # NOTE(review): the visible trainer hunk fires this hook right after `self.optimizer.step()` and
    # before `optimizer_step_was_skipped` is consulted, so it presumably also runs when the step was
    # skipped — confirm against the full training loop.

def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of a substep during gradient accumulation.
Expand Down Expand Up @@ -470,6 +476,9 @@ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: T
control.should_save = False
return self.call_event("on_step_begin", args, state, control)

def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
    """Dispatch the `on_optimizer_step` event through `call_event` and hand back its result."""
    event_name = "on_optimizer_step"
    return self.call_event(event_name, args, state, control)

def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
    """Dispatch the `on_substep_end` event through `call_event` and hand back its result."""
    event_name = "on_substep_end"
    return self.call_event(event_name, args, state, control)

Expand Down
5 changes: 4 additions & 1 deletion tests/trainer/test_trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def on_epoch_end(self, args, state, control, **kwargs):
def on_step_begin(self, args, state, control, **kwargs):
    """Record that the `on_step_begin` hook fired by appending its name to `self.events`."""
    event_name = "on_step_begin"
    self.events.append(event_name)

def on_optimizer_step(self, args, state, control, **kwargs):
    """Record that the `on_optimizer_step` hook fired by appending its name to `self.events`."""
    event_name = "on_optimizer_step"
    self.events.append(event_name)

def on_step_end(self, args, state, control, **kwargs):
    """Record that the `on_step_end` hook fired by appending its name to `self.events`."""
    event_name = "on_step_end"
    self.events.append(event_name)

Expand Down Expand Up @@ -148,7 +151,7 @@ def get_expected_events(self, trainer):
expected_events.append("on_epoch_begin")
for _ in range(train_dl_len):
step += 1
expected_events += ["on_step_begin", "on_step_end"]
expected_events += ["on_step_begin", "on_optimizer_step", "on_step_end"]
if step % trainer.args.logging_steps == 0:
expected_events.append("on_log")
if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
Expand Down

0 comments on commit 5c88253

Please sign in to comment.