Add checkpoint events to mosaicml logger
b-chu committed May 23, 2024
1 parent 57c7b72 commit 9f33e1f
Showing 2 changed files with 17 additions and 12 deletions.
20 changes: 10 additions & 10 deletions composer/loggers/mosaicml_logger.py
@@ -92,30 +92,30 @@ def __init__(
         self._enabled = False
 
     def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
-        self._log_metadata(hyperparameters)
+        self.log_metadata(hyperparameters)
 
     def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
-        self._log_metadata(metrics)
+        self.log_metadata(metrics)
 
     def log_exception(self, exception: Exception):
-        self._log_metadata({'exception': exception_to_json_serializable_dict(exception)})
+        self.log_metadata({'exception': exception_to_json_serializable_dict(exception)})
         self._flush_metadata(force_flush=True)
 
     def after_load(self, state: State, logger: Logger) -> None:
         # Log model data downloaded and initialized for run events
         log.debug(f'Logging model initialized time to metadata')
-        self._log_metadata({'model_initialized_time': time.time()})
+        self.log_metadata({'model_initialized_time': time.time()})
         # Log WandB run URL if it exists. Must run on after_load as WandB is setup on event init
         for callback in state.callbacks:
             if isinstance(callback, WandBLogger):
                 run_url = callback.run_url
                 if run_url is not None:
-                    self._log_metadata({'wandb/run_url': run_url})
+                    self.log_metadata({'wandb/run_url': run_url})
                     log.debug(f'Logging WandB run URL to metadata: {run_url}')
                 else:
                     log.debug('WandB run URL not found, not logging to metadata')
             if isinstance(callback, MLFlowLogger) and callback._enabled:
-                self._log_metadata({'mlflow/run_url': callback.run_url})
+                self.log_metadata({'mlflow/run_url': callback.run_url})
                 log.debug(f'Logging MLFlow run URL to metadata: {callback.run_url}')
         self._flush_metadata(force_flush=True)

@@ -125,18 +125,18 @@ def batch_start(self, state: State, logger: Logger) -> None:
 
     def batch_end(self, state: State, logger: Logger) -> None:
         training_progress_data = self._get_training_progress_metrics(state)
-        self._log_metadata(training_progress_data)
+        self.log_metadata(training_progress_data)
         self._flush_metadata()
 
     def epoch_end(self, state: State, logger: Logger) -> None:
         self._flush_metadata()
 
     def fit_end(self, state: State, logger: Logger) -> None:
         # Log model training finished time for run events
-        self._log_metadata({'train_finished_time': time.time()})
+        self.log_metadata({'train_finished_time': time.time()})
         training_progress_data = self._get_training_progress_metrics(state)
         log.debug(f'\nLogging FINAL training progress data to metadata:\n{dict_to_str(training_progress_data)}')
-        self._log_metadata(training_progress_data)
+        self.log_metadata(training_progress_data)
         self._flush_metadata(force_flush=True)
 
     def eval_end(self, state: State, logger: Logger) -> None:
@@ -150,7 +150,7 @@ def close(self, state: State, logger: Logger) -> None:
         if self._enabled:
             wait(self._futures) # Ignore raised errors on close
 
-    def _log_metadata(self, metadata: Dict[str, Any]) -> None:
+    def log_metadata(self, metadata: Dict[str, Any]) -> None:
         """Buffer metadata and prefix keys with mosaicml."""
         if self._enabled:
             for key, val in metadata.items():
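The only functional change in this file is the rename of _log_metadata to log_metadata, which makes the metadata-buffering API public so that code outside the class (such as the remote uploader below) can call it. A minimal sketch of an external caller follows; it assumes a run on the MosaicML platform (off-platform, the "if self._enabled" guard shown above makes the call a no-op), and the metadata key is purely illustrative:

    import time

    from composer.loggers import MosaicMLLogger

    mosaic_logger = MosaicMLLogger()

    # Buffer an event timestamp; per the docstring above, keys are prefixed with
    # mosaicml before being flushed (e.g. on batch_end, fit_end, or a forced flush).
    mosaic_logger.log_metadata({'my_component/event_time': time.time()})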
9 changes: 7 additions & 2 deletions composer/loggers/remote_uploader_downloader.py
@@ -22,7 +22,7 @@
 
 import torch
 
-from composer.loggers.logger import Logger
+from composer.loggers import Logger, MosaicMLLogger
 from composer.loggers.logger_destination import LoggerDestination
 from composer.utils import (
     GCSObjectStore,
@@ -343,7 +343,6 @@ def remote_backend(self) -> ObjectStore:
         return self._remote_backend
 
     def init(self, state: State, logger: Logger) -> None:
-        del logger # unused
         if self._worker_flag is not None:
             raise RuntimeError('The RemoteUploaderDownloader is already initialized.')
         self._worker_flag = self._finished_cls()
@@ -386,6 +385,7 @@ def init(self, state: State, logger: Logger) -> None:
                 'num_attempts': self.num_attempts,
                 'completed_queue': self._completed_queue,
                 'exception_queue': self._exception_queue,
+                'logger': logger,
             },
             # The worker threads are joined in the shutdown procedure, so it is OK to set the daemon status
             # Setting daemon status prevents the process from hanging if close was never called (e.g. in doctests)
@@ -654,6 +654,7 @@ def _upload_worker(
     remote_backend_name: str,
     backend_kwargs: Dict[str, Any],
     num_attempts: int,
+    logger: Logger,
 ):
     """A long-running function to handle uploading files to the object store.
@@ -707,3 +708,7 @@ def upload_file(retry_index: int = 0):
         time.sleep(local_rank * local_rank_stagger)
 
         upload_file()
+
+        for destination in logger.destinations:
+            if isinstance(destination, MosaicMLLogger):
+                destination.log_metadata({'checkpoint_uploaded_time': time.time()})
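This second file threads the Logger into the upload worker and, after upload_file() returns, records a checkpoint_uploaded_time event on any MosaicMLLogger destination via the now-public log_metadata. A self-contained sketch of that pattern is below; FakeMosaicMLLogger and fake_upload are stand-ins invented for illustration, and the queue handling and retry logic of the real worker are omitted:

    import time
    from typing import Any, Dict, List


    class FakeMosaicMLLogger:
        # Stand-in for MosaicMLLogger: buffers metadata, prefixing keys the way the
        # real log_metadata describes in its docstring.

        def __init__(self) -> None:
            self.buffered: Dict[str, Any] = {}

        def log_metadata(self, metadata: Dict[str, Any]) -> None:
            for key, val in metadata.items():
                self.buffered[f'mosaicml/{key}'] = val


    def fake_upload(file_path: str) -> None:
        # Placeholder for the retried object-store upload in the real worker.
        print(f'uploading {file_path}')


    def upload_and_record(file_path: str, destinations: List[Any]) -> None:
        fake_upload(file_path)
        # Mirror the commit's pattern: once the upload has finished, notify every
        # MosaicML-aware destination with a wall-clock timestamp.
        for destination in destinations:
            if isinstance(destination, FakeMosaicMLLogger):
                destination.log_metadata({'checkpoint_uploaded_time': time.time()})


    destinations = [FakeMosaicMLLogger()]
    upload_and_record('ep1-ba100-rank0.pt', destinations)
    print(destinations[0].buffered)  # {'mosaicml/checkpoint_uploaded_time': ...}

The 'logger': logger entry added to the worker kwargs in the init hunk above is what makes these destinations reachable from inside _upload_worker.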
