Skip to content

Commit

Permalink
Fix socket client using(before socket server creation). (#1361)
Browse files Browse the repository at this point in the history
* optimize ckpt engine's queue using

* lint

* lint

* lint

* lint
  • Loading branch information
BalaBalaYi authored Nov 27, 2024
1 parent 397f304 commit 0674b71
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
4 changes: 4 additions & 0 deletions dlrover/python/common/multi_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ def __init__(self, name="", create=False):
self._server = None
self._init_socket()

@property
def name(self):
return self._name

def unlink(self):
try:
os.unlink(self._socket_file)
Expand Down
31 changes: 25 additions & 6 deletions dlrover/trainer/torch/flash_checkpoint/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
from dlrover.python.common import env_utils
from dlrover.python.common.constants import CheckpointConstant
from dlrover.python.common.log import default_logger as logger
from dlrover.python.common.multi_process import SharedLock, SharedQueue
from dlrover.python.common.multi_process import (
LocalSocketComm,
SharedLock,
SharedQueue,
)
from dlrover.python.common.singleton import Singleton
from dlrover.python.common.storage import CheckpointStorage
from dlrover.python.elastic_agent.torch.ckpt_saver import (
Expand Down Expand Up @@ -133,6 +137,21 @@ def start_saver_process():
return None


def wait_socket_server(socket_server: LocalSocketComm, timeout=60):
"""
Socket client should not be used before socket server is created.
"""

start_time = time.time()
while not socket_server.is_available():
time.sleep(0.1)
if time.time() - start_time > timeout:
raise TimeoutError(
"Timed out waiting for socket server: "
f"{socket_server.name}."
)


class CheckpointEngine(metaclass=ABCMeta):
"""
The checkpoint engine synchronously writes the state dict into
Expand All @@ -144,7 +163,7 @@ class CheckpointEngine(metaclass=ABCMeta):
the training loop and call `save_to_storage`.
If the training process fail, the agent in main process can continuously
saves the state dict from the shared memory into the storage.
save the state dict from the shared memory into the storage.
Args:
checkpoint_dir (str): the directory to save checkpoint.
Expand Down Expand Up @@ -199,8 +218,7 @@ def __init__(
self._shm_lock = SharedLock(name=lock_name, create=False)

# need to wait until the socket server is created(by the saver)
while not self._shm_lock.is_available():
time.sleep(0.1)
wait_socket_server(self._shm_lock)

self._shm_handler = SharedMemoryHandler(
self.local_shard_id, host=False
Expand Down Expand Up @@ -296,6 +314,8 @@ def _notify_agent_to_create_saver(self):
},
)

wait_socket_server(queue)

logger.info(
"Notify agent to create a checkpoint saver using: "
f"{class_meta.__dict__}."
Expand All @@ -315,8 +335,7 @@ def _update_saver_config(self):
"The event queue cannot be None on local rank 0."
)

while not self._event_queue.is_available():
time.sleep(0.1)
wait_socket_server(self._event_queue)

logger.info(f"Update saver config: {event.__dict__}")
self._event_queue.put(event)
Expand Down

0 comments on commit 0674b71

Please sign in to comment.