chore: raise exception in HF trainer for mismatched train units [MD-456] (#9669)
azhou-determined authored Jul 18, 2024
1 parent e956f28 commit c36705b
Showing 1 changed file with 11 additions and 13 deletions.
harness/determined/transformers/_hf_callback.py — 24 changes: 11 additions & 13 deletions
@@ -264,28 +264,26 @@ def _wait_for_metrics(self, control: transformers.TrainerControl) -> None:
     def _check_searcher_compatibility(self, args: transformers.TrainingArguments) -> None:
         if self.searcher_unit == "batches":
             if args.max_steps == -1:
-                self._log_config_mismatch("epochs", args.num_train_epochs)
+                self._raise_config_mismatch("epochs", args.num_train_epochs)
             elif args.max_steps != self.searcher_max_length:
-                self._log_config_mismatch("batches", args.max_steps)
+                self._raise_config_mismatch("batches", args.max_steps)
         elif self.searcher_unit == "epochs":
             if args.max_steps != -1:
-                self._log_config_mismatch("batches", args.max_steps)
+                self._raise_config_mismatch("batches", args.max_steps)
             elif args.num_train_epochs != self.searcher_max_length:
-                self._log_config_mismatch("epochs", args.num_train_epochs)
+                self._raise_config_mismatch("epochs", args.num_train_epochs)
 
-    def _log_config_mismatch(
+    def _raise_config_mismatch(
         self,
         trainer_units: str,
         trainer_len: float,
     ) -> None:
-        logger.warning(
-            f"Searcher configuration does not match HF Trainer configuration. "
-            f"Searcher uses {self.searcher_unit}={self.searcher_max_length}, "
-            f"while HF Trainer uses {trainer_units}={trainer_len}. "
-            f"Continuing this run may cause Searcher not to behave correctly. "
-            f"Make sure to match the units between HF Trainer and Searcher: "
-            f"use (--num_train_epochs and searcher.max_length.epochs) OR "
-            f"(--max_steps and searcher.max_length.batches)."
+        raise ValueError(
+            f"HF trainer units {trainer_units}={trainer_len} MUST match searcher config "
+            f"{self.searcher_unit}={self.searcher_max_length}. "
+            f"Modify either --num_train_epochs for the training script or "
+            f"searcher.max_length.epochs in the experiment config so they are the same value "
+            f"(--max_steps and searcher.max_length.batches if using batches)."
         )
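For context, the practical effect of the change: a units mismatch between the Determined searcher configuration and the HF TrainingArguments now fails fast with a ValueError instead of emitting a warning that was easy to miss in the logs. A minimal sketch of the two sides the check compares (values here are hypothetical, not taken from the commit):

# Searcher side, in the Determined experiment config (YAML), hypothetical values:
#   searcher:
#     max_length:
#       epochs: 3
import transformers

# Mismatched units: setting max_steps switches the HF Trainer to batch-based
# training while the searcher above counts epochs, so after this commit the
# callback's compatibility check raises ValueError rather than warning.
mismatched = transformers.TrainingArguments(output_dir="/tmp/out", max_steps=1000)

# Matched units: leave max_steps at its default (-1) and keep num_train_epochs
# equal to searcher.max_length.epochs.
matched = transformers.TrainingArguments(output_dir="/tmp/out", num_train_epochs=3)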

