Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checkpoint events #3468

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import shutil
import tempfile
import textwrap
import time
from pathlib import Path
from typing import Any, Callable, Optional, Union

from composer.core import Callback, Event, State, Time, Timestamp
from composer.loggers import Logger, MLFlowLogger
from composer.loggers import Logger, MLFlowLogger, MosaicMLLogger
from composer.utils import (
FORMAT_NAME_WITH_DIST_AND_TIME_TABLE,
FORMAT_NAME_WITH_DIST_TABLE,
Expand Down Expand Up @@ -619,8 +620,13 @@ def _rotate_checkpoints(self, sharding_enabled: bool = False):
if dist.get_global_rank() == 0:
shutil.rmtree(prefix_dir)

def _log_checkpoint_upload(self, logger: Logger):
    """Report a checkpoint-upload timestamp to every MosaicML logger destination.

    Walks ``logger.destinations`` and, for each destination that is a
    ``MosaicMLLogger``, logs the current ``time.time()`` value under the
    ``checkpoint_uploaded_time`` metadata key with ``force_flush=True``.
    Destinations of any other type are ignored.

    Args:
        logger (Logger): The logger whose destinations are inspected.
    """
    # Lazy filter: each destination is type-checked immediately before it is
    # used, so the timestamp is taken per destination, not once up front.
    mosaic_destinations = (
        candidate for candidate in logger.destinations if isinstance(candidate, MosaicMLLogger)
    )
    for mosaic_logger in mosaic_destinations:
        upload_metadata = {'checkpoint_uploaded_time': time.time()}
        mosaic_logger.log_metadata(upload_metadata, force_flush=True)

def batch_end(self, state: State, logger: Logger) -> None:
del state, logger # unused
del state # unused
if self.remote_uploader is None:
return
self.remote_uploader.check_workers()
Expand All @@ -643,13 +649,14 @@ def batch_end(self, state: State, logger: Logger) -> None:
file_path=local_symlink_file,
overwrite=True,
)
self._log_checkpoint_upload(logger)
break
else:
raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
self.symlink_upload_tasks = undone_symlink_upload_tasks

def fit_end(self, state: State, logger: Logger) -> None:
del state, logger # unused
del state # unused
if self.remote_uploader is None:
return
log.info('Waiting for checkpoint uploading to finish')
Expand All @@ -666,6 +673,7 @@ def fit_end(self, state: State, logger: Logger) -> None:
overwrite=True,
)
symlink_upload_future.result()
self._log_checkpoint_upload(logger)
else:
raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
log.info('Checkpoint uploading finished!')
Expand Down
8 changes: 3 additions & 5 deletions composer/loggers/remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import torch

from composer.loggers import Logger, MosaicMLLogger
from composer.loggers import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import (
MLFlowObjectStore,
Expand Down Expand Up @@ -308,12 +308,13 @@ def remote_backend(self) -> ObjectStore:
return self._remote_backend

def init(self, state: State, logger: Logger) -> None:
del logger # unused

if self._worker_flag is not None:
raise RuntimeError('The RemoteUploaderDownloader is already initialized.')
self._worker_flag = self._finished_cls()
self._run_name = state.run_name
file_name_to_test = self._remote_file_name('.credentials_validated_successfully')
self._logger = logger

# Create the enqueue thread
self._enqueue_thread_flag = self._finished_cls()
Expand Down Expand Up @@ -426,9 +427,6 @@ def _enqueue_uploads(self):
break
self._enqueued_objects.remove(object_name)
self._completed_queue.task_done()
for destination in self._logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({'checkpoint_uploaded_time': time.time()}, force_flush=True)

# Enqueue all objects that are in self._logged_objects but not in self._file_upload_queue
objects_to_delete = []
Expand Down
Loading