Skip to content

Commit

Permalink
Fix checkpoint events (#3468)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu authored and mvpatel2000 committed Jul 21, 2024
1 parent 14bc187 commit a5dc155
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
14 changes: 11 additions & 3 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import shutil
import tempfile
import textwrap
import time
from pathlib import Path
from typing import Any, Callable, Optional, Union

from composer.core import Callback, Event, State, Time, Timestamp
from composer.loggers import Logger, MLFlowLogger
from composer.loggers import Logger, MLFlowLogger, MosaicMLLogger
from composer.utils import (
FORMAT_NAME_WITH_DIST_AND_TIME_TABLE,
FORMAT_NAME_WITH_DIST_TABLE,
Expand Down Expand Up @@ -619,8 +620,13 @@ def _rotate_checkpoints(self, sharding_enabled: bool = False):
if dist.get_global_rank() == 0:
shutil.rmtree(prefix_dir)

def _log_checkpoint_upload(self, logger: Logger):
    """Report the checkpoint upload time to MosaicML logger destinations.

    Walks ``logger.destinations`` and, for every destination that is a
    ``MosaicMLLogger``, logs a ``checkpoint_uploaded_time`` metadata entry
    holding the current ``time.time()`` value, requesting a forced flush.
    Destinations of any other type are left untouched.

    Args:
        logger (Logger): The logger whose destinations are inspected.
    """
    # Lazy filter: each matching destination is handled as it is found, so the
    # timestamp is taken separately for every MosaicMLLogger.
    mosaic_destinations = (dest for dest in logger.destinations if isinstance(dest, MosaicMLLogger))
    for mosaic_dest in mosaic_destinations:
        upload_metadata = {'checkpoint_uploaded_time': time.time()}
        mosaic_dest.log_metadata(upload_metadata, force_flush=True)

def batch_end(self, state: State, logger: Logger) -> None:
del state, logger # unused
del state # unused
if self.remote_uploader is None:
return
self.remote_uploader.check_workers()
Expand All @@ -643,13 +649,14 @@ def batch_end(self, state: State, logger: Logger) -> None:
file_path=local_symlink_file,
overwrite=True,
)
self._log_checkpoint_upload(logger)
break
else:
raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
self.symlink_upload_tasks = undone_symlink_upload_tasks

def fit_end(self, state: State, logger: Logger) -> None:
del state, logger # unused
del state # unused
if self.remote_uploader is None:
return
log.info('Waiting for checkpoint uploading to finish')
Expand All @@ -666,6 +673,7 @@ def fit_end(self, state: State, logger: Logger) -> None:
overwrite=True,
)
symlink_upload_future.result()
self._log_checkpoint_upload(logger)
else:
raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
log.info('Checkpoint uploading finished!')
Expand Down
8 changes: 3 additions & 5 deletions composer/loggers/remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import torch

from composer.loggers import Logger, MosaicMLLogger
from composer.loggers import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import (
MLFlowObjectStore,
Expand Down Expand Up @@ -308,12 +308,13 @@ def remote_backend(self) -> ObjectStore:
return self._remote_backend

def init(self, state: State, logger: Logger) -> None:
del logger # unused

if self._worker_flag is not None:
raise RuntimeError('The RemoteUploaderDownloader is already initialized.')
self._worker_flag = self._finished_cls()
self._run_name = state.run_name
file_name_to_test = self._remote_file_name('.credentials_validated_successfully')
self._logger = logger

# Create the enqueue thread
self._enqueue_thread_flag = self._finished_cls()
Expand Down Expand Up @@ -426,9 +427,6 @@ def _enqueue_uploads(self):
break
self._enqueued_objects.remove(object_name)
self._completed_queue.task_done()
for destination in self._logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({'checkpoint_uploaded_time': time.time()}, force_flush=True)

# Enqueue all objects that are in self._logged_objects but not in self._file_upload_queue
objects_to_delete = []
Expand Down

0 comments on commit a5dc155

Please sign in to comment.