Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align logged attributes for errors and run metadata in kill_loss_spike_callback.py #1494

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def _detect_high_losses(self, current_step: int) -> bool:

return is_high_loss

def _log_metadata(self, logger: Logger, key: str, message: str) -> None:
def _log_metadata(self, logger: Logger, key: str, value: dict) -> None:
for destination in logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({
key: message,
key: value,
'loss_window': list(self.loss_window),
})

Expand All @@ -122,22 +122,39 @@ def _handle_loss_spike(
logger: Logger,
running_loss_avg: float,
) -> None:
message = f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.'
self._log_metadata(logger, 'loss_spike', message)
if not self.log_only:
if self.log_only:
self._log_metadata(
logger,
'loss_spike',
{
'outlier_multiplier': self.outlier_multiplier,
'running_loss_avg': running_loss_avg,
'outlier_counter': self.outlier_counter,
},
)
else:
raise LossSpikeError(
outlier_multiplier=self.outlier_multiplier,
running_loss_avg=round(running_loss_avg),
running_loss_avg=running_loss_avg,
outlier_counter=self.outlier_counter,
loss_window=list(self.loss_window),
)

def _handle_high_losses(self, logger: Logger) -> None:
message = f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.'
self._log_metadata(logger, 'high_loss', message)
if not self.log_only:
if self.log_only:
self._log_metadata(
logger,
'high_loss',
{
'loss_cap': self.loss_cap,
'window_size': self.window_size,
},
)
else:
raise HighLossError(
loss_cap=self.loss_cap,
window_size=self.window_size,
loss_window=list(self.loss_window),
)

def _set_window_size(self, state: State) -> None:
Expand Down
8 changes: 6 additions & 2 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,18 +395,20 @@ class LossSpikeError(UserError):
def __init__(
self,
outlier_multiplier: float,
running_loss_avg: int,
running_loss_avg: float,
outlier_counter: int,
loss_window: list[float],
) -> None:
message = f'Training stopped due to a loss spike. The training loss was more than {outlier_multiplier} times greater than \
the running average loss (approx. {running_loss_avg}) over {outlier_counter} consecutive training steps. \
the running average loss (approx. {round(running_loss_avg, 1)}) over {outlier_counter} consecutive training steps. \
Please try submitting the run again with a lower learning rate.'

super().__init__(
message,
outlier_multiplier=outlier_multiplier,
running_loss_avg=running_loss_avg,
outlier_counter=outlier_counter,
loss_window=loss_window,
)


Expand All @@ -417,6 +419,7 @@ def __init__(
self,
loss_cap: float,
window_size: int,
loss_window: list[float],
) -> None:
message = f'Training stopped due to consistently high losses. The training loss exceeded the threshold of {loss_cap} \
for more than half of the {window_size} most recent training steps. Please try submitting the run again with a lower learning rate.'
Expand All @@ -425,4 +428,5 @@ def __init__(
message,
loss_cap=loss_cap,
window_size=window_size,
loss_window=loss_window,
)
Loading