diff --git a/CHANGELOG.md b/CHANGELOG.md
index 58888aae..176d095c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Allow running on Augusta cluster with existing train scripts.
 - Added `olmo_core.utils.logging_configured()` function to check if logging has been configured.
 
+### Fixed
+
+- Fixed a potential distributed deadlock bug when training without a separate CPU-only bookkeeping backend.
+- Removed some unnecessary host-device syncs in `olmo_core.distributed.utils`.
+- Added `Trainer(Config).async_bookkeeping` field to toggle async bookkeeping.
+
 ## [v1.6.0](https://github.com/allenai/OLMo-core/releases/tag/v1.6.0) - 2024-11-01
 
 ### Added
diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py
index 2bd47e5b..1b0569d6 100644
--- a/src/olmo_core/distributed/utils.py
+++ b/src/olmo_core/distributed/utils.py
@@ -279,7 +279,7 @@ def synchronize_value(
     """
     if dist.is_available() and dist.is_initialized():
         is_tensor = isinstance(value, torch.Tensor)
-        value_tensor = value.to(device) if is_tensor else torch.tensor(value, device=device)  # type: ignore
+        value_tensor = move_to_device(value, device) if is_tensor else move_to_device(torch.tensor(value), device)  # type: ignore
         dist.broadcast(value_tensor, src, group=group)
         return value_tensor if is_tensor else value_tensor.item()  # type: ignore
     else:
@@ -303,7 +303,7 @@ def all_reduce_value(
     """
     if dist.is_available() and dist.is_initialized():
         is_tensor = isinstance(value, torch.Tensor)
-        value_tensor = value.to(device) if is_tensor else torch.tensor(value, device=device)  # type: ignore
+        value_tensor = move_to_device(value, device) if is_tensor else move_to_device(torch.tensor(value), device)  # type: ignore
         dist.all_reduce(value_tensor, op=op, group=group)
         return value_tensor if is_tensor else value_tensor.item()  # type: ignore
     else:
diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py
index 64994f9b..792cf5cc 100644
--- a/src/olmo_core/internal/experiment.py
+++ b/src/olmo_core/internal/experiment.py
@@ -4,6 +4,7 @@
 from typing import Callable, Dict, List, Optional, cast
 
 from beaker import Beaker
+from torch.distributed.device_mesh import DeviceMesh
 
 from olmo_core.config import Config, StrEnum
 from olmo_core.data import (
@@ -64,6 +65,22 @@ class CommonComponents(Config):
     callbacks: Dict[str, Callback]
 
 
+class DPMeshType(StrEnum):
+    full = "full"
+    hybrid = "hybrid"
+
+
+@dataclass
+class DPMeshConfig(Config):
+    name: DPMeshType = DPMeshType.hybrid
+
+    def build(self) -> Optional[DeviceMesh]:
+        if get_num_nodes() == 1 or self.name == DPMeshType.full:
+            return None
+        else:
+            return init_hybrid_shard_mesh()
+
+
 @dataclass
 class ExperimentConfig(Config):
     run_name: str
@@ -73,6 +90,7 @@ class ExperimentConfig(Config):
     dataset: NumpyDatasetConfig
     data_loader: NumpyDataLoaderConfig
     trainer: TrainerConfig
+    dp_mesh: DPMeshConfig
     init_seed: int = 12536
 
@@ -260,6 +278,7 @@ def build_config(
         dataset=common.dataset,
         data_loader=common.data_loader,
         trainer=trainer,
+        dp_mesh=DPMeshConfig(),
     )
 
     if finalize_config is not None:
@@ -302,7 +321,7 @@ def train(config: ExperimentConfig):
         init_device="meta",
         device=get_default_device(),
         max_seq_len=config.dataset.sequence_length,
-        dp_mesh=None if get_num_nodes() == 1 else init_hybrid_shard_mesh(),
+        dp_mesh=config.dp_mesh.build(),
     )
     optim = config.optim.build(model)
     dataset = config.dataset.build()
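The `DPMeshConfig` added to `src/olmo_core/internal/experiment.py` above replaces the inline `None if get_num_nodes() == 1 else init_hybrid_shard_mesh()` expression in `train()`. A minimal usage sketch, not part of the diff, assuming only that the classes land in `olmo_core.internal.experiment` exactly as shown:

```python
# Hypothetical sketch: picking the data-parallel mesh through the new config class.
from olmo_core.internal.experiment import DPMeshConfig, DPMeshType

# Default: hybrid sharding, which still resolves to None on a single node.
hybrid_mesh = DPMeshConfig().build()

# Opting into a full (flat) DP mesh: build() returns None, so the model
# builder falls back to its default mesh instead of a hybrid-shard mesh.
full_mesh = DPMeshConfig(name=DPMeshType.full).build()
```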
diff --git a/src/olmo_core/train/callbacks/checkpointer.py b/src/olmo_core/train/callbacks/checkpointer.py
index dc9506c0..cf2b36f8 100644
--- a/src/olmo_core/train/callbacks/checkpointer.py
+++ b/src/olmo_core/train/callbacks/checkpointer.py
@@ -75,9 +75,10 @@ class CheckpointerCallback(Callback):
     Save a pretrain checkpoint. Defaults to ``True`` unless the trainer resumes from a checkpoint.
     """
 
-    save_async: bool = False
+    save_async: Optional[bool] = None
     """
-    Save checkpoints asynchronously. Requires a backend that supports CPU.
+    Save checkpoints asynchronously. Requires a separate CPU-only backend.
+    Defaults to ``True`` if there is one.
     """
 
     remove: CheckpointRemovalStrategy = CheckpointRemovalStrategy.ephemeral_only
@@ -145,6 +146,9 @@ def _remove_checkpoint(self, path: str):
         self.trainer.thread_pool.submit(clear_directory, path)
 
     def pre_train(self):
+        if self.save_async is None:
+            self.save_async = backend_supports_cpu()
+
         # Maybe create a new process group for async checkpointing.
         if is_distributed() and self.save_async and self.checkpointer.process_group is None:
             if not backend_supports_cpu():
diff --git a/src/olmo_core/train/config.py b/src/olmo_core/train/config.py
index 7dfa119b..f1d1f868 100644
--- a/src/olmo_core/train/config.py
+++ b/src/olmo_core/train/config.py
@@ -47,6 +47,7 @@ class TrainerConfig(Config):
     compile_loss: bool = False
     z_loss_multiplier: Optional[float] = None
     autocast_precision: Optional[DType] = None
+    async_bookkeeping: Optional[bool] = None
 
     def add_callback(self, name: str, callback: Callback):
         """
diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py
index 95fd310d..b246fe87 100644
--- a/src/olmo_core/train/trainer.py
+++ b/src/olmo_core/train/trainer.py
@@ -8,12 +8,14 @@
 from pathlib import Path
 from typing import (
     Any,
+    Callable,
     Dict,
     Generator,
     Literal,
     Optional,
     Tuple,
     TypedDict,
+    TypeVar,
     Union,
     cast,
 )
@@ -45,7 +47,7 @@
     fused_cross_entropy_loss,
 )
 from ..optim import SkipStepOptimizer
-from ..utils import move_to_device
+from ..utils import cuda_sync_debug_mode, move_to_device
 from .callbacks import (
     Callback,
     CheckpointerCallback,
@@ -65,6 +67,8 @@
 OPTIM_STEP_SKIPPED_METRIC = "optim/step skipped"
 SEQ_LEN_METRIC = "data/sequence length"
 
+T = TypeVar("T")
+
 
 class TrainerStateDict(TypedDict):
     global_step: int
@@ -254,6 +258,12 @@ class Trainer:
     not affect the learning rate schedule.
     """
 
+    async_bookkeeping: Optional[bool] = None
+    """
+    Do collective bookkeeping operations like reducing metrics asynchronously.
+    This requires a separate CPU-only backend, and will default to ``True`` if one is available.
+    """
+
     _metrics: Dict[int, Dict[str, torch.Tensor]] = field(default_factory=OrderedDict)
     _metrics_reduce_type: Dict[str, Optional[ReduceType]] = field(default_factory=dict)
     _canceled: bool = False
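With the changes above, both `CheckpointerCallback.save_async` and `Trainer(Config).async_bookkeeping` default to `None` and resolve at runtime to whether a CPU-capable backend is available. A hedged sketch of explicitly opting out, using only fields shown in this diff; `trainer_config` is assumed to be a `TrainerConfig` built elsewhere with its required fields, and the callback's remaining fields are assumed to keep their defaults:

```python
# Hypothetical opt-out sketch: force blocking bookkeeping and synchronous checkpointing
# even when a CPU-only (e.g. gloo) backend has been initialized alongside NCCL.
from olmo_core.train.callbacks import CheckpointerCallback

trainer_config.async_bookkeeping = False  # new TrainerConfig field from this diff
trainer_config.add_callback(
    "checkpointer",
    CheckpointerCallback(save_async=False),  # leaving it as None means "auto-detect"
)
```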
@@ -314,18 +324,19 @@ def __post_init__(self):
 
         # Maybe create separate process group for bookkeeping.
         if self._bookkeeping_pg is None and is_distributed():
-            if backend_supports_cpu():
-                log.info("Creating new process group for bookkeeping")
+            if self.async_bookkeeping is None:
+                self.async_bookkeeping = backend_supports_cpu()
+            if self.async_bookkeeping:
+                if not backend_supports_cpu():
+                    raise OLMoConfigurationError(
+                        "A CPU-only backend is required for async bookkeeping"
+                    )
+                log.info("Creating new process group for async bookkeeping")
                 self._bookkeeping_pg = dist.new_group(
                     ranks=None
                     if self.dp_process_group is None
                     else dist.get_process_group_ranks(self.dp_process_group)
                 )
-            else:
-                log.warning(
-                    "No CPU backend configured, bookkeeping collectives will occur on the default "
-                    "backend and will be blocking. This may result in slower training throughput."
-                )
 
         # Check data loader configuration.
         if self.data_loader.dp_world_size != get_world_size(self.dp_process_group):
@@ -393,7 +404,7 @@ def training_complete(self) -> bool:
             and self.global_step > 0
             and self.global_step % self.cancel_check_interval == 0
         ):
-            self.thread_pool.submit(self._check_if_canceled)
+            self.check_if_canceled()
 
         if self.is_canceled:
             return True
@@ -472,7 +483,7 @@ def bookkeeping_device(self) -> torch.device:
         The device used for collective bookkeeping (non-training) operations that can potentially
         use a different backend.
         """
-        if backend_supports_cpu():
+        if self.async_bookkeeping and backend_supports_cpu():
             return torch.device("cpu")
         else:
             return self.device
@@ -519,7 +530,7 @@ def check_if_canceled(self):
         Asynchronously check if the run is canceled. Use :data:`is_canceled` to see the result.
         This needs to be called by all ranks at the same point in the training loop.
         """
-        self.thread_pool.submit(self._check_if_canceled)
+        self._run_bookkeeping_op(self._check_if_canceled)
 
     def fit(self):
         """
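The `_run_bookkeeping_op` helper introduced further down in this diff centralizes the submit-to-thread-pool-with-callback logic that `check_if_canceled()` and `_log_metrics()` previously inlined or duplicated. A standalone sketch of that pattern, with illustrative names that are not olmo_core APIs:

```python
# Minimal sketch: run an operation in a thread pool when that is safe, otherwise inline,
# and hand the result to an optional callback either way.
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def run_op(
    pool: ThreadPoolExecutor,
    op: Callable[..., T],
    *args,
    run_async: bool = True,
    cb: Optional[Callable[[T], None]] = None,
    **kwargs,
):
    if run_async:
        future: Future[T] = pool.submit(op, *args, **kwargs)
        if cb is not None:
            # The done-callback fires in the worker thread once `op` returns or raises.
            future.add_done_callback(lambda fut: cb(fut.result()))
    else:
        result = op(*args, **kwargs)
        if cb is not None:
            cb(result)
```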
@@ -981,21 +992,51 @@ def _handle_os_signal(self, signalnum, stack_frame):
         log.warning(msg)
         self.cancel_run(msg)
 
+    def _run_bookkeeping_op(
+        self, op: Callable[..., T], *args, cb: Optional[Callable[[T], None]] = None, **kwargs
+    ):
+        if (
+            self.async_bookkeeping
+            and self.bookkeeping_device.type == "cpu"
+            and self.bookkeeping_pg is not None
+        ):
+            # Can safely run in the thread pool.
+            future = self.thread_pool.submit(op, *args, **kwargs)
+            if cb is not None:
+
+                def callback(fut: Future[T]):
+                    try:
+                        cb(fut.result())  # type: ignore[misc]
+                    except BaseException as e:
+                        log.exception(e)
+                        self._error = e
+
+                future.add_done_callback(callback)
+        else:
+            result = op(*args, **kwargs)
+            if cb is not None:
+                cb(result)
+
     def _check_if_canceled(self):
         if self._canceled:
             return
 
         canceling_rank = self._canceling_rank if self._canceling_rank is not None else -1
-        canceling_rank = all_reduce_value(
-            canceling_rank, self.bookkeeping_device, op=dist.ReduceOp.MAX, group=self.bookkeeping_pg
-        )
-        if canceling_rank >= 0:
-            cancel_reason = scatter_object(self._cancel_reason, src=canceling_rank)
-            assert cancel_reason is not None
-            self._canceled = True
-            self._canceling_rank = canceling_rank
-            self._cancel_reason = cancel_reason
-            log.warning(f"Run canceled from rank {canceling_rank}. Reason: {cancel_reason}")
+        # NOTE: this is a known host-device sync (potentially) so we don't need the warning
+        with cuda_sync_debug_mode(0):
+            canceling_rank = all_reduce_value(
+                canceling_rank,
+                self.bookkeeping_device,
+                op=dist.ReduceOp.MAX,
+                group=self.bookkeeping_pg,
+            )
+            if canceling_rank >= 0:
+                cancel_reason = scatter_object(self._cancel_reason, src=canceling_rank)
+                assert cancel_reason is not None
+                self._canceled = True
+                self._canceling_rank = canceling_rank
+                self._cancel_reason = cancel_reason
+                log.warning(f"Run canceled from rank {canceling_rank}. Reason: {cancel_reason}")
 
     def _log_metrics(self):
         if not self._metrics:
@@ -1008,37 +1049,14 @@ def _log_metrics(self):
         # so CUDA training can continue.
         metrics_to_reduce = move_metrics(self._metrics, self.bookkeeping_device)
         self._metrics.clear()
-
-        if self.bookkeeping_device.type == "cpu" and self.bookkeeping_pg is not None:
-            # If we have a separate CPU backend and process group we can safely reduce
-            # metrics on CPU in a thread.
-            future = self.thread_pool.submit(
-                reduce_metrics,
-                metrics_to_reduce,
-                self._metrics_reduce_type,
-                self.bookkeeping_device,
-                process_group=self.bookkeeping_pg,
-            )
-
-            def callback(fut):
-                try:
-                    self._check_and_pass_on_metrics(fut.result())
-                except BaseException as e:
-                    log.exception(e)
-                    self._error = e
-
-            future.add_done_callback(callback)
-        else:
-            # Otherwise we have to reduce them now in the main thread.
-            # NOTE: if we're training on GPU and didn't have a host device sync above, this will
-            # trigger a host-device sync as we transfer the metrics back to CPU post-reducing.
-            metrics = reduce_metrics(
-                metrics_to_reduce,
-                self._metrics_reduce_type,
-                self.bookkeeping_device,
-                process_group=self.bookkeeping_pg,
-            )
-            self._check_and_pass_on_metrics(metrics)
+        self._run_bookkeeping_op(
+            reduce_metrics,
+            metrics_to_reduce,
+            self._metrics_reduce_type,
+            self.bookkeeping_device,
+            process_group=self.bookkeeping_pg,
+            cb=self._check_and_pass_on_metrics,
+        )
 
     def _check_and_pass_on_metrics(self, metrics: Dict[int, Dict[str, float]]):
         for step in sorted(metrics.keys()):
diff --git a/src/olmo_core/train/utils.py b/src/olmo_core/train/utils.py
index d6402e41..f164e095 100644
--- a/src/olmo_core/train/utils.py
+++ b/src/olmo_core/train/utils.py
@@ -128,11 +128,13 @@ def move_metrics(
         get_local_tensor(m)
         for step_metrics in source.values()
         for m in step_metrics.values()
-        if m.device != device
+        # NOTE: compare device type since 'torch.device("cuda") != torch.device("cuda:0")'
+        # even when both point to the same physical device.
+        if m.device.type != device.type
     ]
     metrics_to_move: Optional[torch.Tensor] = None
     if metrics_to_move_list:
-        # NOTE: this is a known host-device sync so we don't need the warning
+        # NOTE: this is a known host-device sync (potentially) so we don't need the warning
         with cuda_sync_debug_mode(0):
             metrics_to_move = torch.stack(metrics_to_move_list).to(
                 device, non_blocking=non_blocking
@@ -145,7 +147,7 @@ def move_metrics(
         for name, m in step_metrics.items():
             if step not in target:
                 target[step] = OrderedDict()
-            if metrics_to_move is not None and m.device != device:
+            if metrics_to_move is not None and m.device.type != device.type:
                 target[step][name] = metrics_to_move[idx]
                 idx += 1
             else:
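One note on the `move_metrics` change in `src/olmo_core/train/utils.py`: the filter now compares device *types* because PyTorch device equality is index-sensitive. A small illustration (runs without a GPU, since constructing a `torch.device` does not touch the device):

```python
import torch

a = torch.device("cuda")
b = torch.device("cuda:0")
assert a != b            # index None vs. index 0, so the full devices compare unequal
assert a.type == b.type  # both are "cuda", which is what the updated check relies on
```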