diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py
index 19039af0f1..6d91237bc8 100644
--- a/llmfoundry/callbacks/__init__.py
+++ b/llmfoundry/callbacks/__init__.py
@@ -19,14 +19,17 @@
 from llmfoundry.callbacks.eval_output_logging_callback import EvalOutputLogging
 from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
 from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
-from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import \
-    MegaBlocksMoE_TokPerExpert
-from llmfoundry.callbacks.monolithic_ckpt_callback import \
-    MonolithicCheckpointSaver
+from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import (
+    MegaBlocksMoE_TokPerExpert,
+)
+from llmfoundry.callbacks.monolithic_ckpt_callback import (
+    MonolithicCheckpointSaver,
+)
 from llmfoundry.callbacks.resumption_callbacks import (
     GlobalLRScaling,
     LayerFreezing,
 )
+from llmfoundry.callbacks.run_timeout_callback import RunTimeoutCallback
 from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector
 from llmfoundry.registry import callbacks, callbacks_with_config
 
@@ -47,6 +50,7 @@
 callbacks.register('oom_observer', func=OOMObserver)
 callbacks.register('eval_output_logging', func=EvalOutputLogging)
 callbacks.register('mbmoe_tok_per_expert', func=MegaBlocksMoE_TokPerExpert)
+callbacks.register('run_timeout', func=RunTimeoutCallback)
 callbacks_with_config.register('async_eval', func=AsyncEval)
 callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
 
diff --git a/llmfoundry/callbacks/run_timeout_callback.py b/llmfoundry/callbacks/run_timeout_callback.py
new file mode 100644
index 0000000000..eb8051240d
--- /dev/null
+++ b/llmfoundry/callbacks/run_timeout_callback.py
@@ -0,0 +1,58 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import signal
+import threading
+from typing import Optional
+
+from composer import Callback, Logger, State
+from composer.loggers import MosaicMLLogger
+
+from llmfoundry.utils.exceptions import RunTimeoutError
+
+log = logging.getLogger(__name__)
+
+
+def _timeout(timeout: int, mosaicml_logger: Optional[MosaicMLLogger] = None):
+    log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',)
+    if mosaicml_logger is not None:
+        mosaicml_logger.log_exception(RunTimeoutError(timeout=timeout))
+    os.kill(os.getpid(), signal.SIGINT)
+
+
+class RunTimeoutCallback(Callback):
+
+    def __init__(
+        self,
+        timeout: int = 1800,
+    ):
+        self.timeout = timeout
+        self.mosaicml_logger: Optional[MosaicMLLogger] = None
+        self.timer: Optional[threading.Timer] = None
+
+    def init(self, state: State, logger: Logger):
+        for callback in state.callbacks:
+            if isinstance(callback, MosaicMLLogger):
+                self.mosaicml_logger = callback
+
+    def _reset(self):
+        if self.timer is not None:
+            self.timer.cancel()
+        self.timer = None
+
+    def _timeout(self):
+        self._reset()
+        self.timer = threading.Timer(
+            self.timeout,
+            _timeout,
+            [self.timeout, self.mosaicml_logger],
+        )
+        self.timer.daemon = True
+        self.timer.start()
+
+    def fit_end(self, state: State, logger: Logger):
+        del state
+        del logger
+        self._timeout()
diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py
index 8e9e46a1cf..6921297866 100644
--- a/llmfoundry/utils/exceptions.py
+++ b/llmfoundry/utils/exceptions.py
@@ -246,3 +246,12 @@ def __init__(self, dataset_name: str, split: str) -> None:
         self.split = split
         message = f'Your dataset (name={dataset_name}, split={split}) is misconfigured. Please check your dataset format and make sure you can load your dataset locally.'
         super().__init__(message)
+
+
+class RunTimeoutError(RuntimeError):
+    """Error thrown when a run times out."""
+
+    def __init__(self, timeout: int) -> None:
+        self.timeout = timeout
+        message = f'Run timed out after {timeout} seconds.'
+        super().__init__(message)