diff --git a/.coveragerc b/.coveragerc
index 071b61b24d..d9174acf52 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,4 +1,5 @@
 [run]
+concurrency = multiprocessing,thread
 omit =
     ding/utils/slurm_helper.py
     ding/utils/file_helper.py
diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py
index 4ff09167df..6ff67d8301 100644
--- a/ding/framework/middleware/__init__.py
+++ b/ding/framework/middleware/__init__.py
@@ -3,3 +3,4 @@
 from .learner import OffPolicyLearner, HERLearner
 from .ckpt_handler import CkptSaver
 from .distributer import ContextExchanger, ModelExchanger
+from .barrier import Barrier, BarrierRuntime
diff --git a/ding/framework/middleware/barrier.py b/ding/framework/middleware/barrier.py
new file mode 100644
index 0000000000..f958c079ae
--- /dev/null
+++ b/ding/framework/middleware/barrier.py
@@ -0,0 +1,227 @@
+from time import sleep, time
+from ditk import logging
+from ding.framework import task
+from ding.utils.lock_helper import LockContext, LockContextType
+from ding.utils.design_helper import SingletonMetaclass
+
+
+class BarrierRuntime(metaclass=SingletonMetaclass):
+
+    def __init__(self, node_id: int, max_world_size: int = 100):
+        """
+        Overview:
+            'BarrierRuntime' is a singleton class. In addition, it must be initialized before the
+            class 'Parallel' starts MQ, otherwise the messages sent by other nodes may be lost after
+            the detection is completed. We don't have a message retransmission mechanism, and losing
+            a message means deadlock.
+        Arguments:
+            - node_id (int): Process ID.
+            - max_world_size (int, optional): The maximum total number of processes that can be
+                synchronized, the default value is 100.
+        """
+        self.node_id = node_id
+        self._has_detected = False
+        self._range_len = len(str(max_world_size)) + 1
+
+        self._barrier_epoch = 0
+        self._barrier_recv_peers_buff = dict()
+        self._barrier_recv_peers = dict()
+        self._barrier_ack_peers = []
+        self._barrier_lock = LockContext(LockContextType.THREAD_LOCK)
+
+        self.mq_type = task.router.mq_type
+        self._connected_peers = dict()
+        self._connected_peers_lock = LockContext(LockContextType.THREAD_LOCK)
+        self._keep_alive_daemon = False
+
+        self._event_name_detect = "b_det"
+        self.event_name_req = "b_req"
+        self.event_name_ack = "b_ack"
+
+    def _alive_msg_handler(self, peer_id):
+        with self._connected_peers_lock:
+            self._connected_peers[peer_id] = time()
+
+    def _add_barrier_req(self, msg):
+        peer, epoch = self._unpickle_barrier_tag(msg)
+        logging.debug("Node:[{}] recv barrier request from node:{}, epoch:{}".format(self.node_id, peer, epoch))
+        with self._barrier_lock:
+            if peer not in self._barrier_recv_peers:
+                self._barrier_recv_peers[peer] = []
+            self._barrier_recv_peers[peer].append(epoch)
+
+    def _add_barrier_ack(self, peer):
+        logging.debug("Node:[{}] recv barrier ack from node:{}".format(self.node_id, peer))
+        with self._barrier_lock:
+            self._barrier_ack_peers.append(peer)
+
+    def _unpickle_barrier_tag(self, msg):
+        return msg % self._range_len, msg // self._range_len
+
+    def pickle_barrier_tag(self):
+        return int(self._barrier_epoch * self._range_len + self.node_id)
+
+    def reset_all_peers(self):
+        with self._barrier_lock:
+            for peer, q in self._barrier_recv_peers.items():
+                if len(q) != 0:
+                    assert q.pop(0) == self._barrier_epoch
+            self._barrier_ack_peers = []
+            self._barrier_epoch += 1
+
+    def get_recv_num(self):
+        count = 0
+        with self._barrier_lock:
+            if len(self._barrier_recv_peers) > 0:
+                for _, q in self._barrier_recv_peers.items():
+                    if len(q) > 0 and q[0] == self._barrier_epoch:
+                        count += 1
+        return count
+
+    def get_ack_num(self):
+        with self._barrier_lock:
+            return len(self._barrier_ack_peers)
+
+    def detect_alive(self, expected, timeout):
+        # The barrier can only block other nodes within the visible range of the current node.
+        # If the 'attach_to' list of a node is empty, it does not know how many nodes will attach to it,
+        # so we cannot specify the effective range of a barrier in advance.
+        assert task._running
+        task.on(self._event_name_detect, self._alive_msg_handler)
+        task.on(self.event_name_req, self._add_barrier_req)
+        task.on(self.event_name_ack, self._add_barrier_ack)
+        start = time()
+        while True:
+            sleep(0.1)
+            task.emit(self._event_name_detect, self.node_id, only_remote=True)
+            # In case the other node has not had time to receive our detect message,
+            # we will send an additional round.
+            if self._has_detected:
+                break
+            with self._connected_peers_lock:
+                if len(self._connected_peers) == expected:
+                    self._has_detected = True
+
+            if time() - start > timeout:
+                raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
+
+        task.off(self._event_name_detect)
+        logging.info(
+            "Barrier detect node done, node-[{}] has connected with {} active nodes!".format(self.node_id, expected)
+        )
+
+
+class BarrierContext:
+
+    def __init__(self, runtime: BarrierRuntime, detect_timeout, expected_peer_num: int = 0):
+        self._runtime = runtime
+        self._expected_peer_num = expected_peer_num
+        self._timeout = detect_timeout
+
+    def __enter__(self):
+        if not self._runtime._has_detected:
+            self._runtime.detect_alive(self._expected_peer_num, self._timeout)
+
+    def __exit__(self, exc_type, exc_value, tb):
+        if exc_type is not None:
+            import traceback
+            traceback.print_exception(exc_type, exc_value, tb)
+        self._runtime.reset_all_peers()
+
+
+class Barrier:
+
+    def __init__(self, attch_from_nums: int, timeout: int = 60):
+        """
+        Overview:
+            Barrier() is a middleware for debugging or profiling. It can synchronize the task step of each
+            process within the scope of all visible processes. When using Barrier(), you need to pay
+            attention to the following points:
+
+            1. All processes must call Barrier() the same number of times, otherwise a deadlock occurs.
+
+            2. 'attch_from_nums' is a very important variable. This value indicates the number of times
+                the current process will be attached to by other processes (the number of connections
+                established).
+                For example:
+                    Node0: address: 127.0.0.1:12345, attach_to = []
+                    Node1: address: 127.0.0.1:12346, attach_to = ["tcp://127.0.0.1:12345"]
+                For Node0, the 'attch_from_nums' value is 1. (It will be attached to by Node1.)
+                For Node1, the 'attch_from_nums' value is 0. (No one will attach to Node1.)
+                Please note that this value must be given correctly, otherwise, for a node whose 'attach_to'
+                list is empty, it cannot perceive how many processes will establish connections with it,
+                and no form of synchronization can be performed.
+
+            3. Barrier() is thread-safe, but it is not recommended to use barrier in multithreading. You need
+                to carefully calculate the number of times each thread calls Barrier() to avoid deadlock.
+
+            4. In normal training tasks, please do not use Barrier(): it forces step synchronization
+                between processes and therefore greatly damages training efficiency. In addition, if your
+                training task has dynamic processes, do not use Barrier(), in order to prevent deadlock.
+
+        Arguments:
+            - attch_from_nums (int): The number of processes that will attach to the current process.
+            - timeout (int, optional): The timeout for successfully detecting 'expected_peer_num'
+                nodes, the default value is 60 seconds.
+        """
+        self.node_id = task.router.node_id
+        self.timeout = timeout
+        self._runtime: BarrierRuntime = task.router.barrier_runtime
+        self._barrier_peers_nums = task.get_attch_to_len() + attch_from_nums
+
+        logging.info(
+            "Node:[{}], attach to num is:{}, attach from num is:{}".format(
+                self.node_id, task.get_attch_to_len(), attch_from_nums
+            )
+        )
+
+    def __call__(self, ctx):
+        self._wait_barrier(ctx)
+        yield
+        self._wait_barrier(ctx)
+
+    def _wait_barrier(self, ctx):
+        self_ready = False
+        with BarrierContext(self._runtime, self.timeout, self._barrier_peers_nums):
+            logging.debug("Node:[{}] enter barrier".format(self.node_id))
+            # Step1: Notify all the attached nodes that we have reached the barrier.
+            task.emit(self._runtime.event_name_req, self._runtime.pickle_barrier_tag(), only_remote=True)
+            logging.debug("Node:[{}] sent barrier request".format(self.node_id))
+
+            # Step2: Check the number of barrier requests we have received.
+            # In the current CI design of DI-engine, there will always be a node whose 'attach_to' list is empty,
+            # so there will always be a node that will send ACK unconditionally, so deadlock will not occur.
+            if self._runtime.get_recv_num() == self._barrier_peers_nums:
+                self_ready = True
+
+            # Step3: Wait until we ourselves are ready.
+            # Even if the current process has reached the barrier, we will not send an ack immediately;
+            # we need to wait for the slowest directly or indirectly connected peer to
+            # reach the barrier.
+            start = time()
+            if not self_ready:
+                while True:
+                    if time() - start > self.timeout:
+                        raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
+
+                    if self._runtime.get_recv_num() != self._barrier_peers_nums:
+                        sleep(0.1)
+                    else:
+                        break
+
+            # Step4: Notify all attached nodes that we are ready.
+            task.emit(self._runtime.event_name_ack, self.node_id, only_remote=True)
+            logging.debug("Node:[{}] sent barrier ack".format(self.node_id))
+
+            # Step5: Wait until all directly or indirectly connected nodes are ready.
+            start = time()
+            while True:
+                if time() - start > self.timeout:
+                    raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
+
+                if self._runtime.get_ack_num() != self._barrier_peers_nums:
+                    sleep(0.1)
+                else:
+                    break
+
+            logging.info("Node-[{}] env_step:[{}] barrier finish".format(self.node_id, ctx.env_step))
diff --git a/ding/framework/middleware/tests/test_barrier.py b/ding/framework/middleware/tests/test_barrier.py
new file mode 100644
index 0000000000..de176eda2d
--- /dev/null
+++ b/ding/framework/middleware/tests/test_barrier.py
@@ -0,0 +1,144 @@
+import random
+import time
+import socket
+import pytest
+import multiprocessing as mp
+from ditk import logging
+from ding.framework import task
+from ding.framework.parallel import Parallel
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware.barrier import Barrier
+
+PORTS_LIST = ["1235", "1236", "1237"]
+
+
+class EnvStepMiddleware:
+
+    def __call__(self, ctx):
+        yield
+        ctx.env_step += 1
+
+
+class SleepMiddleware:
+
+    def __init__(self, node_id):
+        self.node_id = node_id
+
+    def random_sleep(self, direction, step):
+        random.seed(self.node_id + step)
+        sleep_second = random.randint(1, 5)
+        logging.info("Node:[{}] env_step:[{}]-{} will sleep:{}s".format(self.node_id, step, direction, sleep_second))
+        for i in range(sleep_second):
+            time.sleep(1)
+            print("Node:[{}] sleeping...".format(self.node_id))
+        logging.info("Node:[{}] env_step:[{}]-{} wake up!".format(self.node_id, step, direction))
+
+    def __call__(self, ctx):
+        self.random_sleep("forward", ctx.env_step)
+        yield
+        self.random_sleep("backward", ctx.env_step)
+
+
+def star_barrier():
+    with task.start(ctx=OnlineRLContext()):
+        node_id = task.router.node_id
+        if node_id == 0:
+            attch_from_nums = 3
+        else:
+            attch_from_nums = 0
+        barrier = Barrier(attch_from_nums)
+        task.use(barrier, lock=False)
+        task.use(SleepMiddleware(node_id), lock=False)
+        task.use(barrier, lock=False)
+        task.use(EnvStepMiddleware(), lock=False)
+        try:
+            task.run(2)
+        except Exception as e:
+            logging.error(e)
+            assert False
+
+
+def mesh_barrier():
+    with task.start(ctx=OnlineRLContext()):
+        node_id = task.router.node_id
+        attch_from_nums = 3 - task.router.node_id
+        barrier = Barrier(attch_from_nums)
+        task.use(barrier, lock=False)
+        task.use(SleepMiddleware(node_id), lock=False)
+        task.use(barrier, lock=False)
+        task.use(EnvStepMiddleware(), lock=False)
+        try:
+            task.run(2)
+        except Exception as e:
+            logging.error(e)
+            assert False
+
+
+def unmatch_barrier():
+    with task.start(ctx=OnlineRLContext()):
+        node_id = task.router.node_id
+        attch_from_nums = 3 - task.router.node_id
+        task.use(Barrier(attch_from_nums, 5), lock=False)
+        if node_id != 2:
+            task.use(Barrier(attch_from_nums, 5), lock=False)
+        try:
+            task.run(2)
+        except TimeoutError as e:
+            assert node_id != 2
+            logging.info("Node:[{}] timeout with barrier".format(node_id))
+        else:
+            time.sleep(5)
+            assert node_id == 2
+            logging.info("Node:[{}] finish barrier".format(node_id))
+
+
+def launch_barrier(args):
+    i, topo, fn, test_id = args
+    address = socket.gethostbyname(socket.gethostname())
+    topology = "alone"
+    attach_to = []
+    port_base = PORTS_LIST[test_id]
+    port = port_base + str(i)
+    if topo == 'star':
+        if i != 0:
+            attach_to = ['tcp://{}:{}{}'.format(address, port_base, 0)]
+    elif topo == 'mesh':
+        for j in range(i):
+            attach_to.append('tcp://{}:{}{}'.format(address, port_base, j))
+
+    Parallel.runner(
+        node_ids=i,
+        ports=int(port),
+        attach_to=attach_to,
+        topology=topology,
+        protocol="tcp",
+        n_parallel_workers=1,
+        startup_interval=0
+    )(fn)
+
+
+@pytest.mark.unittest
+def test_star_topology_barrier():
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=4) as pool:
+        pool.map(launch_barrier, [[i, 'star', star_barrier, 0] for i in range(4)])
+        pool.close()
+        pool.join()
+
+
+@pytest.mark.unittest
+def test_mesh_topology_barrier():
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=4) as pool:
+        pool.map(launch_barrier, [[i, 'mesh', mesh_barrier, 1] for i in range(4)])
+        pool.close()
+        pool.join()
+
+
+@pytest.mark.unittest
+def test_unmatch_barrier():
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=4) as pool:
+        pool.map(launch_barrier, [[i, 'mesh', unmatch_barrier, 2] for i in range(4)])
+        pool.close()
+        pool.join()
diff --git a/ding/framework/parallel.py b/ding/framework/parallel.py
index 38e343e495..df7b430a8f 100644
--- a/ding/framework/parallel.py
+++ b/ding/framework/parallel.py
@@ -56,6 +56,9 @@ def _run(
         self._listener = Thread(target=self.listen, name="mq_listener", daemon=True)
         self._listener.start()
 
+        self.mq_type = mq_type
+        self.barrier_runtime = Parallel.get_barrier_runtime()(self.node_id)
+
     @classmethod
     def runner(
         cls,
@@ -372,6 +375,18 @@ def get_ip(cls):
         s.close()
         return ip
 
+    def get_attch_to_len(self) -> int:
+        """
+        Overview:
+            Get the length of the 'attach_to' list of the message queue.
+        Returns:
+            int: The length of self._mq.attach_to. Returns 0 if self._mq is not initialized.
+        """
+        if self._mq:
+            if hasattr(self._mq, 'attach_to'):
+                return len(self._mq.attach_to)
+        return 0
+
     def __enter__(self) -> "Parallel":
         return self
 
@@ -389,3 +404,9 @@ def stop(self):
         self._listener.join(timeout=1)
         self._listener = None
         self._event_loop.stop()
+
+    @classmethod
+    def get_barrier_runtime(cls):
+        # Import BarrierRuntime inside the method to avoid a circular import.
+        from ding.framework.middleware.barrier import BarrierRuntime
+        return BarrierRuntime
diff --git a/ding/framework/task.py b/ding/framework/task.py
index 8a0b3f749e..081eb8cdae 100644
--- a/ding/framework/task.py
+++ b/ding/framework/task.py
@@ -536,5 +536,17 @@ def _activate_async(self):
         if not self._async_loop:
             self._async_loop = asyncio.new_event_loop()
 
+    def get_attch_to_len(self) -> int:
+        """
+        Overview:
+            Get the length of the 'attach_to' list in Parallel._mq.
+        Returns:
+            int: The length of the 'attach_to' list of Parallel._mq.
+        """
+        if self.router.is_active:
+            return self.router.get_attch_to_len()
+        else:
+            raise RuntimeError("The router is inactive, failed to obtain the length of the 'attach_to' list.")
+
 
 task = Task()
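Below is a minimal usage sketch (not part of the diff) showing how the new Barrier middleware can be wired into a two-process pipeline. It mirrors the star-topology test above: Node0 keeps an empty attach_to list and is attached to by Node1, so it passes attch_from_nums=1, while Node1 attaches to Node0 and passes 0. The pipeline/launch helper names, the loopback address, and the ports are assumptions made for this example; Barrier, Parallel.runner, task.use and task.run are the APIs introduced or exercised in this diff.

# Hypothetical two-process example (illustrative only, not part of the diff).
import multiprocessing as mp

from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware.barrier import Barrier
from ding.framework.parallel import Parallel


def step_middleware(ctx):
    # Placeholder for real pipeline middleware.
    ctx.env_step += 1


def pipeline():
    with task.start(ctx=OnlineRLContext()):
        # Node0 is attached to by Node1, so attch_from_nums=1;
        # Node1 attaches to Node0 and is attached to by no one, so attch_from_nums=0.
        attch_from_nums = 1 if task.router.node_id == 0 else 0
        barrier = Barrier(attch_from_nums)
        task.use(barrier, lock=False)  # synchronize before the step body
        task.use(step_middleware, lock=False)
        task.use(barrier, lock=False)  # synchronize after the step body
        task.run(2)


def launch(node_id):
    # Node1 attaches to Node0's port; address and ports are arbitrary examples.
    attach_to = [] if node_id == 0 else ["tcp://127.0.0.1:12945"]
    Parallel.runner(
        node_ids=node_id,
        ports=12945 + node_id,
        attach_to=attach_to,
        topology="alone",
        protocol="tcp",
        n_parallel_workers=1,
        startup_interval=0
    )(pipeline)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=launch, args=(i, )) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

As in the tests, both processes register the same Barrier instance twice and therefore call it the same number of times per step, which is the condition the Barrier docstring requires to avoid deadlock.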