diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index 38e469d7559fe..e50b2285a4c11 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -368,7 +368,7 @@ def __init__(self, algo_class: Optional[type] = None):
         self.grad_clip = None
         self.grad_clip_by = "global_norm"
         # Simple logic for now: If None, use `train_batch_size`.
-        self.train_batch_size_per_learner = None
+        self._train_batch_size_per_learner = None
         self.train_batch_size = 32  # @OldAPIStack
         # These setting have been adopted from the original PPO batch settings:
@@ -2324,7 +2324,7 @@ def training(
                 )
             self.grad_clip_by = grad_clip_by
         if train_batch_size_per_learner is not NotProvided:
-            self.train_batch_size_per_learner = train_batch_size_per_learner
+            self._train_batch_size_per_learner = train_batch_size_per_learner
         if train_batch_size is not NotProvided:
             self.train_batch_size = train_batch_size
         if num_epochs is not NotProvided:
@@ -3763,14 +3763,26 @@ def rl_module_spec(self):
         return default_rl_module_spec

     @property
-    def total_train_batch_size(self):
-        if (
-            self.train_batch_size_per_learner is not None
-            and self.enable_rl_module_and_learner
-        ):
-            return self.train_batch_size_per_learner * (self.num_learners or 1)
-        else:
-            return self.train_batch_size
+    def train_batch_size_per_learner(self) -> int:
+        # If not set explicitly, try to infer the value.
+        if self._train_batch_size_per_learner is None:
+            return self.train_batch_size // (self.num_learners or 1)
+        return self._train_batch_size_per_learner
+
+    @train_batch_size_per_learner.setter
+    def train_batch_size_per_learner(self, value: int) -> None:
+        self._train_batch_size_per_learner = value
+
+    @property
+    def total_train_batch_size(self) -> int:
+        """Returns the effective total train batch size.
+
+        New API stack: `train_batch_size_per_learner` * [effective num Learners].
+
+        @OldAPIStack (user never touches `train_batch_size_per_learner` or
+        `num_learners`): `train_batch_size`.
+        """
+        return self.train_batch_size_per_learner * (self.num_learners or 1)

     # TODO: Make rollout_fragment_length as read-only property and replace the current
     # self.rollout_fragment_length a private variable.
@@ -3905,18 +3917,11 @@ def validate_train_batch_size_vs_rollout_fragment_length(self) -> None:
         asking the user to set rollout_fragment_length to `auto` or to a matching
         value.

-        Also, only checks this if `train_batch_size` > 0 (DDPPO sets this
-        to -1 to auto-calculate the actual batch size later).
-
         Raises:
             ValueError: If there is a mismatch between user provided
                 `rollout_fragment_length` and `total_train_batch_size`.
         """
-        if (
-            self.rollout_fragment_length != "auto"
-            and not self.in_evaluation
-            and self.total_train_batch_size > 0
-        ):
+        if self.rollout_fragment_length != "auto" and not self.in_evaluation:
             min_batch_size = (
                 max(self.num_env_runners, 1)
                 * self.num_envs_per_env_runner
diff --git a/rllib/algorithms/appo/appo_learner.py b/rllib/algorithms/appo/appo_learner.py
index db59871b131c5..235bc823209f8 100644
--- a/rllib/algorithms/appo/appo_learner.py
+++ b/rllib/algorithms/appo/appo_learner.py
@@ -99,8 +99,7 @@ def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
                     config.target_network_update_freq
                     * config.circular_buffer_num_batches
                     * config.circular_buffer_iterations_per_batch
-                    * config.total_train_batch_size
-                    / (config.num_learners or 1)
+                    * config.train_batch_size_per_learner
                 )
             ):
                 for (
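
For illustration only (not part of the patch): a minimal, self-contained sketch of the batch-size resolution logic introduced above. MiniConfig is a hypothetical stand-in for AlgorithmConfig; the attribute names mirror the ones touched in this diff, everything else is assumed.

    # Standalone illustration of the resolution logic added in this patch.
    # `MiniConfig` is a hypothetical stand-in for AlgorithmConfig, not RLlib API.
    class MiniConfig:
        def __init__(self, train_batch_size=32, num_learners=0):
            self._train_batch_size_per_learner = None
            self.train_batch_size = train_batch_size
            self.num_learners = num_learners

        @property
        def train_batch_size_per_learner(self) -> int:
            # If not set explicitly, infer from the old-stack `train_batch_size`.
            if self._train_batch_size_per_learner is None:
                return self.train_batch_size // (self.num_learners or 1)
            return self._train_batch_size_per_learner

        @train_batch_size_per_learner.setter
        def train_batch_size_per_learner(self, value: int) -> None:
            self._train_batch_size_per_learner = value

        @property
        def total_train_batch_size(self) -> int:
            # Total = per-learner size times the effective number of Learners
            # (a local Learner counts as 1).
            return self.train_batch_size_per_learner * (self.num_learners or 1)

    cfg = MiniConfig(train_batch_size=4000, num_learners=2)
    assert cfg.train_batch_size_per_learner == 2000  # inferred: 4000 // 2
    assert cfg.total_train_batch_size == 4000

    cfg.train_batch_size_per_learner = 512  # explicit per-learner setting wins
    assert cfg.total_train_batch_size == 1024  # 512 * 2 learners

With this setter in place, an unset per-learner value falls back to the old-stack train_batch_size divided across Learners, while downstream code (such as the APPO learner change above) can always read train_batch_size_per_learner directly instead of re-deriving it from total_train_batch_size / num_learners.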