From c86ebc86a07c12f73b0f2be8873a526eeb275770 Mon Sep 17 00:00:00 2001
From: root <23239305+b-chu@users.noreply.github.com>
Date: Wed, 10 Jul 2024 21:09:49 +0000
Subject: [PATCH] Fix checkpoint events

Move the 'checkpoint_uploaded_time' metadata logging out of
RemoteUploaderDownloader._enqueue_uploads and into CheckpointSaver.
The event is now logged from CheckpointSaver.batch_end right after the
symlink upload is issued, and from CheckpointSaver.fit_end right after
the symlink upload future has completed.

RemoteUploaderDownloader.init no longer stores the logger, so the
MosaicMLLogger import is dropped from that module.

CheckpointSaver.batch_end used to delete `logger` as unused. It now
passes `logger` to _log_checkpoint_upload, so only `state` is deleted;
keeping the old `del state, logger` would raise UnboundLocalError at
the new call site.
---
Reviewer notes (not part of the commit message):

* The opening lines of CheckpointSaver.fit_end are outside the hunks of
  this patch. If fit_end also starts with `del state, logger`, it needs
  the same change as batch_end -- TODO confirm against the full file.
* The `index` blob hashes and the leading commit hash are carried over
  from the original patch and do not reflect this revision.
* Context-line indentation was reconstructed from a whitespace-collapsed
  copy of the patch; apply with `git apply --ignore-whitespace` if the
  context does not match exactly.

 composer/callbacks/checkpoint_saver.py         | 13 +++++++++++--
 composer/loggers/remote_uploader_downloader.py |  8 +++-----
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py
index 29468e66c36..c04cfda9e17 100644
--- a/composer/callbacks/checkpoint_saver.py
+++ b/composer/callbacks/checkpoint_saver.py
@@ -11,11 +11,12 @@
 import shutil
 import tempfile
 import textwrap
+import time
 from pathlib import Path
 from typing import Any, Callable, Optional, Union
 
 from composer.core import Callback, Event, State, Time, Timestamp
-from composer.loggers import Logger, MLFlowLogger
+from composer.loggers import Logger, MLFlowLogger, MosaicMLLogger
 from composer.utils import (
     FORMAT_NAME_WITH_DIST_AND_TIME_TABLE,
     FORMAT_NAME_WITH_DIST_TABLE,
@@ -619,6 +620,12 @@ def _rotate_checkpoints(self, sharding_enabled: bool = False):
                 if dist.get_global_rank() == 0:
                     shutil.rmtree(prefix_dir)
 
+    def _log_checkpoint_upload(self, logger: Logger):
+        """Log the current time as ``checkpoint_uploaded_time`` to every ``MosaicMLLogger`` destination."""
+        for destination in logger.destinations:
+            if isinstance(destination, MosaicMLLogger):
+                destination.log_metadata({'checkpoint_uploaded_time': time.time()}, force_flush=True)
+
     def batch_end(self, state: State, logger: Logger) -> None:
-        del state, logger  # unused
+        del state  # unused
         if self.remote_uploader is None:
@@ -643,6 +650,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
                         file_path=local_symlink_file,
                         overwrite=True,
                     )
+                    self._log_checkpoint_upload(logger)
                     break
                 else:
                     raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
@@ -666,6 +674,7 @@ def fit_end(self, state: State, logger: Logger) -> None:
                     overwrite=True,
                 )
                 symlink_upload_future.result()
+                self._log_checkpoint_upload(logger)
             else:
                 raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
         log.info('Checkpoint uploading finished!')
diff --git a/composer/loggers/remote_uploader_downloader.py b/composer/loggers/remote_uploader_downloader.py
index 9378d5a8d43..a143ac14213 100644
--- a/composer/loggers/remote_uploader_downloader.py
+++ b/composer/loggers/remote_uploader_downloader.py
@@ -22,7 +22,7 @@
 
 import torch
 
-from composer.loggers import Logger, MosaicMLLogger
+from composer.loggers import Logger
 from composer.loggers.logger_destination import LoggerDestination
 from composer.utils import (
     MLFlowObjectStore,
@@ -308,12 +308,13 @@ def remote_backend(self) -> ObjectStore:
         return self._remote_backend
 
     def init(self, state: State, logger: Logger) -> None:
+        del logger  # unused
+
         if self._worker_flag is not None:
             raise RuntimeError('The RemoteUploaderDownloader is already initialized.')
         self._worker_flag = self._finished_cls()
         self._run_name = state.run_name
         file_name_to_test = self._remote_file_name('.credentials_validated_successfully')
-        self._logger = logger
 
         # Create the enqueue thread
         self._enqueue_thread_flag = self._finished_cls()
@@ -426,9 +427,6 @@ def _enqueue_uploads(self):
                         break
                     self._enqueued_objects.remove(object_name)
                     self._completed_queue.task_done()
-                    for destination in self._logger.destinations:
-                        if isinstance(destination, MosaicMLLogger):
-                            destination.log_metadata({'checkpoint_uploaded_time': time.time()}, force_flush=True)
 
                 # Enqueue all objects that are in self._logged_objects but not in self._file_upload_queue
                 objects_to_delete = []