From a5dc1555da1a1e9c7c4b707d2a66e8c244d614c6 Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Mon, 15 Jul 2024 10:27:38 -0400 Subject: [PATCH] Fix checkpoint events (#3468) --- composer/callbacks/checkpoint_saver.py | 14 +++++++++++--- composer/loggers/remote_uploader_downloader.py | 8 +++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py index 29468e66c3..661b3046ba 100644 --- a/composer/callbacks/checkpoint_saver.py +++ b/composer/callbacks/checkpoint_saver.py @@ -11,11 +11,12 @@ import shutil import tempfile import textwrap +import time from pathlib import Path from typing import Any, Callable, Optional, Union from composer.core import Callback, Event, State, Time, Timestamp -from composer.loggers import Logger, MLFlowLogger +from composer.loggers import Logger, MLFlowLogger, MosaicMLLogger from composer.utils import ( FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, @@ -619,8 +620,13 @@ def _rotate_checkpoints(self, sharding_enabled: bool = False): if dist.get_global_rank() == 0: shutil.rmtree(prefix_dir) + def _log_checkpoint_upload(self, logger: Logger): + for destination in logger.destinations: + if isinstance(destination, MosaicMLLogger): + destination.log_metadata({'checkpoint_uploaded_time': time.time()}, force_flush=True) + def batch_end(self, state: State, logger: Logger) -> None: - del state, logger # unused + del state # unused if self.remote_uploader is None: return self.remote_uploader.check_workers() @@ -643,13 +649,14 @@ def batch_end(self, state: State, logger: Logger) -> None: file_path=local_symlink_file, overwrite=True, ) + self._log_checkpoint_upload(logger) break else: raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}') self.symlink_upload_tasks = undone_symlink_upload_tasks def fit_end(self, state: State, logger: Logger) -> None: - del state, logger # unused + del state # unused if self.remote_uploader is None: return log.info('Waiting for checkpoint uploading to finish') @@ -666,6 +673,7 @@ def fit_end(self, state: State, logger: Logger) -> None: overwrite=True, ) symlink_upload_future.result() + self._log_checkpoint_upload(logger) else: raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}') log.info('Checkpoint uploading finished!') diff --git a/composer/loggers/remote_uploader_downloader.py b/composer/loggers/remote_uploader_downloader.py index 9378d5a8d4..a143ac1421 100644 --- a/composer/loggers/remote_uploader_downloader.py +++ b/composer/loggers/remote_uploader_downloader.py @@ -22,7 +22,7 @@ import torch -from composer.loggers import Logger, MosaicMLLogger +from composer.loggers import Logger from composer.loggers.logger_destination import LoggerDestination from composer.utils import ( MLFlowObjectStore, @@ -308,12 +308,13 @@ def remote_backend(self) -> ObjectStore: return self._remote_backend def init(self, state: State, logger: Logger) -> None: + del logger # unused + if self._worker_flag is not None: raise RuntimeError('The RemoteUploaderDownloader is already initialized.') self._worker_flag = self._finished_cls() self._run_name = state.run_name file_name_to_test = self._remote_file_name('.credentials_validated_successfully') - self._logger = logger # Create the enqueue thread self._enqueue_thread_flag = self._finished_cls() @@ -426,9 +427,6 @@ def _enqueue_uploads(self): break self._enqueued_objects.remove(object_name) self._completed_queue.task_done() - for destination in self._logger.destinations: - if isinstance(destination, MosaicMLLogger): - destination.log_metadata({'checkpoint_uploaded_time': time.time()}, force_flush=True) # Enqueue all objects that are in self._logged_objects but not in self._file_upload_queue objects_to_delete = []