From be8b0c7dfb2cab2492ebb4b404cd6476bdad651a Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 6 Jan 2022 15:53:28 +0100 Subject: [PATCH 1/9] wip. --- rllib/agents/dqn/apex.py | 28 ++++- rllib/agents/impala/impala.py | 2 +- rllib/agents/ppo/appo.py | 6 - rllib/agents/trainer.py | 32 ++++- rllib/env/base_env.py | 2 +- rllib/env/multi_agent_env.py | 2 +- rllib/evaluation/rollout_worker.py | 33 +++-- rllib/evaluation/worker_set.py | 6 +- rllib/execution/rollout_ops.py | 195 +++++++++++++++++++++++++++-- rllib/execution/train_ops.py | 7 -- rllib/execution/tree_agg.py | 19 ++- rllib/policy/policy.py | 38 +++++- rllib/policy/rnn_sequencing.py | 2 +- rllib/policy/sample_batch.py | 6 +- rllib/policy/torch_policy.py | 2 +- rllib/utils/actors.py | 132 ++++++++++++++----- rllib/utils/debug.py | 2 +- rllib/utils/test_utils.py | 15 ++- rllib/utils/tf_utils.py | 3 +- rllib/utils/torch_utils.py | 4 +- rllib/utils/typing.py | 6 +- 21 files changed, 448 insertions(+), 94 deletions(-) diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index 781f8184ab43..f7ef668bed4b 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -14,6 +14,7 @@ import collections import copy +import platform from typing import Tuple import ray @@ -32,7 +33,7 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts from ray.rllib.execution.train_ops import UpdateTargetNetwork from ray.rllib.utils import merge_dicts -from ray.rllib.utils.actors import create_colocated +from ray.rllib.utils.actors import create_colocated_actors from ray.rllib.utils.annotations import override from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict @@ -59,6 +60,11 @@ # TODO(jungong) : add proper replay_buffer_config after # DistributedReplayBuffer type is supported. "replay_buffer_config": None, + # Whether all shards of the replay buffer must be co-located + # with the learner process (running the execution plan). + # If False, replay shards may be created on different node(s). + "replay_buffer_shards_colocated_with_driver": True, + "learning_starts": 50000, "train_batch_size": 512, "rollout_fragment_length": 50, @@ -129,7 +135,8 @@ def execution_plan(workers: WorkerSet, config: dict, # Create a number of replay buffer actors. num_replay_buffer_shards = config["optimizer"][ "num_replay_buffer_shards"] - replay_actors = create_colocated(ReplayActor, [ + + args = [ num_replay_buffer_shards, config["learning_starts"], config["buffer_size"], @@ -139,7 +146,22 @@ def execution_plan(workers: WorkerSet, config: dict, config["prioritized_replay_eps"], config["multiagent"]["replay_mode"], config.get("replay_sequence_length", 1), - ], num_replay_buffer_shards) + ] + # Place all replay buffer shards on the same node as the learner + # (driver process that runs this execution plan). + if config["replay_buffer_shards_colocated_with_driver"]: + replay_actors = create_colocated_actors( + actor_specs=[ + # (class, args, kwargs={}, count) + (ReplayActor, args, {}, num_replay_buffer_shards) # [0] + ], + node=platform.node(), # localhost + )[0] + # Place replay buffer shards on any node(s). + else: + replay_actors = [ + ReplayActor(*args) for _ in range(num_replay_buffer_shards) + ] # Start the learner thread. 
learner_thread = LearnerThread(workers.local_worker()) diff --git a/rllib/agents/impala/impala.py b/rllib/agents/impala/impala.py index a0d31e6096e1..12f3b7959114 100644 --- a/rllib/agents/impala/impala.py +++ b/rllib/agents/impala/impala.py @@ -366,7 +366,7 @@ def default_resource_request(cls, config): { # Evaluation (remote) workers. # Note: The local eval worker is located on the driver - # CPU. + # CPU or not even created iff >0 eval workers. "CPU": eval_config.get("num_cpus_per_worker", cf["num_cpus_per_worker"]), "GPU": eval_config.get("num_gpus_per_worker", diff --git a/rllib/agents/ppo/appo.py b/rllib/agents/ppo/appo.py index 1c14db3c20ae..e030de383916 100644 --- a/rllib/agents/ppo/appo.py +++ b/rllib/agents/ppo/appo.py @@ -120,12 +120,6 @@ def __init__(self, config, *args, **kwargs): self.workers.local_worker().foreach_trainable_policy( lambda p, _: p.update_target()) - # TODO: Remove this once ImpalaTrainer directly inherits from Trainer - # (instead of being created by `build_trainer()` utility). - @override(impala.ImpalaTrainer) - def _init(self, *args, **kwargs): - raise NotImplementedError - @classmethod @override(Trainer) def get_default_config(cls) -> TrainerConfigDict: diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index e55ab358e51e..08912aa4cea8 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -29,6 +29,7 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.buffers.multi_agent_replay_buffer import \ MultiAgentReplayBuffer +from ray.rllib.execution.common import WORKER_UPDATE_TIMER from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts, \ synchronous_parallel_sample from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep, \ @@ -775,8 +776,8 @@ def env_creator_from_classpath(env_context): # Set Trainer's seed after we have - if necessary - enabled # tf eager-execution. - update_global_seed_if_necessary( - config.get("framework"), config.get("seed")) + update_global_seed_if_necessary(self.config["framework"], + self.config["seed"]) self.validate_config(self.config) if not callable(self.config["callbacks"]): @@ -844,6 +845,11 @@ def env_creator_from_classpath(env_context): self.workers, self.config, **self._kwargs_for_execution_plan()) + # Now that workers have been created, update our policy specs + # in the config[multiagent] dict with the correct spaces. + self.config["multiagent"]["policies"] = \ + self.workers.local_worker().policy_map.policy_specs + # Evaluation WorkerSet setup. # User would like to setup a separate evaluation worker set. @@ -1295,6 +1301,12 @@ def training_iteration(self) -> ResultDict: else: train_results = multi_gpu_train_one_step(self, train_batch) + # Update weights - after learning on the local worker - on all remote + # workers. + if self.workers.remote_workers(): + with self._timers[WORKER_UPDATE_TIMER]: + self.workers.sync_weights() + return train_results @DeveloperAPI @@ -1976,6 +1988,22 @@ def merge_trainer_configs(cls, config2: PartialTrainerConfigDict, _allow_unknown_configs: Optional[bool] = None ) -> TrainerConfigDict: + """Merges a complete Trainer config with a partial override dict. + + Respects nested structures within the config dicts. The values in the + partial override dict take priority. + + Args: + config1: The complete Trainer's dict to be merged (overridden) + with `config2`. + config2: The partial override config dict to merge on top of + `config1`. 
+ _allow_unknown_configs: If True, keys in `config2` that don't exist + in `config1` are allowed and will be added to the final config. + + Returns: + The merged full trainer config dict. + """ config1 = copy.deepcopy(config1) if "callbacks" in config2 and type(config2["callbacks"]) is dict: legacy_callbacks_dict = config2["callbacks"] diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 760152c79369..881917a4c042 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -722,7 +722,7 @@ def convert_to_base_env( The resulting BaseEnv object. """ - from ray.rllib.env.remote_vector_env import RemoteBaseEnv + from ray.rllib.env.remote_base_env import RemoteBaseEnv from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.vector_env import VectorEnv, VectorEnvWrapper diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index c5025bc949a0..508528264207 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -168,7 +168,7 @@ def to_base_env( Returns: The resulting BaseEnv object. """ - from ray.rllib.env.remote_vector_env import RemoteBaseEnv + from ray.rllib.env.remote_base_env import RemoteBaseEnv if remote_envs: env = RemoteBaseEnv( make_env, diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index ef3f4d27e785..353078a61804 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -6,10 +6,11 @@ import platform import os import tree # pip install dm_tree -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \ +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \ TYPE_CHECKING, Union import ray +from ray import ObjectRef from ray import cloudpickle as pickle from ray.rllib.env.base_env import BaseEnv, convert_to_base_env from ray.rllib.env.env_context import EnvContext @@ -46,7 +47,7 @@ from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \ ModelConfigDict, ModelGradients, ModelWeights, \ MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \ - SampleBatchType + SampleBatchType, T from ray.util.debug import log_once, disable_log_once_globally, \ enable_periodic_logging from ray.util.iter import ParallelIteratorWorker @@ -56,9 +57,6 @@ from ray.rllib.evaluation.observation_function import ObservationFunction from ray.rllib.agents.callbacks import DefaultCallbacks # noqa -# Generic type var for foreach_* methods. -T = TypeVar("T") - tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -1389,6 +1387,14 @@ def set_weights(self, >>> # Set `global_vars` (timestep) as well. >>> worker.set_weights(weights, {"timestep": 42}) """ + # If per-policy weights are object refs, `ray.get()` them first. + if weights and isinstance(next(iter(weights.values())), ObjectRef): + actual_weights = ray.get(list(weights.values())) + weights = { + pid: actual_weights[i] + for i, pid in enumerate(weights.keys()) + } + for pid, w in weights.items(): self.policy_map[pid].set_weights(w) if global_vars: @@ -1436,19 +1442,26 @@ def stop(self) -> None: sess.close() @DeveloperAPI - def apply(self, func: Callable[["RolloutWorker", Optional[Any]], T], - *args) -> T: + def apply( + self, + func: Callable[["RolloutWorker", Optional[Any], Optional[Any]], T], + *args, **kwargs) -> T: """Calls the given function with this rollout worker instance. 
+ Useful for when the RolloutWorker class has been converted into a + ActorHandle and the user needs to execute some functionality (e.g. + add a property) on the underlying policy object. + Args: - func: The function to call with this RolloutWorker as first - argument. + func: The function to call, with this RolloutWorker as first + argument, followed by *args, and **kwargs. args: Optional additional args to pass to the function call. + kwargs: Optional additional kwargs to pass to the function call. Returns: The return value of the function call. """ - return func(self, *args) + return func(self, *args, **kwargs) def setup_torch_data_parallel(self, url: str, world_rank: int, world_size: int, backend: str) -> None: diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 7e022f374523..993aa847ea63 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -159,13 +159,15 @@ def sync_weights(self, if self.remote_workers() or from_worker is not None: weights = (from_worker or self.local_worker()).get_weights(policies) + # Put weights only once into object store and use same object + # ref to synch to all workers. weights_ref = ray.put(weights) # Sync to all remote workers in this WorkerSet. for to_worker in self.remote_workers(): to_worker.set_weights.remote(weights_ref) - # If from_worker is provided, also sync to this WorkerSet's local - # worker. + # If `from_worker` is provided, also sync to this WorkerSet's + # local worker. if from_worker is not None and self.local_worker() is not None: self.local_worker().set_weights(weights) diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index 338652e08ea0..9aa39a5994a5 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -1,10 +1,11 @@ +from collections import defaultdict import logging -from typing import List, Tuple import time +from typing import Any, Callable, Dict, List, Optional, Tuple, \ + TYPE_CHECKING import ray -from ray.util.iter import from_actors, LocalIterator -from ray.util.iter_metrics import SharedMetrics +from ray.actor import ActorHandle from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \ @@ -12,27 +13,207 @@ _check_sample_batch_type, _get_shared_metrics from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch +from ray.rllib.utils.annotations import ExperimentalAPI from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \ LEARNER_STATS_KEY from ray.rllib.utils.sgd import standardized from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients +from ray.util.iter import from_actors, LocalIterator +from ray.util.iter_metrics import SharedMetrics + +if TYPE_CHECKING: + from ray.rllib.agents.trainer import Trainer + from ray.rllib.evaluation.rollout_worker import RolloutWorker logger = logging.getLogger(__name__) -def synchronous_parallel_sample(workers: WorkerSet) -> List[SampleBatch]: +@ExperimentalAPI +def synchronous_parallel_sample( + worker_set: WorkerSet, + remote_fn: Optional[Callable[["RolloutWorker"], None]] = None, +) -> List[SampleBatch]: + """Runs parallel and synchronous rollouts on all remote workers. + + Waits for all workers to return from the remote calls. + + If no remote workers exist (num_workers == 0), use the local worker + for sampling. 
+ + Alternatively to calling `worker.sample.remote()`, the user can provide a + `remote_fn()`, which will be applied to the worker(s) instead. + + Args: + worker_set: The WorkerSet to use for sampling. + remote_fn: If provided, use `worker.apply.remote(remote_fn)` instead + of `worker.sample.remote()` to generate the requests. + + Returns: + The list of collected sample batch types (one for each parallel + rollout worker in the given `worker_set`). + + Examples: + >>> # 2 remote workers (num_workers=2): + >>> batches = synchronous_parallel_sample(trainer.workers) + >>> print(len(batches)) + ... 2 + >>> print(batches[0]) + ... SampleBatch(16: ['obs', 'actions', 'rewards', 'dones']) + + >>> # 0 remote workers (num_workers=0): Using the local worker. + >>> batches = synchronous_parallel_sample(trainer.workers) + >>> print(len(batches)) + ... 1 + """ # No remote workers in the set -> Use local worker for collecting # samples. - if not workers.remote_workers(): - return [workers.local_worker().sample()] + if not worker_set.remote_workers(): + return [worker_set.local_worker().sample()] # Loop over remote workers' `sample()` method in parallel. sample_batches = ray.get( - [r.sample.remote() for r in workers.remote_workers()]) + [r.sample.remote() for r in worker_set.remote_workers()]) + # Return all collected batches. return sample_batches +# TODO: Move to generic parallel ops module and rename to +# `asynchronous_parallel_requests`: +@ExperimentalAPI +def asynchronous_parallel_sample( + trainer: "Trainer", + actors: List[ActorHandle], + ray_wait_timeout_s: Optional[float] = None, + max_remote_requests_in_flight_per_actor: int = 2, + remote_fn: Optional[Callable[["RolloutWorker"], None]] = None, + remote_args: Optional[List[List[Any]]] = None, + remote_kwargs: Optional[List[Dict[str, Any]]] = None, +) -> Optional[List[SampleBatch]]: + """Runs parallel and asynchronous rollouts on all remote workers. + + May use a timeout (if provided) on `ray.wait()` and returns only those + samples that could be gathered in the timeout window. Allows a maximum + of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight + per remote actor. + + Alternatively to calling `actor.sample.remote()`, the user can provide a + `remote_fn()`, which will be applied to the actor(s) instead. + + Args: + trainer: The Trainer object that we run the sampling for. + actors: The List of ActorHandles to perform the remote requests on. + ray_wait_timeout_s: Timeout (in sec) to be used for the underlying + `ray.wait()` calls. If None (default), never time out (block + until at least one actor returns something). + max_remote_requests_in_flight_per_actor: Maximum number of remote + requests sent to each actor. 2 (default) is probably + sufficient to avoid idle times between two requests. + remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of + `actor.sample.remote()` to generate the requests. + remote_args: If provided, use this list (per-actor) of lists (call + args) as *args to be passed to the `remote_fn`. + E.g.: actors=[A, B], + remote_args=[[...] <- *args for A, [...] <- *args for B]. + remote_kwargs: If provided, use this list (per-actor) of dicts + (kwargs) as **kwargs to be passed to the `remote_fn`. + E.g.: actors=[A, B], + remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B]. + + Returns: + The list of asynchronously collected sample batch types. None, if no + samples are ready. 
+ + Examples: + >>> # 2 remote rollout workers (num_workers=2): + >>> batches = asynchronous_parallel_sample( + ... trainer, + ... actors=trainer.workers.remote_workers(), + ... ray_wait_timeout_s=0.1, + ... remote_fn=lambda w: time.sleep(1) # sleep 1sec + ... ) + >>> print(len(batches)) + ... 2 + >>> # Expect a timeout to have happened. + >>> batches[0] is None and batches[1] is None + ... True + """ + + if remote_args is not None: + assert len(remote_args) == len(actors) + if remote_kwargs is not None: + assert len(remote_kwargs) == len(actors) + + # Create a map inside Trainer instance that maps actorss to sets of open + # requests (object refs). This way, we keep track, of which actorss have + # already been sent how many requests + # (`max_remote_requests_in_flight_per_actor` arg). + if not hasattr(trainer, "_remote_requests_in_flight"): + trainer._remote_requests_in_flight = defaultdict(set) + + # Collect all currently pending remote requests into a single set of + # object refs. + pending_remotes = set() + # Also build a map to get the associated actor for each remote request. + remote_to_actor = {} + for actor, set_ in trainer._remote_requests_in_flight.items(): + pending_remotes |= set_ + for r in set_: + remote_to_actor[r] = actor + + # Add new requests, if possible (if + # `max_remote_requests_in_flight_per_actor` setting allows it). + for actor_idx, actor in enumerate(actors): + # Still room for another request to this actor. + if len(trainer._remote_requests_in_flight[actor]) < \ + max_remote_requests_in_flight_per_actor: + if remote_fn is None: + req = actor.sample.remote() + else: + args = remote_args[actor_idx] if remote_args else [] + kwargs = remote_kwargs[actor_idx] if remote_kwargs else {} + req = actor.apply.remote(remote_fn, *args, **kwargs) + # Add to our set to send to ray.wait(). + pending_remotes.add(req) + # Keep our mappings properly updated. + trainer._remote_requests_in_flight[actor].add(req) + remote_to_actor[req] = actor + + # There must always be pending remote requests. + assert len(pending_remotes) > 0 + pending_remote_list = list(pending_remotes) + + # No timeout: Block until at least one result is returned. + if ray_wait_timeout_s is None: + # First try to do a `ray.wait` w/o timeout for efficiency. + ready, _ = ray.wait( + pending_remote_list, num_returns=len(pending_remotes), timeout=0) + # Nothing returned and `timeout` is None -> Fall back to a + # blocking wait to make sure we can return something. + if not ready: + ready, _ = ray.wait(pending_remote_list, num_returns=1) + # Timeout: Do a `ray.wait() call` w/ timeout. + else: + ready, _ = ray.wait( + pending_remote_list, + num_returns=len(pending_remotes), + timeout=ray_wait_timeout_s) + + # Return None if nothing ready after the timeout. + if not ready: + return + + for obj_ref in ready: + # Remove in-flight record for this ref. + trainer._remote_requests_in_flight[remote_to_actor[obj_ref]].remove( + obj_ref) + remote_to_actor.pop(obj_ref) + + results = ray.get(ready) + + return results + + def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync", num_async=1) -> LocalIterator[SampleBatch]: """Operator to collect experiences in parallel from rollout workers. 
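The request-throttling logic in `asynchronous_parallel_sample()` above (at most `max_remote_requests_in_flight_per_actor` outstanding calls per actor, gathered through a single `ray.wait()`) can be tried out in isolation. The sketch below uses a made-up `Sampler` actor instead of RLlib's RolloutWorkers; only the bookkeeping pattern is taken from the code above, everything else is an assumption for illustration.

# Illustrative stand-in actor; not part of this patch.
import time
from collections import defaultdict

import ray


@ray.remote
class Sampler:
    def sample(self):
        time.sleep(0.1)
        return "fake_batch"


ray.init()
actors = [Sampler.remote() for _ in range(2)]

max_in_flight = 2
in_flight = defaultdict(set)  # actor -> set of pending object refs
ref_to_actor = {}

# Top up: keep at most `max_in_flight` outstanding requests per actor.
for actor in actors:
    while len(in_flight[actor]) < max_in_flight:
        ref = actor.sample.remote()
        in_flight[actor].add(ref)
        ref_to_actor[ref] = actor

# Collect whatever finished within the timeout window (possibly nothing).
ready, _ = ray.wait(
    list(ref_to_actor), num_returns=len(ref_to_actor), timeout=0.05)
for ref in ready:
    in_flight[ref_to_actor.pop(ref)].discard(ref)
results = ray.get(ready)
print(len(results), "sample batches ready")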
diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index 3f448d8d1432..8fd57e438074 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -55,13 +55,6 @@ def train_one_step(trainer, train_batch) -> Dict: trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps() - # Update weights - after learning on the local worker - on all remote - # workers. - if workers.remote_workers(): - with trainer._timers[WORKER_UPDATE_TIMER]: - weights = ray.put(workers.local_worker().get_weights(policies)) - for e in workers.remote_workers(): - e.set_weights.remote(weights) return info diff --git a/rllib/execution/tree_agg.py b/rllib/execution/tree_agg.py index 6880fb1cbbad..3766769d3acf 100644 --- a/rllib/execution/tree_agg.py +++ b/rllib/execution/tree_agg.py @@ -9,7 +9,7 @@ from ray.rllib.execution.replay_ops import MixInReplay from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches from ray.rllib.policy.sample_batch import MultiAgentBatch -from ray.rllib.utils.actors import create_colocated +from ray.rllib.utils.actors import create_colocated_actors from ray.rllib.utils.typing import SampleBatchType, ModelWeights from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \ from_actors, LocalIterator @@ -91,11 +91,18 @@ def gather_experiences_tree_aggregation(workers: WorkerSet, ] # This spawns |num_aggregation_workers| intermediate actors that aggregate - # experiences in parallel. We force colocation on the same node to maximize - # data bandwidth between them and the driver. - train_batches = from_actors([ - create_colocated(Aggregator, [config, g], 1)[0] for g in rollout_groups - ]) + # experiences in parallel. We force colocation on the same node (localhost) + # to maximize data bandwidth between them and the driver. + all_co_located = create_colocated_actors( + actor_specs=[ + # (class, args, kwargs={}, count=1) + (Aggregator, [config, g], {}, 1) for g in rollout_groups + ], + node=platform.node()) + + # Use the first ([0]) of each created group (each group only has one + # actor: count=1). + train_batches = from_actors([group[0] for group in all_co_located]) # TODO(ekl) properly account for replay. 
def record_steps_sampled(batch): diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index c01f3f051bbc..a5c2115ed661 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -4,9 +4,12 @@ from gym.spaces import Box import logging import numpy as np +import platform import tree # pip install dm_tree -from typing import Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING +import ray +from ray.actor import ActorHandle from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 @@ -21,7 +24,8 @@ from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \ get_dummy_batch_for_space, unbatch from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \ - TensorType, TensorStructType, TrainerConfigDict, Tuple, Union + PolicyID, T, TensorType, TensorStructType, TrainerConfigDict, Tuple, \ + Union tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -638,6 +642,27 @@ def set_state( self.set_weights(state["weights"]) self.global_timestep = state["global_timestep"] + @ExperimentalAPI + def apply(self, + func: Callable[["Policy", Optional[Any], Optional[Any]], T], + *args, **kwargs) -> T: + """Calls the given function with this Policy instance. + + Useful for when the Policy class has been converted into a ActorHandle + and the user needs to execute some functionality (e.g. add a property) + on the underlying policy object. + + Args: + func: The function to call, with this Policy as first + argument, followed by *args, and **kwargs. + args: Optional additional args to pass to the function call. + kwargs: Optional additional kwargs to pass to the function call. + + Returns: + The return value of the function call. + """ + return func(self, *args, **kwargs) + @DeveloperAPI def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None: """Called on an update to global vars. @@ -697,6 +722,15 @@ def get_session(self) -> Optional["tf1.Session"]: """ return None + def get_host(self) -> str: + """Returns the computer's network name. + + Returns: + The computer's networks name or an empty string, if the network + name could not be determined. + """ + return platform.node() + def _create_exploration(self) -> Exploration: """Creates the Policy's Exploration object. diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 0e4c36570612..41884e01768b 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -294,7 +294,7 @@ def chop_into_sequences( f = np.array(f) length = len(seq_lens) * max_seq_len - if f.dtype == np.object or f.dtype.type is np.str_: + if f.dtype == object or f.dtype.type is np.str_: f_pad = [None] * length else: # Make sure type doesn't change. diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index ab18d93c7d3a..b80615967720 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -620,7 +620,7 @@ def _zero_pad_in_place(path, value): or path[0] == SampleBatch.SEQ_LENS: return # Generate zero-filled primer of len=max_seq_len. - if value.dtype == np.object or value.dtype.type is np.str_: + if value.dtype == object or value.dtype.type is np.str_: f_pad = [None] * length else: # Make sure type doesn't change. @@ -651,13 +651,13 @@ def _zero_pad_in_place(path, value): return self - # Experimental method. 
+ @ExperimentalAPI def to_device(self, device, framework="torch"): """TODO: transfer batch to given device as framework tensor.""" if framework == "torch": assert torch is not None for k, v in self.items(): - if isinstance(v, np.ndarray) and v.dtype != np.object: + if isinstance(v, np.ndarray) and v.dtype != object: self[k] = torch.from_numpy(v).to(device) else: raise NotImplementedError diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index a022147e4fa1..7631ee7995a4 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -1170,7 +1170,7 @@ def on_global_var_update(self, global_vars): @DeveloperAPI class DirectStepOptimizer: - """Typesafe method for indicating apply gradients can directly step the + """Typesafe method for indicating `apply_gradients` can directly step the optimizers with in-place gradients. """ _instance = None diff --git a/rllib/utils/actors.py b/rllib/utils/actors.py index 06eec3c16376..422df2beaa2b 100644 --- a/rllib/utils/actors.py +++ b/rllib/utils/actors.py @@ -1,7 +1,11 @@ +from collections import deque import logging import platform +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type + import ray -from collections import deque +from ray.actor import ActorHandle +from ray.rllib.utils.deprecation import Deprecated logger = logging.getLogger(__name__) @@ -65,6 +69,77 @@ def count(self): return len(self._tasks) +def create_colocated_actors( + actor_specs: Sequence[Tuple[Type, Any, Any, int]], + node: Optional[str] = "localhost", + max_attempts: int = 10, +) -> Dict[Type, List[ActorHandle]]: + """Create co-located actors of any type(s) on any node. + + Args: + actor_specs: Tuple/list with tuples consisting of: 1) The + (already @ray.remote) class(es) to construct, 2) c'tor args, + 3) c'tor kwargs, and 4) the number of actors of that class with + given args/kwargs to construct. + node: The node to co-locate the actors on. By default ("localhost"), + place the actors on the node the caller of this function is + located on. Use None for indicating that any (resource fulfilling) + node in the clusted may be used. + max_attempts: The maximum number of co-location attempts to + perform before throwing an error. + + Returns: + A dict mapping the created types to the list of n ActorHandles + created (and co-located) for that type. + """ + if node == "localhost": + node = platform.host() + + # Maps types to lists of already co-located actors. + ok = [[] for _ in range(len(actor_specs))] + attempt = 1 + while attempt < max_attempts: + all_good = True + for i, (typ, args, kwargs, count) in enumerate(actor_specs): + args = args or [] # Allow None. + kwargs = kwargs or {} # Allow None. + if len(ok[i]) < count: + all_good = False + co_located = try_create_colocated( + cls=typ, + args=args, + kwargs=kwargs, + count=count * attempt, + node=node) + # If node did not matter, from here on, use the host that the + # first actor(s) are already located on. 
+ if node is None: + node = ray.get(co_located[0].get_host.remote()) + ok[i].extend(co_located) + elif len(ok[i]) > count: + for a in ok[i][count:]: + a.__ray_terminate__.remote() + ok[i] = ok[i][:count] + if all_good: + break + elif attempt == max_attempts - 1: + raise Exception( + "Unable to create enough colocated actors -> aborting.") + attempt += 1 + + return ok + + +def try_create_colocated(cls, args, count, kwargs=None, node=None): + kwargs = kwargs or {} + actors = [cls.remote(*args, **kwargs) for _ in range(count)] + co_located, non_co_located = split_colocated(actors, node=node) + logger.info("Got {} colocated actors of {}".format(len(co_located), count)) + for a in non_co_located: + a.__ray_terminate__.remote() + return co_located + + def drop_colocated(actors): colocated, non_colocated = split_colocated(actors) for a in colocated: @@ -72,38 +147,33 @@ def drop_colocated(actors): return non_colocated -def split_colocated(actors): - localhost = platform.node() +def split_colocated(actors, node=None): + # Get nodes of all created actors. hosts = ray.get([a.get_host.remote() for a in actors]) - local = [] - non_local = [] + # Split into co-located (on `node) and non-co-located (not on `node`). + co_located = [] + non_co_located = [] + + # If node not provided, use 1st actor's node as the desired one. + if node is None: + node = hosts[0] + for host, a in zip(hosts, actors): - if host == localhost: - local.append(a) + # This actor has been placed on the correct node. + if host == node: + co_located.append(a) + # This actor has been placed on a different node. else: - non_local.append(a) - return local, non_local + non_co_located.append(a) + return co_located, non_co_located -def try_create_colocated(cls, args, count): - actors = [cls.remote(*args) for _ in range(count)] - local, rest = split_colocated(actors) - logger.info("Got {} colocated actors of {}".format(len(local), count)) - for a in rest: - a.__ray_terminate__.remote() - return local - - -def create_colocated(cls, args, count): - logger.info("Trying to create {} colocated actors".format(count)) - ok = [] - i = 1 - while len(ok) < count and i < 10: - attempt = try_create_colocated(cls, args, count * i) - ok.extend(attempt) - i += 1 - if len(ok) < count: - raise Exception("Unable to create enough colocated actors, abort.") - for a in ok[count:]: - a.__ray_terminate__.remote() - return ok[:count] +@Deprecated(new="create_colocated_actors", error=False) +def create_colocated(cls, arg, count): + kwargs = {} + args = arg + + return create_colocated_actors( + actor_specs=[(cls, args, kwargs, count)], + node=platform.node(), # force on localhost + )[cls] diff --git a/rllib/utils/debug.py b/rllib/utils/debug.py index 90d475cdf9e1..02080e6ba572 100644 --- a/rllib/utils/debug.py +++ b/rllib/utils/debug.py @@ -42,7 +42,7 @@ def _summarize(obj): if obj.size == 0: return _StringValue("np.ndarray({}, dtype={})".format( obj.shape, obj.dtype)) - elif obj.dtype == np.object or obj.dtype.type is np.str_: + elif obj.dtype == object or obj.dtype.type is np.str_: return _StringValue("np.ndarray({}, dtype={}, head={})".format( obj.shape, obj.dtype, _summarize(obj[0]))) else: diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index df1f608b7856..dd1f6e12ce25 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -214,7 +214,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False): "ERROR: x ({}) is not the same as y ({})!".format(x, y) # String/byte comparisons. 
elif hasattr(x, "dtype") and \ - (x.dtype == np.object or str(x.dtype).startswith(" Date: Thu, 6 Jan 2022 15:58:41 +0100 Subject: [PATCH 2/9] wip. --- rllib/agents/impala/impala.py | 2 +- rllib/agents/ppo/appo.py | 6 ---- rllib/agents/trainer.py | 32 +++++++++++++++++-- rllib/env/base_env.py | 2 +- rllib/env/multi_agent_env.py | 2 +- rllib/evaluation/rollout_worker.py | 25 +++++++++------ rllib/evaluation/worker_set.py | 6 ++-- .../buffers/multi_agent_replay_buffer.py | 13 +++++--- rllib/execution/buffers/replay_buffer.py | 15 ++++++--- rllib/execution/train_ops.py | 7 ---- rllib/policy/policy.py | 32 ++++++++++++++++++- rllib/policy/rnn_sequencing.py | 2 +- rllib/policy/sample_batch.py | 6 ++-- rllib/policy/torch_policy.py | 2 +- rllib/utils/debug.py | 2 +- rllib/utils/test_utils.py | 15 ++++++--- rllib/utils/tf_utils.py | 3 +- rllib/utils/torch_utils.py | 4 +-- rllib/utils/typing.py | 6 +++- 19 files changed, 126 insertions(+), 56 deletions(-) diff --git a/rllib/agents/impala/impala.py b/rllib/agents/impala/impala.py index a0d31e6096e1..12f3b7959114 100644 --- a/rllib/agents/impala/impala.py +++ b/rllib/agents/impala/impala.py @@ -366,7 +366,7 @@ def default_resource_request(cls, config): { # Evaluation (remote) workers. # Note: The local eval worker is located on the driver - # CPU. + # CPU or not even created iff >0 eval workers. "CPU": eval_config.get("num_cpus_per_worker", cf["num_cpus_per_worker"]), "GPU": eval_config.get("num_gpus_per_worker", diff --git a/rllib/agents/ppo/appo.py b/rllib/agents/ppo/appo.py index 1c14db3c20ae..e030de383916 100644 --- a/rllib/agents/ppo/appo.py +++ b/rllib/agents/ppo/appo.py @@ -120,12 +120,6 @@ def __init__(self, config, *args, **kwargs): self.workers.local_worker().foreach_trainable_policy( lambda p, _: p.update_target()) - # TODO: Remove this once ImpalaTrainer directly inherits from Trainer - # (instead of being created by `build_trainer()` utility). - @override(impala.ImpalaTrainer) - def _init(self, *args, **kwargs): - raise NotImplementedError - @classmethod @override(Trainer) def get_default_config(cls) -> TrainerConfigDict: diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index e55ab358e51e..08912aa4cea8 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -29,6 +29,7 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.buffers.multi_agent_replay_buffer import \ MultiAgentReplayBuffer +from ray.rllib.execution.common import WORKER_UPDATE_TIMER from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts, \ synchronous_parallel_sample from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep, \ @@ -775,8 +776,8 @@ def env_creator_from_classpath(env_context): # Set Trainer's seed after we have - if necessary - enabled # tf eager-execution. - update_global_seed_if_necessary( - config.get("framework"), config.get("seed")) + update_global_seed_if_necessary(self.config["framework"], + self.config["seed"]) self.validate_config(self.config) if not callable(self.config["callbacks"]): @@ -844,6 +845,11 @@ def env_creator_from_classpath(env_context): self.workers, self.config, **self._kwargs_for_execution_plan()) + # Now that workers have been created, update our policy specs + # in the config[multiagent] dict with the correct spaces. + self.config["multiagent"]["policies"] = \ + self.workers.local_worker().policy_map.policy_specs + # Evaluation WorkerSet setup. # User would like to setup a separate evaluation worker set. 
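For context on the weight-syncing hunk that follows (calling `self.workers.sync_weights()` after local training): the pattern behind `WorkerSet.sync_weights()` in this series is to `ray.put()` the weights once and hand the same object ref to every remote worker, instead of re-serializing the dict per worker. A minimal stand-alone sketch of that idea, using an assumed toy `Worker` actor rather than an actual RolloutWorker:

# Toy actor standing in for a remote RolloutWorker (assumption, not RLlib code).
import numpy as np
import ray


@ray.remote
class Worker:
    def __init__(self):
        self.weights = None

    def set_weights(self, weights):
        # `weights` arrives as the de-referenced dict; Ray resolves a
        # top-level ObjectRef argument before calling the method.
        self.weights = weights
        return True


ray.init()
workers = [Worker.remote() for _ in range(4)]

weights = {"default_policy": {"w": np.zeros((256, 256))}}

# Put the (potentially large) weights into the object store exactly once ...
weights_ref = ray.put(weights)
# ... and pass the same reference to all workers.
ray.get([w.set_weights.remote(weights_ref) for w in workers])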
@@ -1295,6 +1301,12 @@ def training_iteration(self) -> ResultDict: else: train_results = multi_gpu_train_one_step(self, train_batch) + # Update weights - after learning on the local worker - on all remote + # workers. + if self.workers.remote_workers(): + with self._timers[WORKER_UPDATE_TIMER]: + self.workers.sync_weights() + return train_results @DeveloperAPI @@ -1976,6 +1988,22 @@ def merge_trainer_configs(cls, config2: PartialTrainerConfigDict, _allow_unknown_configs: Optional[bool] = None ) -> TrainerConfigDict: + """Merges a complete Trainer config with a partial override dict. + + Respects nested structures within the config dicts. The values in the + partial override dict take priority. + + Args: + config1: The complete Trainer's dict to be merged (overridden) + with `config2`. + config2: The partial override config dict to merge on top of + `config1`. + _allow_unknown_configs: If True, keys in `config2` that don't exist + in `config1` are allowed and will be added to the final config. + + Returns: + The merged full trainer config dict. + """ config1 = copy.deepcopy(config1) if "callbacks" in config2 and type(config2["callbacks"]) is dict: legacy_callbacks_dict = config2["callbacks"] diff --git a/rllib/env/base_env.py b/rllib/env/base_env.py index 760152c79369..881917a4c042 100644 --- a/rllib/env/base_env.py +++ b/rllib/env/base_env.py @@ -722,7 +722,7 @@ def convert_to_base_env( The resulting BaseEnv object. """ - from ray.rllib.env.remote_vector_env import RemoteBaseEnv + from ray.rllib.env.remote_base_env import RemoteBaseEnv from ray.rllib.env.external_env import ExternalEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.vector_env import VectorEnv, VectorEnvWrapper diff --git a/rllib/env/multi_agent_env.py b/rllib/env/multi_agent_env.py index c5025bc949a0..508528264207 100644 --- a/rllib/env/multi_agent_env.py +++ b/rllib/env/multi_agent_env.py @@ -168,7 +168,7 @@ def to_base_env( Returns: The resulting BaseEnv object. """ - from ray.rllib.env.remote_vector_env import RemoteBaseEnv + from ray.rllib.env.remote_base_env import RemoteBaseEnv if remote_envs: env = RemoteBaseEnv( make_env, diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index ef3f4d27e785..5ea8bbb5a96d 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -6,10 +6,11 @@ import platform import os import tree # pip install dm_tree -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \ +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \ TYPE_CHECKING, Union import ray +from ray import ObjectRef from ray import cloudpickle as pickle from ray.rllib.env.base_env import BaseEnv, convert_to_base_env from ray.rllib.env.env_context import EnvContext @@ -46,7 +47,7 @@ from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \ ModelConfigDict, ModelGradients, ModelWeights, \ MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \ - SampleBatchType + SampleBatchType, T from ray.util.debug import log_once, disable_log_once_globally, \ enable_periodic_logging from ray.util.iter import ParallelIteratorWorker @@ -56,9 +57,6 @@ from ray.rllib.evaluation.observation_function import ObservationFunction from ray.rllib.agents.callbacks import DefaultCallbacks # noqa -# Generic type var for foreach_* methods. 
-T = TypeVar("T") - tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -1436,19 +1434,26 @@ def stop(self) -> None: sess.close() @DeveloperAPI - def apply(self, func: Callable[["RolloutWorker", Optional[Any]], T], - *args) -> T: + def apply( + self, + func: Callable[["RolloutWorker", Optional[Any], Optional[Any]], T], + *args, **kwargs) -> T: """Calls the given function with this rollout worker instance. + Useful for when the RolloutWorker class has been converted into a + ActorHandle and the user needs to execute some functionality (e.g. + add a property) on the underlying policy object. + Args: - func: The function to call with this RolloutWorker as first - argument. + func: The function to call, with this RolloutWorker as first + argument, followed by *args, and **kwargs. args: Optional additional args to pass to the function call. + kwargs: Optional additional kwargs to pass to the function call. Returns: The return value of the function call. """ - return func(self, *args) + return func(self, *args, **kwargs) def setup_torch_data_parallel(self, url: str, world_rank: int, world_size: int, backend: str) -> None: diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 7e022f374523..993aa847ea63 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -159,13 +159,15 @@ def sync_weights(self, if self.remote_workers() or from_worker is not None: weights = (from_worker or self.local_worker()).get_weights(policies) + # Put weights only once into object store and use same object + # ref to synch to all workers. weights_ref = ray.put(weights) # Sync to all remote workers in this WorkerSet. for to_worker in self.remote_workers(): to_worker.set_weights.remote(weights_ref) - # If from_worker is provided, also sync to this WorkerSet's local - # worker. + # If `from_worker` is provided, also sync to this WorkerSet's + # local worker. if from_worker is not None and self.local_worker() is not None: self.local_worker().set_weights(weights) diff --git a/rllib/execution/buffers/multi_agent_replay_buffer.py b/rllib/execution/buffers/multi_agent_replay_buffer.py index 70997409c439..4e7d7d8f822e 100644 --- a/rllib/execution/buffers/multi_agent_replay_buffer.py +++ b/rllib/execution/buffers/multi_agent_replay_buffer.py @@ -1,11 +1,11 @@ import collections import platform -from typing import Dict, Any +from typing import Any, Dict import numpy as np import ray from ray.rllib import SampleBatch -from ray.rllib.execution import PrioritizedReplayBuffer +from ray.rllib.execution import PrioritizedReplayBuffer, ReplayBuffer from ray.rllib.execution.buffers.replay_buffer import logger, _ALL_POLICIES from ray.rllib.policy.rnn_sequencing import \ timeslice_along_seq_lens_with_overlap @@ -54,7 +54,7 @@ def __init__( `self.replay_batch_size` will be set to the number of sequences sampled (B). prioritized_replay_alpha (float): Alpha parameter for a prioritized - replay buffer. + replay buffer. Use 0.0 for no prioritization. prioritized_replay_beta (float): Beta parameter for a prioritized replay buffer. 
prioritized_replay_eps (float): Epsilon parameter for a prioritized @@ -108,8 +108,11 @@ def gen_replay(): ParallelIteratorWorker.__init__(self, gen_replay, False) def new_buffer(): - return PrioritizedReplayBuffer( - self.capacity, alpha=prioritized_replay_alpha) + if prioritized_replay_alpha == 0.0: + return ReplayBuffer(self.capacity) + else: + return PrioritizedReplayBuffer( + self.capacity, alpha=prioritized_replay_alpha) self.replay_buffers = collections.defaultdict(new_buffer) diff --git a/rllib/execution/buffers/replay_buffer.py b/rllib/execution/buffers/replay_buffer.py index f4dc9ee58806..fc77d8369f1e 100644 --- a/rllib/execution/buffers/replay_buffer.py +++ b/rllib/execution/buffers/replay_buffer.py @@ -52,7 +52,7 @@ class ReplayBuffer: def __init__(self, capacity: int = 10000, size: Optional[int] = DEPRECATED_VALUE): - """Initializes a Replaybuffer instance. + """Initializes a ReplayBuffer instance. Args: capacity: Max number of timesteps to store in the FIFO @@ -84,6 +84,7 @@ def __init__(self, self._est_size_bytes = 0 def __len__(self) -> int: + """Returns the number of items currently stored in this buffer.""" return len(self._storage) @DeveloperAPI @@ -147,7 +148,7 @@ def stats(self, debug: bool = False) -> dict: """Returns the stats of this buffer. Args: - debug: If true, adds sample eviction statistics to the returned + debug: If True, adds sample eviction statistics to the returned stats dict. Returns: @@ -253,7 +254,11 @@ def _sample_proportional(self, num_items: int) -> List[int]: @DeveloperAPI @override(ReplayBuffer) def sample(self, num_items: int, beta: float) -> SampleBatchType: - """Sample a batch of experiences and return priority weights, indices. + """Sample `num_items` items from this buffer, including prio. weights. + + If less than `num_items` records are in this buffer, some samples in + the results may be repeated to fulfil the batch size (`num_items`) + request. Args: num_items: Number of items to sample from this buffer. @@ -272,11 +277,11 @@ def sample(self, num_items: int, beta: float) -> SampleBatchType: weights = [] batch_indexes = [] p_min = self._it_min.min() / self._it_sum.sum() - max_weight = (p_min * len(self._storage))**(-beta) + max_weight = (p_min * len(self))**(-beta) for idx in idxes: p_sample = self._it_sum[idx] / self._it_sum.sum() - weight = (p_sample * len(self._storage))**(-beta) + weight = (p_sample * len(self))**(-beta) count = self._storage[idx].count # If zero-padded, count will not be the actual batch size of the # data. diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index 3f448d8d1432..8fd57e438074 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -55,13 +55,6 @@ def train_one_step(trainer, train_batch) -> Dict: trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps() - # Update weights - after learning on the local worker - on all remote - # workers. 
- if workers.remote_workers(): - with trainer._timers[WORKER_UPDATE_TIMER]: - weights = ray.put(workers.local_worker().get_weights(policies)) - for e in workers.remote_workers(): - e.set_weights.remote(weights) return info diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index c01f3f051bbc..2488f6a813a2 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -21,7 +21,7 @@ from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \ get_dummy_batch_for_space, unbatch from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \ - TensorType, TensorStructType, TrainerConfigDict, Tuple, Union + T, TensorType, TensorStructType, TrainerConfigDict, Tuple, Union tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -638,6 +638,27 @@ def set_state( self.set_weights(state["weights"]) self.global_timestep = state["global_timestep"] + @ExperimentalAPI + def apply(self, + func: Callable[["Policy", Optional[Any], Optional[Any]], T], + *args, **kwargs) -> T: + """Calls the given function with this Policy instance. + + Useful for when the Policy class has been converted into a ActorHandle + and the user needs to execute some functionality (e.g. add a property) + on the underlying policy object. + + Args: + func: The function to call, with this Policy as first + argument, followed by *args, and **kwargs. + args: Optional additional args to pass to the function call. + kwargs: Optional additional kwargs to pass to the function call. + + Returns: + The return value of the function call. + """ + return func(self, *args, **kwargs) + @DeveloperAPI def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None: """Called on an update to global vars. @@ -697,6 +718,15 @@ def get_session(self) -> Optional["tf1.Session"]: """ return None + def get_host(self) -> str: + """Returns the computer's network name. + + Returns: + The computer's networks name or an empty string, if the network + name could not be determined. + """ + return platform.node() + def _create_exploration(self) -> Exploration: """Creates the Policy's Exploration object. diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 0e4c36570612..41884e01768b 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -294,7 +294,7 @@ def chop_into_sequences( f = np.array(f) length = len(seq_lens) * max_seq_len - if f.dtype == np.object or f.dtype.type is np.str_: + if f.dtype == object or f.dtype.type is np.str_: f_pad = [None] * length else: # Make sure type doesn't change. diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index ab18d93c7d3a..b80615967720 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -620,7 +620,7 @@ def _zero_pad_in_place(path, value): or path[0] == SampleBatch.SEQ_LENS: return # Generate zero-filled primer of len=max_seq_len. - if value.dtype == np.object or value.dtype.type is np.str_: + if value.dtype == object or value.dtype.type is np.str_: f_pad = [None] * length else: # Make sure type doesn't change. @@ -651,13 +651,13 @@ def _zero_pad_in_place(path, value): return self - # Experimental method. 
+ @ExperimentalAPI def to_device(self, device, framework="torch"): """TODO: transfer batch to given device as framework tensor.""" if framework == "torch": assert torch is not None for k, v in self.items(): - if isinstance(v, np.ndarray) and v.dtype != np.object: + if isinstance(v, np.ndarray) and v.dtype != object: self[k] = torch.from_numpy(v).to(device) else: raise NotImplementedError diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index a022147e4fa1..7631ee7995a4 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -1170,7 +1170,7 @@ def on_global_var_update(self, global_vars): @DeveloperAPI class DirectStepOptimizer: - """Typesafe method for indicating apply gradients can directly step the + """Typesafe method for indicating `apply_gradients` can directly step the optimizers with in-place gradients. """ _instance = None diff --git a/rllib/utils/debug.py b/rllib/utils/debug.py index 90d475cdf9e1..02080e6ba572 100644 --- a/rllib/utils/debug.py +++ b/rllib/utils/debug.py @@ -42,7 +42,7 @@ def _summarize(obj): if obj.size == 0: return _StringValue("np.ndarray({}, dtype={})".format( obj.shape, obj.dtype)) - elif obj.dtype == np.object or obj.dtype.type is np.str_: + elif obj.dtype == object or obj.dtype.type is np.str_: return _StringValue("np.ndarray({}, dtype={}, head={})".format( obj.shape, obj.dtype, _summarize(obj[0]))) else: diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index df1f608b7856..dd1f6e12ce25 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -214,7 +214,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False): "ERROR: x ({}) is not the same as y ({})!".format(x, y) # String/byte comparisons. elif hasattr(x, "dtype") and \ - (x.dtype == np.object or str(x.dtype).startswith(" Date: Thu, 6 Jan 2022 16:00:37 +0100 Subject: [PATCH 3/9] LINT. --- rllib/evaluation/rollout_worker.py | 1 - rllib/policy/policy.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 5ea8bbb5a96d..ff77b494bbd9 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -10,7 +10,6 @@ TYPE_CHECKING, Union import ray -from ray import ObjectRef from ray import cloudpickle as pickle from ray.rllib.env.base_env import BaseEnv, convert_to_base_env from ray.rllib.env.env_context import EnvContext diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 2488f6a813a2..e65dfdae5570 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -4,8 +4,9 @@ from gym.spaces import Box import logging import numpy as np +import platform import tree # pip install dm_tree -from typing import Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.catalog import ModelCatalog From 488443fe5a1c7b9695d7e5cb6d815137211ccb56 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 6 Jan 2022 16:13:37 +0100 Subject: [PATCH 4/9] wip. 
--- rllib/policy/policy.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index a5c2115ed661..ca7a585495e6 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -8,8 +8,6 @@ import tree # pip install dm_tree from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING -import ray -from ray.actor import ActorHandle from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 @@ -24,7 +22,7 @@ from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \ get_dummy_batch_for_space, unbatch from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \ - PolicyID, T, TensorType, TensorStructType, TrainerConfigDict, Tuple, \ + T, TensorType, TensorStructType, TrainerConfigDict, Tuple, \ Union tf1, tf, tfv = try_import_tf() From cf07481a949a096ed7c5addb47df4faadb855ec8 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 10 Jan 2022 12:03:26 +0100 Subject: [PATCH 5/9] wip. --- rllib/agents/trainer.py | 2 +- rllib/evaluation/rollout_worker.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index b1a0d80a5e3a..3982b6d71d25 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -848,7 +848,7 @@ def env_creator_from_classpath(env_context): # Now that workers have been created, update our policy # specs in the config[multiagent] dict with the correct spaces. self.config["multiagent"]["policies"] = \ - self.workers.local_worker().policy_map.policy_specs + self.workers.local_worker().policy_dict # Evaluation WorkerSet setup. # User would like to setup a separate evaluation worker set. diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 9dec4df0fbc5..9592a8525da2 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -538,16 +538,16 @@ def make_sub_env(vector_index): self.make_sub_env_fn = make_sub_env self.spaces = spaces - policy_dict = _determine_spaces_for_multi_agent_dict( + self.policy_dict = _determine_spaces_for_multi_agent_dict( policy_spec, self.env, spaces=self.spaces, policy_config=policy_config) # List of IDs of those policies, which should be trained. - # By default, these are all policies found in the policy_dict. + # By default, these are all policies found in `self.policy_dict`. 
self.policies_to_train: List[PolicyID] = policies_to_train or list( - policy_dict.keys()) + self.policy_dict.keys()) self.set_policies_to_train(self.policies_to_train) self.policy_map: PolicyMap = None @@ -584,7 +584,7 @@ def make_sub_env(vector_index): f"is ignored.") self._build_policy_map( - policy_dict, + self.policy_dict, policy_config, session_creator=tf_session_creator, seed=seed) @@ -1112,7 +1112,7 @@ def add_policy( """ if policy_id in self.policy_map: raise ValueError(f"Policy ID '{policy_id}' already in policy map!") - policy_dict = _determine_spaces_for_multi_agent_dict( + policy_dict_to_add = _determine_spaces_for_multi_agent_dict( { policy_id: PolicySpec(policy_cls, observation_space, action_space, config or {}) @@ -1121,8 +1121,9 @@ def add_policy( spaces=self.spaces, policy_config=self.policy_config, ) + self.policy_dict.update(policy_dict_to_add) self._build_policy_map( - policy_dict, + policy_dict_to_add, self.policy_config, seed=self.policy_config.get("seed")) new_policy = self.policy_map[policy_id] From 51fae1071c3b434271dc3954beee9e2fcce9585c Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 11 Jan 2022 16:26:23 +0100 Subject: [PATCH 6/9] wip. --- rllib/agents/ddpg/apex.py | 5 +++++ rllib/agents/dqn/apex.py | 1 + 2 files changed, 6 insertions(+) diff --git a/rllib/agents/ddpg/apex.py b/rllib/agents/ddpg/apex.py index 99f82d2f0717..73e4ef041dc7 100644 --- a/rllib/agents/ddpg/apex.py +++ b/rllib/agents/ddpg/apex.py @@ -23,6 +23,10 @@ "buffer_size": 2000000, # TODO(jungong) : update once Apex supports replay_buffer_config. "replay_buffer_config": None, + # Whether all shards of the replay buffer must be co-located + # with the learner process (running the execution plan). + # If False, replay shards may be created on different node(s). + "replay_buffer_shards_colocated_with_driver": True, "learning_starts": 50000, "train_batch_size": 512, "rollout_fragment_length": 50, @@ -31,6 +35,7 @@ "worker_side_prioritization": True, "min_iter_time_s": 30, }, + _allow_unknown_configs=True, ) diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index f7ef668bed4b..0c056f8275d1 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -56,6 +56,7 @@ "n_step": 3, "num_gpus": 1, "num_workers": 32, + "buffer_size": 2000000, # TODO(jungong) : add proper replay_buffer_config after # DistributedReplayBuffer type is supported. From d4168caa1a40557ba263ebacc98fc04a3767fc78 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Wed, 12 Jan 2022 14:41:40 +0100 Subject: [PATCH 7/9] wip. --- rllib/agents/dqn/apex.py | 10 +-- rllib/execution/rollout_ops.py | 18 ++--- rllib/utils/actors.py | 127 ++++++++++++++++++++++++++------- 3 files changed, 113 insertions(+), 42 deletions(-) diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index 0c056f8275d1..137effcda965 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -137,7 +137,7 @@ def execution_plan(workers: WorkerSet, config: dict, num_replay_buffer_shards = config["optimizer"][ "num_replay_buffer_shards"] - args = [ + replay_actor_args = [ num_replay_buffer_shards, config["learning_starts"], config["buffer_size"], @@ -154,14 +154,16 @@ def execution_plan(workers: WorkerSet, config: dict, replay_actors = create_colocated_actors( actor_specs=[ # (class, args, kwargs={}, count) - (ReplayActor, args, {}, num_replay_buffer_shards) # [0] + (ReplayActor, replay_actor_args, {}, + num_replay_buffer_shards) ], node=platform.node(), # localhost - )[0] + )[0] # [0]=only one item in `actor_specs`. 
         # Place replay buffer shards on any node(s).
         else:
             replay_actors = [
-                ReplayActor(*args) for _ in range(num_replay_buffer_shards)
+                ReplayActor(*replay_actor_args)
+                for _ in range(num_replay_buffer_shards)
             ]

         # Start the learner thread.
diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py
index 9aa39a5994a5..97b05e6e5afb 100644
--- a/rllib/execution/rollout_ops.py
+++ b/rllib/execution/rollout_ops.py
@@ -1,4 +1,3 @@
-from collections import defaultdict
 import logging
 import time
 from typing import Any, Callable, Dict, List, Optional, Tuple, \
@@ -144,19 +143,12 @@ def asynchronous_parallel_sample(
     if remote_kwargs is not None:
         assert len(remote_kwargs) == len(actors)

-    # Create a map inside Trainer instance that maps actorss to sets of open
-    # requests (object refs). This way, we keep track, of which actorss have
-    # already been sent how many requests
-    # (`max_remote_requests_in_flight_per_actor` arg).
-    if not hasattr(trainer, "_remote_requests_in_flight"):
-        trainer._remote_requests_in_flight = defaultdict(set)
-
     # Collect all currently pending remote requests into a single set of
     # object refs.
     pending_remotes = set()
     # Also build a map to get the associated actor for each remote request.
     remote_to_actor = {}
-    for actor, set_ in trainer._remote_requests_in_flight.items():
+    for actor, set_ in trainer.remote_requests_in_flight.items():
         pending_remotes |= set_
         for r in set_:
             remote_to_actor[r] = actor
@@ -165,7 +157,7 @@
     # `max_remote_requests_in_flight_per_actor` setting allows it).
     for actor_idx, actor in enumerate(actors):
         # Still room for another request to this actor.
-        if len(trainer._remote_requests_in_flight[actor]) < \
+        if len(trainer.remote_requests_in_flight[actor]) < \
                 max_remote_requests_in_flight_per_actor:
             if remote_fn is None:
                 req = actor.sample.remote()
@@ -176,7 +168,7 @@
             # Add to our set to send to ray.wait().
             pending_remotes.add(req)
             # Keep our mappings properly updated.
-            trainer._remote_requests_in_flight[actor].add(req)
+            trainer.remote_requests_in_flight[actor].add(req)
             remote_to_actor[req] = actor

     # There must always be pending remote requests.
@@ -201,11 +193,11 @@
     # Return None if nothing ready after the timeout.
     if not ready:
-        return
+        return None

     for obj_ref in ready:
         # Remove in-flight record for this ref.
-        trainer._remote_requests_in_flight[remote_to_actor[obj_ref]].remove(
+        trainer.remote_requests_in_flight[remote_to_actor[obj_ref]].remove(
             obj_ref)
         remote_to_actor.pop(obj_ref)
diff --git a/rllib/utils/actors.py b/rllib/utils/actors.py
index 422df2beaa2b..8cb4102acba4 100644
--- a/rllib/utils/actors.py
+++ b/rllib/utils/actors.py
@@ -1,10 +1,10 @@
-from collections import deque
+from collections import defaultdict, deque
 import logging
 import platform
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Type

 import ray
-from ray.actor import ActorHandle
+from ray.actor import ActorClass, ActorHandle
 from ray.rllib.utils.deprecation import Deprecated

 logger = logging.getLogger(__name__)
@@ -84,7 +84,7 @@ def create_colocated_actors(
         node: The node to co-locate the actors on. By default ("localhost"),
             place the actors on the node the caller of this function is
             located on. Use None for indicating that any (resource fulfilling)
-            node in the clusted may be used.
+            node in the cluster may be used.
         max_attempts: The maximum number of co-location attempts to perform
             before throwing an error.

@@ -93,7 +93,7 @@ def create_colocated_actors(
         created (and co-located) for that type.
     """
     if node == "localhost":
-        node = platform.host()
+        node = platform.node()

     # Maps types to lists of already co-located actors.
     ok = [[] for _ in range(len(actor_specs))]
@@ -130,7 +130,41 @@ def create_colocated_actors(
     return ok


-def try_create_colocated(cls, args, count, kwargs=None, node=None):
+def try_create_colocated(
+        cls: Type[ActorClass],
+        args: List[Any],
+        count: int,
+        kwargs: Optional[List[Any]] = None,
+        node: Optional[str] = "localhost",
+) -> List[ActorHandle]:
+    """Tries to co-locate (same node) a set of Actors of the same type.
+
+    Returns a list of successfully co-located actors. All actors that could
+    not be co-located (with the others on the given node) will not be in this
+    list.
+
+    Creates each actor via its remote() constructor and then checks whether
+    it has been co-located (on the same node) with the other (already created)
+    ones. If not, terminates the just created actor.
+
+    Args:
+        cls: The Actor class to use (already @ray.remote "converted").
+        args: List of args to pass to the Actor's constructor. One item
+            per to-be-created actor (`count`).
+        count: Number of actors of the given `cls` to construct.
+        kwargs: Optional list of kwargs to pass to the Actor's constructor.
+            One item per to-be-created actor (`count`).
+        node: The node to co-locate the actors on. By default ("localhost"),
+            place the actors on the node the caller of this function is
+            located on. If None, will try to co-locate all actors on
+            any available node.
+
+    Returns:
+        List containing all successfully co-located actor handles.
+    """
+    if node == "localhost":
+        node = platform.node()
+
     kwargs = kwargs or {}
     actors = [cls.remote(*args, **kwargs) for _ in range(count)]
     co_located, non_co_located = split_colocated(actors, node=node)
@@ -140,32 +174,67 @@ def try_create_colocated(cls, args, count, kwargs=None, node=None):
     return co_located


-def drop_colocated(actors):
-    colocated, non_colocated = split_colocated(actors)
-    for a in colocated:
-        a.__ray_terminate__.remote()
-    return non_colocated
+def split_colocated(
+        actors: List[ActorHandle],
+        node: Optional[str] = "localhost",
+) -> Tuple[List[ActorHandle], List[ActorHandle]]:
+    """Splits up given actors into colocated (on same node) and non colocated.
+
+    The co-location criterion depends on the `node` given:
+    If given (or default: platform.node()): Consider all actors that are on
+    that node "colocated".
+    If None: Consider the largest sub-set of actors that are all located on
+    the same node (whatever that node is) as "colocated".
+
+    Args:
+        actors: The list of actor handles to split into "colocated" and
+            "non colocated".
+        node: The node defining "colocation" criterion. If provided, consider
+            those actors "colocated" that sit on this node. If None, use the
+            largest subset within `actors` that are sitting on the same
+            (any) node.

+    Returns:
+        Tuple of two lists: 1) Co-located ActorHandles, 2) non co-located
+        ActorHandles.
+    """
+    if node == "localhost":
+        node = platform.node()

-def split_colocated(actors, node=None):
     # Get nodes of all created actors.
     hosts = ray.get([a.get_host.remote() for a in actors])
-    # Split into co-located (on `node) and non-co-located (not on `node`).
-    co_located = []
-    non_co_located = []
-    # If node not provided, use 1st actor's node as the desired one.
+
+    # If `node` not provided, use the largest group of actors that sit on the
+    # same node, regardless of what that node is.
     if node is None:
-        node = hosts[0]
-
-    for host, a in zip(hosts, actors):
-        # This actor has been placed on the correct node.
-        if host == node:
-            co_located.append(a)
-        # This actor has been placed on a different node.
-        else:
-            non_co_located.append(a)
-    return co_located, non_co_located
+        node_groups = defaultdict(set)
+        for host, actor in zip(hosts, actors):
+            node_groups[host].add(actor)
+        max_ = -1
+        largest_group = None
+        for host in node_groups:
+            if max_ < len(node_groups[host]):
+                max_ = len(node_groups[host])
+                largest_group = host
+        non_co_located = []
+        for host in node_groups:
+            if host != largest_group:
+                non_co_located.extend(list(node_groups[host]))
+        return list(node_groups[largest_group]), non_co_located
+    # Node provided (or default: localhost): Consider those actors "colocated"
+    # that were placed on `node`.
+    else:
+        # Split into co-located (on `node`) and non-co-located (not on `node`).
+        co_located = []
+        non_co_located = []
+        for host, a in zip(hosts, actors):
+            # This actor has been placed on the correct node.
+            if host == node:
+                co_located.append(a)
+            # This actor has been placed on a different node.
+            else:
+                non_co_located.append(a)
+        return co_located, non_co_located


 @Deprecated(new="create_colocated_actors", error=False)
@@ -177,3 +246,11 @@ def create_colocated(cls, arg, count):
         actor_specs=[(cls, args, kwargs, count)],
         node=platform.node(),  # force on localhost
     )[cls]
+
+
+@Deprecated(error=False)
+def drop_colocated(actors: List[ActorHandle]) -> List[ActorHandle]:
+    colocated, non_colocated = split_colocated(actors)
+    for a in colocated:
+        a.__ray_terminate__.remote()
+    return non_colocated

From affa2c625404ac21999b7c1a0d879272722a4f06 Mon Sep 17 00:00:00 2001
From: sven1977
Date: Wed, 12 Jan 2022 14:48:38 +0100
Subject: [PATCH 8/9] wip.

---
 rllib/agents/trainer.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 3982b6d71d25..ed2d3d86b141 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -11,9 +11,11 @@
 import pickle
 import tempfile
 import time
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Callable, DefaultDict, Dict, List, Optional, Set, Tuple, \
+    Type, Union

 import ray
+from ray.actor import ActorHandle
 from ray.exceptions import RayError
 from ray.rllib.agents.callbacks import DefaultCallbacks
 from ray.rllib.env.env_context import EnvContext
@@ -722,8 +724,9 @@ def default_logger_creator(config):
         self._episode_history = []
         self._episodes_to_be_collected = []

-        # Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
-        self.evaluation_workers = None
+        # Evaluation WorkerSet.
+        self.evaluation_workers: Optional[WorkerSet] = None
+        # Metrics most recently returned by `self.evaluate()`.
         self.evaluation_metrics = {}

         super().__init__(config, logger_creator, remote_checkpoint_dir,
@@ -798,12 +801,19 @@ def env_creator_from_classpath(env_context):
         self.local_replay_buffer = (
             self._create_local_replay_buffer_if_necessary(self.config))

+        # Create a dict, mapping ActorHandles to sets of open remote
+        # requests (object refs). This way, we keep track of which actors
+        # inside this Trainer (e.g. a remote RolloutWorker) have
+        # already been sent how many (e.g. `sample()`) requests.
+        self.remote_requests_in_flight: \
+            DefaultDict[ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
+
         # Deprecated way of implementing Trainer sub-classes (or "templates"
         # via the soon-to-be deprecated `build_trainer` utility function).
         # Instead, sub-classes should override the Trainable's `setup()`
         # method and call super().setup() from within that override at some
         # point.
-        self.workers = None
+        self.workers: Optional[WorkerSet] = None
         self.train_exec_impl = None

         # Old design: Override `Trainer._init` (or use `build_trainer()`, which
@@ -909,7 +919,7 @@ def env_creator_from_classpath(env_context):
             # If evaluation_num_workers=0, use the evaluation set's local
             # worker for evaluation, otherwise, use its remote workers
             # (parallelized evaluation).
-            self.evaluation_workers = self._make_workers(
+            self.evaluation_workers: WorkerSet = self._make_workers(
                 env_creator=self.env_creator,
                 validate_env=None,
                 policy_class=self.get_default_policy_class(self.config),

From bd1a0bc0957f0d8bdb13b279b62a5bf9d712d458 Mon Sep 17 00:00:00 2001
From: sven1977
Date: Thu, 13 Jan 2022 09:29:05 +0100
Subject: [PATCH 9/9] wip.

---
 rllib/agents/ddpg/apex.py   |  7 ++++++-
 rllib/agents/dqn/apex.py    |  7 ++++++-
 rllib/execution/tree_agg.py |  6 +++++-
 rllib/utils/actors.py       | 42 ++++++++++++++++++++++++-------------
 4 files changed, 45 insertions(+), 17 deletions(-)

diff --git a/rllib/agents/ddpg/apex.py b/rllib/agents/ddpg/apex.py
index 73e4ef041dc7..8669c4e90d86 100644
--- a/rllib/agents/ddpg/apex.py
+++ b/rllib/agents/ddpg/apex.py
@@ -25,7 +25,12 @@
         "replay_buffer_config": None,
         # Whether all shards of the replay buffer must be co-located
         # with the learner process (running the execution plan).
-        # If False, replay shards may be created on different node(s).
+        # This is preferred b/c the learner process should have quick
+        # access to the data from the buffer shards, avoiding network
+        # traffic each time samples from the buffer(s) are drawn.
+        # Set this to False for relaxing this constraint and allowing
+        # replay shards to be created on node(s) other than the one
+        # on which the learner is located.
         "replay_buffer_shards_colocated_with_driver": True,
         "learning_starts": 50000,
         "train_batch_size": 512,
diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py
index 137effcda965..d9d1ab3f536b 100644
--- a/rllib/agents/dqn/apex.py
+++ b/rllib/agents/dqn/apex.py
@@ -63,7 +63,12 @@
         "replay_buffer_config": None,
         # Whether all shards of the replay buffer must be co-located
         # with the learner process (running the execution plan).
-        # If False, replay shards may be created on different node(s).
+        # This is preferred b/c the learner process should have quick
+        # access to the data from the buffer shards, avoiding network
+        # traffic each time samples from the buffer(s) are drawn.
+        # Set this to False for relaxing this constraint and allowing
+        # replay shards to be created on node(s) other than the one
+        # on which the learner is located.
         "replay_buffer_shards_colocated_with_driver": True,

         "learning_starts": 50000,
diff --git a/rllib/execution/tree_agg.py b/rllib/execution/tree_agg.py
index 3766769d3acf..812a4437b808 100644
--- a/rllib/execution/tree_agg.py
+++ b/rllib/execution/tree_agg.py
@@ -93,12 +93,16 @@ def gather_experiences_tree_aggregation(workers: WorkerSet,
     # This spawns |num_aggregation_workers| intermediate actors that aggregate
     # experiences in parallel. We force colocation on the same node (localhost)
     # to maximize data bandwidth between them and the driver.
+    localhost = platform.node()
+    assert localhost != "", \
+        "ERROR: Cannot determine local node name! " \
+        "`platform.node()` returned empty string."
     all_co_located = create_colocated_actors(
         actor_specs=[
             # (class, args, kwargs={}, count=1)
             (Aggregator, [config, g], {}, 1) for g in rollout_groups
         ],
-        node=platform.node())
+        node=localhost)

     # Use the first ([0]) of each created group (each group only has one
     # actor: count=1).
diff --git a/rllib/utils/actors.py b/rllib/utils/actors.py
index 8cb4102acba4..a700462c5b22 100644
--- a/rllib/utils/actors.py
+++ b/rllib/utils/actors.py
@@ -95,39 +95,53 @@ def create_colocated_actors(
     if node == "localhost":
         node = platform.node()

-    # Maps types to lists of already co-located actors.
+    # Maps each entry in `actor_specs` to lists of already co-located actors.
     ok = [[] for _ in range(len(actor_specs))]
-    attempt = 1
-    while attempt < max_attempts:
+
+    # Try n times to co-locate all given actor types (`actor_specs`).
+    # With each (failed) attempt, increase the number of actors we try to
+    # create (on the same node), then kill the ones that have been created in
+    # excess.
+    for attempt in range(max_attempts):
+        # If any attempt to co-locate fails, set this to False and we'll do
+        # another attempt.
         all_good = True
+        # Process all `actor_specs` in sequence.
         for i, (typ, args, kwargs, count) in enumerate(actor_specs):
             args = args or []  # Allow None.
             kwargs = kwargs or {}  # Allow None.
+            # We don't have enough actors yet of this spec co-located on
+            # the desired node.
             if len(ok[i]) < count:
-                all_good = False
                 co_located = try_create_colocated(
                     cls=typ,
                     args=args,
                     kwargs=kwargs,
-                    count=count * attempt,
+                    count=count * (attempt + 1),
                     node=node)
-                # If node did not matter, from here on, use the host that the
-                # first actor(s) are already located on.
+                # If node did not matter (None), from here on, use the host
+                # that the first actor(s) are already co-located on.
                 if node is None:
                     node = ray.get(co_located[0].get_host.remote())
+                # Add the newly co-located actors to the `ok` list.
                 ok[i].extend(co_located)
-            elif len(ok[i]) > count:
+            # If we still don't have enough -> We'll have to do another
+            # attempt.
+            if len(ok[i]) < count:
+                all_good = False
+            # We created too many actors for this spec -> Kill/truncate
+            # the excess ones.
+            if len(ok[i]) > count:
                 for a in ok[i][count:]:
                     a.__ray_terminate__.remote()
                 ok[i] = ok[i][:count]
+
+        # All `actor_specs` have been fulfilled, return lists of
+        # co-located actors.
        if all_good:
-            break
-        elif attempt == max_attempts - 1:
-            raise Exception(
-                "Unable to create enough colocated actors -> aborting.")
-        attempt += 1
+            return ok

-    return ok
+    raise Exception("Unable to create enough colocated actors -> aborting.")


 def try_create_colocated(
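
For illustration only (none of the snippets below are part of the patch series above): a minimal sketch of how the `create_colocated_actors()` utility from `rllib/utils/actors.py` can be called directly. `ShardActor` is a made-up stand-in for an actor such as `ReplayActor`; the only assumption is that it exposes the `get_host()` method that `split_colocated()` relies on.

    import platform

    import ray
    from ray.rllib.utils.actors import create_colocated_actors


    @ray.remote
    class ShardActor:
        """Toy stand-in for e.g. a ReplayActor shard."""

        def get_host(self) -> str:
            # split_colocated() calls this to find out where the actor lives.
            return platform.node()


    if __name__ == "__main__":
        ray.init()
        # One entry in `actor_specs`: (class, args, kwargs, count).
        colocated = create_colocated_actors(
            actor_specs=[(ShardActor, [], {}, 2)],
            node=platform.node(),  # force both actors onto this node
        )
        # The result is indexed per spec, hence the [0] used in apex.py.
        shard_actors = colocated[0]
        assert len(shard_actors) == 2

This mirrors the call pattern in `apex.py` and `tree_agg.py` above, where the single-spec result is immediately indexed with `[0]`.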
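A lower-level sketch probing `try_create_colocated()` and `split_colocated()` directly. `HostProbe` is again a hypothetical toy actor providing the assumed `get_host()` method; the snippet only checks the invariant that all surviving actors report the same host.

    import platform

    import ray
    from ray.rllib.utils.actors import split_colocated, try_create_colocated


    @ray.remote
    class HostProbe:
        def get_host(self) -> str:
            return platform.node()


    if __name__ == "__main__":
        ray.init()
        actors = try_create_colocated(
            cls=HostProbe, args=[], count=2, node=platform.node())
        # All surviving actors report the same host.
        hosts = ray.get([a.get_host.remote() for a in actors])
        assert len(set(hosts)) <= 1

        # With node=None, the largest same-node group is kept.
        co_located, non_co_located = split_colocated(actors, node=None)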
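A pure-Python illustration (hypothetical host names, plain strings instead of ActorHandles) of the "largest group wins" rule that the new `split_colocated(node=None)` branch implements; up to tie-breaking, the explicit `max_`/`largest_group` loop in the patch is equivalent to a single `max()` over the groups.

    from collections import defaultdict

    hosts = ["node-a", "node-b", "node-a", "node-a", "node-b"]
    actors = [f"actor-{i}" for i in range(len(hosts))]  # stand-ins for handles

    node_groups = defaultdict(set)
    for host, actor in zip(hosts, actors):
        node_groups[host].add(actor)

    largest_group = max(node_groups, key=lambda h: len(node_groups[h]))
    co_located = list(node_groups[largest_group])
    non_co_located = [
        a for h, group in node_groups.items() if h != largest_group
        for a in group
    ]

    assert largest_group == "node-a"
    assert len(co_located) == 3 and len(non_co_located) == 2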
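The `_allow_unknown_configs=True` argument added to the APEX-DDPG config merge is needed because `replay_buffer_shards_colocated_with_driver` is a brand-new top-level key that the base DDPG config does not contain. A hedged sketch of the same mechanism, assuming the usual `DDPGTrainer`/`DEFAULT_CONFIG` exports of this RLlib version:

    from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG

    merged = DDPGTrainer.merge_trainer_configs(
        DEFAULT_CONFIG,
        {
            # New key, unknown to the base DDPG config:
            "replay_buffer_shards_colocated_with_driver": True,
        },
        _allow_unknown_configs=True,
    )
    assert merged["replay_buffer_shards_colocated_with_driver"] is True

Without `_allow_unknown_configs=True`, the merge would reject the unknown top-level key.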
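Finally, a sketch of how the new `replay_buffer_shards_colocated_with_driver` flag could be switched off in an Ape-X DQN setup so that replay shards may be placed on any node. The worker, GPU, and shard counts are arbitrary examples, not recommendations, and running this requires a Ray cluster (or local machine) with enough CPUs.

    from ray.rllib.agents.dqn.apex import ApexTrainer

    config = {
        "num_workers": 4,
        "num_gpus": 0,
        "optimizer": {"num_replay_buffer_shards": 2},
        # Allow the ReplayActor shards to be created on nodes other than
        # the one running the learner/driver process.
        "replay_buffer_shards_colocated_with_driver": False,
    }
    trainer = ApexTrainer(config=config, env="CartPole-v0")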