diff --git a/src/lib.rs b/src/lib.rs
index f1b5a6b..16ae317 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -44,9 +44,9 @@ impl Manager {
         bind: String,
         store_addr: String,
         world_size: u64,
-    ) -> Self {
+    ) -> PyResult<Self> {
         py.allow_threads(move || {
-            let runtime = Runtime::new().unwrap();
+            let runtime = Runtime::new()?;
             let manager = runtime
                 .block_on(manager::Manager::new(
                     replica_id,
@@ -56,13 +56,13 @@ impl Manager {
                     store_addr,
                     world_size,
                 ))
-                .unwrap();
+                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
             let handle = runtime.spawn(manager.clone().run());
-            Self {
+            Ok(Self {
                 handle: handle,
                 manager: manager,
                 _runtime: runtime,
-            }
+            })
         })
     }
 
@@ -89,7 +89,7 @@ impl ManagerClient {
     #[new]
     fn new(py: Python<'_>, addr: String, timeout: Duration) -> PyResult<Self> {
         py.allow_threads(move || {
-            let runtime = Runtime::new().unwrap();
+            let runtime = Runtime::new()?;
             let client = runtime
                 .block_on(manager::manager_client_new(addr, timeout))
                 .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -193,14 +193,16 @@ fn reset_python_signals(py: Python<'_>) -> PyResult<()> {
 }
 
 #[pyfunction]
-fn lighthouse_main(py: Python<'_>) {
-    reset_python_signals(py).unwrap();
+fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
+    reset_python_signals(py)?;
 
     let mut args = env::args();
     args.next(); // discard binary arg
     let opt = lighthouse::LighthouseOpt::from_iter(args);
 
-    let rt = Runtime::new().unwrap();
-    rt.block_on(lighthouse_main_async(opt)).unwrap();
+    let rt = Runtime::new()?;
+    rt.block_on(lighthouse_main_async(opt))
+        .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+    Ok(())
 }
 
 async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
@@ -223,7 +225,7 @@ impl Lighthouse {
     #[new]
     fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
         py.allow_threads(move || {
-            let rt = Runtime::new().unwrap();
+            let rt = Runtime::new()?;
 
             let lighthouse = rt
                 .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
@@ -232,7 +234,7 @@ impl Lighthouse {
                     join_timeout_ms: 100,
                     quorum_tick_ms: 100,
                 }))
-                .unwrap();
+                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
             Ok(Self {
                 handle: rt.spawn(lighthouse.clone().run()),
@@ -261,7 +263,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
         .show_module_names(true)
         .timestamp(stderrlog::Timestamp::Millisecond)
         .init()
-        .unwrap();
+        .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
     m.add_class::<Manager>()?;
     m.add_class::<ManagerClient>()?;
diff --git a/torchft/manager.py b/torchft/manager.py
index e8b8d3b..42d229f 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -140,6 +140,7 @@ def __init__(
             wait_for_workers=False,
         )
         self._pg = pg
+        self._manager = None
 
         if rank == 0:
             hostname = socket.gethostname()
@@ -148,7 +149,8 @@ def __init__(
             lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
 
             if replica_id is None:
-                replica_id = str(uuid.uuid4())
+                replica_id = ""
+            replica_id = replica_id + str(uuid.uuid4())
             self._manager = _Manager(
                 replica_id=replica_id,
                 lighthouse_addr=lighthouse_addr,
@@ -180,6 +182,8 @@ def shutdown(self) -> None:
         Shutdown the manager and checkpoint server.
         """
         self._ckpt_server.shutdown()
+        if self._manager is not None:
+            self._manager.shutdown()
 
     def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
         """
@@ -364,7 +368,7 @@ def _async_quorum(self) -> None:
             self._participating_rank = None
 
         if quorum_id != self._quorum_id:
-            logger.info(f"reconfiguring for quorum_id {quorum_id}")
+            logger.info(f"{replica_rank=} reconfiguring for quorum_id {quorum_id}")
             store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
             # We use the replica rank and world as we want all replicas in the PG.
             self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
@@ -373,7 +377,7 @@ def _async_quorum(self) -> None:
         # See manager.rs for healing conditions
         if heal:
             self._healing = True
-            logger.info("healing required")
+            logger.info(f"{replica_rank=} healing required")
 
             logger.info(f"fetching checkpoint server address from {address}")
             primary_client = ManagerClient(address, timeout=self._timeout)
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index e9c9261..b324ed0 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -1,4 +1,6 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from contextlib import ExitStack
+from typing import Set, Tuple
 from unittest import TestCase
 
 import torch
@@ -24,63 +26,108 @@ def forward(self, x):
         return self.model(x)
 
 
-def train_loop(replica_id: int, lighthouse_address: str) -> None:
-    store = dist.TCPStore(
-        host_name="localhost",
-        port=0,
-        is_master=True,
-        wait_for_workers=False,
-    )
-
-    def load_state_dict(state_dict):
-        m.load_state_dict(state_dict["model"])
-        optimizer.load_state_dict(state_dict["optim"])
-
-    def state_dict():
-        return {
-            "model": m.state_dict(),
-            "optim": optimizer.state_dict(),
-        }
-
-    pg = ProcessGroupGloo()
-    manager = Manager(
-        pg=pg,
-        min_replica_size=2,
-        load_state_dict=load_state_dict,
-        state_dict=state_dict,
-        replica_id=str(replica_id),
-        store_addr="localhost",
-        store_port=store.port,
-        rank=0,
-        world_size=1,
-        lighthouse_addr=lighthouse_address,
-        port=19530 + replica_id,
-    )
-    m = DistributedDataParallel(manager, MyModel())
-    optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
-    criterion = nn.CrossEntropyLoss()
-
-    while True:
-        inputs = torch.rand(2, 3)
-        labels = torch.randint(4, (2,))
-
-        optimizer.zero_grad()
-        out = m(inputs)
-        loss = criterion(out, labels)
-
-        loss.backward()
-        optimizer.step()
-
-        # TODO: assert weights are equal across replicas
-
-        if manager.current_step() >= 5:
-            break
-
-    manager.shutdown()
+class InjectedFailure(Exception):
+    pass
+
+
+class FailureInjector:
+    def __init__(self) -> None:
+        self._failures: Set[int] = set()
+        self.count = 0
+
+    def fail_at(self, step: int) -> "FailureInjector":
+        self._failures.add(step)
+        return self
+
+    def check(self, step: int) -> None:
+        if step in self._failures:
+            self.count += 1
+            self._failures.remove(step)
+            print(f"injecting failure {step=}")
+            raise InjectedFailure(f"injected failure {step=}")
+
+
+def worker_manager(
+    replica_id: int,
+    lighthouse_address: str,
+    failure_injector: FailureInjector,
+    attempts: int = 3,
+) -> None:
+    for i in range(attempts):
+        try:
+            print(f"starting worker {replica_id} attempt {i}")
+            return train_loop(
+                replica_id, lighthouse_address, failure_injector=failure_injector
+            )
+        except InjectedFailure as e:
+            print("got injected failure", i, e)
+            if i == attempts - 1:
+                raise
+            continue
+
+
+def train_loop(
+    replica_id: int, lighthouse_address: str, failure_injector: FailureInjector
+) -> None:
+    with ExitStack() as stack:
+        store = dist.TCPStore(
+            host_name="localhost",
+            port=0,
+            is_master=True,
+            wait_for_workers=False,
+        )
+
+        def load_state_dict(state_dict):
+            m.load_state_dict(state_dict["model"])
+            optimizer.load_state_dict(state_dict["optim"])
+
+        def state_dict():
+            return {
+                "model": m.state_dict(),
+                "optim": optimizer.state_dict(),
+            }
+
+        pg = ProcessGroupGloo()
+        manager = Manager(
+            pg=pg,
+            min_replica_size=2,
+            load_state_dict=load_state_dict,
+            state_dict=state_dict,
+            replica_id=str(replica_id),
+            store_addr="localhost",
+            store_port=store.port,
+            rank=0,
+            world_size=1,
+            lighthouse_addr=lighthouse_address,
+            port=19530 + replica_id,
+        )
+        stack.callback(manager.shutdown)
+
+        m = DistributedDataParallel(manager, MyModel())
+        optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
+        criterion = nn.CrossEntropyLoss()
+
+        while True:
+            print(f"worker {replica_id} starting step {manager.current_step()}")
+            inputs = torch.rand(2, 3)
+            labels = torch.randint(4, (2,))
+
+            optimizer.zero_grad()
+            out = m(inputs)
+            loss = criterion(out, labels)
+
+            loss.backward()
+            optimizer.step()
+
+            if manager.current_step() >= 5:
+                # return state_dict so we can check consistency
+                return state_dict()
+
+            failure_injector.check(manager.current_step())
 
 
 class ManagerIntegTest(TestCase):
-    def test_ddp(self):
+    def test_ddp_healthy(self):
         lighthouse = Lighthouse(
             bind="[::]:0",
             min_replicas=2,
         )
@@ -90,11 +137,60 @@ def test_ddp(self):
         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
             for replica_id in range(num_replicas):
+                failure_injector = FailureInjector()
+                futures.append(
+                    executor.submit(
+                        worker_manager,
+                        replica_id,
+                        lighthouse.address(),
+                        failure_injector=failure_injector,
+                    )
+                )
+
+        state_dicts = []
+
+        for fut in as_completed(futures):
+            state_dicts.append(fut.result())
+
+        lighthouse.shutdown()
+
+        for state_dict in state_dicts:
+            torch.testing.assert_close(state_dict, state_dicts[0])
+
+    def test_ddp_recovery(self):
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector().fail_at(2),
+        ]
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
                 futures.append(
-                    executor.submit(train_loop, replica_id, lighthouse.address())
+                    executor.submit(
+                        worker_manager,
+                        replica_id,
+                        lighthouse.address(),
+                        failure_injector=failure_injector,
+                    )
                 )
 
+        state_dicts = []
+
         for fut in as_completed(futures):
-            fut.result()
+            state_dicts.append(fut.result())
 
         lighthouse.shutdown()
+
+        for state_dict in state_dicts:
+            torch.testing.assert_close(state_dict, state_dicts[0])
+
+        self.assertEqual(failure_injectors[1].count, 1)
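
Note on the failure-injection semantics above (illustrative sketch, not part of the patch): FailureInjector.check removes a step from its set after raising, so each registered failure fires exactly once. That is why worker_manager's retry loop succeeds on a later attempt and why test_ddp_recovery asserts failure_injectors[1].count == 1.

    injector = FailureInjector().fail_at(2)
    injector.check(1)  # step 1 was never registered, returns normally
    try:
        injector.check(2)  # raises InjectedFailure and consumes step 2
    except InjectedFailure:
        pass
    injector.check(2)  # no-op: the step-2 failure was already consumed
    assert injector.count == 1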