
[RLlib] Preparatory PR for multi-agent, multi-GPU learning agent (alpha-star style) #02. #21649

Merged

Commits (changes shown from 26 of 29 commits)
f6f3581  wip  (sven1977, Jan 17, 2022)
a9f3098  wip  (sven1977, Jan 17, 2022)
5fe33ee  wip.  (sven1977, Jan 20, 2022)
2f2c546  fixes  (sven1977, Jan 20, 2022)
21efe03  fix  (sven1977, Jan 20, 2022)
05802c9  Merge branch 'master' of https://github.com/ray-project/ray into dece…  (sven1977, Jan 21, 2022)
6616496  wip  (sven1977, Jan 21, 2022)
59c8a33  Merge branch 'master' of https://github.com/ray-project/ray into dece…  (sven1977, Jan 24, 2022)
2312bcd  fixes  (sven1977, Jan 24, 2022)
b9b8e98  fixes  (sven1977, Jan 24, 2022)
3a1b7f0  Merge branch 'master' of https://github.com/ray-project/ray into dece…  (sven1977, Jan 25, 2022)
76d02ef  merge  (sven1977, Jan 25, 2022)
3ad6097  Merge branch 'master' of https://github.com/ray-project/ray into dece…  (sven1977, Jan 25, 2022)
e3c9222  wip.  (sven1977, Jan 25, 2022)
fb01568  fixes.  (sven1977, Jan 25, 2022)
aa990c4  Merge branch 'decentralized_multi_agent_learning_03' into decentraliz…  (sven1977, Jan 25, 2022)
616b467  wip.  (sven1977, Jan 25, 2022)
526e0e6  fix  (sven1977, Jan 25, 2022)
86dad3a  Merge branch 'decentralized_multi_agent_learning_03' into decentraliz…  (sven1977, Jan 25, 2022)
ae5e118  wip.  (sven1977, Jan 25, 2022)
ada9719  Merge branch 'master' of https://github.com/ray-project/ray into dece…  (sven1977, Jan 25, 2022)
73a589b  wip  (sven1977, Jan 26, 2022)
baaf2ec  wip  (sven1977, Jan 26, 2022)
fda66dc  Merge branch 'master' of https://github.com/ray-project/ray into dece…  (sven1977, Jan 26, 2022)
9d0d9ec  wip.  (sven1977, Jan 27, 2022)
f31d1d8  Merge branch 'master' of https://github.com/ray-project/ray into dece…  (sven1977, Jan 27, 2022)
303fff6  wip  (sven1977, Jan 27, 2022)
ad18fcf  Merge branch 'master' of https://github.com/ray-project/ray into dece…  (sven1977, Jan 27, 2022)
bec657a  wip.  (sven1977, Jan 27, 2022)
146 changes: 146 additions & 0 deletions rllib/execution/buffers/mixin_replay_buffer.py
@@ -0,0 +1,146 @@
import collections
import platform
import random
from typing import Optional

from ray.rllib.execution.replay_ops import SimpleReplayBuffer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType


class MixInMultiAgentReplayBuffer:
"""This buffer adds replayed samples to a stream of new experiences.

- Any newly added batch (`add_batch()`) is immediately returned upon
the next `replay` call (close to on-policy) as well as being moved
into the buffer.
- Additionally, a certain number of old samples is mixed into the
returned sample according to a given "replay ratio".
- If >1 calls to `add_batch()` are made without any `replay()` calls
in between, all newly added batches are returned (plus some older samples
according to the "replay ratio").

Examples:
# replay ratio 0.66 (2/3 replayed, 1/3 new samples):
>>> buffer = MixInMultiAgentReplayBuffer(capacity=100,
... replay_ratio=0.66)
>>> buffer.add_batch(<A>)
>>> buffer.add_batch(<B>)
>>> buffer.replay()
... [<A>, <B>, <B>]
>>> buffer.add_batch(<C>)
>>> buffer.replay()
... [<C>, <A>, <B>]
>>> # or: [<C>, <A>, <A>] or [<C>, <B>, <B>], but always <C> as it
>>> # is the newest sample

>>> buffer.add_batch(<D>)
>>> buffer.replay()
... [<D>, <A>, <C>]

# replay ratio 0.0 -> replay disabled:
>>> buffer = MixInMultiAgentReplayBuffer(capacity=100, replay_ratio=0.0)
>>> buffer.add_batch(<A>)
>>> buffer.replay()
... [<A>]
>>> buffer.add_batch(<B>)
>>> buffer.replay()
... [<B>]
"""

def __init__(self, capacity: int, replay_ratio: float):
"""Initializes MixInReplay instance.

Args:
capacity (int): Number of batches to store in total.
replay_ratio (float): Ratio of replayed samples in the returned
batches. E.g. a ratio of 0.0 means only return new samples
(no replay), a ratio of 0.5 means always return newest sample
plus one old one (1:1), a ratio of 0.66 means always return
the newest sample plus 2 old (replayed) ones (1:2), etc...
"""
self.capacity = capacity
self.replay_ratio = replay_ratio
self.replay_proportion = None
if self.replay_ratio != 1.0:
self.replay_proportion = self.replay_ratio / (
1.0 - self.replay_ratio)

def new_buffer():
return SimpleReplayBuffer(num_slots=capacity)

self.replay_buffers = collections.defaultdict(new_buffer)

# Metrics.
self.add_batch_timer = TimerStat()
self.replay_timer = TimerStat()
self.update_priorities_timer = TimerStat()

# Added timesteps over lifetime.
self.num_added = 0

# Last added batch(es).
self.last_added_batches = collections.defaultdict(list)

def add_batch(self, batch: SampleBatchType) -> None:
"""Adds a batch to the appropriate policy's replay buffer.

Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
it is not a MultiAgentBatch. Subsequently adds the individual policy
batches to the storage.

Args:
batch: The batch to be added.
"""
# Make a copy so the replay buffer doesn't pin plasma memory.
batch = batch.copy()
batch = batch.as_multi_agent()

with self.add_batch_timer:
for policy_id, sample_batch in batch.policy_batches.items():
self.replay_buffers[policy_id].add_batch(sample_batch)
self.last_added_batches[policy_id].append(sample_batch)
self.num_added += batch.count

def replay(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> \
Optional[SampleBatchType]:
buffer = self.replay_buffers[policy_id]
# Return None, if:
# - Buffer empty or
# - `replay_ratio` < 1.0 (new samples required in returned batch)
# and no new samples to mix with replayed ones.
if len(buffer) == 0 or (len(self.last_added_batches[policy_id]) == 0
and self.replay_ratio < 1.0):
return None

# Mix buffer's last added batches with older replayed batches.
with self.replay_timer:
output_batches = self.last_added_batches[policy_id].copy()
self.last_added_batches[policy_id].clear()
Member: since you clear right after copy, why not:

    output_batches = self.last_added_batches[policy_id]
    self.last_added_batches[policy_id] = []

?

Member: this is to save a .copy() op

Member: do you think this is a good idea? this is the only comment I have left.

Contributor (author): Great catch! Indeed, it saves the copy. Done. :)


# No replay desired -> Return here.
if self.replay_ratio == 0.0:
return SampleBatch.concat_samples(output_batches)
# Only replay desired -> Return a (replayed) sample from the
# buffer.
elif self.replay_ratio == 1.0:
return buffer.replay()

# Replay ratio = old / [old + new]
# Replay proportion: old / new
num_new = len(output_batches)
replay_proportion = self.replay_proportion
while random.random() < num_new * replay_proportion:
replay_proportion -= 1
output_batches.append(buffer.replay())
return SampleBatch.concat_samples(output_batches)

def get_host(self) -> str:
"""Returns the computer's network name.

Returns:
The computer's network name, or an empty string if the name could not
be determined.
"""
return platform.node()
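
A minimal usage sketch of the buffer added above (not part of the diff). It assumes the module path rllib/execution/buffers/mixin_replay_buffer.py introduced by this PR is importable and uses made-up SampleBatch contents. With replay_ratio=0.66, each `replay()` call returns the newly added batch plus, on average, roughly two replayed ones.

import numpy as np

from ray.rllib.execution.buffers.mixin_replay_buffer import \
    MixInMultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

# 2/3 replayed, 1/3 new samples per returned batch (on average).
buffer = MixInMultiAgentReplayBuffer(capacity=100, replay_ratio=0.66)

for i in range(5):
    # Dummy 4-timestep batch; it gets wrapped into a MultiAgentBatch internally.
    buffer.add_batch(SampleBatch({"obs": np.random.random((4, 2))}))
    mixed = buffer.replay()  # newest batch + ~2 replayed ones (stochastic)
    print(mixed.count)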
11 changes: 8 additions & 3 deletions rllib/execution/buffers/multi_agent_replay_buffer.py
@@ -1,6 +1,6 @@
import collections
import platform
from typing import Any, Dict
from typing import Any, Dict, Optional

import numpy as np
import ray
@@ -13,7 +13,7 @@
from ray.rllib.utils import deprecation_warning
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.utils.typing import PolicyID, SampleBatchType
from ray.util.iter import ParallelIteratorWorker


@@ -195,7 +195,7 @@ def add_batch(self, batch: SampleBatchType) -> None:
time_slice, weight=weight)
self.num_added += batch.count

def replay(self) -> SampleBatchType:
def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
"""If this buffer was given a fake batch, return it, otherwise return
a MultiAgentBatch with samples.
"""
@@ -211,8 +211,13 @@ def replay(self) -> SampleBatchType:
# Lockstep mode: Sample an equal number of steps from all policies
# at the same time.
if self.replay_mode == "lockstep":
assert policy_id is None, \
"`policy_id` specifier not allowed in `locksetp` mode!"
return self.replay_buffers[_ALL_POLICIES].sample(
self.replay_batch_size, beta=self.prioritized_replay_beta)
elif policy_id is not None:
return self.replay_buffers[policy_id].sample(
self.replay_batch_size, beta=self.prioritized_replay_beta)
else:
samples = {}
for policy_id, replay_buffer in self.replay_buffers.items():
(remaining lines of this diff not shown)
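
A hypothetical sketch (not from the diff) of the new per-policy `replay()` call; the constructor arguments and the policy ID are assumptions for illustration only.

from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
    MultiAgentReplayBuffer

buffer = MultiAgentReplayBuffer(capacity=1000, replay_batch_size=32)
# ... fill via buffer.add_batch(multi_agent_batch) during rollouts ...

# New in this diff: sample only one policy's experiences.
batch_p1 = buffer.replay(policy_id="policy_1")
# Unchanged default: sample from all policies' buffers.
batch_all = buffer.replay()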
22 changes: 16 additions & 6 deletions rllib/execution/buffers/replay_buffer.py
@@ -132,19 +132,25 @@ def add(self, item: SampleBatchType, weight: float) -> None:

@DeveloperAPI
def sample(self, num_items: int, beta: float = 0.0) -> SampleBatchType:
"""Sample a batch of experiences.
"""Sample a batch of size `num_items` from this buffer.

If fewer than `num_items` records are in this buffer, some samples may
be repeated to fulfil the requested batch size (`num_items`).

Args:
num_items: Number of items to sample from this buffer.
beta: This is ignored (only used by prioritized replay buffers).
beta: The prioritized replay beta value. Only relevant if this
ReplayBuffer is a PrioritizedReplayBuffer.

Returns:
Concatenated batch of items.
"""
idxes = [
random.randint(0,
len(self._storage) - 1) for _ in range(num_items)
]
# If we don't have any samples yet in this buffer, return None.
if len(self) == 0:
return None

idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
sample = self._encode_sample(idxes)
# Update our timesteps counters.
self._num_timesteps_sampled += len(sample)
@@ -282,6 +288,10 @@ def sample(self, num_items: int, beta: float) -> SampleBatchType:
"batch_indexes" fields denoting IS of each sampled
transition and original idxes in buffer of sampled experiences.
"""
# If we don't have any samples yet in this buffer, return None.
if len(self) == 0:
return None

assert beta >= 0.0

idxes = self._sample_proportional(num_items)
(remaining lines of this diff not shown)
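
A short sketch of the new behavior (buffer size and call pattern are assumptions): sampling from an empty buffer now returns None instead of failing on `random.randint(0, -1)`, so callers can simply guard on the return value.

from ray.rllib.execution.buffers.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(capacity=100)
batch = buffer.sample(num_items=32)
if batch is None:
    # Buffer still empty -> skip the learning update this iteration.
    pass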
152 changes: 152 additions & 0 deletions rllib/execution/parallel_requests.py
@@ -0,0 +1,152 @@
import logging
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set

import ray
from ray.actor import ActorHandle
from ray.rllib.utils.annotations import ExperimentalAPI

logger = logging.getLogger(__name__)


@ExperimentalAPI
def asynchronous_parallel_requests(
Member: this function looks really familiar. are we not replacing some existing logics with this util func call somewhere?

sven1977 (Contributor, author): Yeah, sorry, moved it into a new module for better clarity: This function may not only be used to collect SampleBatches from a RolloutWorker, but works generically on any set (and types!) of ray remote actors.

sven1977 (Contributor, author, Jan 25, 2022): I removed the old code (not used yet anywhere anyways) in replay_ops.py.

remote_requests_in_flight: DefaultDict[ActorHandle, Set[
ray.ObjectRef]],
actors: List[ActorHandle],
ray_wait_timeout_s: Optional[float] = None,
max_remote_requests_in_flight_per_actor: int = 2,
remote_fn: Optional[Callable[[ActorHandle, Any, Any], Any]] = None,
remote_args: Optional[List[List[Any]]] = None,
remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Dict[ActorHandle, Any]:
"""Runs parallel and asynchronous rollouts on all remote workers.

May use a timeout (if provided) on `ray.wait()` and returns only those
samples that could be gathered in the timeout window. Allows a maximum
of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
per remote actor.

As an alternative to calling `actor.sample.remote()`, the user can
provide a `remote_fn()`, which will be applied to the actor(s) instead.

Args:
remote_requests_in_flight: Dict mapping actor handles to a set of
their currently-in-flight pending requests (those we expect to
ray.get results for next). If you have an RLlib Trainer that calls
this function, you can use its `self.remote_requests_in_flight`
property here.
actors: The List of ActorHandles to perform the remote requests on.
ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
`ray.wait()` calls. If None (default), never time out (block
until at least one actor returns something).
max_remote_requests_in_flight_per_actor: Maximum number of remote
requests sent to each actor. 2 (default) is probably
sufficient to avoid idle times between two requests.
remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
`actor.sample.remote()` to generate the requests.
remote_args: If provided, use this list (per-actor) of lists (call
args) as *args to be passed to the `remote_fn`.
E.g.: actors=[A, B],
remote_args=[[...] <- *args for A, [...] <- *args for B].
remote_kwargs: If provided, use this list (per-actor) of dicts
(kwargs) as **kwargs to be passed to the `remote_fn`.
E.g.: actors=[A, B],
remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].

Returns:
A dict mapping actor handles to the results received by sending requests
to these actors. An empty dict, if no results are ready within the given
timeout.

Examples:
>>> # 2 remote rollout workers (num_workers=2):
>>> batches = asynchronous_parallel_requests(
... trainer.remote_requests_in_flight,
... actors=trainer.workers.remote_workers(),
... ray_wait_timeout_s=0.1,
... remote_fn=lambda w: time.sleep(1)  # sleep 1sec
... )
>>> print(len(batches))
... 0
>>> # Expect a timeout to have happened: the sleeping requests are not
>>> # ready within 0.1s, so an empty dict is returned.
"""

if remote_args is not None:
assert len(remote_args) == len(actors)
if remote_kwargs is not None:
assert len(remote_kwargs) == len(actors)

# For faster hash lookup.
actor_set = set(actors)

# Collect all currently pending remote requests into a single set of
# object refs.
pending_remotes = set()
# Also build a map to get the associated actor for each remote request.
remote_to_actor = {}
for actor, set_ in remote_requests_in_flight.items():
# Only consider those actors' pending requests that are in
# the given `actors` list.
if actor in actor_set:
pending_remotes |= set_
for r in set_:
remote_to_actor[r] = actor

# Add new requests, if possible (if
# `max_remote_requests_in_flight_per_actor` setting allows it).
for actor_idx, actor in enumerate(actors):
# Still room for another request to this actor.
if len(remote_requests_in_flight[actor]) < \
max_remote_requests_in_flight_per_actor:
if remote_fn is None:
req = actor.sample.remote()
else:
args = remote_args[actor_idx] if remote_args else []
kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
req = actor.apply.remote(remote_fn, *args, **kwargs)
# Add to our set to send to ray.wait().
pending_remotes.add(req)
# Keep our mappings properly updated.
remote_requests_in_flight[actor].add(req)
remote_to_actor[req] = actor

# There must always be pending remote requests.
assert len(pending_remotes) > 0
pending_remote_list = list(pending_remotes)

# No timeout: Block until at least one result is returned.
if ray_wait_timeout_s is None:
# First try a non-blocking `ray.wait()` (timeout=0) for efficiency.
ready, _ = ray.wait(
pending_remote_list, num_returns=len(pending_remotes), timeout=0)
# Nothing returned and `timeout` is None -> Fall back to a
# blocking wait to make sure we can return something.
if not ready:
ready, _ = ray.wait(pending_remote_list, num_returns=1)
# Timeout given: Do a `ray.wait()` call w/ this timeout.
else:
ready, _ = ray.wait(
pending_remote_list,
num_returns=len(pending_remotes),
timeout=ray_wait_timeout_s)

# Return empty results if nothing ready after the timeout.
if not ready:
return {}

# Remove in-flight records for ready refs.
for obj_ref in ready:
remote_requests_in_flight[remote_to_actor[obj_ref]].remove(obj_ref)

# Do one ray.get().
results = ray.get(ready)
assert len(ready) == len(results)

# Return mapping from (ready) actors to their results.
ret = {}
for obj_ref, result in zip(ready, results):
ret[remote_to_actor[obj_ref]] = result

return ret
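
A self-contained sketch (not from the PR) of driving this utility from a simple loop. The DummyWorker actor stands in for an RLlib RolloutWorker and only needs a `sample()` method; the 0.1s timeout and 0.5s sleep are arbitrary illustration values, and the import path is the module added by this PR.

import collections
import time

import ray
from ray.rllib.execution.parallel_requests import \
    asynchronous_parallel_requests

ray.init()

@ray.remote
class DummyWorker:
    # Stand-in for a rollout worker; only `sample()` is used here.
    def sample(self):
        time.sleep(0.5)
        return "fake-sample-batch"

workers = [DummyWorker.remote() for _ in range(2)]
requests_in_flight = collections.defaultdict(set)

for _ in range(5):
    # Take whatever finished within 0.1s; unfinished requests stay in
    # flight and are picked up by a later call.
    results = asynchronous_parallel_requests(
        remote_requests_in_flight=requests_in_flight,
        actors=workers,
        ray_wait_timeout_s=0.1,
        max_remote_requests_in_flight_per_actor=2,
    )
    for actor, result in results.items():
        print(result)
    time.sleep(0.3)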