manager_integ_tests: added recovery test (#28)
d4l3k authored Dec 8, 2024
1 parent 7b93da7 commit 9878980
Showing 3 changed files with 174 additions and 72 deletions.
28 changes: 15 additions & 13 deletions src/lib.rs
@@ -44,9 +44,9 @@ impl Manager {
bind: String,
store_addr: String,
world_size: u64,
) -> Self {
) -> PyResult<Self> {
py.allow_threads(move || {
let runtime = Runtime::new().unwrap();
let runtime = Runtime::new()?;
let manager = runtime
.block_on(manager::Manager::new(
replica_id,
@@ -56,13 +56,13 @@ impl Manager {
store_addr,
world_size,
))
.unwrap();
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
let handle = runtime.spawn(manager.clone().run());
Self {
Ok(Self {
handle: handle,
manager: manager,
_runtime: runtime,
}
})
})
}

@@ -89,7 +89,7 @@ impl ManagerClient {
#[new]
fn new(py: Python<'_>, addr: String, timeout: Duration) -> PyResult<Self> {
py.allow_threads(move || {
let runtime = Runtime::new().unwrap();
let runtime = Runtime::new()?;
let client = runtime
.block_on(manager::manager_client_new(addr, timeout))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -193,14 +193,16 @@ fn reset_python_signals(py: Python<'_>) -> PyResult<()> {
}

#[pyfunction]
fn lighthouse_main(py: Python<'_>) {
reset_python_signals(py).unwrap();
fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
reset_python_signals(py)?;

let mut args = env::args();
args.next(); // discard binary arg
let opt = lighthouse::LighthouseOpt::from_iter(args);
let rt = Runtime::new().unwrap();
rt.block_on(lighthouse_main_async(opt)).unwrap();
let rt = Runtime::new()?;
rt.block_on(lighthouse_main_async(opt))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
Ok(())
}

async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
@@ -223,7 +225,7 @@ impl Lighthouse {
#[new]
fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
py.allow_threads(move || {
let rt = Runtime::new().unwrap();
let rt = Runtime::new()?;

let lighthouse = rt
.block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
@@ -232,7 +234,7 @@
join_timeout_ms: 100,
quorum_tick_ms: 100,
}))
.unwrap();
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

Ok(Self {
handle: rt.spawn(lighthouse.clone().run()),
@@ -261,7 +263,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
.show_module_names(true)
.timestamp(stderrlog::Timestamp::Millisecond)
.init()
.unwrap();
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

m.add_class::<Manager>()?;
m.add_class::<ManagerClient>()?;
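
The lib.rs changes above swap .unwrap() panics for PyResult returns, so failures inside the Rust bindings reach Python as ordinary exceptions. A minimal sketch of what that means for a caller, reusing the Lighthouse binding changed here; the import path is an assumption, not part of this commit:

from torchft.torchft import Lighthouse  # assumed import path for the compiled extension

try:
    lighthouse = Lighthouse(bind="[::]:0", min_replicas=2)
except RuntimeError as exc:
    # errors mapped via PyRuntimeError::new_err arrive as RuntimeError
    print(f"failed to start lighthouse: {exc}")
else:
    print(f"lighthouse listening on {lighthouse.address()}")
    lighthouse.shutdown()
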
10 changes: 7 additions & 3 deletions torchft/manager.py
@@ -140,6 +140,7 @@ def __init__(
wait_for_workers=False,
)
self._pg = pg
self._manager = None

if rank == 0:
hostname = socket.gethostname()
@@ -148,7 +149,8 @@ def __init__(
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]

if replica_id is None:
replica_id = str(uuid.uuid4())
replica_id = ""
replica_id = replica_id + str(uuid.uuid4())
self._manager = _Manager(
replica_id=replica_id,
lighthouse_addr=lighthouse_addr,
@@ -180,6 +182,8 @@ def shutdown(self) -> None:
Shutdown the manager and checkpoint server.
"""
self._ckpt_server.shutdown()
if self._manager is not None:
self._manager.shutdown()

def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
"""
@@ -364,7 +368,7 @@ def _async_quorum(self) -> None:
self._participating_rank = None

if quorum_id != self._quorum_id:
logger.info(f"reconfiguring for quorum_id {quorum_id}")
logger.info(f"{replica_rank=} reconfiguring for quorum_id {quorum_id}")
store_prefixed_addr = f"{store_address}/torchft/{quorum_id}/{self._rank}"
# We use the replica rank and world as we want all replicas in the PG.
self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
@@ -373,7 +377,7 @@ def _async_quorum(self) -> None:
# See manager.rs for healing conditions
if heal:
self._healing = True
logger.info("healing required")
logger.info(f"{replica_rank}= healing required")

logger.info(f"fetching checkpoint server address from {address}")
primary_client = ManagerClient(address, timeout=self._timeout)
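
The replica_id handling added above turns an explicit replica_id into a prefix and always appends a random UUID, so every Manager instance stays uniquely identifiable across restarts. A small sketch of that naming behaviour in isolation; make_replica_id is an illustrative helper, not part of this commit:

import uuid

def make_replica_id(replica_id=None):
    # Mirrors the logic added in Manager.__init__: None becomes an empty
    # prefix, and every instance gets a unique UUID suffix.
    if replica_id is None:
        replica_id = ""
    return replica_id + str(uuid.uuid4())

print(make_replica_id())           # e.g. "8f14e45f-..."
print(make_replica_id("train_0"))  # e.g. "train_08f14e45f-..."
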
208 changes: 152 additions & 56 deletions torchft/manager_integ_test.py
@@ -1,4 +1,6 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from typing import Set, Tuple
from unittest import TestCase

import torch
@@ -24,63 +26,108 @@ def forward(self, x):
return self.model(x)


def train_loop(replica_id: int, lighthouse_address: str) -> None:
store = dist.TCPStore(
host_name="localhost",
port=0,
is_master=True,
wait_for_workers=False,
)

def load_state_dict(state_dict):
m.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optim"])

def state_dict():
return {
"model": m.state_dict(),
"optim": optimizer.state_dict(),
}

pg = ProcessGroupGloo()
manager = Manager(
pg=pg,
min_replica_size=2,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=str(replica_id),
store_addr="localhost",
store_port=store.port,
rank=0,
world_size=1,
lighthouse_addr=lighthouse_address,
port=19530 + replica_id,
)
m = DistributedDataParallel(manager, MyModel())
optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
criterion = nn.CrossEntropyLoss()

while True:
inputs = torch.rand(2, 3)
labels = torch.randint(4, (2,))

optimizer.zero_grad()
out = m(inputs)
loss = criterion(out, labels)

loss.backward()
optimizer.step()

# TODO: assert weights are equal across replicas

if manager.current_step() >= 5:
break

manager.shutdown()
class InjectedFailure(Exception):
pass


class FailureInjector:
def __init__(self) -> None:
self._failures: Set[int] = set()
self.count = 0

def fail_at(self, step: int) -> "FailureInjector":
self._failures.add(step)
return self

def check(self, step: int) -> None:
if step in self._failures:
self.count += 1
self._failures.remove(step)
print(f"injecting failure {step=}")
raise InjectedFailure(f"injected failure {step=}")


def worker_manager(
replica_id: int,
lighthouse_address: str,
failure_injector: FailureInjector,
attempts: int = 3,
) -> None:
for i in range(attempts):
try:
print(f"starting worker {replica_id} attempt {i}")
return train_loop(
replica_id, lighthouse_address, failure_injector=failure_injector
)
except InjectedFailure as e:
print("got injected failure", i, e)
if i == attempts - 1:
raise
continue


def train_loop(
replica_id: int, lighthouse_address: str, failure_injector: FailureInjector
) -> None:
with ExitStack() as stack:
store = dist.TCPStore(
host_name="localhost",
port=0,
is_master=True,
wait_for_workers=False,
)

def load_state_dict(state_dict):
m.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optim"])

def state_dict():
return {
"model": m.state_dict(),
"optim": optimizer.state_dict(),
}

pg = ProcessGroupGloo()
manager = Manager(
pg=pg,
min_replica_size=2,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=str(replica_id),
store_addr="localhost",
store_port=store.port,
rank=0,
world_size=1,
lighthouse_addr=lighthouse_address,
port=19530 + replica_id,
)
stack.callback(manager.shutdown)

m = DistributedDataParallel(manager, MyModel())
optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
criterion = nn.CrossEntropyLoss()

while True:
print(f"worker {replica_id} starting step {manager.current_step()}")
inputs = torch.rand(2, 3)
labels = torch.randint(4, (2,))

optimizer.zero_grad()
out = m(inputs)
loss = criterion(out, labels)

loss.backward()
optimizer.step()

if manager.current_step() >= 5:
# return state_dict so we can check consistency
return state_dict()

failure_injector.check(manager.current_step())


class ManagerIntegTest(TestCase):
def test_ddp(self):
def test_ddp_healthy(self):
lighthouse = Lighthouse(
bind="[::]:0",
min_replicas=2,
@@ -90,11 +137,60 @@ def test_ddp(self):

with ThreadPoolExecutor(max_workers=num_replicas) as executor:
for replica_id in range(num_replicas):
failure_injector = FailureInjector()
futures.append(
executor.submit(
worker_manager,
replica_id,
lighthouse.address(),
failure_injector=failure_injector,
)
)

state_dicts = []

for fut in as_completed(futures):
state_dicts.append(fut.result())

lighthouse.shutdown()

for state_dict in state_dicts:
torch.testing.assert_close(state_dict, state_dicts[0])

def test_ddp_recovery(self):
lighthouse = Lighthouse(
bind="[::]:0",
min_replicas=2,
)
num_replicas = 2
futures = []

failure_injectors = [
FailureInjector(),
FailureInjector().fail_at(2),
]

with ThreadPoolExecutor(max_workers=num_replicas) as executor:
for replica_id, failure_injector in zip(
range(num_replicas), failure_injectors
):
futures.append(
executor.submit(train_loop, replica_id, lighthouse.address())
executor.submit(
worker_manager,
replica_id,
lighthouse.address(),
failure_injector=failure_injector,
)
)

state_dicts = []

for fut in as_completed(futures):
fut.result()
state_dicts.append(fut.result())

lighthouse.shutdown()

for state_dict in state_dicts:
torch.testing.assert_close(state_dict, state_dicts[0])

self.assertEqual(failure_injectors[1].count, 1)
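
To make the recovery path easier to follow, here is a standalone sketch of the retry pattern that test_ddp_recovery exercises, reusing the FailureInjector and InjectedFailure classes defined above; train_step and run_with_retries are illustrative stand-ins for the real train_loop / worker_manager pair, not code from this commit:

def train_step(step: int, failure_injector: FailureInjector) -> None:
    # Stand-in for one training iteration: give the injector a chance to
    # raise at this step.
    failure_injector.check(step)

def run_with_retries(failure_injector: FailureInjector, attempts: int = 3) -> int:
    # Mirrors worker_manager: retry the loop until an attempt survives all
    # steps, re-raising only if every attempt fails.
    for attempt in range(attempts):
        try:
            for step in range(5):
                train_step(step, failure_injector)
            return attempt
        except InjectedFailure:
            if attempt == attempts - 1:
                raise
    raise RuntimeError("unreachable for attempts > 0")

injector = FailureInjector().fail_at(2)
assert run_with_retries(injector) == 1  # attempt 0 dies at step 2, attempt 1 completes
assert injector.count == 1              # the injected failure fires exactly once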
