diff --git a/rllib/agents/ddpg/apex.py b/rllib/agents/ddpg/apex.py
index 99f82d2f0717b..8669c4e90d864 100644
--- a/rllib/agents/ddpg/apex.py
+++ b/rllib/agents/ddpg/apex.py
@@ -23,6 +23,15 @@
         "buffer_size": 2000000,
         # TODO(jungong) : update once Apex supports replay_buffer_config.
         "replay_buffer_config": None,
+        # Whether all shards of the replay buffer must be co-located
+        # with the learner process (running the execution plan).
+        # This is preferred because the learner process should have quick
+        # access to the data from the buffer shards, avoiding network
+        # traffic each time samples from the buffer(s) are drawn.
+        # Set this to False to relax this constraint and allow replay
+        # shards to be created on node(s) other than the one on which the
+        # learner is located.
+        "replay_buffer_shards_colocated_with_driver": True,
         "learning_starts": 50000,
         "train_batch_size": 512,
         "rollout_fragment_length": 50,
@@ -31,6 +40,7 @@
         "worker_side_prioritization": True,
         "min_iter_time_s": 30,
     },
+    _allow_unknown_configs=True,
 )
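A minimal usage sketch of the new flag (not part of the patch; environment name is assumed): with the registered `APEX_DDPG` trainer, the flag is just another config key.

import ray
from ray import tune

ray.init()

# Relax the co-location constraint: replay shards may be placed on any
# node(s) in the cluster, not just the learner's node.
tune.run(
    "APEX_DDPG",
    config={
        "env": "Pendulum-v1",
        "num_workers": 8,
        "replay_buffer_shards_colocated_with_driver": False,
    },
    stop={"timesteps_total": 100000},
)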
diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py
index 781f8184ab436..d9d1ab3f536b4 100644
--- a/rllib/agents/dqn/apex.py
+++ b/rllib/agents/dqn/apex.py
@@ -14,6 +14,7 @@
 
 import collections
 import copy
+import platform
 from typing import Tuple
 
 import ray
@@ -32,7 +33,7 @@
 from ray.rllib.execution.rollout_ops import ParallelRollouts
 from ray.rllib.execution.train_ops import UpdateTargetNetwork
 from ray.rllib.utils import merge_dicts
-from ray.rllib.utils.actors import create_colocated
+from ray.rllib.utils.actors import create_colocated_actors
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
 from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
@@ -55,10 +56,21 @@
         "n_step": 3,
         "num_gpus": 1,
         "num_workers": 32,
+        "buffer_size": 2000000,
         # TODO(jungong) : add proper replay_buffer_config after
         # DistributedReplayBuffer type is supported.
         "replay_buffer_config": None,
+        # Whether all shards of the replay buffer must be co-located
+        # with the learner process (running the execution plan).
+        # This is preferred because the learner process should have quick
+        # access to the data from the buffer shards, avoiding network
+        # traffic each time samples from the buffer(s) are drawn.
+        # Set this to False to relax this constraint and allow replay
+        # shards to be created on node(s) other than the one on which the
+        # learner is located.
+        "replay_buffer_shards_colocated_with_driver": True,
+
         "learning_starts": 50000,
         "train_batch_size": 512,
         "rollout_fragment_length": 50,
@@ -129,7 +141,8 @@ def execution_plan(workers: WorkerSet, config: dict,
         # Create a number of replay buffer actors.
         num_replay_buffer_shards = config["optimizer"][
            "num_replay_buffer_shards"]
-        replay_actors = create_colocated(ReplayActor, [
+
+        replay_actor_args = [
            num_replay_buffer_shards,
            config["learning_starts"],
            config["buffer_size"],
@@ -139,7 +152,24 @@ def execution_plan(workers: WorkerSet, config: dict,
            config["prioritized_replay_eps"],
            config["multiagent"]["replay_mode"],
            config.get("replay_sequence_length", 1),
-        ], num_replay_buffer_shards)
+        ]
+        # Place all replay buffer shards on the same node as the learner
+        # (driver process that runs this execution plan).
+        if config["replay_buffer_shards_colocated_with_driver"]:
+            replay_actors = create_colocated_actors(
+                actor_specs=[
+                    # (class, args, kwargs={}, count)
+                    (ReplayActor, replay_actor_args, {},
+                     num_replay_buffer_shards)
+                ],
+                node=platform.node(),  # localhost
+            )[0]  # [0]=only one item in `actor_specs`.
+        # Place replay buffer shards on any node(s).
+        else:
+            replay_actors = [
+                ReplayActor.remote(*replay_actor_args)
+                for _ in range(num_replay_buffer_shards)
+            ]
 
         # Start the learner thread.
         learner_thread = LearnerThread(workers.local_worker())
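Illustrative only (not part of the patch): one way to check where the shards landed, reusing the `replay_actors` and `config` names from the execution plan above and assuming the replay actors expose the same `get_host()` helper that `split_colocated()` relies on.

import platform

import ray

# `replay_actors` as created in the execution plan above.
hosts = ray.get([actor.get_host.remote() for actor in replay_actors])
if config["replay_buffer_shards_colocated_with_driver"]:
    # All shards should report the learner's (driver's) node name.
    assert all(h == platform.node() for h in hosts), hosts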
diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 3b03482eb9725..6f935d7839d60 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -11,9 +11,11 @@
 import pickle
 import tempfile
 import time
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Callable, DefaultDict, Dict, List, Optional, Set, Tuple, \
+    Type, Union
 
 import ray
+from ray.actor import ActorHandle
 from ray.exceptions import RayError
 from ray.rllib.agents.callbacks import DefaultCallbacks
 from ray.rllib.env.env_context import EnvContext
@@ -722,8 +724,9 @@ def default_logger_creator(config):
         self._episode_history = []
         self._episodes_to_be_collected = []
 
-        # Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
-        self.evaluation_workers = None
+        # Evaluation WorkerSet.
+        self.evaluation_workers: Optional[WorkerSet] = None
+        # Metrics most recently returned by `self.evaluate()`.
         self.evaluation_metrics = {}
 
         super().__init__(config, logger_creator, remote_checkpoint_dir,
@@ -798,12 +801,19 @@ def env_creator_from_classpath(env_context):
             self.local_replay_buffer = (
                 self._create_local_replay_buffer_if_necessary(self.config))
 
+            # Create a dict mapping ActorHandles to sets of open remote
+            # requests (object refs). This way, we keep track of how many
+            # requests (e.g. `sample()`) are currently in flight for each
+            # actor inside this Trainer (e.g. a remote RolloutWorker).
+            self.remote_requests_in_flight: \
+                DefaultDict[ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
+
             # Deprecated way of implementing Trainer sub-classes (or "templates"
             # via the soon-to-be deprecated `build_trainer` utility function).
             # Instead, sub-classes should override the Trainable's `setup()`
             # method and call super().setup() from within that override at some
             # point.
-            self.workers = None
+            self.workers: Optional[WorkerSet] = None
             self.train_exec_impl = None
 
             # Old design: Override `Trainer._init` (or use `build_trainer()`, which
@@ -845,13 +855,10 @@ def env_creator_from_classpath(env_context):
                     self.workers, self.config,
                     **self._kwargs_for_execution_plan())
 
-            # TODO: Now that workers have been created, update our policy
-            #  specs in the config[multiagent] dict with the correct spaces.
-            #  However, this leads to a problem with the evaluation
-            #  workers' observation one-hot preprocessor in
-            #  `examples/documentation/rllib_in_6sec.py` script.
-            # self.config["multiagent"]["policies"] = \
-            #     self.workers.local_worker().policy_map.policy_specs
+            # Now that workers have been created, update our policy
+            # specs in the config[multiagent] dict with the correct spaces.
+            self.config["multiagent"]["policies"] = \
+                self.workers.local_worker().policy_dict
 
             # Evaluation WorkerSet setup.
             # User would like to setup a separate evaluation worker set.
@@ -912,7 +919,7 @@ def env_creator_from_classpath(env_context):
                 # If evaluation_num_workers=0, use the evaluation set's local
                 # worker for evaluation, otherwise, use its remote workers
                 # (parallelized evaluation).
-                self.evaluation_workers = self._make_workers(
+                self.evaluation_workers: WorkerSet = self._make_workers(
                     env_creator=self.env_creator,
                     validate_env=None,
                     policy_class=self.get_default_policy_class(self.config),
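The `remote_requests_in_flight` dict is plain bookkeeping; a self-contained toy sketch of the pattern it enables (hypothetical `Worker` actor, not RLlib code) looks roughly like this:

from collections import defaultdict
from typing import DefaultDict, Set

import ray
from ray.actor import ActorHandle

ray.init()


@ray.remote
class Worker:
    def sample(self):
        return 42


workers = [Worker.remote() for _ in range(2)]
in_flight: DefaultDict[ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

# Only send a new request to an actor if fewer than 2 are pending for it.
for w in workers:
    if len(in_flight[w]) < 2:
        in_flight[w].add(w.sample.remote())

# Later: resolve whatever is ready and clear the bookkeeping entries.
ready, _ = ray.wait([r for s in in_flight.values() for r in s], timeout=0.1)
for w in workers:
    in_flight[w] -= set(ready)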
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py
index 0531c9c8efa04..9592a8525da20 100644
--- a/rllib/evaluation/rollout_worker.py
+++ b/rllib/evaluation/rollout_worker.py
@@ -10,6 +10,7 @@
     TYPE_CHECKING, Union
 
 import ray
+from ray import ObjectRef
 from ray import cloudpickle as pickle
 from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
 from ray.rllib.env.env_context import EnvContext
@@ -537,16 +538,16 @@ def make_sub_env(vector_index):
         self.make_sub_env_fn = make_sub_env
         self.spaces = spaces
 
-        policy_dict = _determine_spaces_for_multi_agent_dict(
+        self.policy_dict = _determine_spaces_for_multi_agent_dict(
             policy_spec,
             self.env,
             spaces=self.spaces,
             policy_config=policy_config)
 
         # List of IDs of those policies, which should be trained.
-        # By default, these are all policies found in the policy_dict.
+        # By default, these are all policies found in `self.policy_dict`.
         self.policies_to_train: List[PolicyID] = policies_to_train or list(
-            policy_dict.keys())
+            self.policy_dict.keys())
         self.set_policies_to_train(self.policies_to_train)
 
         self.policy_map: PolicyMap = None
@@ -583,7 +584,7 @@ def make_sub_env(vector_index):
                     f"is ignored.")
 
         self._build_policy_map(
-            policy_dict,
+            self.policy_dict,
             policy_config,
             session_creator=tf_session_creator,
             seed=seed)
@@ -1111,7 +1112,7 @@ def add_policy(
         """
         if policy_id in self.policy_map:
             raise ValueError(f"Policy ID '{policy_id}' already in policy map!")
-        policy_dict = _determine_spaces_for_multi_agent_dict(
+        policy_dict_to_add = _determine_spaces_for_multi_agent_dict(
             {
                 policy_id: PolicySpec(policy_cls, observation_space,
                                       action_space, config or {})
@@ -1120,8 +1121,9 @@ def add_policy(
             spaces=self.spaces,
             policy_config=self.policy_config,
         )
+        self.policy_dict.update(policy_dict_to_add)
         self._build_policy_map(
-            policy_dict,
+            policy_dict_to_add,
             self.policy_config,
             seed=self.policy_config.get("seed"))
         new_policy = self.policy_map[policy_id]
@@ -1386,6 +1388,14 @@ def set_weights(self,
             >>> # Set `global_vars` (timestep) as well.
             >>> worker.set_weights(weights, {"timestep": 42})
         """
+        # If per-policy weights are object refs, `ray.get()` them first.
+        if weights and isinstance(next(iter(weights.values())), ObjectRef):
+            actual_weights = ray.get(list(weights.values()))
+            weights = {
+                pid: actual_weights[i]
+                for i, pid in enumerate(weights.keys())
+            }
+
         for pid, w in weights.items():
             self.policy_map[pid].set_weights(w)
         if global_vars:
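The `ObjectRef` branch above lets weights be put into the object store once and broadcast by reference; a rough sketch (assuming a generic `trainer` with remote rollout workers):

import ray

# Put each policy's weights into the object store once ...
weights = trainer.workers.local_worker().get_weights()
weight_refs = {pid: ray.put(w) for pid, w in weights.items()}

# ... then broadcast only the (small) refs to all remote workers; each
# worker resolves the refs inside `set_weights()`.
for worker in trainer.workers.remote_workers():
    worker.set_weights.remote(weight_refs)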
diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py
index 8a0b354add715..eae7ac8483cfb 100644
--- a/rllib/execution/rollout_ops.py
+++ b/rllib/execution/rollout_ops.py
@@ -1,10 +1,10 @@
 import logging
-from typing import List, Tuple
 import time
+from typing import Any, Callable, Dict, List, Optional, Tuple, \
+    TYPE_CHECKING
 
 import ray
-from ray.util.iter import from_actors, LocalIterator
-from ray.util.iter_metrics import SharedMetrics
+from ray.actor import ActorHandle
 from ray.rllib.evaluation.rollout_worker import get_global_worker
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
@@ -12,27 +12,200 @@
     _check_sample_batch_type, _get_shared_metrics
 from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
+from ray.rllib.utils.annotations import ExperimentalAPI
 from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
     LEARNER_STATS_KEY
 from ray.rllib.utils.sgd import standardized
 from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients
+from ray.util.iter import from_actors, LocalIterator
+from ray.util.iter_metrics import SharedMetrics
+
+if TYPE_CHECKING:
+    from ray.rllib.agents.trainer import Trainer
+    from ray.rllib.evaluation.rollout_worker import RolloutWorker
 
 logger = logging.getLogger(__name__)
 
 
-def synchronous_parallel_sample(workers: WorkerSet) -> List[SampleBatch]:
+@ExperimentalAPI
+def synchronous_parallel_sample(
+        worker_set: WorkerSet,
+        remote_fn: Optional[Callable[["RolloutWorker"], None]] = None,
+) -> List[SampleBatch]:
+    """Runs parallel and synchronous rollouts on all remote workers.
+
+    Waits for all workers to return from the remote calls.
+
+    If no remote workers exist (num_workers == 0), use the local worker
+    for sampling.
+
+    As an alternative to calling `worker.sample.remote()`, the user can
+    provide a `remote_fn()`, which will be applied to the worker(s) instead.
+
+    Args:
+        worker_set: The WorkerSet to use for sampling.
+        remote_fn: If provided, use `worker.apply.remote(remote_fn)` instead
+            of `worker.sample.remote()` to generate the requests.
+
+    Returns:
+        The list of collected sample batch types (one for each parallel
+        rollout worker in the given `worker_set`).
+
+    Examples:
+        >>> # 2 remote workers (num_workers=2):
+        >>> batches = synchronous_parallel_sample(trainer.workers)
+        >>> print(len(batches))
+        ... 2
+        >>> print(batches[0])
+        ... SampleBatch(16: ['obs', 'actions', 'rewards', 'dones'])
+
+        >>> # 0 remote workers (num_workers=0): Using the local worker.
+        >>> batches = synchronous_parallel_sample(trainer.workers)
+        >>> print(len(batches))
+        ... 1
+    """
     # No remote workers in the set -> Use local worker for collecting
     # samples.
-    if not workers.remote_workers():
-        return [workers.local_worker().sample()]
+    if not worker_set.remote_workers():
+        return [worker_set.local_worker().sample()]
 
     # Loop over remote workers' `sample()` method in parallel.
     sample_batches = ray.get(
-        [r.sample.remote() for r in workers.remote_workers()])
+        [r.sample.remote() for r in worker_set.remote_workers()])
 
+    # Return all collected batches.
     return sample_batches
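A typical follow-up to `synchronous_parallel_sample()` is concatenating the per-worker results into one train batch; a short sketch assuming a `trainer` handle and single-agent `SampleBatch` returns:

from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.policy.sample_batch import SampleBatch

batches = synchronous_parallel_sample(trainer.workers)
# Concatenate the per-worker batches into a single train batch.
train_batch = SampleBatch.concat_samples(batches)
print(train_batch.count)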
+
+
+# TODO: Move to generic parallel ops module and rename to
+#  `asynchronous_parallel_requests`:
+@ExperimentalAPI
+def asynchronous_parallel_sample(
+        trainer: "Trainer",
+        actors: List[ActorHandle],
+        ray_wait_timeout_s: Optional[float] = None,
+        max_remote_requests_in_flight_per_actor: int = 2,
+        remote_fn: Optional[Callable[["RolloutWorker"], None]] = None,
+        remote_args: Optional[List[List[Any]]] = None,
+        remote_kwargs: Optional[List[Dict[str, Any]]] = None,
+) -> Optional[List[SampleBatch]]:
+    """Runs parallel and asynchronous rollouts on all remote workers.
+
+    May use a timeout (if provided) on `ray.wait()` and return only those
+    samples that could be gathered in the timeout window. Allows a maximum
+    of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
+    per remote actor.
+
+    As an alternative to calling `actor.sample.remote()`, the user can
+    provide a `remote_fn()`, which will be applied to the actor(s) instead.
+
+    Args:
+        trainer: The Trainer object that we run the sampling for.
+        actors: The List of ActorHandles to perform the remote requests on.
+        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
+            `ray.wait()` calls. If None (default), never time out (block
+            until at least one actor returns something).
+        max_remote_requests_in_flight_per_actor: Maximum number of remote
+            requests sent to each actor. 2 (default) is probably
+            sufficient to avoid idle times between two requests.
+        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
+            `actor.sample.remote()` to generate the requests.
+        remote_args: If provided, use this list (per-actor) of lists (call
+            args) as *args to be passed to the `remote_fn`.
+            E.g.: actors=[A, B],
+            remote_args=[[...] <- *args for A, [...] <- *args for B].
+        remote_kwargs: If provided, use this list (per-actor) of dicts
+            (kwargs) as **kwargs to be passed to the `remote_fn`.
+            E.g.: actors=[A, B],
+            remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].
+
+    Returns:
+        The list of asynchronously collected sample batch types. None, if no
+        samples are ready.
+
+    Examples:
+        >>> # 2 remote rollout workers (num_workers=2):
+        >>> batches = asynchronous_parallel_sample(
+        ...     trainer,
+        ...     actors=trainer.workers.remote_workers(),
+        ...     ray_wait_timeout_s=0.1,
+        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
+        ... )
+        >>> # Expect a timeout to have happened: No results are ready yet,
+        >>> # so None is returned.
+        >>> print(batches)
+        ... None
+    """
+
+    if remote_args is not None:
+        assert len(remote_args) == len(actors)
+    if remote_kwargs is not None:
+        assert len(remote_kwargs) == len(actors)
+
+    # Collect all currently pending remote requests into a single set of
+    # object refs.
+    pending_remotes = set()
+    # Also build a map to get the associated actor for each remote request.
+    remote_to_actor = {}
+    for actor, set_ in trainer.remote_requests_in_flight.items():
+        pending_remotes |= set_
+        for r in set_:
+            remote_to_actor[r] = actor
+
+    # Add new requests, if possible (if
+    # `max_remote_requests_in_flight_per_actor` setting allows it).
+    for actor_idx, actor in enumerate(actors):
+        # Still room for another request to this actor.
+        if len(trainer.remote_requests_in_flight[actor]) < \
+                max_remote_requests_in_flight_per_actor:
+            if remote_fn is None:
+                req = actor.sample.remote()
+            else:
+                args = remote_args[actor_idx] if remote_args else []
+                kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
+                req = actor.apply.remote(remote_fn, *args, **kwargs)
+            # Add to our set to send to ray.wait().
+            pending_remotes.add(req)
+            # Keep our mappings properly updated.
+            trainer.remote_requests_in_flight[actor].add(req)
+            remote_to_actor[req] = actor
+
+    # There must always be pending remote requests.
+    assert len(pending_remotes) > 0
+    pending_remote_list = list(pending_remotes)
+
+    # No timeout: Block until at least one result is returned.
+    if ray_wait_timeout_s is None:
+        # First try to do a `ray.wait` w/o timeout for efficiency.
+        ready, _ = ray.wait(
+            pending_remote_list, num_returns=len(pending_remotes), timeout=0)
+        # Nothing returned and `timeout` is None -> Fall back to a
+        # blocking wait to make sure we can return something.
+        if not ready:
+            ready, _ = ray.wait(pending_remote_list, num_returns=1)
+    # Timeout: Do a `ray.wait()` call w/ timeout.
+    else:
+        ready, _ = ray.wait(
+            pending_remote_list,
+            num_returns=len(pending_remotes),
+            timeout=ray_wait_timeout_s)
+
+        # Return None if nothing ready after the timeout.
+        if not ready:
+            return None
+
+    for obj_ref in ready:
+        # Remove in-flight record for this ref.
+        trainer.remote_requests_in_flight[remote_to_actor[obj_ref]].remove(
+            obj_ref)
+        remote_to_actor.pop(obj_ref)
+
+    results = ray.get(ready)
+
+    return results
+
+
 def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
                      num_async=1) -> LocalIterator[SampleBatch]:
     """Operator to collect experiences in parallel from rollout workers.
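A hedged sketch of how a custom training loop might poll `asynchronous_parallel_sample()` (mirroring the docstring example; `trainer` is assumed to be set up with `remote_requests_in_flight` as in the Trainer changes above):

from ray.rllib.execution.rollout_ops import asynchronous_parallel_sample

collected = 0
sample_batches = asynchronous_parallel_sample(
    trainer,
    actors=trainer.workers.remote_workers(),
    ray_wait_timeout_s=0.01,
    max_remote_requests_in_flight_per_actor=2,
)
# None means no request finished within the timeout window; simply try
# again on the next iteration.
if sample_batches is not None:
    for batch in sample_batches:
        collected += batch.count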
diff --git a/rllib/execution/tree_agg.py b/rllib/execution/tree_agg.py
index 6880fb1cbbada..812a4437b8081 100644
--- a/rllib/execution/tree_agg.py
+++ b/rllib/execution/tree_agg.py
@@ -9,7 +9,7 @@
 from ray.rllib.execution.replay_ops import MixInReplay
 from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
 from ray.rllib.policy.sample_batch import MultiAgentBatch
-from ray.rllib.utils.actors import create_colocated
+from ray.rllib.utils.actors import create_colocated_actors
 from ray.rllib.utils.typing import SampleBatchType, ModelWeights
 from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \
     from_actors, LocalIterator
@@ -91,11 +91,22 @@ def gather_experiences_tree_aggregation(workers: WorkerSet,
     ]
 
     # This spawns |num_aggregation_workers| intermediate actors that aggregate
-    # experiences in parallel. We force colocation on the same node to maximize
-    # data bandwidth between them and the driver.
-    train_batches = from_actors([
-        create_colocated(Aggregator, [config, g], 1)[0] for g in rollout_groups
-    ])
+    # experiences in parallel. We force colocation on the same node (localhost)
+    # to maximize data bandwidth between them and the driver.
+    localhost = platform.node()
+    assert localhost != "", \
+        "ERROR: Cannot determine local node name! " \
+        "`platform.node()` returned empty string."
+    all_co_located = create_colocated_actors(
+        actor_specs=[
+            # (class, args, kwargs={}, count=1)
+            (Aggregator, [config, g], {}, 1) for g in rollout_groups
+        ],
+        node=localhost)
+
+    # Use the first ([0]) of each created group (each group only has one
+    # actor: count=1).
+    train_batches = from_actors([group[0] for group in all_co_located])
 
     # TODO(ekl) properly account for replay.
     def record_steps_sampled(batch):
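A self-contained toy illustrating the `actor_specs` format used above: each spec with its own c'tor args gets its own `(cls, args, kwargs, count)` entry (hypothetical `Echo` actor; `get_host()` is what the co-location check calls):

import platform

import ray
from ray.rllib.utils.actors import create_colocated_actors

ray.init()


@ray.remote
class Echo:
    def __init__(self, tag):
        self.tag = tag

    def get_host(self):
        # Used by split_colocated() to determine this actor's node.
        return platform.node()


groups = create_colocated_actors(
    actor_specs=[(Echo, [tag], {}, 1) for tag in ["a", "b", "c"]],
    node=platform.node(),
)
# One single-actor group per spec, all on the local node.
assert len(groups) == 3 and all(len(g) == 1 for g in groups)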
diff --git a/rllib/utils/actors.py b/rllib/utils/actors.py
index 06eec3c163762..a700462c5b223 100644
--- a/rllib/utils/actors.py
+++ b/rllib/utils/actors.py
@@ -1,7 +1,11 @@
+from collections import defaultdict, deque
 import logging
 import platform
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
+
 import ray
-from collections import deque
+from ray.actor import ActorClass, ActorHandle
+from ray.rllib.utils.deprecation import Deprecated
 
 logger = logging.getLogger(__name__)
 
@@ -65,45 +69,202 @@ def count(self):
         return len(self._tasks)
 
 
-def drop_colocated(actors):
-    colocated, non_colocated = split_colocated(actors)
-    for a in colocated:
+def create_colocated_actors(
+        actor_specs: Sequence[Tuple[Type, Any, Any, int]],
+        node: Optional[str] = "localhost",
+        max_attempts: int = 10,
+) -> List[List[ActorHandle]]:
+    """Create co-located actors of any type(s) on any node.
+
+    Args:
+        actor_specs: Tuple/list with tuples consisting of: 1) The
+            (already @ray.remote) class(es) to construct, 2) c'tor args,
+            3) c'tor kwargs, and 4) the number of actors of that class with
+            given args/kwargs to construct.
+        node: The node to co-locate the actors on. By default ("localhost"),
+            place the actors on the node the caller of this function is
+            located on. Use None to indicate that any (resource fulfilling)
+            node in the cluster may be used.
+        max_attempts: The maximum number of co-location attempts to
+            perform before throwing an error.
+
+    Returns:
+        A list with one item per spec in `actor_specs`; each item is the
+        list of n ActorHandles created (and co-located) for that spec.
+    """
+    if node == "localhost":
+        node = platform.node()
+
+    # Maps each entry in `actor_specs` to lists of already co-located actors.
+    ok = [[] for _ in range(len(actor_specs))]
+
+    # Try n times to co-locate all given actor types (`actor_specs`).
+    # With each (failed) attempt, increase the number of actors we try to
+    # create (on the same node), then kill the ones that have been created in
+    # excess.
+    for attempt in range(max_attempts):
+        # If any attempt to co-locate fails, set this to False and we'll do
+        # another attempt.
+        all_good = True
+        # Process all `actor_specs` in sequence.
+        for i, (typ, args, kwargs, count) in enumerate(actor_specs):
+            args = args or []  # Allow None.
+            kwargs = kwargs or {}  # Allow None.
+            # We don't have enough actors yet of this spec co-located on
+            # the desired node.
+            if len(ok[i]) < count:
+                co_located = try_create_colocated(
+                    cls=typ,
+                    args=args,
+                    kwargs=kwargs,
+                    count=count * (attempt + 1),
+                    node=node)
+                # If node did not matter (None), from here on, use the host
+                # that the first actor(s) are already co-located on.
+                if node is None:
+                    node = ray.get(co_located[0].get_host.remote())
+                # Add the newly co-located actors to the `ok` list.
+                ok[i].extend(co_located)
+                # If we still don't have enough -> We'll have to do another
+                # attempt.
+                if len(ok[i]) < count:
+                    all_good = False
+            # We created too many actors for this spec -> Kill/truncate
+            # the excess ones.
+            if len(ok[i]) > count:
+                for a in ok[i][count:]:
+                    a.__ray_terminate__.remote()
+                ok[i] = ok[i][:count]
+
+        # All `actor_specs` have been fulfilled, return lists of
+        # co-located actors.
+        if all_good:
+            return ok
+
+    raise Exception("Unable to create enough colocated actors -> aborting.")
+
+
+def try_create_colocated(
+        cls: Type[ActorClass],
+        args: List[Any],
+        count: int,
+        kwargs: Optional[Dict[str, Any]] = None,
+        node: Optional[str] = "localhost",
+) -> List[ActorHandle]:
+    """Tries to co-locate (same node) a set of Actors of the same type.
+
+    Returns a list of successfully co-located actors. All actors that could
+    not be co-located (with the others on the given node) will not be in this
+    list.
+
+    Creates each actor via its remote() constructor and then checks whether
+    it has been co-located (on the same node) with the other (already created)
+    ones. If not, terminates the just created actor.
+
+    Args:
+        cls: The Actor class to use (already @ray.remote "converted").
+        args: Args to pass to the Actor's constructor (the same args are
+            used for each of the `count` to-be-created actors).
+        count: Number of actors of the given `cls` to construct.
+        kwargs: Optional kwargs to pass to the Actor's constructor (the same
+            kwargs are used for each of the `count` to-be-created actors).
+        node: The node to co-locate the actors on. By default ("localhost"),
+            place the actors on the node the caller of this function is
+            located on. If None, will try to co-locate all actors on
+            any available node.
+
+    Returns:
+        List containing all successfully co-located actor handles.
+    """
+    if node == "localhost":
+        node = platform.node()
+
+    kwargs = kwargs or {}
+    actors = [cls.remote(*args, **kwargs) for _ in range(count)]
+    co_located, non_co_located = split_colocated(actors, node=node)
+    logger.info("Got {} colocated actors of {}".format(len(co_located), count))
+    for a in non_co_located:
         a.__ray_terminate__.remote()
-    return non_colocated
+    return co_located
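For intuition on the retry loop above: attempt `k` (0-based) asks `try_create_colocated()` for `count * (k + 1)` actors and keeps only those that landed on the target node. A direct call looks like this sketch (reusing the hypothetical `Echo` actor from the previous example):

from ray.rllib.utils.actors import try_create_colocated

# Ask for 4 Echo actors on this node; actors scheduled elsewhere are
# terminated and only the co-located ones are returned.
co_located = try_create_colocated(
    cls=Echo, args=["x"], count=4, node=platform.node())
print(len(co_located))  # <= 4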
+
+
+def split_colocated(
+        actors: List[ActorHandle],
+        node: Optional[str] = "localhost",
+) -> Tuple[List[ActorHandle], List[ActorHandle]]:
+    """Splits up given actors into colocated (on same node) and non-colocated.
+
+    The co-location criterion depends on the `node` given:
+    If given (or default: platform.node()): Consider all actors that are on
+    that node "colocated".
+    If None: Consider the largest sub-set of actors that are all located on
+    the same node (whatever that node is) as "colocated".
 
-def split_colocated(actors):
-    localhost = platform.node()
+    Args:
+        actors: The list of actor handles to split into "colocated" and
+            "non-colocated".
+        node: The node defining the "colocation" criterion. If provided,
+            consider those actors "colocated" that sit on this node. If
+            None, use the largest subset within `actors` that are sitting
+            on the same (any) node.
+
+    Returns:
+        Tuple of two lists: 1) Co-located ActorHandles, 2) non co-located
+        ActorHandles.
+    """
+    if node == "localhost":
+        node = platform.node()
+
+    # Get nodes of all created actors.
     hosts = ray.get([a.get_host.remote() for a in actors])
-    local = []
-    non_local = []
-    for host, a in zip(hosts, actors):
-        if host == localhost:
-            local.append(a)
-        else:
-            non_local.append(a)
-    return local, non_local
+
+    # If `node` not provided, use the largest group of actors that sit on the
+    # same node, regardless of what that node is.
+    if node is None:
+        node_groups = defaultdict(set)
+        for host, actor in zip(hosts, actors):
+            node_groups[host].add(actor)
+        max_ = -1
+        largest_group = None
+        for host in node_groups:
+            if max_ < len(node_groups[host]):
+                max_ = len(node_groups[host])
+                largest_group = host
+        non_co_located = []
+        for host in node_groups:
+            if host != largest_group:
+                non_co_located.extend(list(node_groups[host]))
+        return list(node_groups[largest_group]), non_co_located
+    # Node provided (or default: localhost): Consider those actors "colocated"
+    # that were placed on `node`.
+    else:
+        # Split into co-located (on `node`) and non-co-located (not on
+        # `node`).
+        co_located = []
+        non_co_located = []
+        for host, a in zip(hosts, actors):
+            # This actor has been placed on the correct node.
+            if host == node:
+                co_located.append(a)
+            # This actor has been placed on a different node.
+            else:
+                non_co_located.append(a)
+        return co_located, non_co_located
 
 
-def try_create_colocated(cls, args, count):
-    actors = [cls.remote(*args) for _ in range(count)]
-    local, rest = split_colocated(actors)
-    logger.info("Got {} colocated actors of {}".format(len(local), count))
-    for a in rest:
-        a.__ray_terminate__.remote()
-    return local
-
-
-def create_colocated(cls, args, count):
-    logger.info("Trying to create {} colocated actors".format(count))
-    ok = []
-    i = 1
-    while len(ok) < count and i < 10:
-        attempt = try_create_colocated(cls, args, count * i)
-        ok.extend(attempt)
-        i += 1
-    if len(ok) < count:
-        raise Exception("Unable to create enough colocated actors, abort.")
-    for a in ok[count:]:
+
+@Deprecated(new="create_colocated_actors", error=False)
+def create_colocated(cls, arg, count):
+    kwargs = {}
+    args = arg
+
+    return create_colocated_actors(
+        actor_specs=[(cls, args, kwargs, count)],
+        node=platform.node(),  # force on localhost
+    )[0]  # [0]=only one item in `actor_specs`.
+
+
+@Deprecated(error=False)
+def drop_colocated(actors: List[ActorHandle]) -> List[ActorHandle]:
+    colocated, non_colocated = split_colocated(actors)
+    for a in colocated:
         a.__ray_terminate__.remote()
-    return ok[:count]
+    return non_colocated
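Finally, a migration sketch for callers of the now-deprecated helper (argument names follow the Ape-X usage above; not part of the patch):

import platform

# Old (deprecated):
#   replay_actors = create_colocated(ReplayActor, replay_actor_args, 4)
# New equivalent:
replay_actors = create_colocated_actors(
    actor_specs=[(ReplayActor, replay_actor_args, {}, 4)],
    node=platform.node(),  # co-locate on the caller's node
)[0]  # [0] = only one entry in `actor_specs`.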