Skip to content

Commit

Permalink
Fix checkpoint events
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu committed Jul 15, 2024
1 parent 91519e8 commit c86ebc8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
10 changes: 9 additions & 1 deletion composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import shutil
import tempfile
import textwrap
import time
from pathlib import Path
from typing import Any, Callable, Optional, Union

from composer.core import Callback, Event, State, Time, Timestamp
from composer.loggers import Logger, MLFlowLogger
from composer.loggers import Logger, MLFlowLogger, MosaicMLLogger
from composer.utils import (
FORMAT_NAME_WITH_DIST_AND_TIME_TABLE,
FORMAT_NAME_WITH_DIST_TABLE,
Expand Down Expand Up @@ -619,6 +620,11 @@ def _rotate_checkpoints(self, sharding_enabled: bool = False):
if dist.get_global_rank() == 0:
shutil.rmtree(prefix_dir)

def _log_checkpoint_upload(self, logger: Logger):
    """Report the checkpoint upload time to every MosaicML logger destination.

    Walks ``logger.destinations`` and, for each destination that is a
    ``MosaicMLLogger``, logs a ``checkpoint_uploaded_time`` metadata entry
    holding the current ``time.time()`` value, passing ``force_flush=True``.
    Destinations of any other type are left untouched.

    Args:
        logger (Logger): The logger whose destinations are inspected.
    """
    for candidate in logger.destinations:
        # Only MosaicMLLogger destinations receive the upload timestamp.
        if not isinstance(candidate, MosaicMLLogger):
            continue
        upload_metadata = {'checkpoint_uploaded_time': time.time()}
        candidate.log_metadata(upload_metadata, force_flush=True)

def batch_end(self, state: State, logger: Logger) -> None:
del state, logger # unused
if self.remote_uploader is None:
Expand All @@ -643,6 +649,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
file_path=local_symlink_file,
overwrite=True,
)
self._log_checkpoint_upload(logger)
break
else:
raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
Expand All @@ -666,6 +673,7 @@ def fit_end(self, state: State, logger: Logger) -> None:
overwrite=True,
)
symlink_upload_future.result()
self._log_checkpoint_upload(logger)
else:
raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
log.info('Checkpoint uploading finished!')
Expand Down
8 changes: 3 additions & 5 deletions composer/loggers/remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import torch

from composer.loggers import Logger, MosaicMLLogger
from composer.loggers import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import (
MLFlowObjectStore,
Expand Down Expand Up @@ -308,12 +308,13 @@ def remote_backend(self) -> ObjectStore:
return self._remote_backend

def init(self, state: State, logger: Logger) -> None:
del logger # unused

if self._worker_flag is not None:
raise RuntimeError('The RemoteUploaderDownloader is already initialized.')
self._worker_flag = self._finished_cls()
self._run_name = state.run_name
file_name_to_test = self._remote_file_name('.credentials_validated_successfully')
self._logger = logger

# Create the enqueue thread
self._enqueue_thread_flag = self._finished_cls()
Expand Down Expand Up @@ -426,9 +427,6 @@ def _enqueue_uploads(self):
break
self._enqueued_objects.remove(object_name)
self._completed_queue.task_done()
for destination in self._logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({'checkpoint_uploaded_time': time.time()}, force_flush=True)

# Enqueue all objects that are in self._logged_objects but not in self._file_upload_queue
objects_to_delete = []
Expand Down

0 comments on commit c86ebc8

Please sign in to comment.