[RLlib] Fix train_batch_size_per_learner problems. #49715

Merged
41 changes: 23 additions & 18 deletions rllib/algorithms/algorithm_config.py
@@ -368,7 +368,7 @@ def __init__(self, algo_class: Optional[type] = None):
        self.grad_clip = None
        self.grad_clip_by = "global_norm"
        # Simple logic for now: If None, use `train_batch_size`.
-        self.train_batch_size_per_learner = None
+        self._train_batch_size_per_learner = None
        self.train_batch_size = 32  # @OldAPIStack

        # These setting have been adopted from the original PPO batch settings:
@@ -2324,7 +2324,7 @@ def training(
                )
            self.grad_clip_by = grad_clip_by
        if train_batch_size_per_learner is not NotProvided:
-            self.train_batch_size_per_learner = train_batch_size_per_learner
+            self._train_batch_size_per_learner = train_batch_size_per_learner
        if train_batch_size is not NotProvided:
            self.train_batch_size = train_batch_size
        if num_epochs is not NotProvided:
@@ -3763,14 +3763,26 @@ def rl_module_spec(self):
        return default_rl_module_spec

    @property
-    def total_train_batch_size(self):
-        if (
-            self.train_batch_size_per_learner is not None
-            and self.enable_rl_module_and_learner
-        ):
-            return self.train_batch_size_per_learner * (self.num_learners or 1)
-        else:
-            return self.train_batch_size
+    def train_batch_size_per_learner(self) -> int:
+        # If not set explicitly, try to infer the value.
+        if self._train_batch_size_per_learner is None:
+            return self.train_batch_size // (self.num_learners or 1)
+        return self._train_batch_size_per_learner
+
+    @train_batch_size_per_learner.setter
+    def train_batch_size_per_learner(self, value: int) -> None:
+        self._train_batch_size_per_learner = value
+
+    @property
+    def total_train_batch_size(self) -> int:
+        """Returns the effective total train batch size.
+
+        New API stack: `train_batch_size_per_learner` * [effective num Learners].
+
+        @OldAPIStack: User never touches `train_batch_size_per_learner` or
+        `num_learners`) -> `train_batch_size`.
+        """
+        return self.train_batch_size_per_learner * (self.num_learners or 1)
Collaborator: Why do we not return the private attribute `self.train_batch_size` here?

Contributor (author): Because `self.train_batch_size` is an old API stack setting, we should no longer reference it anywhere in the new API stack logic.


    # TODO: Make rollout_fragment_length as read-only property and replace the current
    # self.rollout_fragment_length a private variable.
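
To make the new batch-size logic concrete, here is a small standalone sketch that mirrors the property, setter, and fallback introduced in this hunk. The `ConfigSketch` class and its default values are illustrative only, not the actual RLlib `AlgorithmConfig`:

```python
# Standalone sketch of the new train_batch_size_per_learner /
# total_train_batch_size logic from this PR (illustrative only).

class ConfigSketch:
    def __init__(self):
        self._train_batch_size_per_learner = None  # new private attribute
        self.train_batch_size = 32                 # @OldAPIStack setting
        self.num_learners = 0                      # 0 means "local Learner only"

    @property
    def train_batch_size_per_learner(self) -> int:
        # If not set explicitly, infer it from the old-stack `train_batch_size`.
        if self._train_batch_size_per_learner is None:
            return self.train_batch_size // (self.num_learners or 1)
        return self._train_batch_size_per_learner

    @train_batch_size_per_learner.setter
    def train_batch_size_per_learner(self, value: int) -> None:
        self._train_batch_size_per_learner = value

    @property
    def total_train_batch_size(self) -> int:
        # New API stack: per-Learner size times the effective number of Learners.
        return self.train_batch_size_per_learner * (self.num_learners or 1)


cfg = ConfigSketch()
cfg.num_learners = 2
cfg.train_batch_size_per_learner = 512
assert cfg.total_train_batch_size == 1024  # 512 per Learner x 2 Learners

cfg_default = ConfigSketch()  # nothing set explicitly
assert cfg_default.train_batch_size_per_learner == 32  # falls back to train_batch_size
```

This also illustrates the author's point above: `total_train_batch_size` is now derived entirely from the per-Learner value, and the old-stack `train_batch_size` is only consulted as a fallback inside the `train_batch_size_per_learner` property.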
@@ -3905,18 +3917,11 @@ def validate_train_batch_size_vs_rollout_fragment_length(self) -> None:
        asking the user to set rollout_fragment_length to `auto` or to a matching
        value.

-        Also, only checks this if `train_batch_size` > 0 (DDPPO sets this
-        to -1 to auto-calculate the actual batch size later).
-
        Raises:
            ValueError: If there is a mismatch between user provided
                `rollout_fragment_length` and `total_train_batch_size`.
        """
-        if (
-            self.rollout_fragment_length != "auto"
-            and not self.in_evaluation
-            and self.total_train_batch_size > 0
-        ):
+        if self.rollout_fragment_length != "auto" and not self.in_evaluation:
Collaborator: Again, is "auto" now the only configuration possible?

            min_batch_size = (
                max(self.num_env_runners, 1)
                * self.num_envs_per_env_runner
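
The hunk above is truncated, but the docstring describes the intent: verify that the user-chosen `rollout_fragment_length` can actually produce `total_train_batch_size`. Below is a hedged sketch of that relationship; the parameter names follow the config attributes shown in the diff, while the exact check and error message in RLlib may differ:

```python
# Illustrative sketch only: checks whether a chosen rollout_fragment_length is
# compatible with total_train_batch_size, mirroring the docstring above.

def check_fragment_vs_batch(
    num_env_runners: int,
    num_envs_per_env_runner: int,
    rollout_fragment_length: int,
    total_train_batch_size: int,
) -> None:
    # Smallest amount of data one sampling step can produce.
    min_batch_size = (
        max(num_env_runners, 1)
        * num_envs_per_env_runner
        * rollout_fragment_length
    )
    if total_train_batch_size % min_batch_size != 0:
        raise ValueError(
            f"total_train_batch_size={total_train_batch_size} is not a multiple of "
            f"the per-sample-step size {min_batch_size}; set "
            "rollout_fragment_length='auto' or pick a matching value."
        )


# 2 EnvRunners x 1 env each x fragment length 256 -> 512 timesteps per sampling
# step, which divides a total train batch size of 1024 evenly (no error raised).
check_fragment_vs_batch(2, 1, 256, 1024)
```

Note that after this PR the validation is only skipped when `rollout_fragment_length == "auto"` or during evaluation; the old `total_train_batch_size > 0` escape hatch (which existed for DDPPO's auto-sized batches) is gone, which is what the reviewer's question above refers to.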
3 changes: 1 addition & 2 deletions rllib/algorithms/appo/appo_learner.py
@@ -99,8 +99,7 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
                config.target_network_update_freq
                * config.circular_buffer_num_batches
                * config.circular_buffer_iterations_per_batch
-                * config.total_train_batch_size
-                / (config.num_learners or 1)
+                * config.train_batch_size_per_learner
            )
        ):
            for (
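
Given the new `total_train_batch_size` property, this APPO change is a pure simplification: dividing the total by the number of Learners just recovers the per-Learner size. A tiny illustrative check (the numbers are made up, not from the PR):

```python
# Illustrative equivalence of the old and new expressions in appo_learner.py.
train_batch_size_per_learner = 512
num_learners = 4

# New property definition: total = per-Learner size * effective num Learners.
total_train_batch_size = train_batch_size_per_learner * (num_learners or 1)

old_expression = total_train_batch_size / (num_learners or 1)  # 512.0
new_expression = train_batch_size_per_learner                  # 512

assert old_expression == new_expression
```

The new form also avoids the float produced by the old true division.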