[RLlib] APPO (new API stack) enhancements vol 05: Auto-sleep time AND thread-safety for MetricsLogger. (ray-project#48868)

Signed-off-by: Connor Sanders <[email protected]>
sven1977 authored and jecsand838 committed Dec 4, 2024
1 parent 7e09904 commit 2a166ab
Showing 17 changed files with 456 additions and 232 deletions.
23 changes: 19 additions & 4 deletions rllib/algorithms/algorithm.py
@@ -3548,11 +3548,26 @@ def _compile_iteration_results_new_api_stack(self, *, train_results, eval_result
),
}

# Compile all throughput stats.
throughputs = {}

def _reduce(p, s):
if isinstance(s, Stats):
ret = s.peek()
_throughput = s.peek(throughput=True)
if _throughput is not None:
_curr = throughputs
for k in p[:-1]:
_curr = _curr.setdefault(k, {})
_curr[p[-1] + "_throughput"] = _throughput
else:
ret = s
return ret

# Resolve all `Stats` leafs by peeking (get their reduced values).
return tree.map_structure(
lambda s: s.peek() if isinstance(s, Stats) else s,
results,
)
all_results = tree.map_structure_with_path(_reduce, results)
deep_update(all_results, throughputs, new_keys_allowed=True)
return all_results

def __repr__(self):
if self.config.enable_rl_module_and_learner:
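
The new `_reduce` helper above peeks every `Stats` leaf and, whenever a leaf also carries a throughput value, records it under a sibling `<key>_throughput` entry before the two dicts are merged via `deep_update`. Below is a minimal, self-contained sketch of that idea; `FakeStats` is a stand-in for RLlib's `Stats`, and plain-dict recursion replaces `tree.map_structure_with_path`, so this only illustrates the data flow, not the actual RLlib API.

class FakeStats:
    """Stand-in for RLlib's Stats: peek() returns the reduced value,
    peek(throughput=True) returns a throughput value or None."""

    def __init__(self, value, throughput=None):
        self._value = value
        self._throughput = throughput

    def peek(self, throughput=False):
        return self._throughput if throughput else self._value


def compile_results(results):
    throughputs = {}

    def _reduce(path, leaf):
        if isinstance(leaf, FakeStats):
            value = leaf.peek()
            tp = leaf.peek(throughput=True)
            if tp is not None:
                # Record the throughput under the same nested path,
                # with a "_throughput" suffix on the leaf key.
                node = throughputs
                for key in path[:-1]:
                    node = node.setdefault(key, {})
                node[path[-1] + "_throughput"] = tp
            return value
        return leaf

    def _map(path, node):
        if isinstance(node, dict):
            return {k: _map(path + (k,), v) for k, v in node.items()}
        return _reduce(path, node)

    def _merge(dst, src):
        # Plain-dict equivalent of deep_update(..., new_keys_allowed=True).
        for k, v in src.items():
            if isinstance(v, dict):
                _merge(dst.setdefault(k, {}), v)
            else:
                dst[k] = v

    all_results = _map((), results)
    _merge(all_results, throughputs)
    return all_results


results = {"env_runners": {"num_env_steps_sampled": FakeStats(1000, throughput=250.0)}}
print(compile_results(results))
# {'env_runners': {'num_env_steps_sampled': 1000,
#                  'num_env_steps_sampled_throughput': 250.0}}
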
17 changes: 12 additions & 5 deletions rllib/algorithms/algorithm_config.py
@@ -1813,9 +1813,11 @@ def env_runners(
fill up, causing spilling of objects to disk. This can cause any
asynchronous requests to become very slow, making your experiment run
slowly as well. You can inspect the object store during your experiment
via a call to `ray memory` on your head node, and by using the Ray
through a call to `ray memory` on your head node, and by using the Ray
dashboard. If you're seeing that the object store is filling up,
turn down the number of remote requests in flight or enable compression.
turn down the number of remote requests in flight, enable compression,
or increase the object store memory, for example through:
`ray.init(object_store_memory=10 * 1024 * 1024 * 1024)  # =10 GB`
sample_collector: For the old API stack only. The SampleCollector class to
be used to collect and retrieve environment-, model-, and sampler data.
Override the SampleCollector base class to implement your own
@@ -2144,9 +2146,14 @@ def learners(
CUDA devices. For example if `os.environ["CUDA_VISIBLE_DEVICES"] = "1"`
and `local_gpu_idx=0`, RLlib uses the GPU with ID=1 on the node.
max_requests_in_flight_per_learner: Max number of in-flight requests
to each Learner (actor)). See the
`ray.rllib.utils.actor_manager.FaultTolerantActorManager` class for more
details.
to each Learner (actor). You normally do not have to tune this setting
(default is 3). For asynchronous algorithms, however, it determines the
"queue" size for incoming batches (or lists of episodes) into each
Learner worker and thus also how much off-policyness is acceptable.
Off-policyness here is the difference between the number of updates a
policy has undergone on the Learner vs on the EnvRunners.
See the `ray.rllib.utils.actor_manager.FaultTolerantActorManager` class
for more details.
Returns:
This updated AlgorithmConfig object.
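
A hedged usage sketch for the two settings whose docstrings change above (object store sizing and `max_requests_in_flight_per_learner`). Keyword availability can vary by Ray version, so treat this as an illustration of the docstring rather than a verified configuration.

import ray
from ray.rllib.algorithms.appo import APPOConfig

# Enlarge the object store if `ray memory` shows it filling up (10 GB here).
ray.init(object_store_memory=10 * 1024 * 1024 * 1024)

config = (
    APPOConfig()
    .environment("CartPole-v1")
    # A smaller per-Learner queue of in-flight batches means less
    # off-policyness, at the risk of the Learner occasionally idling.
    .learners(max_requests_in_flight_per_learner=2)
)
algo = config.build()
print(algo.train())
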
3 changes: 2 additions & 1 deletion rllib/algorithms/appo/appo.py
@@ -32,7 +32,8 @@

LEARNER_RESULTS_KL_KEY = "mean_kl_loss"
LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff"
TARGET_ACTION_DIST_LOGITS_KEY = "target_action_dist_logits"
OLD_ACTION_DIST_KEY = "old_action_dist"
OLD_ACTION_DIST_LOGITS_KEY = "old_action_dist_logits"


class APPOConfig(IMPALAConfig):
3 changes: 1 addition & 2 deletions rllib/algorithms/appo/appo_learner.py
@@ -12,7 +12,6 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
from ray.rllib.utils.metrics import (
ALL_MODULES,
LAST_TARGET_UPDATE_TS,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
@@ -89,7 +88,7 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:

# TODO (sven): Maybe we should have a `after_gradient_based_update`
# method per module?
curr_timestep = self.metrics.peek((ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME))
curr_timestep = timesteps.get(NUM_ENV_STEPS_TRAINED_LIFETIME, 0)
for module_id, module in self.module._rl_modules.items():
config = self.config.get_config_for_module(module_id)

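
The one-line change above reads the current training timestep from the `timesteps` dict handed to the hook instead of peeking the MetricsLogger, which other threads may be updating concurrently. A tiny standalone sketch of that pattern follows; the update-period check is a hypothetical illustration, not RLlib's exact logic.

NUM_ENV_STEPS_TRAINED_LIFETIME = "num_env_steps_trained_lifetime"

def should_update_target(timesteps, last_update_ts, target_update_freq):
    # Prefer the timestep snapshot passed in by the caller over a live peek
    # into shared metrics state.
    curr_timestep = timesteps.get(NUM_ENV_STEPS_TRAINED_LIFETIME, 0)
    if curr_timestep - last_update_ts >= target_update_freq:
        return True, curr_timestep
    return False, last_update_ts

print(should_update_target({NUM_ENV_STEPS_TRAINED_LIFETIME: 5000}, 0, 2000))
# (True, 5000)
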
4 changes: 2 additions & 2 deletions rllib/algorithms/appo/appo_rl_module.py
@@ -2,7 +2,7 @@
from typing import Any, Dict, List, Tuple

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.algorithms.appo.appo import TARGET_ACTION_DIST_LOGITS_KEY
from ray.rllib.algorithms.appo.appo import OLD_ACTION_DIST_LOGITS_KEY
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
@@ -32,7 +32,7 @@ def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
old_pi_inputs_encoded = self._old_encoder(batch)[ENCODER_OUT][ACTOR]
old_action_dist_logits = self._old_pi(old_pi_inputs_encoded)
return {TARGET_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}
return {OLD_ACTION_DIST_LOGITS_KEY: old_action_dist_logits}

@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(PPORLModule)
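
A standalone PyTorch sketch of the data flow in `forward_target` above: a frozen copy of the encoder and pi head produces the "old" action-distribution logits returned under `OLD_ACTION_DIST_LOGITS_KEY`. RLlib builds the frozen copies via `make_target_network()`; the toy module below only illustrates the idea and is not RLlib's actual RLModule structure.

import copy

import torch
import torch.nn as nn

OLD_ACTION_DIST_LOGITS_KEY = "old_action_dist_logits"


class TinyAPPOModule(nn.Module):
    def __init__(self, obs_dim=4, hidden=32, num_actions=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.pi = nn.Linear(hidden, num_actions)
        # Frozen "old" copies of encoder and pi head.
        self._old_encoder = copy.deepcopy(self.encoder).requires_grad_(False)
        self._old_pi = copy.deepcopy(self.pi).requires_grad_(False)

    def forward_target(self, obs):
        with torch.no_grad():
            old_logits = self._old_pi(self._old_encoder(obs))
        return {OLD_ACTION_DIST_LOGITS_KEY: old_logits}


module = TinyAPPOModule()
print(module.forward_target(torch.randn(8, 4))[OLD_ACTION_DIST_LOGITS_KEY].shape)
# torch.Size([8, 2])
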
126 changes: 43 additions & 83 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
@@ -1,21 +1,10 @@
"""Asynchronous Proximal Policy Optimization (APPO)
The algorithm is described in [1] (under the name of "IMPACT"):
Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#appo
[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
Luo et al. 2020
https://arxiv.org/pdf/1912.00167
"""
from typing import Dict

from ray.rllib.algorithms.appo.appo import (
APPOConfig,
LEARNER_RESULTS_CURR_KL_COEFF_KEY,
LEARNER_RESULTS_KL_KEY,
TARGET_ACTION_DIST_LOGITS_KEY,
OLD_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.appo.appo_learner import APPOLearner
from ray.rllib.algorithms.impala.torch.impala_torch_learner import IMPALATorchLearner
@@ -71,49 +60,45 @@ def compute_loss_for_module(
)

action_dist_cls_train = module.get_train_action_dist_cls()

# Policy being trained (current).
current_action_dist = action_dist_cls_train.from_logits(
target_policy_dist = action_dist_cls_train.from_logits(
fwd_out[Columns.ACTION_DIST_INPUTS]
)
current_actions_logp = current_action_dist.logp(batch[Columns.ACTIONS])
current_actions_logp_time_major = make_time_major(
current_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,

old_target_policy_dist = action_dist_cls_train.from_logits(
module.forward_target(batch)[OLD_ACTION_DIST_LOGITS_KEY]
)
old_target_policy_actions_logp = old_target_policy_dist.logp(
batch[Columns.ACTIONS]
)
behaviour_actions_logp = batch[Columns.ACTION_LOGP]
target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS])

# Target policy.
target_action_dist = action_dist_cls_train.from_logits(
module.forward_target(batch)[TARGET_ACTION_DIST_LOGITS_KEY]
behaviour_actions_logp_time_major = make_time_major(
behaviour_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
target_actions_logp = target_action_dist.logp(batch[Columns.ACTIONS])
target_actions_logp_time_major = make_time_major(
target_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)

# EnvRunner's policy (behavior).
behavior_actions_logp = batch[Columns.ACTION_LOGP]
behavior_actions_logp_time_major = make_time_major(
behavior_actions_logp,
old_actions_logp_time_major = make_time_major(
old_target_policy_actions_logp,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)

rewards_time_major = make_time_major(
batch[Columns.REWARDS],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)

assert Columns.VALUES_BOOTSTRAPPED not in batch
values_time_major = make_time_major(
values,
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
)
assert Columns.VALUES_BOOTSTRAPPED not in batch
# Use as bootstrap values the vf-preds in the next "batch row", except
# for the very last row (which doesn't have a next row), for which the
# bootstrap value does not matter b/c it has a +1ts value at its end
@@ -127,86 +112,61 @@
dim=0,
)

# The discount factor that is used should be `gamma * lambda_`, except for
# termination timesteps, in which case the discount factor should be 0.
# The discount factor that is used should be gamma except for timesteps where
# the episode is terminated. In that case, the discount factor should be 0.
discounts_time_major = (
(
1.0
- make_time_major(
batch[Columns.TERMINATEDS],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
).float()
# See [1] 3.1: Discounts must contain the GAE lambda_ parameter as well.
)
* config.gamma
* config.lambda_
)
1.0
- make_time_major(
batch[Columns.TERMINATEDS],
trajectory_len=rollout_frag_or_episode_len,
recurrent_seq_len=recurrent_seq_len,
).float()
) * config.gamma

# Note that vtrace will compute the main loop on the CPU for better performance.
vtrace_adjusted_target_values, pg_advantages = vtrace_torch(
# See [1] 3.1: For AˆV-GAE, the ratios used are: min(c¯, π(target)/π(i))
# π(target)
target_action_log_probs=target_actions_logp_time_major,
# π(i)
behaviour_action_log_probs=behavior_actions_logp_time_major,
# See [1] 3.1: Discounts must contain the GAE lambda_ parameter as well.
target_action_log_probs=old_actions_logp_time_major,
behaviour_action_log_probs=behaviour_actions_logp_time_major,
discounts=discounts_time_major,
rewards=rewards_time_major,
values=values_time_major,
bootstrap_values=bootstrap_values,
# c¯
clip_rho_threshold=config.vtrace_clip_rho_threshold,
# c¯ (but we allow users to distinguish between c¯ used for
# value estimates and c¯ used for the advantages.
clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold,
clip_rho_threshold=config.vtrace_clip_rho_threshold,
)
pg_advantages = pg_advantages * loss_mask_time_major

# The policy gradient loss.
# As described in [1], use a logp-ratio of:
# min(π(i) / π(target), ρ) * (π / π(i)), where ..
# - π are the action probs from the current (learner) policy
# - π(i) are the action probs from the ith EnvRunner
# - π(target) are the action probs from the target network
# - ρ is the "target-worker clipping" (2.0 in the paper)
target_worker_is_ratio = torch.clip(
torch.exp(
behavior_actions_logp_time_major - target_actions_logp_time_major
),
# The policy gradients loss.
is_ratio = torch.clip(
torch.exp(behaviour_actions_logp_time_major - old_actions_logp_time_major),
0.0,
config.target_worker_clipping,
2.0,
)
target_worker_logp_ratio = target_worker_is_ratio * torch.exp(
current_actions_logp_time_major - behavior_actions_logp_time_major
logp_ratio = is_ratio * torch.exp(
target_actions_logp_time_major - behaviour_actions_logp_time_major
)

surrogate_loss = torch.minimum(
pg_advantages * target_worker_logp_ratio,
pg_advantages * logp_ratio,
pg_advantages
* torch.clip(
target_worker_logp_ratio,
1 - config.clip_param,
1 + config.clip_param,
),
* torch.clip(logp_ratio, 1 - config.clip_param, 1 + config.clip_param),
)
mean_pi_loss = -(torch.sum(surrogate_loss) / size_loss_mask)

# Compute KL-loss (if required): KL divergence between current action dist.
# and target action dist.
if config.use_kl_loss:
action_kl = target_action_dist.kl(current_action_dist) * loss_mask
action_kl = old_target_policy_dist.kl(target_policy_dist) * loss_mask
mean_kl_loss = torch.sum(action_kl) / size_loss_mask
else:
mean_kl_loss = 0.0
mean_pi_loss = -(torch.sum(surrogate_loss) / size_loss_mask)

# Compute value function loss.
# The baseline loss.
delta = values_time_major - vtrace_adjusted_target_values
vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0) * loss_mask_time_major)
mean_vf_loss = vf_loss / size_loss_mask

# Compute entropy loss.
# The entropy loss.
mean_entropy_loss = (
-torch.sum(current_action_dist.entropy() * loss_mask) / size_loss_mask
-torch.sum(target_policy_dist.entropy() * loss_mask) / size_loss_mask
)

# The summed weighted loss.
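
A compact PyTorch sketch of the clipped surrogate objective shown in the diff above, using flat log-prob tensors (no time-major reshaping, no v-trace) and random placeholder advantages. The clip constant 2.0 and `clip_param` mirror the code above; everything else is a simplified illustration.

import torch


def appo_surrogate_loss(
    target_logp,     # log pi(a|s) of the policy being trained
    behaviour_logp,  # log pi_i(a|s) of the EnvRunner that sampled the batch
    old_logp,        # log pi_old(a|s) of the frozen "old" (target) policy
    advantages,
    clip_param=0.2,
):
    # Importance ratio between behaviour and old policy, clipped to [0, 2].
    is_ratio = torch.clamp(torch.exp(behaviour_logp - old_logp), 0.0, 2.0)
    # Combined log-prob ratio used in the PPO-style clip objective.
    logp_ratio = is_ratio * torch.exp(target_logp - behaviour_logp)
    surrogate = torch.minimum(
        advantages * logp_ratio,
        advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param),
    )
    return -surrogate.mean()


logp = torch.randn(16)
print(appo_surrogate_loss(logp, logp - 0.1, logp - 0.2, torch.randn(16)))
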