diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index 7e5e7a137..6aace690f 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -52,6 +52,7 @@
 from .exceptions import OLMoCheckpointError
 from .optim import Optimizer, fix_optim_state_dict
 from .safetensors_util import safetensors_file_to_state_dict
+from .torch_util import SingleAccelerator as SINGLE
 from .torch_util import (
     barrier,
     gc_cuda,
@@ -645,7 +646,7 @@ def save_checkpoint(
                     self._write_optim_dict(
                         optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
                     )
-            elif isinstance(dist_model, DDP):
+            elif isinstance(dist_model, DDP) or isinstance(dist_model, SINGLE):
                 # _write_model_dict and _write_optim_dict only write checkpoints for rank 0
                 # First, get the model state dict from DDP wrapped model
                 model_state_dict = dist_model.module.state_dict()
@@ -660,7 +661,7 @@
                 )
             else:
                 log.info(
-                    "`FullCheckpointer.save_checkpoint` only supported for FSDP and DDP distributed strategies!"
+                    "`FullCheckpointer.save_checkpoint` only supported for FSDP, DDP, and SINGLE distributed strategies!"
                 )

             # Save trainer state.
@@ -757,7 +758,7 @@ def restore_checkpoint(
                 torch.cuda.empty_cache()
                 barrier()
                 del optim_state_dict_to_load
-        elif isinstance(dist_model, DDP):
+        elif isinstance(dist_model, DDP) or isinstance(dist_model, SINGLE):
             # Load model state.
             with torch.no_grad():
                 state_dict_to_load = load_state_dict(
@@ -773,11 +774,12 @@
                 )
             optim.load_state_dict(optim_state_dict_to_load)
             gc.collect()
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
             barrier()
         else:
             raise NotImplementedError(
-                "`FullCheckpointer.restore_checkpoint` only supported for FSDP and DDP distributed strategies!"
+                "`FullCheckpointer.restore_checkpoint` only supported for FSDP, DDP, and SINGLE distributed strategies!"
             )

         # Load other state.
diff --git a/olmo/config.py b/olmo/config.py
index 4e197a341..91f744745 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -719,6 +719,10 @@ class DistributedStrategy(StrEnum):
     Wrap OLMo in torch.distributed.fsdp.FullyShardedDataParallel to train across ranks.
     """

+    single = "single"
+    """
+    Train on a single device, i.e., do not distribute training. For development and debugging.
+ """ class DDPGradSyncMode(StrEnum): batch = "batch" diff --git a/olmo/torch_util.py b/olmo/torch_util.py index 0aa52961e..b6f3e5bd3 100644 --- a/olmo/torch_util.py +++ b/olmo/torch_util.py @@ -156,3 +156,11 @@ def get_cumulative_document_lengths(doc_lens: torch.Tensor) -> torch.Tensor: torch.cumsum(doc_lens.masked_select(doc_lens != 0), 0, dtype=torch.int32), ] ) + +class SingleAccelerator(torch.nn.Module): + process_group = None + def __init__(self, module: torch.nn.Module): + super().__init__() + self.module = module + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) diff --git a/olmo/train.py b/olmo/train.py index 105f82e40..a7b5426ae 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -332,7 +332,8 @@ def trainer_state_dict(self) -> Dict[str, Any]: "python": random.getstate(), "numpy": np.random.get_state(), "torch": torch.random.get_rng_state(), - "cuda": torch.cuda.get_rng_state(), + "cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None, + "mps": torch.mps.get_rng_state() if torch.mps.is_available() else None, }, } @@ -430,7 +431,10 @@ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None: random.setstate(rng_state["python"]) np.random.set_state(rng_state["numpy"]) torch.set_rng_state(rng_state["torch"]) - torch.cuda.set_rng_state(rng_state["cuda"]) + if rng_state.get("cuda", None) is not None: + torch.cuda.set_rng_state(rng_state["cuda"]) + if rng_state.get("mps", None) is not None: + torch.mps.set_rng_state(rng_state["mps"]) def _save_checkpoint( self, checkpointer: Checkpointer, checkpoint_type: CheckpointType @@ -1247,7 +1251,7 @@ def on_trace_ready(p): stop_at = min(stop_at, self.global_step + extra_steps) # Maybe save sharded checkpoint. - if self.cfg.distributed_strategy != DistributedStrategy.ddp: + if self.cfg.distributed_strategy == DistributedStrategy.fsdp: if save_checkpoints and ( cancel_initiated or ( diff --git a/scripts/train.py b/scripts/train.py index b4d89be2d..cc2ad4c0b 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -27,6 +27,7 @@ from olmo.exceptions import OLMoCliError, OLMoConfigurationError from olmo.model import OLMo from olmo.optim import BoltOnWarmupScheduler, build_optimizer, build_scheduler +from olmo.torch_util import SingleAccelerator as SINGLE from olmo.torch_util import ( barrier, get_default_device, @@ -65,9 +66,14 @@ def main(cfg: TrainConfig) -> None: barrier() # Set CUDA device. - torch.cuda.set_device(f"cuda:{get_local_rank()}") - torch.cuda.empty_cache() - device = torch.device("cuda") + if torch.cuda.is_available(): + torch.cuda.set_device(f"cuda:{get_local_rank()}") + torch.cuda.empty_cache() + device = torch.device("cuda") + elif torch.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") # Fill some configuration options. 
     cfg.model.precision = cfg.precision
@@ -211,8 +217,9 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
             param_init_fn=param_init_fn,
             **hybrid_sharding_fsdp_kwargs,
         )
-    elif cfg.distributed_strategy is None:
-        raise NotImplementedError("Single accelerator training not implemented yet!")
+    elif cfg.distributed_strategy == DistributedStrategy.single:
+        param_init_fn = None
+        dist_model = SINGLE(olmo_model.to(device))

     # when param_init_fn is None, FSDP will call reset_parameters() automatically
     if param_init_fn is not None or cfg.distributed_strategy == DistributedStrategy.ddp:
@@ -287,7 +294,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
             cfg.reset_optimizer_state = False

     if not cfg.dry_run and not cfg.no_pre_train_checkpoint and cfg.load_path is None:
-        if cfg.distributed_strategy == DistributedStrategy.ddp:
+        if cfg.distributed_strategy in [DistributedStrategy.ddp, DistributedStrategy.single]:
             checkpoint_type = CheckpointType.unsharded

             if cfg.save_interval_unsharded is None:
@@ -363,17 +370,20 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         print(f"failed to set multiprocessing start method: {e}")
     log.info(f"Multiprocessing start method set to '{mp.get_start_method()}'")

-    # Set CUDA device.
-    torch.cuda.set_device(f"cuda:{get_local_rank()}")
-
-    # Initialize process group.
-    device_as_string = f"cuda:{get_local_rank()}"
-    torch.cuda.set_device(
-        device_as_string
-    )  # Set this early to prevent GPU 0 from picking up a bunch of tensors it shouldn't have.
-    dist.init_process_group(
-        backend="nccl", timeout=timedelta(minutes=30), device_id=torch.device(device_as_string)
-    )
+    if torch.cuda.is_available():
+        # Set CUDA device.
+        torch.cuda.set_device(f"cuda:{get_local_rank()}")
+
+        # Initialize process group.
+        device_as_string = f"cuda:{get_local_rank()}"
+        torch.cuda.set_device(
+            device_as_string
+        )  # Set this early to prevent GPU 0 from picking up a bunch of tensors it shouldn't have.
+        dist.init_process_group(
+            backend="nccl", timeout=timedelta(minutes=30), device_id=torch.device(device_as_string)
+        )
+    else:
+        dist.init_process_group(backend="gloo", timeout=timedelta(minutes=30))
     log.info("Process group initialized")

     prepare_cli_environment()
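
Note (not part of the patch): a minimal usage sketch of the new SingleAccelerator wrapper, illustrating the interface contract the checkpointing changes above rely on. The toy torch.nn.Linear model and the "cpu" device string are hypothetical stand-ins for the real OLMo model and device selection.

    import torch

    from olmo.torch_util import SingleAccelerator

    # Wrap a (toy) model the same way scripts/train.py does for the "single" strategy:
    #     dist_model = SINGLE(olmo_model.to(device))
    model = torch.nn.Linear(8, 8)
    dist_model = SingleAccelerator(model.to("cpu"))

    # forward() simply delegates to the wrapped module ...
    out = dist_model(torch.randn(2, 8))

    # ... and, like DDP, the wrapper exposes `.module`, which is why FullCheckpointer
    # can call `dist_model.module.state_dict()` on the DDP and SINGLE branches alike.
    state_dict = dist_model.module.state_dict()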