Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] New ConnectorV3 API #05: PPO runs in single-agent mode in this API stack #42272

Conversation

sven1977
Copy link
Contributor

@sven1977 sven1977 commented Jan 9, 2024

EnvRunners support new ConnectorV3 API; PPO runs in single-agent mode in this API stack
This PR:

  • Adds a new config key: train_batch_size_per_learner to better distinguish between total effective batch size and batch size per (GPU) learner worker.
  • Makes large changes to the PPO algorithm when run with the new API stack + EnvRunners:
    • Forwards episode data directly from EnvRunner(s) to Learner worker(s) w/o having to form a MultiAgentBatch first.
    • Removes need for PPO's forward_exploration to perform a value-function pass. This is an essential improvement in code quality as we now have full separation between the sampling- and the learning worlds. The EnvRunner (sampling world) is no longer concerned with having to think about what the PPOLearner (learning world) might need and only needs to compute actions for the next env step.
    • All vf-computations, GAE, and advantages computations have been moved to the Learner side and these operations are now performed in a batched fashion (on all provided episodes at once). Having the episodes still intact on the Learner side helps reducing the complexity of these computations.

Benchmark results:
Learns Pong in ~5min via examples/connectors/connector_v2_frame_stacking.py example script:

Args: --num-gpus=8 --num-env-runners=95 --framework=torch

on commit: 790a537

Trial status: 1 RUNNING
Current time: 2024-01-12 12:41:55. Total running time: 7min 0s
Logical resource usage: 96.0/96 CPUs, 8.0/8 GPUs (0.0/1.0 accelerator_type:V100)
╭──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Trial name            status       iter     total time (s)       ts     reward     episode_reward_max     episode_reward_min     episode_len_mean     episodes_this_iter │
├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ PPO_env_0b2b7_00000   RUNNING       226             362.71   904000      19.62                     21                      9              1728.96                      0 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
…runner_support_connectors_04_learner_api_changes
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
…runner_support_connectors_04_learner_api_changes
Signed-off-by: sven1977 <[email protected]>
…runner_support_connectors_04_learner_api_changes
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
…runner_support_connectors_04_learner_api_changes
Signed-off-by: sven1977 <[email protected]>
@sven1977 sven1977 added do-not-merge Do not merge this PR! rllib-newstack rllib-oldstack-cleanup Issues related to cleaning up classes, utilities on the old API stack labels Jan 9, 2024
@@ -550,24 +638,3 @@ def training_step(self) -> ResultDict:
self.workers.local_worker().set_global_vars(global_vars)

return train_results

def postprocess_episodes(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer needed here. Episodes are sent directly to Learner(s) as-is.

@@ -39,6 +47,78 @@ def build(self) -> None:
)
)

@override(Learner)
def _preprocess_train_data(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: Only called on the new API stack + EnvRunners.

if not episodes:
return batch, episodes

# Make all episodes one ts longer in order to just have a single batch
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New way to do GAE:

  • elongate all episodes by one artificial ts.
  • perform vf-predictions AND bootstrap value predictions in one single batch (b/c we have the extra timestep!)
    • use the learner connector to make sure this forward pass is done using the correct (custom?) batch format.
  • remove extra timesteps from episodes (and computed advantages)

SampleBatch.VF_PREDS,
SampleBatch.ACTION_DIST_INPUTS,
]
return self.output_specs_inference()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simplified

@@ -40,6 +40,11 @@ def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]:
the policy distribution to be used for computing KL divergence between the old
policy and the new policy during training.
"""
# TODO (sven): Make this the only bahevior once PPO has been migrated
# to new API stack (including EnvRunners!).
if self.config.model_config_dict.get("uses_new_env_runners"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

temporary hack to make sure RLModule knows, when it still has to compute vf-preds via forward_exploration (old and hybrid API stacks).

@@ -272,6 +281,40 @@ def __init__(
# the final results dict in the `self.compile_update_results()` method.
self._metrics = defaultdict(dict)

@OverrideToImplementCustomLogic_CallToSuperRecommended
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved here for better ordering of methods (used to be all the way at the bottom of class).


# Build learner connector pipeline used on this Learner worker.
# TODO (sven): Support multi-agent cases.
if self.config.uses_new_env_runners and not self.config.is_multi_agent():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now: Only on new API stack + EnvRunner + single-agent: use Learner connector (w/o this PPO on new stack would not learn).

@@ -87,7 +86,13 @@ def __iter__(self):
def get_len(b):
return len(b[SampleBatch.SEQ_LENS])

n_steps = int(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug fix. When slicing on a BxT batch, we should slice properly along B-axis (with the correct slice size!).

return value

data = tree.map_structure(map_, self)
infos = self.pop(SampleBatch.INFOS, None)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplifications.

# we return the values here and slice them separately
# TODO(Artur): Clean this hack up.
return value
return value[start_padded:stop_padded]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplifications.

Copy link
Contributor

@kouroshHakha kouroshHakha left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stmp

Signed-off-by: sven1977 <[email protected]>
…runner_support_connectors_05_ppo_w_connectorv2s
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
…runner_support_connectors_05_ppo_w_connectorv2s
Signed-off-by: sven1977 <[email protected]>
@sven1977 sven1977 merged commit e03dd6e into ray-project:master Jan 19, 2024
9 checks passed
@Mark2000
Copy link
Contributor

Mark2000 commented May 2, 2024

@sven1977 Could you speak more to why GAE support was dropped for APPO in this release?

        # IMPALA and APPO need vtrace (A3C Policies no longer exist).
         if not self.vtrace:
             raise ValueError(
                 "IMPALA and APPO do NOT support vtrace=False anymore! Set "
                 "`config.training(vtrace=True)`."
             )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
rllib RLlib related issues rllib-newstack rllib-oldstack-cleanup Issues related to cleaning up classes, utilities on the old API stack
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants