Skip to content

Commit

Permalink
process_group: wait for futher_thread join before creating new one
Browse files Browse the repository at this point in the history
  • Loading branch information
dwancn committed Jan 17, 2025
1 parent 2f97660 commit 82b1c16
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
6 changes: 5 additions & 1 deletion torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,12 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
if self._rx is not None:
self._rx.close()
if self._future_queue is not None:
# wait for the future thread to exit and then close the queue
self._future_queue.put(_QUEUE_CLOSE)
assert self._future_queue is not None
assert self._future_thread is not None
self._future_thread.join(timeout=10.0)
if self._future_thread.is_alive():
raise RuntimeError("future thread did not exit")
self._future_queue.close()

ctx = mp.get_context("spawn")
Expand Down
30 changes: 30 additions & 0 deletions torchft/process_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,36 @@ def test_baby_gloo_timeout(self) -> None:
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
a.configure(store_addr, 0, 2)

def test_reconfigure_baby_process_group(self) -> None:
store = TCPStore(
host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
store_addr = f"localhost:{store.port}/prefix"

a = ProcessGroupBabyGloo()
a.configure(store_addr, 0, 1)
futher_thread_1 = a._future_thread
futher_queue_1 = a._future_queue
p_1 = a._p

store_addr = f"localhost:{store.port}/prefix2"
a.configure(store_addr, 0, 1)
futher_thread_2 = a._future_thread
futher_queue_2 = a._future_queue
p_2 = a._p

self.assertNotEqual(futher_thread_1, futher_thread_2)
self.assertNotEqual(futher_queue_1, futher_queue_2)
self.assertNotEqual(p_1, p_2)

self.assertFalse(futher_thread_1.is_alive())
self.assertTrue(futher_queue_1._closed)
self.assertFalse(p_1.is_alive())

self.assertTrue(futher_thread_2.is_alive())
self.assertFalse(futher_queue_2._closed)
self.assertTrue(p_2.is_alive())

def test_dummy(self) -> None:
pg = ProcessGroupDummy(0, 1)
m = nn.Linear(3, 4)
Expand Down

0 comments on commit 82b1c16

Please sign in to comment.