diff --git a/rllib/execution/buffers/mixin_replay_buffer.py b/rllib/execution/buffers/mixin_replay_buffer.py
new file mode 100644
index 0000000000000..2bbd0afbd49a1
--- /dev/null
+++ b/rllib/execution/buffers/mixin_replay_buffer.py
@@ -0,0 +1,146 @@
+import collections
+import platform
+import random
+from typing import Optional
+
+from ray.rllib.execution.replay_ops import SimpleReplayBuffer
+from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
+from ray.rllib.utils.timer import TimerStat
+from ray.rllib.utils.typing import PolicyID, SampleBatchType
+
+
+class MixInMultiAgentReplayBuffer:
+    """This buffer adds replayed samples to a stream of new experiences.
+
+    - Any newly added batch (`add_batch()`) is immediately returned upon
+      the next `replay` call (close to on-policy) as well as being moved
+      into the buffer.
+    - Additionally, a certain number of old samples is mixed into the
+      returned sample according to a given "replay ratio".
+    - If >1 calls to `add_batch()` are made without any `replay()` calls
+      in between, all newly added batches are returned (plus some older samples
+      according to the "replay ratio").
+
+    Examples:
+        # replay ratio 0.66 (2/3 replayed, 1/3 new samples):
+        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
+        ...                                      replay_ratio=0.66)
+        >>> buffer.add_batch(<A>)
+        >>> buffer.add_batch(<B>)
+        >>> buffer.replay()
+        ... [<A>, <B>, <B>]
+        >>> buffer.add_batch(<C>)
+        >>> buffer.replay()
+        ... [<C>, <A>, <B>]
+        >>> # or: [<C>, <A>, <A>] or [<C>, <B>, <B>], but always <C> as it
+        >>> # is the newest sample
+
+        >>> buffer.add_batch(<D>)
+        >>> buffer.replay()
+        ... [<D>, <A>, <C>]
+
+        # replay ratio 0.0 -> replay disabled:
+        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
+        ...                                      replay_ratio=0.0)
+        >>> buffer.add_batch(<A>)
+        >>> buffer.replay()
+        ... [<A>]
+        >>> buffer.add_batch(<B>)
+        >>> buffer.replay()
+        ... [<B>]
+    """
+
+    def __init__(self, capacity: int, replay_ratio: float):
+        """Initializes a MixInMultiAgentReplayBuffer instance.
+
+        Args:
+            capacity (int): Number of batches to store in total.
+            replay_ratio (float): Ratio of replayed samples in the returned
+                batches. E.g. a ratio of 0.0 means only return new samples
+                (no replay), a ratio of 0.5 means always return the newest
+                sample plus one old one (1:1), a ratio of 0.66 means always
+                return the newest sample plus 2 old (replayed) ones (1:2),
+                etc.
+        """
+        self.capacity = capacity
+        self.replay_ratio = replay_ratio
+        self.replay_proportion = None
+        if self.replay_ratio != 1.0:
+            self.replay_proportion = self.replay_ratio / (
+                1.0 - self.replay_ratio)
+
+        def new_buffer():
+            return SimpleReplayBuffer(num_slots=capacity)
+
+        self.replay_buffers = collections.defaultdict(new_buffer)
+
+        # Metrics.
+        self.add_batch_timer = TimerStat()
+        self.replay_timer = TimerStat()
+        self.update_priorities_timer = TimerStat()
+
+        # Added timesteps over lifetime.
+        self.num_added = 0
+
+        # Last added batch(es).
+        self.last_added_batches = collections.defaultdict(list)
+
+    def add_batch(self, batch: SampleBatchType) -> None:
+        """Adds a batch to the appropriate policy's replay buffer.
+
+        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
+        it is not a MultiAgentBatch. Subsequently adds the individual policy
+        batches to the storage.
+
+        Args:
+            batch: The batch to be added.
+        """
+        # Make a copy so the replay buffer doesn't pin plasma memory.
+        batch = batch.copy()
+        batch = batch.as_multi_agent()
+
+        with self.add_batch_timer:
+            for policy_id, sample_batch in batch.policy_batches.items():
+                self.replay_buffers[policy_id].add_batch(sample_batch)
+                self.last_added_batches[policy_id].append(sample_batch)
+        self.num_added += batch.count
+
+    def replay(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
+            Optional[SampleBatchType]:
+        buffer = self.replay_buffers[policy_id]
+        # Return None, if:
+        # - Buffer empty or
+        # - `replay_ratio` < 1.0 (new samples required in returned batch)
+        #   and no new samples to mix with replayed ones.
+        if len(buffer) == 0 or (len(self.last_added_batches[policy_id]) == 0
+                                and self.replay_ratio < 1.0):
+            return None
+
+        # Mix buffer's last added batches with older replayed batches.
+        with self.replay_timer:
+            output_batches = self.last_added_batches[policy_id]
+            self.last_added_batches[policy_id] = []
+
+            # No replay desired -> Return here.
+            if self.replay_ratio == 0.0:
+                return SampleBatch.concat_samples(output_batches)
+            # Only replay desired -> Return a (replayed) sample from the
+            # buffer.
+            elif self.replay_ratio == 1.0:
+                return buffer.replay()
+
+            # Replay ratio = old / [old + new]
+            # Replay proportion: old / new
+            num_new = len(output_batches)
+            replay_proportion = self.replay_proportion
+            while random.random() < num_new * replay_proportion:
+                replay_proportion -= 1
+                output_batches.append(buffer.replay())
+            return SampleBatch.concat_samples(output_batches)
+
+    def get_host(self) -> str:
+        """Returns the computer's network name.
+
+        Returns:
+            The computer's network name or an empty string, if the network
+            name could not be determined.
+        """
+        return platform.node()
diff --git a/rllib/execution/buffers/multi_agent_replay_buffer.py b/rllib/execution/buffers/multi_agent_replay_buffer.py
index 87cb75f9a2afa..1c9bdd104043a 100644
--- a/rllib/execution/buffers/multi_agent_replay_buffer.py
+++ b/rllib/execution/buffers/multi_agent_replay_buffer.py
@@ -1,6 +1,6 @@
 import collections
 import platform
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import numpy as np
 import ray
@@ -13,7 +13,7 @@ from ray.rllib.utils import deprecation_warning
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.timer import TimerStat
-from ray.rllib.utils.typing import SampleBatchType
+from ray.rllib.utils.typing import PolicyID, SampleBatchType
 from ray.util.iter import ParallelIteratorWorker
@@ -195,7 +195,7 @@ def add_batch(self, batch: SampleBatchType) -> None:
                         time_slice, weight=weight)
         self.num_added += batch.count
 
-    def replay(self) -> SampleBatchType:
+    def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
         """If this buffer was given a fake batch, return it, otherwise
        return a MultiAgentBatch with samples.
        """
@@ -211,8 +211,13 @@ def replay(self) -> SampleBatchType:
             # Lockstep mode: Sample from all policies at the same time an
             # equal amount of steps.
             if self.replay_mode == "lockstep":
+                assert policy_id is None, \
+                    "`policy_id` specifier not allowed in `lockstep` mode!"
return self.replay_buffers[_ALL_POLICIES].sample( self.replay_batch_size, beta=self.prioritized_replay_beta) + elif policy_id is not None: + return self.replay_buffers[policy_id].sample( + self.replay_batch_size, beta=self.prioritized_replay_beta) else: samples = {} for policy_id, replay_buffer in self.replay_buffers.items(): diff --git a/rllib/execution/buffers/replay_buffer.py b/rllib/execution/buffers/replay_buffer.py index 5f5f1d883b4e0..1d19b00fa5d3f 100644 --- a/rllib/execution/buffers/replay_buffer.py +++ b/rllib/execution/buffers/replay_buffer.py @@ -132,19 +132,25 @@ def add(self, item: SampleBatchType, weight: float) -> None: @DeveloperAPI def sample(self, num_items: int, beta: float = 0.0) -> SampleBatchType: - """Sample a batch of experiences. + """Sample a batch of size `num_items` from this buffer. + + If less than `num_items` records are in this buffer, some samples in + the results may be repeated to fulfil the batch size (`num_items`) + request. Args: num_items: Number of items to sample from this buffer. - beta: This is ignored (only used by prioritized replay buffers). + beta: The prioritized replay beta value. Only relevant if this + ReplayBuffer is a PrioritizedReplayBuffer. Returns: Concatenated batch of items. """ - idxes = [ - random.randint(0, - len(self._storage) - 1) for _ in range(num_items) - ] + # If we don't have any samples yet in this buffer, return None. + if len(self) == 0: + return None + + idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)] sample = self._encode_sample(idxes) # Update our timesteps counters. self._num_timesteps_sampled += len(sample) @@ -282,6 +288,10 @@ def sample(self, num_items: int, beta: float) -> SampleBatchType: "batch_indexes" fields denoting IS of each sampled transition and original idxes in buffer of sampled experiences. """ + # If we don't have any samples yet in this buffer, return None. + if len(self) == 0: + return None + assert beta >= 0.0 idxes = self._sample_proportional(num_items) diff --git a/rllib/execution/parallel_requests.py b/rllib/execution/parallel_requests.py new file mode 100644 index 0000000000000..67b291cada1f6 --- /dev/null +++ b/rllib/execution/parallel_requests.py @@ -0,0 +1,152 @@ +import logging +from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set + +import ray +from ray.actor import ActorHandle +from ray.rllib.utils.annotations import ExperimentalAPI + +logger = logging.getLogger(__name__) + + +@ExperimentalAPI +def asynchronous_parallel_requests( + remote_requests_in_flight: DefaultDict[ActorHandle, Set[ + ray.ObjectRef]], + actors: List[ActorHandle], + ray_wait_timeout_s: Optional[float] = None, + max_remote_requests_in_flight_per_actor: int = 2, + remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None, + remote_args: Optional[List[List[Any]]] = None, + remote_kwargs: Optional[List[Dict[str, Any]]] = None, +) -> Dict[ActorHandle, Any]: + """Runs parallel and asynchronous rollouts on all remote workers. + + May use a timeout (if provided) on `ray.wait()` and returns only those + samples that could be gathered in the timeout window. Allows a maximum + of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight + per remote actor. + + Alternatively to calling `actor.sample.remote()`, the user can provide a + `remote_fn()`, which will be applied to the actor(s) instead. + + Args: + remote_requests_in_flight: Dict mapping actor handles to a set of + their currently-in-flight pending requests (those we expect to + ray.get results for next). 
If you have an RLlib Trainer that calls
+            this function, you can use its `self.remote_requests_in_flight`
+            property here.
+        actors: The List of ActorHandles to perform the remote requests on.
+        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
+            `ray.wait()` calls. If None (default), never time out (block
+            until at least one actor returns something).
+        max_remote_requests_in_flight_per_actor: Maximum number of remote
+            requests sent to each actor. 2 (default) is probably
+            sufficient to avoid idle times between two requests.
+        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
+            `actor.sample.remote()` to generate the requests.
+        remote_args: If provided, use this list (per-actor) of lists (call
+            args) as *args to be passed to the `remote_fn`.
+            E.g.: actors=[A, B],
+            remote_args=[[...] <- *args for A, [...] <- *args for B].
+        remote_kwargs: If provided, use this list (per-actor) of dicts
+            (kwargs) as **kwargs to be passed to the `remote_fn`.
+            E.g.: actors=[A, B],
+            remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].
+
+    Returns:
+        A dict mapping actor handles to the results received by sending
+        requests to these actors. Actors whose requests did not return
+        within the timeout window are not included; the dict is empty if
+        no results are ready.
+
+    Examples:
+        >>> # 2 remote rollout workers (num_workers=2):
+        >>> results = asynchronous_parallel_requests(
+        ...     trainer.remote_requests_in_flight,
+        ...     actors=trainer.workers.remote_workers(),
+        ...     ray_wait_timeout_s=0.1,
+        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
+        ... )
+        >>> # Expect a timeout to have happened: neither worker was ready
+        >>> # within 0.1s, so no results were returned.
+        >>> print(len(results))
+        ... 0
+    """
+
+    if remote_args is not None:
+        assert len(remote_args) == len(actors)
+    if remote_kwargs is not None:
+        assert len(remote_kwargs) == len(actors)
+
+    # For faster hash lookup.
+    actor_set = set(actors)
+
+    # Collect all currently pending remote requests into a single set of
+    # object refs.
+    pending_remotes = set()
+    # Also build a map to get the associated actor for each remote request.
+    remote_to_actor = {}
+    for actor, set_ in remote_requests_in_flight.items():
+        # Only consider those actors' pending requests that are in
+        # the given `actors` list.
+        if actor in actor_set:
+            pending_remotes |= set_
+            for r in set_:
+                remote_to_actor[r] = actor
+
+    # Add new requests, if possible (if
+    # `max_remote_requests_in_flight_per_actor` setting allows it).
+    for actor_idx, actor in enumerate(actors):
+        # Still room for another request to this actor.
+        if len(remote_requests_in_flight[actor]) < \
+                max_remote_requests_in_flight_per_actor:
+            if remote_fn is None:
+                req = actor.sample.remote()
+            else:
+                args = remote_args[actor_idx] if remote_args else []
+                kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
+                req = actor.apply.remote(remote_fn, *args, **kwargs)
+            # Add to our set to send to ray.wait().
+            pending_remotes.add(req)
+            # Keep our mappings properly updated.
+            remote_requests_in_flight[actor].add(req)
+            remote_to_actor[req] = actor
+
+    # There must always be pending remote requests.
+    assert len(pending_remotes) > 0
+    pending_remote_list = list(pending_remotes)
+
+    # No timeout: Block until at least one result is returned.
+    if ray_wait_timeout_s is None:
+        # First try to do a `ray.wait` w/o timeout for efficiency.
+        ready, _ = ray.wait(
+            pending_remote_list, num_returns=len(pending_remotes), timeout=0)
+        # Nothing returned and `timeout` is None -> Fall back to a
+        # blocking wait to make sure we can return something.
+ if not ready: + ready, _ = ray.wait(pending_remote_list, num_returns=1) + # Timeout: Do a `ray.wait() call` w/ timeout. + else: + ready, _ = ray.wait( + pending_remote_list, + num_returns=len(pending_remotes), + timeout=ray_wait_timeout_s) + + # Return empty results if nothing ready after the timeout. + if not ready: + return {} + + # Remove in-flight records for ready refs. + for obj_ref in ready: + remote_requests_in_flight[remote_to_actor[obj_ref]].remove(obj_ref) + + # Do one ray.get(). + results = ray.get(ready) + assert len(ready) == len(results) + + # Return mapping from (ready) actors to their results. + ret = {} + for obj_ref, result in zip(ready, results): + ret[remote_to_actor[obj_ref]] = result + + return ret diff --git a/rllib/execution/replay_ops.py b/rllib/execution/replay_ops.py index 7503b18fec319..6cd1cc5606f61 100644 --- a/rllib/execution/replay_ops.py +++ b/rllib/execution/replay_ops.py @@ -42,12 +42,12 @@ def __init__( actors: An optional list of replay actors to use instead of `local_buffer`. """ - if bool(local_buffer) == bool(actors): + if local_buffer is not None and actors is not None: raise ValueError( "Either `local_buffer` or `replay_actors` must be given, " "not both!") - if local_buffer: + if local_buffer is not None: self.local_actor = local_buffer self.replay_actors = None else: @@ -55,7 +55,7 @@ def __init__( self.replay_actors = actors def __call__(self, batch: SampleBatchType): - if self.local_actor: + if self.local_actor is not None: self.local_actor.add_batch(batch) else: actor = random.choice(self.replay_actors) @@ -64,8 +64,8 @@ def __call__(self, batch: SampleBatchType): def Replay(*, - local_buffer: MultiAgentReplayBuffer = None, - actors: List[ActorHandle] = None, + local_buffer: Optional[MultiAgentReplayBuffer] = None, + actors: Optional[List[ActorHandle]] = None, num_async: int = 4) -> LocalIterator[SampleBatchType]: """Replay experiences from the given buffer or actors. @@ -87,11 +87,11 @@ def Replay(*, SampleBatch(...) """ - if bool(local_buffer) == bool(actors): + if local_buffer is not None and actors is not None: raise ValueError( "Exactly one of local_buffer and replay_actors must be given.") - if actors: + if actors is not None: replay = from_actors(actors) return replay.gather_async( num_async=num_async).filter(lambda x: x is not None) @@ -135,6 +135,8 @@ def __init__(self, self.replay_batches = [] self.replay_index = 0 + self.last_added_batches = [] + def add_batch(self, sample_batch: SampleBatchType) -> None: warn_replay_capacity(item=sample_batch, num_items=self.num_slots) if self.num_slots > 0: @@ -145,9 +147,14 @@ def add_batch(self, sample_batch: SampleBatchType) -> None: self.replay_index += 1 self.replay_index %= self.num_slots + self.last_added_batches.append(sample_batch) + def replay(self) -> SampleBatchType: return random.choice(self.replay_batches) + def __len__(self): + return len(self.replay_batches) + class MixInReplay: """This operator adds replay to a stream of experiences. 
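
Note (reviewer sketch, not part of the diff): a minimal usage example of the new `MixInMultiAgentReplayBuffer` introduced above. The 0.5 replay ratio and the single-timestep `SampleBatch` contents are illustrative assumptions only.

```python
import numpy as np

from ray.rllib.execution.buffers.mixin_replay_buffer import \
    MixInMultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

# replay_ratio=0.5 -> each replay() returns the newest batch plus roughly
# one replayed (older) batch per new one.
buffer = MixInMultiAgentReplayBuffer(capacity=100, replay_ratio=0.5)

for _ in range(10):
    # Illustrative single-timestep batch (any SampleBatch works here).
    buffer.add_batch(SampleBatch({"obs": np.random.random((1, 4))}))
    # Concatenated SampleBatch of ~2 timesteps on average;
    # would be None if the buffer were still empty.
    mixed = buffer.replay()
```

Replayed batches are drawn uniformly from the last `capacity` batches added per policy (via `SimpleReplayBuffer.replay()`), so the returned mix only matches the configured ratio in expectation.
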
diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index 999420f3a0af3..fc4b384c04887 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -1,10 +1,9 @@ import logging import time -from typing import Any, Callable, Container, Dict, List, Optional, Tuple, \ +from typing import Callable, Container, List, Optional, Tuple, \ TYPE_CHECKING import ray -from ray.actor import ActorHandle from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \ @@ -21,7 +20,6 @@ from ray.util.iter_metrics import SharedMetrics if TYPE_CHECKING: - from ray.rllib.agents.trainer import Trainer from ray.rllib.evaluation.rollout_worker import RolloutWorker logger = logging.getLogger(__name__) @@ -77,135 +75,6 @@ def synchronous_parallel_sample( return sample_batches -# TODO: Move to generic parallel ops module and rename to -# `asynchronous_parallel_requests`: -@ExperimentalAPI -def asynchronous_parallel_sample( - trainer: "Trainer", - actors: List[ActorHandle], - ray_wait_timeout_s: Optional[float] = None, - max_remote_requests_in_flight_per_actor: int = 2, - remote_fn: Optional[Callable[["RolloutWorker"], None]] = None, - remote_args: Optional[List[List[Any]]] = None, - remote_kwargs: Optional[List[Dict[str, Any]]] = None, -) -> Optional[List[SampleBatch]]: - """Runs parallel and asynchronous rollouts on all remote workers. - - May use a timeout (if provided) on `ray.wait()` and returns only those - samples that could be gathered in the timeout window. Allows a maximum - of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight - per remote actor. - - Alternatively to calling `actor.sample.remote()`, the user can provide a - `remote_fn()`, which will be applied to the actor(s) instead. - - Args: - trainer: The Trainer object that we run the sampling for. - actors: The List of ActorHandles to perform the remote requests on. - ray_wait_timeout_s: Timeout (in sec) to be used for the underlying - `ray.wait()` calls. If None (default), never time out (block - until at least one actor returns something). - max_remote_requests_in_flight_per_actor: Maximum number of remote - requests sent to each actor. 2 (default) is probably - sufficient to avoid idle times between two requests. - remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of - `actor.sample.remote()` to generate the requests. - remote_args: If provided, use this list (per-actor) of lists (call - args) as *args to be passed to the `remote_fn`. - E.g.: actors=[A, B], - remote_args=[[...] <- *args for A, [...] <- *args for B]. - remote_kwargs: If provided, use this list (per-actor) of dicts - (kwargs) as **kwargs to be passed to the `remote_fn`. - E.g.: actors=[A, B], - remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B]. - - Returns: - The list of asynchronously collected sample batch types. None, if no - samples are ready. - - Examples: - >>> # 2 remote rollout workers (num_workers=2): - >>> batches = asynchronous_parallel_sample( - ... trainer, - ... actors=trainer.workers.remote_workers(), - ... ray_wait_timeout_s=0.1, - ... remote_fn=lambda w: time.sleep(1) # sleep 1sec - ... ) - >>> print(len(batches)) - ... 2 - >>> # Expect a timeout to have happened. - >>> batches[0] is None and batches[1] is None - ... 
True - """ - - if remote_args is not None: - assert len(remote_args) == len(actors) - if remote_kwargs is not None: - assert len(remote_kwargs) == len(actors) - - # Collect all currently pending remote requests into a single set of - # object refs. - pending_remotes = set() - # Also build a map to get the associated actor for each remote request. - remote_to_actor = {} - for actor, set_ in trainer.remote_requests_in_flight.items(): - pending_remotes |= set_ - for r in set_: - remote_to_actor[r] = actor - - # Add new requests, if possible (if - # `max_remote_requests_in_flight_per_actor` setting allows it). - for actor_idx, actor in enumerate(actors): - # Still room for another request to this actor. - if len(trainer.remote_requests_in_flight[actor]) < \ - max_remote_requests_in_flight_per_actor: - if remote_fn is None: - req = actor.sample.remote() - else: - args = remote_args[actor_idx] if remote_args else [] - kwargs = remote_kwargs[actor_idx] if remote_kwargs else {} - req = actor.apply.remote(remote_fn, *args, **kwargs) - # Add to our set to send to ray.wait(). - pending_remotes.add(req) - # Keep our mappings properly updated. - trainer.remote_requests_in_flight[actor].add(req) - remote_to_actor[req] = actor - - # There must always be pending remote requests. - assert len(pending_remotes) > 0 - pending_remote_list = list(pending_remotes) - - # No timeout: Block until at least one result is returned. - if ray_wait_timeout_s is None: - # First try to do a `ray.wait` w/o timeout for efficiency. - ready, _ = ray.wait( - pending_remote_list, num_returns=len(pending_remotes), timeout=0) - # Nothing returned and `timeout` is None -> Fall back to a - # blocking wait to make sure we can return something. - if not ready: - ready, _ = ray.wait(pending_remote_list, num_returns=1) - # Timeout: Do a `ray.wait() call` w/ timeout. - else: - ready, _ = ray.wait( - pending_remote_list, - num_returns=len(pending_remotes), - timeout=ray_wait_timeout_s) - - # Return None if nothing ready after the timeout. - if not ready: - return None - - for obj_ref in ready: - # Remove in-flight record for this ref. - trainer.remote_requests_in_flight[remote_to_actor[obj_ref]].remove( - obj_ref) - remote_to_actor.pop(obj_ref) - - results = ray.get(ready) - - return results - - def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync", num_async=1) -> LocalIterator[SampleBatch]: """Operator to collect experiences in parallel from rollout workers. diff --git a/rllib/execution/tests/test_mixin_multi_agent_replay_buffer.py b/rllib/execution/tests/test_mixin_multi_agent_replay_buffer.py new file mode 100644 index 0000000000000..bb24a4860ed8b --- /dev/null +++ b/rllib/execution/tests/test_mixin_multi_agent_replay_buffer.py @@ -0,0 +1,151 @@ +import numpy as np +import unittest + +from ray.rllib.execution.buffers.mixin_replay_buffer import \ + MixInMultiAgentReplayBuffer +from ray.rllib.policy.sample_batch import SampleBatch + + +class TestMixInMultiAgentReplayBuffer(unittest.TestCase): + """Tests insertion and mixed sampling of the MixInMultiAgentReplayBuffer. + """ + + capacity = 10 + + def _generate_data(self): + return SampleBatch({ + "obs": [np.random.random((4, ))], + "action": [np.random.choice([0, 1])], + "reward": [np.random.rand()], + "new_obs": [np.random.random((4, ))], + "done": [np.random.choice([False, True])], + }) + + def test_mixin_sampling(self): + # 50% replay ratio. + buffer = MixInMultiAgentReplayBuffer( + capacity=self.capacity, replay_ratio=0.5) + # Add a new batch. 
+        batch = self._generate_data()
+        buffer.add_batch(batch)
+        # Expect at least 1 sample to be returned.
+        sample = buffer.replay()
+        self.assertTrue(len(sample) >= 1)
+        # If we insert and replay n times, expect roughly return batches of
+        # len 2 (replay_ratio=0.5 -> 50% replayed samples -> 1 new and 1 old
+        # sample on average in each returned value).
+        results = []
+        for _ in range(100):
+            buffer.add_batch(batch)
+            sample = buffer.replay()
+            results.append(len(sample))
+        self.assertAlmostEqual(np.mean(results), 2.0)
+
+        # 33% replay ratio.
+        buffer = MixInMultiAgentReplayBuffer(
+            capacity=self.capacity, replay_ratio=0.333)
+        # Expect exactly 0 samples to be returned (buffer empty).
+        sample = buffer.replay()
+        self.assertTrue(sample is None)
+        # Add a new batch.
+        batch = self._generate_data()
+        buffer.add_batch(batch)
+        # Expect at least 1 sample to be returned.
+        sample = buffer.replay()
+        self.assertTrue(len(sample) >= 1)
+        # If we insert-2x and replay n times, expect roughly return batches
+        # of len 3 (replay_ratio=0.33 -> 33% replayed samples -> 2 new and
+        # 1 old sample on average in each returned value).
+        results = []
+        for _ in range(100):
+            buffer.add_batch(batch)
+            buffer.add_batch(batch)
+            sample = buffer.replay()
+            results.append(len(sample))
+        self.assertAlmostEqual(np.mean(results), 3.0, delta=0.1)
+
+        # If we insert-1x and replay n times, expect roughly return batches
+        # of len 1.5 (replay_ratio=0.33 -> 33% replayed samples -> 1 new and
+        # 0.5 old samples on average in each returned value).
+        results = []
+        for _ in range(100):
+            buffer.add_batch(batch)
+            sample = buffer.replay()
+            results.append(len(sample))
+        self.assertAlmostEqual(np.mean(results), 1.5, delta=0.1)
+
+        # 90% replay ratio.
+        buffer = MixInMultiAgentReplayBuffer(
+            capacity=self.capacity, replay_ratio=0.9)
+        # Expect exactly 0 samples to be returned (buffer empty).
+        sample = buffer.replay()
+        self.assertTrue(sample is None)
+        # Add a new batch.
+        batch = self._generate_data()
+        buffer.add_batch(batch)
+        # Expect at least 2 samples to be returned (new one plus at least
+        # one replay sample).
+        sample = buffer.replay()
+        self.assertTrue(len(sample) >= 2)
+        # If we insert and replay n times, expect roughly return batches of
+        # len 10 (replay_ratio=0.9 -> 90% replayed samples -> 1 new and
+        # 9 old samples on average in each returned value).
+        results = []
+        for _ in range(100):
+            buffer.add_batch(batch)
+            sample = buffer.replay()
+            results.append(len(sample))
+        self.assertAlmostEqual(np.mean(results), 10.0, delta=0.1)
+
+        # 0% replay ratio -> Only new samples.
+        buffer = MixInMultiAgentReplayBuffer(
+            capacity=self.capacity, replay_ratio=0.0)
+        # Add a new batch.
+        batch = self._generate_data()
+        buffer.add_batch(batch)
+        # Expect exactly 1 sample to be returned.
+        sample = buffer.replay()
+        self.assertTrue(len(sample) == 1)
+        # Expect None to be returned (nothing new to be returned and
+        # no replay allowed (replay_ratio=0.0)).
+        sample = buffer.replay()
+        self.assertTrue(sample is None)
+        # If we insert and replay n times, expect roughly return batches of
+        # len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and
+        # 0 old samples on average in each returned value).
+        results = []
+        for _ in range(100):
+            buffer.add_batch(batch)
+            sample = buffer.replay()
+            results.append(len(sample))
+        self.assertAlmostEqual(np.mean(results), 1.0)
+
+        # 100% replay ratio -> Only replayed (old) samples.
+ buffer = MixInMultiAgentReplayBuffer( + capacity=self.capacity, replay_ratio=1.0) + # Expect exactly 0 samples to be returned (buffer empty). + sample = buffer.replay() + self.assertTrue(sample is None) + # Add a new batch. + batch = self._generate_data() + buffer.add_batch(batch) + # Expect exactly 1 sample to be returned (the new batch). + sample = buffer.replay() + self.assertTrue(len(sample) == 1) + # Another replay -> Expect exactly 1 sample to be returned. + sample = buffer.replay() + self.assertTrue(len(sample) == 1) + # If we replay n times, expect roughly return batches of + # len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old samples + # on average in each returned value). + results = [] + for _ in range(100): + sample = buffer.replay() + results.append(len(sample)) + self.assertAlmostEqual(np.mean(results), 1.0) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/jax/misc.py b/rllib/models/jax/misc.py index 8a9397a1fdb4b..e7b067c03f0ab 100644 --- a/rllib/models/jax/misc.py +++ b/rllib/models/jax/misc.py @@ -38,7 +38,7 @@ def __init__(self, # By default, use Glorot unform initializer. if initializer is None: - initializer = flax.nn.initializers.xavier_uniform() + initializer = nn.initializers.xavier_uniform() self.prng_key = prng_key or jax.random.PRNGKey(int(time.time())) _, self.prng_key = jax.random.split(self.prng_key) diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index e28e3376acf6e..53ace30afd1a9 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -19,6 +19,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.metrics import NUM_AGENT_STEPS_TRAINED from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.spaces.space_utils import normalize_action @@ -296,7 +297,10 @@ def build_eager_tf_policy( class eager_policy_cls(base): def __init__(self, observation_space, action_space, config): - assert tf.executing_eagerly() + # If this class runs as a @ray.remote actor, eager mode may not + # have been activated yet. 
+ if not tf1.executing_eagerly(): + tf1.enable_eager_execution() self.framework = config.get("framework", "tfe") Policy.__init__(self, observation_space, action_space, config) @@ -600,7 +604,10 @@ def learn_on_batch(self, postprocessed_batch): postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch) postprocessed_batch.set_training(True) stats = self._learn_on_batch_helper(postprocessed_batch) - stats.update({"custom_metrics": learn_stats}) + stats.update({ + "custom_metrics": learn_stats, + NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count, + }) return convert_to_numpy(stats) @override(Policy) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 3419c62da54be..f341595c2766a 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -6,8 +6,11 @@ import numpy as np import platform import tree # pip install dm_tree -from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \ + TYPE_CHECKING, Union +import ray +from ray.actor import ActorHandle from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 @@ -22,7 +25,7 @@ from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \ get_dummy_batch_for_space, unbatch from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \ - T, TensorType, TensorStructType, TrainerConfigDict, Tuple, Union + PolicyID, PolicyState, T, TensorType, TensorStructType, TrainerConfigDict tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -451,6 +454,36 @@ def learn_on_batch(self, samples: SampleBatch) -> Dict[str, TensorType]: self.apply_gradients(grads) return grad_info + @ExperimentalAPI + def learn_on_batch_from_replay_buffer( + self, replay_actor: ActorHandle, + policy_id: PolicyID) -> Dict[str, TensorType]: + """Samples a batch from given replay actor and performs an update. + + Args: + replay_actor: The replay buffer actor to sample from. + policy_id: The ID of this policy. + + Returns: + Dictionary of extra metadata from `compute_gradients()`. + """ + # Sample a batch from the given replay actor. + # Note that for better performance (less data sent through the + # network), this policy should be co-located on the same node + # as `replay_actor`. Such a co-location step is usually done during + # the Trainer's `setup()` phase. + batch = ray.get(replay_actor.replay.remote(policy_id=policy_id)) + if batch is None: + return {} + + # Send to own learn_on_batch method for updating. + # TODO: hack w/ `hasattr` + if hasattr(self, "devices") and len(self.devices) > 1: + self.load_batch_into_buffer(batch, buffer_index=0) + return self.learn_on_loaded_batch(offset=0, buffer_index=0) + else: + return self.learn_on_batch(batch) + @DeveloperAPI def load_batch_into_buffer(self, batch: SampleBatch, buffer_index: int = 0) -> int: @@ -606,7 +639,7 @@ def get_initial_state(self) -> List[TensorType]: return [] @DeveloperAPI - def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]: + def get_state(self) -> PolicyState: """Returns the entire current state of this Policy. Note: Not to be confused with an RNN model's internal state. 
@@ -626,10 +659,7 @@ def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]: return state @DeveloperAPI - def set_state( - self, - state: Union[Dict[str, TensorType], List[TensorType]], - ) -> None: + def set_state(self, state: PolicyState) -> None: """Restores the entire current state of this Policy from `state`. Args: diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py index 4324c39037662..27caf30c33f7c 100644 --- a/rllib/utils/metrics/__init__.py +++ b/rllib/utils/metrics/__init__.py @@ -8,11 +8,15 @@ LAST_TARGET_UPDATE_TS = "last_target_update_ts" NUM_TARGET_UPDATES = "num_target_updates" -# Performance timers (keys for metrics.timers). +# Performance timers (keys for Trainer._timers or metrics.timers). APPLY_GRADS_TIMER = "apply_grad" COMPUTE_GRADS_TIMER = "compute_grads" -WORKER_UPDATE_TIMER = "update" +SYNCH_WORKER_WEIGHTS_TIMER = "synch_weights" GRAD_WAIT_TIMER = "grad_wait" SAMPLE_TIMER = "sample" LEARN_ON_BATCH_TIMER = "learn" LOAD_BATCH_TIMER = "load" +TARGET_NET_UPDATE_TIMER = "target_net_update" + +# Deprecated: Use `SYNCH_WORKER_WEIGHTS_TIMER` instead. +WORKER_UPDATE_TIMER = "update" diff --git a/rllib/utils/metrics/window_stat.py b/rllib/utils/metrics/window_stat.py index 9aa0d9f301dff..8c753a2b64980 100644 --- a/rllib/utils/metrics/window_stat.py +++ b/rllib/utils/metrics/window_stat.py @@ -2,27 +2,72 @@ class WindowStat: - def __init__(self, name, n): + """Handles/stores incoming datastream and provides window-based statistics. + + Examples: + >>> win_stats = WindowStat("level", 3) + >>> win_stats.push(5.0) + >>> win_stats.push(7.0) + >>> win_stats.push(7.0) + >>> win_stats.push(10.0) + >>> # Expect 8.0 as the mean of the last 3 values: (7+7+10)/3=8.0 + >>> print(win_stats.mean()) + ... 8.0 + """ + + def __init__(self, name: str, n: int): + """Initializes a WindowStat instance. + + Args: + name: The name of the stats to collect and return stats for. + n: The window size. Statistics will be computed for the last n + items received from the stream. + """ + # The window-size. + self.window_size = n + # The name of the data (used for `self.stats()`). self.name = name - self.items = [None] * n + # List of items to do calculations over (len=self.n). + self.items = [None] * self.window_size + # The current index to insert the next item into `self.items`. self.idx = 0 + # How many items have been added over the lifetime of this object. self.count = 0 - def push(self, obj): + def push(self, obj) -> None: + """Pushes a new value/object into the data buffer.""" + # Insert object at current index. self.items[self.idx] = obj + # Increase insertion index by 1. self.idx += 1 + # Increase lifetime count by 1. self.count += 1 + # Fix index in case of rollover. self.idx %= len(self.items) - def stats(self): + def mean(self) -> float: + """Returns the (NaN-)mean of the last `self.window_size` items. + """ + return float(np.nanmean(self.items[:self.count])) + + def std(self) -> float: + """Returns the (NaN)-stddev of the last `self.window_size` items. + """ + return float(np.nanstd(self.items[:self.count])) + + def quantiles(self) -> np.ndarray: + """Returns ndarray with 0, 10, 50, 90, and 100 percentiles. 
+ """ if not self.count: - _quantiles = [] + return np.ndarray([], dtype=np.float32) else: - _quantiles = np.nanpercentile(self.items[:self.count], - [0, 10, 50, 90, 100]).tolist() + return np.nanpercentile(self.items[:self.count], + [0, 10, 50, 90, 100]).tolist() + + def stats(self): return { self.name + "_count": int(self.count), - self.name + "_mean": float(np.nanmean(self.items[:self.count])), - self.name + "_std": float(np.nanstd(self.items[:self.count])), - self.name + "_quantiles": _quantiles, + self.name + "_mean": self.mean(), + self.name + "_std": self.std(), + self.name + "_quantiles": self.quantiles(), }