[RLlib] APPO (new API stack) enhancements vol 05: Auto-sleep time AND thread-safety for MetricsLogger. #48868

Merged
Changes from all commits
93 commits
2b09969
wip
sven1977 Oct 10, 2024
5089e12
wip
sven1977 Oct 11, 2024
490e254
wip
sven1977 Oct 17, 2024
0c8fb9e
wip
sven1977 Oct 18, 2024
8d46658
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 18, 2024
12185ca
wip
sven1977 Oct 18, 2024
6e4652b
ppo reporting everything ok now.
sven1977 Oct 18, 2024
ef549b8
fix episodes/episodes-lifetime in env runners.
sven1977 Oct 18, 2024
1171ccf
wip
sven1977 Oct 18, 2024
85c48e8
wip
sven1977 Oct 19, 2024
bd5a884
wip
sven1977 Oct 21, 2024
937ff49
wip
sven1977 Oct 21, 2024
4673c96
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 21, 2024
a686fc7
wip
sven1977 Oct 21, 2024
a6fcc37
wip
sven1977 Oct 21, 2024
b6ef29e
wip
sven1977 Oct 22, 2024
666ba01
wip
sven1977 Oct 22, 2024
70939e7
wip
sven1977 Oct 22, 2024
18fbb91
wip
sven1977 Oct 22, 2024
dbf2d07
fix
sven1977 Oct 22, 2024
75f761f
fix
sven1977 Oct 23, 2024
6c1aa7a
Merge branch 'master' of https://github.com/ray-project/ray into fix_…
sven1977 Oct 23, 2024
b164f31
Merge branch 'master' of https://github.com/ray-project/ray into impa…
sven1977 Oct 23, 2024
b65d49c
fix
sven1977 Oct 23, 2024
2342bc9
wip
sven1977 Oct 23, 2024
3707a13
Merge branch 'master' of https://github.com/ray-project/ray into add_…
sven1977 Oct 24, 2024
4d1a4ec
Merge branch 'add_off_policyness_metric_to_new_api_stack' into impala…
sven1977 Oct 24, 2024
ec7159c
wip
sven1977 Oct 24, 2024
7327a4e
wip
sven1977 Oct 24, 2024
64c09e4
wip
sven1977 Oct 24, 2024
d0969d6
wip
sven1977 Oct 24, 2024
29fa4ed
Merge branch 'fix_accumulation_of_results_in_algorithm' into impala_a…
sven1977 Oct 24, 2024
2754c9e
wip
sven1977 Oct 24, 2024
fd45de5
wip
sven1977 Oct 24, 2024
4301721
Merge branch 'master' of https://github.com/ray-project/ray into impa…
sven1977 Oct 24, 2024
4c98e7e
wip
sven1977 Oct 24, 2024
fcfff7b
merge
sven1977 Oct 28, 2024
d2ee136
merge
sven1977 Nov 12, 2024
6e48131
wip
sven1977 Nov 12, 2024
cc6b753
wip
sven1977 Nov 12, 2024
f9af97a
wip
sven1977 Nov 12, 2024
68515d2
fix
sven1977 Nov 13, 2024
5cbbb96
Merge branch 'master' of https://github.com/ray-project/ray into impa…
sven1977 Nov 13, 2024
339ff6e
wip
sven1977 Nov 13, 2024
5a2ce74
learns Pong in some time (>>10min). Not great, but does learn.
sven1977 Nov 13, 2024
bc374d0
test copying dummy batches to circumvent spending time on learner con…
sven1977 Nov 14, 2024
e0844c3
wip
sven1977 Nov 14, 2024
9e2a755
learning Pong in 700sec (R>20.0) on 31 EnvRunners and 1 local L4 GPU.
sven1977 Nov 14, 2024
f330ff3
wip
sven1977 Nov 15, 2024
6a59c81
wip
sven1977 Nov 15, 2024
4866467
merge
sven1977 Nov 15, 2024
901fbc8
wip
sven1977 Nov 15, 2024
1a53331
wip
sven1977 Nov 15, 2024
1aacee7
wip
sven1977 Nov 15, 2024
308ebc9
wip
sven1977 Nov 15, 2024
1540d1f
Learns Pong-v5 in <8min (20.0+) with 31 ER and 1 local GPU.
sven1977 Nov 15, 2024
3a12aba
wip
sven1977 Nov 15, 2024
df18c93
various fixes and enhancements:
sven1977 Nov 17, 2024
329b9de
Merge branch 'master' of https://github.com/ray-project/ray into impa…
sven1977 Nov 18, 2024
3ddb9bb
deadlock and deepcopy fix
sven1977 Nov 18, 2024
abf8f67
fix problem with tensor found in reduced stats -> have to unlock tens…
sven1977 Nov 18, 2024
d213174
wip
sven1977 Nov 18, 2024
f8d6f7b
wip
sven1977 Nov 18, 2024
2b3a468
wip
sven1977 Nov 19, 2024
c0ee159
merge
sven1977 Nov 21, 2024
9c1b7c1
Merge branch 'master' of https://github.com/ray-project/ray into impa…
sven1977 Nov 21, 2024
cc56e55
Merge branch 'master' of https://github.com/ray-project/ray into impa…
sven1977 Nov 21, 2024
d121103
wip
sven1977 Nov 21, 2024
bd87e61
merge
sven1977 Nov 22, 2024
a4f0dff
some bug fixes related to metrics (was not thread safe) and throughpu…
sven1977 Nov 22, 2024
8eeb1fe
Merge branch 'master' of https://github.com/ray-project/ray into impa…
sven1977 Nov 22, 2024
39dc133
wip
sven1977 Nov 22, 2024
540d488
LEARNING!!!
sven1977 Nov 22, 2024
200b200
wip
sven1977 Nov 22, 2024
321d8dc
wip
sven1977 Nov 22, 2024
93a6f7f
wip
sven1977 Nov 23, 2024
3ec5a54
wip
sven1977 Nov 24, 2024
ed24c39
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Nov 27, 2024
4128993
wip
sven1977 Nov 27, 2024
0996f0e
wip
sven1977 Nov 27, 2024
122bb07
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Dec 2, 2024
56b4854
wip
sven1977 Dec 2, 2024
a87ec84
LINT
sven1977 Dec 2, 2024
942a372
test: old APPO loss variable names and no "x lambda".
sven1977 Dec 2, 2024
66bd406
wip
sven1977 Dec 2, 2024
7ecac19
wip
sven1977 Dec 2, 2024
8e12918
LINT; learns Pong in ~12min (maybe less) on 1 local A10 GPU and 31 En…
sven1977 Dec 2, 2024
eef1302
wip
sven1977 Dec 3, 2024
180305f
Merge branch 'master' of https://github.com/ray-project/ray into appo…
sven1977 Dec 3, 2024
00a8fef
wip
sven1977 Dec 3, 2024
ecb4b37
wip
sven1977 Dec 3, 2024
a9eb5da
LINT
sven1977 Dec 3, 2024
d40b0c2
fix
sven1977 Dec 3, 2024
23 changes: 19 additions & 4 deletions rllib/algorithms/algorithm.py
@@ -3548,11 +3548,26 @@ def _compile_iteration_results_new_api_stack(self, *, train_results, eval_result
),
}

        # Compile all throughput stats.
        throughputs = {}

        def _reduce(p, s):
            if isinstance(s, Stats):
                ret = s.peek()
                _throughput = s.peek(throughput=True)
                if _throughput is not None:
                    _curr = throughputs
                    for k in p[:-1]:
                        _curr = _curr.setdefault(k, {})
                    _curr[p[-1] + "_throughput"] = _throughput
            else:
                ret = s
            return ret

        # Resolve all `Stats` leafs by peeking (get their reduced values).
        return tree.map_structure(
            lambda s: s.peek() if isinstance(s, Stats) else s,
            results,
        )
        all_results = tree.map_structure_with_path(_reduce, results)
        deep_update(all_results, throughputs, new_keys_allowed=True)
        return all_results

def __repr__(self):
if self.config.enable_rl_module_and_learner:
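For readers who want to see the throughput-merging pattern from this hunk in isolation, here is a minimal, self-contained sketch. `FakeStats` is a toy stand-in for RLlib's `Stats` (only `peek()` / `peek(throughput=True)` are mimicked), and `merge_nested` is a hypothetical replacement for RLlib's `deep_update`; the only real API shared with the diff is `tree.map_structure_with_path` from the `dm-tree` package.

```python
import tree  # the dm-tree package, also used by RLlib


class FakeStats:
    """Toy stand-in for RLlib's `Stats`: a reduced value plus an optional throughput."""

    def __init__(self, value, throughput=None):
        self._value = value
        self._throughput = throughput

    def peek(self, throughput=False):
        return self._throughput if throughput else self._value


def merge_nested(base, extra):
    """Hypothetical stand-in for RLlib's `deep_update(..., new_keys_allowed=True)`."""
    for key, val in extra.items():
        if isinstance(val, dict) and isinstance(base.get(key), dict):
            merge_nested(base[key], val)
        else:
            base[key] = val
    return base


def compile_results(results):
    throughputs = {}

    def _reduce(path, s):
        # Non-Stats leafs pass through unchanged.
        if not isinstance(s, FakeStats):
            return s
        # If the Stats object also tracks a throughput, record it under a
        # sibling "<key>_throughput" entry in the parallel `throughputs` dict.
        tp = s.peek(throughput=True)
        if tp is not None:
            curr = throughputs
            for k in path[:-1]:
                curr = curr.setdefault(k, {})
            curr[path[-1] + "_throughput"] = tp
        # Replace the Stats leaf by its reduced (peeked) value.
        return s.peek()

    all_results = tree.map_structure_with_path(_reduce, results)
    return merge_nested(all_results, throughputs)


results = {
    "learners": {
        "loss": FakeStats(0.12),
        "env_steps": FakeStats(1000, throughput=250.0),
    }
}
# Prints the reduced values plus an extra "env_steps_throughput" entry.
print(compile_results(results))
```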
17 changes: 12 additions & 5 deletions rllib/algorithms/algorithm_config.py
@@ -1813,9 +1813,11 @@ def env_runners(
fill up, causing spilling of objects to disk. This can cause any
asynchronous requests to become very slow, making your experiment run
slowly as well. You can inspect the object store during your experiment
via a call to `ray memory` on your head node, and by using the Ray
through a call to `ray memory` on your head node, and by using the Ray
dashboard. If you're seeing that the object store is filling up,
turn down the number of remote requests in flight or enable compression.
turn down the number of remote requests in flight, enable compression, or
increase the object store memory, for example through:
`ray.init(object_store_memory=10 * 1024 * 1024 * 1024)  # =10 GB`
Collaborator comment: Awesome!

sample_collector: For the old API stack only. The SampleCollector class to
be used to collect and retrieve environment-, model-, and sampler data.
Override the SampleCollector base class to implement your own
@@ -2144,9 +2146,14 @@ def learners(
CUDA devices. For example if `os.environ["CUDA_VISIBLE_DEVICES"] = "1"`
and `local_gpu_idx=0`, RLlib uses the GPU with ID=1 on the node.
max_requests_in_flight_per_learner: Max number of in-flight requests
to each Learner (actor)). See the
`ray.rllib.utils.actor_manager.FaultTolerantActorManager` class for more
details.
to each Learner (actor). You normally do not have to tune this setting
(default is 3). However, for asynchronous algorithms, it determines
the "queue" size for incoming batches (or lists of episodes) into each
Learner worker, thus also determining, how much off-policy'ness would be
acceptable. The off-policy'ness is the difference between the numbers of
updates a policy has undergone on the Learner vs the EnvRunners.
See the `ray.rllib.utils.actor_manager.FaultTolerantActorManager` class
for more details.

Returns:
This updated AlgorithmConfig object.
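As a quick illustration of how the two settings documented above might be used together in a user script (a hedged sketch, not code from the PR itself; the environment name, worker count, and memory size are arbitrary examples):

```python
import ray
from ray.rllib.algorithms.appo import APPOConfig

# Give the object store more headroom (10 GB), as suggested in the docstring,
# if many asynchronous sample/update requests are in flight.
ray.init(object_store_memory=10 * 1024 * 1024 * 1024)

config = (
    APPOConfig()
    .environment("CartPole-v1")
    .env_runners(num_env_runners=4)
    # Allow at most 3 in-flight update requests per Learner; a larger value
    # increases throughput but also the tolerated off-policy'ness.
    .learners(max_requests_in_flight_per_learner=3)
)
algo = config.build()
print(algo.train())
```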
3 changes: 2 additions & 1 deletion rllib/algorithms/appo/appo.py
@@ -32,7 +32,8 @@

LEARNER_RESULTS_KL_KEY = "mean_kl_loss"
LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff"
TARGET_ACTION_DIST_LOGITS_KEY = "target_action_dist_logits"
OLD_ACTION_DIST_KEY = "old_action_dist"
OLD_ACTION_DIST_LOGITS_KEY = "old_action_dist_logits"


class APPOConfig(IMPALAConfig):
3 changes: 1 addition & 2 deletions rllib/algorithms/appo/appo_learner.py
@@ -12,7 +12,6 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
from ray.rllib.utils.metrics import (
ALL_MODULES,
LAST_TARGET_UPDATE_TS,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
@@ -89,7 +88,7 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:

# TODO (sven): Maybe we should have a `after_gradient_based_update`
# method per module?
curr_timestep = self.metrics.peek((ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME))
curr_timestep = timesteps.get(NUM_ENV_STEPS_TRAINED_LIFETIME, 0)
for module_id, module in self.module._rl_modules.items():
config = self.config.get_config_for_module(module_id)

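To make the change above concrete, here is a hypothetical, stripped-down sketch of how the `timesteps`-based lookup can drive a periodic target-network update. `TargetUpdateMixin`, `last_update_ts`, and `update_target()` are placeholder names for illustration only, not the actual RLlib implementation; only `NUM_ENV_STEPS_TRAINED_LIFETIME` is the real key used in the diff.

```python
from typing import Any, Dict

from ray.rllib.utils.metrics import NUM_ENV_STEPS_TRAINED_LIFETIME


class TargetUpdateMixin:
    def __init__(self, target_network_update_freq: int = 1000):
        self.target_network_update_freq = target_network_update_freq
        self.last_update_ts = 0

    def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
        # Read the current trained-timestep count from the `timesteps` dict the
        # caller passes in (falling back to 0 if the key is missing), instead of
        # peeking a metrics logger.
        curr_timestep = timesteps.get(NUM_ENV_STEPS_TRAINED_LIFETIME, 0)
        if curr_timestep - self.last_update_ts >= self.target_network_update_freq:
            self.update_target()
            self.last_update_ts = curr_timestep

    def update_target(self) -> None:
        print("syncing target network weights")


learner = TargetUpdateMixin(target_network_update_freq=1000)
learner.after_gradient_based_update(timesteps={NUM_ENV_STEPS_TRAINED_LIFETIME: 2000})
```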
4 changes: 2 additions & 2 deletions rllib/algorithms/appo/appo_rl_module.py
@@ -2,7 +2,7 @@
from typing import Any, Dict, List, Tuple

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.algorithms.appo.appo import TARGET_ACTION_DIST_LOGITS_KEY
from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
@@ -32,7 +32,7 @@ def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR]
old_action_dist_logits = self._old_pi(old_pi_inputs_encoded)
return {TARGET_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}
return {OLD_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(PPORLModule)
126 changes: 43 additions & 83 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
@@ -1,21 +1,10 @@
"""Asynchronous Proximal Policy Optimization (APPO)

The algorithm is described in [1] (under the name of "IMPACT"):

Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#appo

[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
Luo et al. 2020
https://arxiv.org/pdf/1912.00167
"""
from typing import Dict

from ray.rllib.algorithms.appo.appo import (
APPOConfig,
LEARNER_RESULTS_CURR_KL_COEFF_KEY,
LEARNER_RESULTS_KL_KEY,
TARGET_ACTION_DIST_LOGITS_KEY,
OLD_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.appo.appo_learner import APPOLearner
from ray.rllib.algorithms.impala.torch.impala_torch_learner import IMPALATorchLearner
@@ -71,49 +60,45 @@ def compute_loss_for_module(
)

action_dist_cls_train = module.get_train_action_dist_cls()

# Policy being trained (current).
current_action_dist = action_dist_cls_train.from_logits(
target_policy_dist = action_dist_cls_train.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)
current_actions_logp = current_action_dist.logp(batch[Columns.ACTIONS])
current_actions_logp_time_major = make_time_major(
current_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,

old_target_policy_dist = action_dist_cls_train.from_logits(
module.forward_target(batch)[OLD_ACTION_DIST_LOGITS_KEY]
)
old_target_policy_actions_logp = old_target_policy_dist.logp(
batch[Columns.ACTIONS]
)
behaviour_actions_logp = batch[Columns.ACTION_LOGP]
target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS])

# Target policy.
target_action_dist = action_dist_cls_train.from_logits(
module.forward_target(batch)[TARGET_ACTION_DIST_LOGITS_KEY]
behaviour_actions_logp_time_major = make_time_major(
behaviour_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
target_actions_logp = target_action_dist.logp(batch[Columns.ACTIONS])
target_actions_logp_time_major = make_time_major(
target_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)

# EnvRunner's policy (behavior).
behavior_actions_logp = batch[Columns.ACTION_LOGP]
behavior_actions_logp_time_major = make_time_major(
behavior_actions_logp,
old_actions_logp_time_major = make_time_major(
old_target_policy_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)

rewards_time_major = make_time_major(
batch[Columns.REWARDS],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)

assert Columns.VALUES_BOOTSTRAPPED not in batch
values_time_major = make_time_major(
values,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
assert Columns.VALUES_BOOTSTRAPPED not in batch
# Use as bootstrap values the vf-preds in the next "batch row", except
# for the very last row (which doesn't have a next row), for which the
# bootstrap value does not matter b/c it has a +1ts value at its end
@@ -127,86 +112,61 @@ def compute_loss_for_module(
dim=0,
)

# The discount factor that is used should be `gamma * lambda_`, except for
# termination timesteps, in which case the discount factor should be 0.
# The discount factor that is used should be gamma except for timesteps where
# the episode is terminated. In that case, the discount factor should be 0.
discounts_time_major = (
(
1.0
- make_time_major(
batch[Columns.TERMINATEDS],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
).float()
# See [1] 3.1: Discounts must contain the GAE lambda_ parameter as well.
)
* config.gamma
* config.lambda_
)
1.0
- make_time_major(
batch[Columns.TERMINATEDS],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
).float()
) * config.gamma

# Note that vtrace will compute the main loop on the CPU for better performance.
vtrace_adjusted_target_values, pg_advantages = vtrace_torch(
# See [1] 3.1: For AˆV-GAE, the ratios used are: min(c¯, π(target)/π(i))
# π(target)
target_action_log_probs=target_actions_logp_time_major,
# π(i)
behaviour_action_log_probs=behavior_actions_logp_time_major,
# See [1] 3.1: Discounts must contain the GAE lambda_ parameter as well.
target_action_log_probs=old_actions_logp_time_major,
behaviour_action_log_probs=behaviour_actions_logp_time_major,
discounts=discounts_time_major,
rewards=rewards_time_major,
values=values_time_major,
bootstrap_values=bootstrap_values,
# c¯
clip_rho_threshold=config.vtrace_clip_rho_threshold,
# c¯ (but we allow users to distinguish between c¯ used for
# value estimates and c¯ used for the advantages.
clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold,
clip_rho_threshold=config.vtrace_clip_rho_threshold,
)
pg_advantages = pg_advantages * loss_mask_time_major

# The policy gradient loss.
# As described in [1], use a logp-ratio of:
# min(π(i) / π(target), ρ) * (π / π(i)), where ..
# - π are the action probs from the current (learner) policy
# - π(i) are the action probs from the ith EnvRunner
# - π(target) are the action probs from the target network
# - ρ is the "target-worker clipping" (2.0 in the paper)
target_worker_is_ratio = torch.clip(
torch.exp(
behavior_actions_logp_time_major - target_actions_logp_time_major
),
# The policy gradients loss.
is_ratio = torch.clip(
torch.exp(behaviour_actions_logp_time_major - old_actions_logp_time_major),
0.0,
config.target_worker_clipping,
2.0,
)
target_worker_logp_ratio = target_worker_is_ratio * torch.exp(
current_actions_logp_time_major - behavior_actions_logp_time_major
logp_ratio = is_ratio * torch.exp(
target_actions_logp_time_major - behaviour_actions_logp_time_major
)

surrogate_loss = torch.minimum(
pg_advantages * target_worker_logp_ratio,
pg_advantages * logp_ratio,
pg_advantages
* torch.clip(
target_worker_logp_ratio,
1 - config.clip_param,
1 + config.clip_param,
),
* torch.clip(logp_ratio, 1 - config.clip_param, 1 + config.clip_param),
)
mean_pi_loss = -(torch.sum(surrogate_loss) / size_loss_mask)

# Compute KL-loss (if required): KL divergence between current action dist.
# and target action dist.
if config.use_kl_loss:
action_kl = target_action_dist.kl(current_action_dist) * loss_mask
action_kl = old_target_policy_dist.kl(target_policy_dist) * loss_mask
mean_kl_loss = torch.sum(action_kl) / size_loss_mask
else:
mean_kl_loss = 0.0
mean_pi_loss = -(torch.sum(surrogate_loss) / size_loss_mask)

# Compute value function loss.
# The baseline loss.
delta = values_time_major - vtrace_adjusted_target_values
vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0) * loss_mask_time_major)
mean_vf_loss = vf_loss / size_loss_mask

# Compute entropy loss.
# The entropy loss.
mean_entropy_loss = (
-torch.sum(current_action_dist.entropy() * loss_mask) / size_loss_mask
-torch.sum(target_policy_dist.entropy() * loss_mask) / size_loss_mask
)

# The summed weighted loss.
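Finally, a standalone PyTorch sketch of the surrogate-loss math in the new `compute_loss_for_module` above, with the time-major reshaping, loss masking, and v-trace return computation stripped out. All tensors and the clip values below are made-up toy inputs; the hard-coded 2.0 corresponds to the IS-ratio clip ("target-worker clipping") used in the diff.

```python
import torch

torch.manual_seed(0)

T = 8  # number of timesteps in this toy batch
behaviour_logp = torch.randn(T)                            # log pi_behaviour(a|s), from the EnvRunner
target_logp = behaviour_logp + 0.05 * torch.randn(T)       # log pi_current(a|s), being trained
old_target_logp = behaviour_logp + 0.05 * torch.randn(T)   # log pi_old_target(a|s), from the target net
pg_advantages = torch.randn(T)                             # advantages, e.g. from v-trace

clip_param = 0.2              # PPO-style clip range
target_worker_clipping = 2.0  # clip on the behaviour / old-target IS ratio

# Importance ratio between the EnvRunner's policy and the old target policy,
# clipped at `target_worker_clipping`.
is_ratio = torch.clip(
    torch.exp(behaviour_logp - old_target_logp), 0.0, target_worker_clipping
)
# Combined log-prob ratio fed into the PPO-style clipped surrogate objective.
logp_ratio = is_ratio * torch.exp(target_logp - behaviour_logp)

surrogate = torch.minimum(
    pg_advantages * logp_ratio,
    pg_advantages * torch.clip(logp_ratio, 1 - clip_param, 1 + clip_param),
)
mean_pi_loss = -surrogate.mean()
print(float(mean_pi_loss))
```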