Some fixes/improvements around synchronous bookkeeping operations (#83)
epwalsh authored Nov 7, 2024
1 parent c435c94 commit 83db5f7
Showing 7 changed files with 110 additions and 60 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Allow running on Augusta cluster with existing train scripts.
- Added `olmo_core.utils.logging_configured()` function to check if logging has been configured.

### Fixed

- Fixed a potential distributed deadlock bug when training without a separate CPU-only bookkeeping backend.
- Removed some unnecessary host-device syncs in `olmo_core.distributed.utils`.
- Added `Trainer(Config).async_bookkeeping` field to toggle async bookkeeping.

## [v1.6.0](https://github.com/allenai/OLMo-core/releases/tag/v1.6.0) - 2024-11-01

### Added
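The deadlock fix and the new `async_bookkeeping` toggle above both come down to one rule: bookkeeping collectives (metric reductions, cancellation checks) may only be pushed to a background thread when they run on a separate CPU-only process group; otherwise they must stay synchronous on the main thread. A standalone sketch of that CPU-only-group pattern, illustrative only and not OLMo-core code, assuming a `torchrun` launch and the composite `cpu:gloo,cuda:nccl` backend:

```python
"""Illustrative sketch: bookkeeping collectives on a CPU-only gloo group.

Launch with: torchrun --nproc-per-node=2 bookkeeping_sketch.py
"""
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.distributed as dist


def reduce_metric(value: float, group: dist.ProcessGroup) -> float:
    # CPU tensor + gloo group: safe to run from a worker thread while the
    # main thread keeps issuing CUDA/NCCL work on the default backend.
    t = torch.tensor(value)
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group)
    return t.item() / dist.get_world_size()


def main():
    backend = "cpu:gloo,cuda:nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)
    bookkeeping_pg = dist.new_group(backend="gloo")  # CPU-only group

    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(reduce_metric, float(dist.get_rank()), bookkeeping_pg)
        # ... the training loop would keep running GPU work here ...
        print("mean rank:", future.result())

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```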
4 changes: 2 additions & 2 deletions src/olmo_core/distributed/utils.py
@@ -279,7 +279,7 @@ def synchronize_value(
"""
if dist.is_available() and dist.is_initialized():
is_tensor = isinstance(value, torch.Tensor)
value_tensor = value.to(device) if is_tensor else torch.tensor(value, device=device) # type: ignore
value_tensor = move_to_device(value, device) if is_tensor else move_to_device(torch.tensor(value), device) # type: ignore
dist.broadcast(value_tensor, src, group=group)
return value_tensor if is_tensor else value_tensor.item() # type: ignore
else:
@@ -303,7 +303,7 @@ def all_reduce_value(
"""
if dist.is_available() and dist.is_initialized():
is_tensor = isinstance(value, torch.Tensor)
value_tensor = value.to(device) if is_tensor else torch.tensor(value, device=device) # type: ignore
value_tensor = move_to_device(value, device) if is_tensor else move_to_device(torch.tensor(value), device) # type: ignore
dist.all_reduce(value_tensor, op=op, group=group)
return value_tensor if is_tensor else value_tensor.item() # type: ignore
else:
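The replaced lines above built scalar tensors directly on the target device with `torch.tensor(value, device=device)`, which the changelog counts among the unnecessary host-device syncs; the fix routes everything through `olmo_core.utils.move_to_device` instead. The helper below is only a guess at what such a utility might do (pinned host memory plus a `non_blocking` copy), and the context manager mirrors the spirit of the `cuda_sync_debug_mode` helper used later in this diff, built on PyTorch's `torch.cuda.set_sync_debug_mode`:

```python
# Hypothetical sketch, not the actual olmo_core.utils implementation.
from contextlib import contextmanager

import torch


@contextmanager
def sync_debug_mode(mode):
    """Temporarily set the CUDA sync debug mode ('default', 'warn', 'error' or 0/1/2)."""
    if torch.cuda.is_available():
        previous = torch.cuda.get_sync_debug_mode()
        torch.cuda.set_sync_debug_mode(mode)
        try:
            yield
        finally:
            torch.cuda.set_sync_debug_mode(previous)
    else:
        yield


def scalar_to_device(value: float, device: torch.device) -> torch.Tensor:
    t = torch.tensor(value)
    if device.type == "cuda":
        # Pinned host memory enables a truly asynchronous host-to-device copy.
        t = t.pin_memory()
        return t.to(device, non_blocking=True)
    return t.to(device)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with sync_debug_mode("warn"):  # surface any accidental blocking syncs
        x = scalar_to_device(3.0, device)
    print(x)
```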
21 changes: 20 additions & 1 deletion src/olmo_core/internal/experiment.py
@@ -4,6 +4,7 @@
from typing import Callable, Dict, List, Optional, cast

from beaker import Beaker
from torch.distributed.device_mesh import DeviceMesh

from olmo_core.config import Config, StrEnum
from olmo_core.data import (
@@ -64,6 +65,22 @@ class CommonComponents(Config):
callbacks: Dict[str, Callback]


class DPMeshType(StrEnum):
full = "full"
hybrid = "hybrid"


@dataclass
class DPMeshConfig(Config):
name: DPMeshType = DPMeshType.hybrid

def build(self) -> Optional[DeviceMesh]:
if get_num_nodes() == 1 or self.name == DPMeshType.full:
return None
else:
return init_hybrid_shard_mesh()


@dataclass
class ExperimentConfig(Config):
run_name: str
@@ -73,6 +90,7 @@ class ExperimentConfig(Config):
dataset: NumpyDatasetConfig
data_loader: NumpyDataLoaderConfig
trainer: TrainerConfig
dp_mesh: DPMeshConfig
init_seed: int = 12536


@@ -260,6 +278,7 @@ def build_config(
dataset=common.dataset,
data_loader=common.data_loader,
trainer=trainer,
dp_mesh=DPMeshConfig(),
)

if finalize_config is not None:
@@ -302,7 +321,7 @@ def train(config: ExperimentConfig):
init_device="meta",
device=get_default_device(),
max_seq_len=config.dataset.sequence_length,
dp_mesh=None if get_num_nodes() == 1 else init_hybrid_shard_mesh(),
dp_mesh=config.dp_mesh.build(),
)
optim = config.optim.build(model)
dataset = config.dataset.build()
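`DPMeshConfig` turns the full-vs-hybrid data-parallel mesh choice into a config option instead of hard-coding it on node count in `train()` (see the change above). Conceptually, a hybrid-shard mesh is a 2-D device mesh that shards within a node and replicates across nodes. A standalone sketch using PyTorch's `init_device_mesh`, illustrative only (OLMo-core's `init_hybrid_shard_mesh` may differ in details), assuming a multi-GPU `torchrun` launch:

```python
# Illustrative sketch of "full" vs. "hybrid" data-parallel meshes.
# Launch with: torchrun --nnodes=N --nproc-per-node=G mesh_sketch.py
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def build_dp_mesh(hybrid: bool):
    world_size = dist.get_world_size()
    gpus_per_node = int(os.environ.get("LOCAL_WORLD_SIZE", torch.cuda.device_count()))
    num_nodes = world_size // gpus_per_node

    if not hybrid or num_nodes == 1:
        # "Full" sharding: one flat mesh over all ranks (FSDP-style).
        return init_device_mesh("cuda", (world_size,), mesh_dim_names=("shard",))

    # "Hybrid" sharding: shard within a node, replicate across nodes (HSDP-style).
    return init_device_mesh(
        "cuda",
        (num_nodes, gpus_per_node),
        mesh_dim_names=("replicate", "shard"),
    )


if __name__ == "__main__":
    dist.init_process_group("cpu:gloo,cuda:nccl")
    mesh = build_dp_mesh(hybrid=True)
    print(f"rank {dist.get_rank()}: {mesh}")
    dist.destroy_process_group()
```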
8 changes: 6 additions & 2 deletions src/olmo_core/train/callbacks/checkpointer.py
@@ -75,9 +75,10 @@ class CheckpointerCallback(Callback):
Save a pretrain checkpoint. Defaults to ``True`` unless the trainer resumes from a checkpoint.
"""

save_async: bool = False
save_async: Optional[bool] = None
"""
Save checkpoints asynchronously. Requires a backend that supports CPU.
Save checkpoints asynchronously. Requires a separate CPU-only backend.
Defaults to ``True`` if there is one.
"""

remove: CheckpointRemovalStrategy = CheckpointRemovalStrategy.ephemeral_only
@@ -145,6 +146,9 @@ def _remove_checkpoint(self, path: str):
self.trainer.thread_pool.submit(clear_directory, path)

def pre_train(self):
if self.save_async is None:
self.save_async = backend_supports_cpu()

# Maybe create a new process group for async checkpointing.
if is_distributed() and self.save_async and self.checkpointer.process_group is None:
if not backend_supports_cpu():
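`save_async` changes from a hard `False` default to `None`, meaning "decide at startup based on whether a CPU-capable backend exists", resolved in `pre_train()` as shown above. A minimal standalone sketch of that resolve-at-startup pattern (hypothetical names; the real check is OLMo-core's `backend_supports_cpu()`):

```python
# Minimal sketch of an Optional[bool] flag resolved from the environment at startup.
from dataclasses import dataclass
from typing import Optional

import torch.distributed as dist


def cpu_collectives_available() -> bool:
    # Stand-in for backend_supports_cpu(); a real check would also handle
    # composite backends like "cpu:gloo,cuda:nccl".
    return dist.is_available() and dist.is_initialized() and "gloo" in str(dist.get_backend())


@dataclass
class SaveOptions:
    save_async: Optional[bool] = None  # None -> pick a default in pre_train()

    def pre_train(self) -> None:
        if self.save_async is None:
            self.save_async = cpu_collectives_available()


if __name__ == "__main__":
    opts = SaveOptions()
    opts.pre_train()
    print(opts.save_async)  # False when no distributed backend is initialized
```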
1 change: 1 addition & 0 deletions src/olmo_core/train/config.py
@@ -47,6 +47,7 @@ class TrainerConfig(Config):
compile_loss: bool = False
z_loss_multiplier: Optional[float] = None
autocast_precision: Optional[DType] = None
async_bookkeeping: Optional[bool] = None

def add_callback(self, name: str, callback: Callback):
"""
122 changes: 70 additions & 52 deletions src/olmo_core/train/trainer.py
@@ -8,12 +8,14 @@
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Generator,
Literal,
Optional,
Tuple,
TypedDict,
TypeVar,
Union,
cast,
)
@@ -45,7 +47,7 @@
fused_cross_entropy_loss,
)
from ..optim import SkipStepOptimizer
from ..utils import move_to_device
from ..utils import cuda_sync_debug_mode, move_to_device
from .callbacks import (
Callback,
CheckpointerCallback,
@@ -65,6 +67,8 @@
OPTIM_STEP_SKIPPED_METRIC = "optim/step skipped"
SEQ_LEN_METRIC = "data/sequence length"

T = TypeVar("T")


class TrainerStateDict(TypedDict):
global_step: int
@@ -254,6 +258,12 @@ class Trainer:
not affect the learning rate schedule.
"""

async_bookkeeping: Optional[bool] = None
"""
Do collective bookkeeping operations like reducing metrics asynchronously.
This requires a separate CPU-only backend, and will default to ``True`` if one is available.
"""

_metrics: Dict[int, Dict[str, torch.Tensor]] = field(default_factory=OrderedDict)
_metrics_reduce_type: Dict[str, Optional[ReduceType]] = field(default_factory=dict)
_canceled: bool = False
@@ -314,18 +324,19 @@ def __post_init__(self):

# Maybe create separate process group for bookkeeping.
if self._bookkeeping_pg is None and is_distributed():
if backend_supports_cpu():
log.info("Creating new process group for bookkeeping")
if self.async_bookkeeping is None:
self.async_bookkeeping = backend_supports_cpu()
if self.async_bookkeeping:
if not backend_supports_cpu():
raise OLMoConfigurationError(
"A CPU-only backend is required for async bookkeeping"
)
log.info("Creating new process group for async bookkeeping")
self._bookkeeping_pg = dist.new_group(
ranks=None
if self.dp_process_group is None
else dist.get_process_group_ranks(self.dp_process_group)
)
else:
log.warning(
"No CPU backend configured, bookkeeping collectives will occur on the default "
"backend and will be blocking. This may result in slower training throughput."
)

# Check data loader configuration.
if self.data_loader.dp_world_size != get_world_size(self.dp_process_group):
@@ -393,7 +404,7 @@ def training_complete(self) -> bool:
and self.global_step > 0
and self.global_step % self.cancel_check_interval == 0
):
self.thread_pool.submit(self._check_if_canceled)
self.check_if_canceled()

if self.is_canceled:
return True
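`training_complete()` now calls the public `check_if_canceled()` instead of submitting the private helper to the thread pool itself; the decision about whether it is safe to go async is centralized in `_run_bookkeeping_op` further down. The cancellation handshake itself boils down to a MAX all-reduce of a "canceling rank" sentinel followed by sharing the reason from that rank. A standalone sketch of that idea (not OLMo-core code; it uses `broadcast_object_list` where the trainer uses its own `scatter_object` helper):

```python
# Standalone sketch of rank-coordinated cancellation via a MAX all-reduce.
# Launch with: torchrun --nproc-per-node=2 cancel_sketch.py
import torch
import torch.distributed as dist


def check_if_canceled(local_cancel_reason, group=None):
    # Ranks that are not canceling contribute -1; MAX picks any canceling rank.
    canceling_rank = dist.get_rank() if local_cancel_reason is not None else -1
    t = torch.tensor(canceling_rank)
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=group)
    winner = int(t.item())
    if winner < 0:
        return None  # nobody canceled
    # Share the reason from the canceling rank with everyone else.
    payload = [local_cancel_reason if dist.get_rank() == winner else None]
    dist.broadcast_object_list(payload, src=winner, group=group)
    return winner, payload[0]


if __name__ == "__main__":
    dist.init_process_group("gloo")
    reason = "pre-empted" if dist.get_rank() == 0 else None
    print(check_if_canceled(reason))
    dist.destroy_process_group()
```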
@@ -472,7 +483,7 @@ def bookkeeping_device(self) -> torch.device:
The device used for collective bookkeeping (non-training) operations that can potentially
use a different backend.
"""
if backend_supports_cpu():
if self.async_bookkeeping and backend_supports_cpu():
return torch.device("cpu")
else:
return self.device
@@ -519,7 +530,7 @@ def check_if_canceled(self):
Asynchronously check if the run is canceled. Use :data:`is_canceled` to see the result.
This needs to be called by all ranks at the same point in the training loop.
"""
self.thread_pool.submit(self._check_if_canceled)
self._run_bookkeeping_op(self._check_if_canceled)

def fit(self):
"""
@@ -981,21 +992,51 @@ def _handle_os_signal(self, signalnum, stack_frame):
log.warning(msg)
self.cancel_run(msg)

def _run_bookkeeping_op(
self, op: Callable[..., T], *args, cb: Optional[Callable[[T], None]] = None, **kwargs
):
if (
self.async_bookkeeping
and self.bookkeeping_device.type == "cpu"
and self.bookkeeping_pg is not None
):
# Can safely run in the thread pool.
future = self.thread_pool.submit(op, *args, **kwargs)
if cb is not None:

def callback(fut: Future[T]):
try:
cb(fut.result()) # type: ignore[misc]
except BaseException as e:
log.exception(e)
self._error = e

future.add_done_callback(callback)
else:
result = op(*args, **kwargs)
if cb is not None:
cb(result)

def _check_if_canceled(self):
if self._canceled:
return

canceling_rank = self._canceling_rank if self._canceling_rank is not None else -1
canceling_rank = all_reduce_value(
canceling_rank, self.bookkeeping_device, op=dist.ReduceOp.MAX, group=self.bookkeeping_pg
)
if canceling_rank >= 0:
cancel_reason = scatter_object(self._cancel_reason, src=canceling_rank)
assert cancel_reason is not None
self._canceled = True
self._canceling_rank = canceling_rank
self._cancel_reason = cancel_reason
log.warning(f"Run canceled from rank {canceling_rank}. Reason: {cancel_reason}")
# NOTE: this is a known host-device sync (potentially) so we don't need the warning
with cuda_sync_debug_mode(0):
canceling_rank = all_reduce_value(
canceling_rank,
self.bookkeeping_device,
op=dist.ReduceOp.MAX,
group=self.bookkeeping_pg,
)
if canceling_rank >= 0:
cancel_reason = scatter_object(self._cancel_reason, src=canceling_rank)
assert cancel_reason is not None
self._canceled = True
self._canceling_rank = canceling_rank
self._cancel_reason = cancel_reason
log.warning(f"Run canceled from rank {canceling_rank}. Reason: {cancel_reason}")

def _log_metrics(self):
if not self._metrics:
@@ -1008,37 +1049,14 @@ def _log_metrics(self):
# so CUDA training can continue.
metrics_to_reduce = move_metrics(self._metrics, self.bookkeeping_device)
self._metrics.clear()

if self.bookkeeping_device.type == "cpu" and self.bookkeeping_pg is not None:
# If we have a separate CPU backend and process group we can safely reduce
# metrics on CPU in a thread.
future = self.thread_pool.submit(
reduce_metrics,
metrics_to_reduce,
self._metrics_reduce_type,
self.bookkeeping_device,
process_group=self.bookkeeping_pg,
)

def callback(fut):
try:
self._check_and_pass_on_metrics(fut.result())
except BaseException as e:
log.exception(e)
self._error = e

future.add_done_callback(callback)
else:
# Otherwise we have to reduce them now in the main thread.
# NOTE: if we're training on GPU and didn't have a host device sync above, this will
# trigger a host-device sync as we transfer the metrics back to CPU post-reducing.
metrics = reduce_metrics(
metrics_to_reduce,
self._metrics_reduce_type,
self.bookkeeping_device,
process_group=self.bookkeeping_pg,
)
self._check_and_pass_on_metrics(metrics)
self._run_bookkeeping_op(
reduce_metrics,
metrics_to_reduce,
self._metrics_reduce_type,
self.bookkeeping_device,
process_group=self.bookkeeping_pg,
cb=self._check_and_pass_on_metrics,
)

def _check_and_pass_on_metrics(self, metrics: Dict[int, Dict[str, float]]):
for step in sorted(metrics.keys()):
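The `_log_metrics` rewrite above folds the two code paths (threaded reduce with a done-callback versus blocking reduce in the main thread) into the shared `_run_bookkeeping_op` helper. The general shape of that helper, as a standalone sketch rather than OLMo-core code:

```python
# Standalone sketch of a "run now or run in a worker thread" helper that
# forwards results to a callback and captures exceptions raised in the thread.
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


class Bookkeeper:
    def __init__(self, async_ok: bool):
        self.async_ok = async_ok
        self.pool = ThreadPoolExecutor(max_workers=1)
        self.error: Optional[BaseException] = None

    def run_op(self, op: Callable[..., T], *args, cb: Optional[Callable[[T], None]] = None, **kwargs):
        if self.async_ok:
            future = self.pool.submit(op, *args, **kwargs)
            if cb is not None:
                def done(fut: Future):
                    try:
                        cb(fut.result())
                    except BaseException as e:  # surface errors to the main loop
                        self.error = e
                future.add_done_callback(done)
        else:
            # No safe way to go async: run the op and the callback inline.
            result = op(*args, **kwargs)
            if cb is not None:
                cb(result)


if __name__ == "__main__":
    bk = Bookkeeper(async_ok=True)
    bk.run_op(sum, [1, 2, 3], cb=lambda total: print("reduced:", total))
    bk.pool.shutdown(wait=True)
```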
8 changes: 5 additions & 3 deletions src/olmo_core/train/utils.py
@@ -128,11 +128,13 @@ def move_metrics(
get_local_tensor(m)
for step_metrics in source.values()
for m in step_metrics.values()
if m.device != device
# NOTE: compare device type since 'torch.device("cuda") != torch.device("cuda:0")'
# even when both point to the same physical device.
if m.device.type != device.type
]
metrics_to_move: Optional[torch.Tensor] = None
if metrics_to_move_list:
# NOTE: this is a known host-device sync so we don't need the warning
# NOTE: this is a known host-device sync (potentially) so we don't need the warning
with cuda_sync_debug_mode(0):
metrics_to_move = torch.stack(metrics_to_move_list).to(
device, non_blocking=non_blocking
@@ -145,7 +147,7 @@
for name, m in step_metrics.items():
if step not in target:
target[step] = OrderedDict()
if metrics_to_move is not None and m.device != device:
if metrics_to_move is not None and m.device.type != device.type:
target[step][name] = metrics_to_move[idx]
idx += 1
else:
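The `move_metrics` change compares `device.type` rather than the device objects themselves, for the reason stated in the new comment. A quick demonstration of the pitfall:

```python
import torch

# A bare "cuda" device (index None) does not compare equal to "cuda:0",
# even though both resolve to the same physical GPU at runtime.
print(torch.device("cuda") == torch.device("cuda:0"))            # False
print(torch.device("cuda").type == torch.device("cuda:0").type)  # True -> "cuda"
```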
