[RLlib] Preparatory PR for multi-agent, multi-GPU learning agent (alpha-star style) #02. #21649
New file (146 added lines):

import collections
import platform
import random
from typing import Optional

from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType


class MixInMultiAgentReplayBuffer:
    """This buffer adds replayed samples to a stream of new experiences.

    - Any newly added batch (`add_batch()`) is immediately returned upon
      the next `replay()` call (close to on-policy), as well as being moved
      into the buffer.
    - Additionally, a certain number of old samples is mixed into the
      returned sample according to a given "replay ratio".
    - If more than one call to `add_batch()` is made without any `replay()`
      calls in between, all newly added batches are returned (plus some
      older samples according to the "replay ratio").

    Examples:
        # Replay ratio 0.66 (2/3 replayed, 1/3 new samples):
        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
        ...                                      replay_ratio=0.66)
        >>> buffer.add_batch(<A>)
        >>> buffer.add_batch(<B>)
        >>> buffer.replay()
        ... [<A>, <B>, <B>]
        >>> buffer.add_batch(<C>)
        >>> buffer.replay()
        ... [<C>, <A>, <B>]
        >>> # or: [<C>, <A>, <A>] or [<C>, <B>, <B>], but always <C>, as it
        >>> # is the newest sample.

        >>> buffer.add_batch(<D>)
        >>> buffer.replay()
        ... [<D>, <A>, <C>]

        # Replay ratio 0.0 -> replay disabled:
        >>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
        ...                                      replay_ratio=0.0)
        >>> buffer.add_batch(<A>)
        >>> buffer.replay()
        ... [<A>]
        >>> buffer.add_batch(<B>)
        >>> buffer.replay()
        ... [<B>]
    """

    def __init__(self, capacity: int, replay_ratio: float):
        """Initializes a MixInMultiAgentReplayBuffer instance.

        Args:
            capacity (int): Number of batches to store in total.
            replay_ratio (float): Ratio of replayed samples in the returned
                batches. E.g. a ratio of 0.0 means only return new samples
                (no replay), a ratio of 0.5 means always return the newest
                sample plus one old one (1:1), a ratio of 0.66 means always
                return the newest sample plus 2 old (replayed) ones (1:2),
                etc.
        """
        self.capacity = capacity
        self.replay_ratio = replay_ratio
        self.replay_proportion = None
        if self.replay_ratio != 1.0:
            self.replay_proportion = self.replay_ratio / (
                1.0 - self.replay_ratio)

        def new_buffer():
            return SimpleReplayBuffer(num_slots=capacity)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()

        # Added timesteps over lifetime.
        self.num_added = 0

        # Last added batch(es).
        self.last_added_batches = collections.defaultdict(list)

    def add_batch(self, batch: SampleBatchType) -> None:
        """Adds a batch to the appropriate policy's replay buffer.

        Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
        it is not a MultiAgentBatch yet. Subsequently, adds the individual
        policy batches to the storage.

        Args:
            batch: The batch to be added.
        """
        # Make a copy so the replay buffer doesn't pin plasma memory.
        batch = batch.copy()
        batch = batch.as_multi_agent()

        with self.add_batch_timer:
            for policy_id, sample_batch in batch.policy_batches.items():
                self.replay_buffers[policy_id].add_batch(sample_batch)
                self.last_added_batches[policy_id].append(sample_batch)
        self.num_added += batch.count

    def replay(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
            Optional[SampleBatchType]:
        buffer = self.replay_buffers[policy_id]
        # Return None if:
        # - The buffer is empty, or
        # - `replay_ratio` < 1.0 (new samples are required in the returned
        #   batch) and there are no new samples to mix with replayed ones.
        if len(buffer) == 0 or (len(self.last_added_batches[policy_id]) == 0
                                and self.replay_ratio < 1.0):
            return None

        # Mix the buffer's last added batches with older, replayed batches.
        with self.replay_timer:
            output_batches = self.last_added_batches[policy_id].copy()
            self.last_added_batches[policy_id].clear()

            # No replay desired -> Return here.
            if self.replay_ratio == 0.0:
                return SampleBatch.concat_samples(output_batches)
            # Only replay desired -> Return a (replayed) sample from the
            # buffer.
            elif self.replay_ratio == 1.0:
                return buffer.replay()

            # Replay ratio = old / (old + new)
            # Replay proportion = old / new
            num_new = len(output_batches)
            replay_proportion = self.replay_proportion
            while random.random() < num_new * replay_proportion:
                replay_proportion -= 1
                output_batches.append(buffer.replay())
            return SampleBatch.concat_samples(output_batches)

    def get_host(self) -> str:
        """Returns the computer's network name.

        Returns:
            The computer's network name or an empty string, if the network
            name could not be determined.
        """
        return platform.node()
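
The `replay()` method converts the user-facing `replay_ratio` (old / [old + new]) into a `replay_proportion` (old / new) and then uses a stochastic-rounding loop to decide how many old batches to append to the newly added ones. Below is a minimal, self-contained sketch (not part of the PR; the helper name is made up) that reproduces this loop for a single new batch:

import random


def num_mixed_in_old_batches(replay_ratio: float, num_new: int = 1) -> int:
    """Simulates the mixing loop in MixInMultiAgentReplayBuffer.replay()."""
    # replay_ratio = old / (old + new)  ->  replay_proportion = old / new.
    replay_proportion = replay_ratio / (1.0 - replay_ratio)
    num_old = 0
    # Same stochastic-rounding loop as in `replay()` above.
    while random.random() < num_new * replay_proportion:
        replay_proportion -= 1
        num_old += 1
    return num_old


# For replay_ratio=0.66 and a single new batch, roughly 2 old batches get
# mixed in on average (expected value = 0.66 / 0.34 ~= 1.94):
draws = [num_mixed_in_old_batches(0.66) for _ in range(100_000)]
print(sum(draws) / len(draws))  # ~1.94

This matches the docstring example above, where a `replay()` call after adding <C> returns <C> plus roughly two replayed batches.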
New file (152 added lines):

import logging
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set

import ray
from ray.actor import ActorHandle
from ray.rllib.utils.annotations import ExperimentalAPI

logger = logging.getLogger(__name__)


@ExperimentalAPI
def asynchronous_parallel_requests(
        remote_requests_in_flight: DefaultDict[ActorHandle, Set[
            ray.ObjectRef]],
        actors: List[ActorHandle],
        ray_wait_timeout_s: Optional[float] = None,
        max_remote_requests_in_flight_per_actor: int = 2,
        remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None,
        remote_args: Optional[List[List[Any]]] = None,
        remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Dict[ActorHandle, Any]:
    """Runs parallel and asynchronous rollouts on all remote workers.

    May use a timeout (if provided) on `ray.wait()` and returns only those
    samples that could be gathered within the timeout window. Allows a
    maximum of `max_remote_requests_in_flight_per_actor` remote calls to be
    in flight per remote actor.

    Instead of calling `actor.sample.remote()`, the user can provide a
    `remote_fn()`, which will be applied to the actor(s) instead.

    Args:
        remote_requests_in_flight: Dict mapping actor handles to a set of
            their currently in-flight pending requests (those we expect to
            `ray.get` results for next). If you have an RLlib Trainer that
            calls this function, you can use its
            `self.remote_requests_in_flight` property here.
        actors: The list of ActorHandles to perform the remote requests on.
        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
            `ray.wait()` calls. If None (default), never time out (block
            until at least one actor returns something).
        max_remote_requests_in_flight_per_actor: Maximum number of remote
            requests sent to each actor. 2 (default) is probably
            sufficient to avoid idle times between two requests.
        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead
            of `actor.sample.remote()` to generate the requests.
        remote_args: If provided, use this list (per-actor) of lists (call
            args) as *args to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_args=[[...] <- *args for A, [...] <- *args for B].
        remote_kwargs: If provided, use this list (per-actor) of dicts
            (kwargs) as **kwargs to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].

    Returns:
        A dict mapping actor handles to the results received by sending
        requests to these actors.
        An empty dict, if no results are ready within the timeout.

    Examples:
        >>> # 2 remote rollout workers (num_workers=2):
        >>> batches = asynchronous_parallel_requests(
        ...     trainer.remote_requests_in_flight,
        ...     actors=trainer.workers.remote_workers(),
        ...     ray_wait_timeout_s=0.1,
        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
        ... )
        >>> print(len(batches))
        ... 2
        >>> # Expect a timeout to have happened.
        >>> batches[0] is None and batches[1] is None
        ... True
    """

    if remote_args is not None:
        assert len(remote_args) == len(actors)
    if remote_kwargs is not None:
        assert len(remote_kwargs) == len(actors)

    # For faster hash lookup.
    actor_set = set(actors)

    # Collect all currently pending remote requests into a single set of
    # object refs.
    pending_remotes = set()
    # Also build a map to get the associated actor for each remote request.
    remote_to_actor = {}
    for actor, set_ in remote_requests_in_flight.items():
        # Only consider those actors' pending requests that are in
        # the given `actors` list.
        if actor in actor_set:
            pending_remotes |= set_
            for r in set_:
                remote_to_actor[r] = actor

    # Add new requests, if possible (if the
    # `max_remote_requests_in_flight_per_actor` setting allows it).
    for actor_idx, actor in enumerate(actors):
        # Still room for another request to this actor.
        if len(remote_requests_in_flight[actor]) < \
                max_remote_requests_in_flight_per_actor:
            if remote_fn is None:
                req = actor.sample.remote()
            else:
                args = remote_args[actor_idx] if remote_args else []
                kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
                req = actor.apply.remote(remote_fn, *args, **kwargs)
            # Add to our set to send to ray.wait().
            pending_remotes.add(req)
            # Keep our mappings properly updated.
            remote_requests_in_flight[actor].add(req)
            remote_to_actor[req] = actor

    # There must always be pending remote requests.
    assert len(pending_remotes) > 0
    pending_remote_list = list(pending_remotes)

    # No timeout: Block until at least one result is returned.
    if ray_wait_timeout_s is None:
        # First, try a `ray.wait()` w/o timeout for efficiency.
        ready, _ = ray.wait(
            pending_remote_list, num_returns=len(pending_remotes), timeout=0)
        # Nothing returned and `timeout` is None -> Fall back to a
        # blocking wait to make sure we can return something.
        if not ready:
            ready, _ = ray.wait(pending_remote_list, num_returns=1)
    # Timeout: Do a `ray.wait()` call w/ timeout.
    else:
        ready, _ = ray.wait(
            pending_remote_list,
            num_returns=len(pending_remotes),
            timeout=ray_wait_timeout_s)

        # Return empty results if nothing is ready after the timeout.
        if not ready:
            return {}

    # Remove in-flight records for ready refs.
    for obj_ref in ready:
        remote_requests_in_flight[remote_to_actor[obj_ref]].remove(obj_ref)

    # Do one ray.get().
    results = ray.get(ready)
    assert len(ready) == len(results)

    # Return the mapping from (ready) actors to their results.
    ret = {}
    for obj_ref, result in zip(ready, results):
        ret[remote_to_actor[obj_ref]] = result

    return ret

Review thread on `asynchronous_parallel_requests()`:

- Reviewer: This function looks really familiar. Are we not replacing some existing logic with this util func call somewhere?
- Author: Yeah, sorry, moved it into a new module for better clarity. This function may not only be used to collect SampleBatches from a RolloutWorker, but works generically on any set (and types!) of ray remote actors.
- Author: I removed the old code (not used anywhere yet anyway) in ...
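
As noted in the review thread above, the utility is not tied to RolloutWorker sampling: by passing `remote_fn`, it can drive any ray actor that exposes an `apply()` method. Below is a hedged usage sketch (not from the PR); the toy `Echo` actor is made up, and the import path of the new module is an assumption, since the diff does not show the file name:

import collections
import time

import ray
# Module path assumed; the new file's name is not visible in this diff.
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests


@ray.remote
class Echo:
    """Toy actor exposing the `apply()` protocol used for `remote_fn` requests."""

    def apply(self, fn, *args, **kwargs):
        return fn(self, *args, **kwargs)


ray.init()
actors = [Echo.remote() for _ in range(2)]

# One set of in-flight object refs per actor handle, as the utility expects.
# (An RLlib Trainer would pass its `self.remote_requests_in_flight` here.)
in_flight = collections.defaultdict(set)

results = {}
while len(results) < len(actors):
    # Each call (re)issues requests up to
    # `max_remote_requests_in_flight_per_actor` and returns only those
    # actors whose request finished within the 0.1s wait window.
    ready = asynchronous_parallel_requests(
        remote_requests_in_flight=in_flight,
        actors=actors,
        ray_wait_timeout_s=0.1,
        max_remote_requests_in_flight_per_actor=1,
        remote_fn=lambda actor: time.sleep(0.5) or "done",
    )
    results.update(ready)

print(results)  # Two actor-handle keys, each mapping to "done".

The polling loop sketches the intended pattern of a trainer calling the utility once per training iteration: requests stay in flight across calls, and results trickle in as they become ready.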
Review thread on `MixInMultiAgentReplayBuffer.replay()`:

- Reviewer: Since you clear right after copy, why not: ... ?
- Reviewer: This is to save a `.copy()` op.
- Reviewer: Do you think this is a good idea? This is the only comment I have left.
- Author: Great catch! Indeed, it saves the copy. Done. :)
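
The exact snippet the reviewer proposed is elided above ("why not: ... ?"). Judging from the exchange, the idea is to hand over the existing list and rebind the dict entry to a fresh list, instead of copying and then clearing; a hedged reconstruction of that change inside `replay()`:

# Hedged reconstruction of the reviewer's suggestion (the exact snippet is
# not shown in the thread): rebind instead of copy + clear, saving one list
# copy per `replay()` call.
output_batches = self.last_added_batches[policy_id]
self.last_added_batches[policy_id] = []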