diff --git a/rllib/BUILD b/rllib/BUILD index 1f3ef0c79a47c..1617e9467ffd8 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -150,16 +150,6 @@ py_test( # -------------------------------------------------------------------- # APPO -py_test( - name = "learning_tests_cartpole_appo_no_vtrace", - main = "tests/run_regression_tests.py", - tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"], - size = "medium", # bazel may complain about it being too long sometimes - medium is on purpose as some frameworks take longer - srcs = ["tests/run_regression_tests.py"], - data = ["tuned_examples/appo/cartpole-appo.yaml"], - args = ["--dir=tuned_examples/appo"] -) - py_test( name = "learning_tests_cartpole_appo_w_rl_modules_and_learner", main = "tests/run_regression_tests.py", @@ -177,7 +167,7 @@ py_test( size = "medium", srcs = ["tests/run_regression_tests.py"], data = [ - "tuned_examples/appo/cartpole-appo-vtrace-separate-losses.py" + "tuned_examples/appo/cartpole-appo-separate-losses.py" ], args = ["--dir=tuned_examples/appo"] ) @@ -208,17 +198,17 @@ py_test( tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"], size = "medium", srcs = ["tests/run_regression_tests.py"], - data = ["tuned_examples/appo/cartpole-appo-vtrace-fake-gpus.yaml"], + data = ["tuned_examples/appo/cartpole-appo-fake-gpus.yaml"], args = ["--dir=tuned_examples/appo"] ) py_test( - name = "learning_tests_stateless_cartpole_appo_vtrace", + name = "learning_tests_stateless_cartpole_appo", main = "tests/run_regression_tests.py", tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete"], size = "enormous", srcs = ["tests/run_regression_tests.py"], - data = ["tuned_examples/appo/stateless-cartpole-appo-vtrace.py"], + data = ["tuned_examples/appo/stateless_cartpole_appo.py"], args = ["--dir=tuned_examples/appo"] ) @@ -1453,6 +1443,13 @@ py_test( srcs = ["utils/exploration/tests/test_explorations.py"] ) +py_test( + name = "test_value_predictions", + tags = ["team:rllib", "utils"], + size = "small", + srcs = ["utils/postprocessing/tests/test_value_predictions.py"] +) + py_test( name = "test_random_encoder", tags = ["team:rllib", "utils"], @@ -1461,7 +1458,7 @@ py_test( ) py_test( - name = "utils/tests/test_torch_utils", + name = "test_torch_utils", tags = ["team:rllib", "utils", "gpu"], size = "medium", srcs = ["utils/tests/test_torch_utils.py"] diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index aafdda3503cc7..44e0e8494c968 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -25,6 +25,7 @@ Set, Tuple, Type, + TYPE_CHECKING, Union, ) @@ -46,7 +47,6 @@ collect_metrics, summarize_episodes, ) -from ray.rllib.evaluation.postprocessing_v2 import postprocess_episodes_to_sample_batch from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.rollout_ops import synchronous_parallel_sample from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step @@ -129,6 +129,8 @@ from ray.util.timer import _Timer from ray.tune.registry import get_trainable_cls +if TYPE_CHECKING: + from ray.rllib.core.learner.learner_group import LearnerGroup try: from ray.rllib.extensions import AlgorithmBase @@ -449,6 +451,9 @@ def __init__( # Placeholder for a local replay buffer instance. self.local_replay_buffer = None + # Placeholder for our LearnerGroup responsible for updating the RLModule(s). 
+ self.learner_group: Optional["LearnerGroup"] = None + # Create a default logger creator if no logger_creator is specified if logger_creator is None: # Default logdir prefix containing the agent's name and the @@ -1410,7 +1415,12 @@ def remote_fn(worker): worker.set_weights( weights=ray.get(weights_ref), weights_seq_no=weights_seq_no ) - episodes = worker.sample(explore=False) + # By episode: Run always only one episode per remote call. + # By timesteps: By default EnvRunner runs for the configured number of + # timesteps (based on `rollout_fragment_length` and `num_envs_per_worker`). + episodes = worker.sample( + explore=False, num_episodes=1 if unit == "episodes" else None + ) metrics = worker.get_metrics() return episodes, metrics, weights_seq_no @@ -1449,11 +1459,13 @@ def remote_fn(worker): rollout_metrics.extend(metrics) i += 1 - # Convert our list of Episodes to a single SampleBatch. - batch = postprocess_episodes_to_sample_batch(episodes) # Collect steps stats. - _agent_steps = batch.agent_steps() - _env_steps = batch.env_steps() + # TODO (sven): Solve for proper multi-agent env/agent steps counting. + # Once we have multi-agent support on EnvRunner stack, we can simply do: + # `len(episode)` for env steps and `episode.num_agent_steps()` for agent + # steps. + _agent_steps = sum(len(e) for e in episodes) + _env_steps = sum(len(e) for e in episodes) # Only complete episodes done by eval workers. if unit == "episodes": @@ -1467,6 +1479,7 @@ def remote_fn(worker): ) if self.reward_estimators: + batch = concat_samples([e.get_sample_batch() for e in episodes]) all_batches.append(batch) agent_steps_this_iter += _agent_steps diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 4700f13cab5b1..84daeae40ec95 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -363,6 +363,8 @@ def __init__(self, algo_class=None): self.grad_clip = None self.grad_clip_by = "global_norm" self.train_batch_size = 32 + # Simple logic for now: If None, use `train_batch_size`. + self.train_batch_size_per_learner = None # TODO (sven): Unsolved problem with RLModules sometimes requiring settings from # the main AlgorithmConfig. We should not require the user to provide those # settings in both, the AlgorithmConfig (as property) AND the model config @@ -871,6 +873,7 @@ def build_env_to_module_connector(self, env): return pipeline def build_module_to_env_connector(self, env): + from ray.rllib.connectors.module_to_env import ( DefaultModuleToEnv, ModuleToEnvPipeline, @@ -1333,11 +1336,11 @@ def environment( Tuple[value1, value2]: Clip at value1 and value2. normalize_actions: If True, RLlib will learn entirely inside a normalized action space (0.0 centered with small stddev; only affecting Box - components). We will unsquash actions (and clip, just in case) to the + components). RLlib will unsquash actions (and clip, just in case) to the bounds of the env's action space before sending actions back to the env. - clip_actions: If True, RLlib will clip actions according to the env's bounds - before sending them back to the env. - TODO: (sven) This option should be deprecated and always be False. + clip_actions: If True, the RLlib default ModuleToEnv connector will clip + actions according to the env's bounds (before sending them into the + `env.step()` call). disable_env_checking: If True, disable the environment pre-checking module. is_atari: This config can be used to explicitly specify whether the env is an Atari env or not. 
If not specified, RLlib will try to auto-detect @@ -1678,6 +1681,7 @@ def training( grad_clip: Optional[float] = NotProvided, grad_clip_by: Optional[str] = NotProvided, train_batch_size: Optional[int] = NotProvided, + train_batch_size_per_learner: Optional[int] = NotProvided, model: Optional[dict] = NotProvided, optimizer: Optional[dict] = NotProvided, max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided, @@ -1726,7 +1730,16 @@ def training( the shapes of these tensors are). grad_clip_by: See `grad_clip` for the effect of this setting on gradient clipping. Allowed values are `value`, `norm`, and `global_norm`. - train_batch_size: Training batch size, if applicable. + train_batch_size_per_learner: Train batch size per individual Learner + worker. This setting only applies to the new API stack. The number + of Learner workers can be set via `config.resources( + num_learner_workers=...)`. The total effective batch size is then + `num_learner_workers` x `train_batch_size_per_learner` and can + be accessed via the property `AlgorithmConfig.total_train_batch_size`. + train_batch_size: Training batch size, if applicable. When on the new API + stack, this setting should no longer be used. Instead, use + `train_batch_size_per_learner` (in combination with + `num_learner_workers`). model: Arguments passed into the policy model. See models/catalog.py for a full list of the available model options. TODO: Provide ModelConfig objects instead of dicts. @@ -1766,6 +1779,8 @@ def training( "or 'global_norm'!" ) self.grad_clip_by = grad_clip_by + if train_batch_size_per_learner is not NotProvided: + self.train_batch_size_per_learner = train_batch_size_per_learner if train_batch_size is not NotProvided: self.train_batch_size = train_batch_size if model is not NotProvided: @@ -2716,20 +2731,29 @@ def uses_new_env_runners(self): self.env_runner_cls, RolloutWorker ) + @property + def total_train_batch_size(self): + if self.train_batch_size_per_learner is not None: + return self.train_batch_size_per_learner * (self.num_learner_workers or 1) + else: + return self.train_batch_size + + # TODO: Make rollout_fragment_length as read-only property and replace the current + # self.rollout_fragment_length a private variable. def get_rollout_fragment_length(self, worker_index: int = 0) -> int: """Automatically infers a proper rollout_fragment_length setting if "auto". Uses the simple formula: - `rollout_fragment_length` = `train_batch_size` / + `rollout_fragment_length` = `total_train_batch_size` / (`num_envs_per_worker` * `num_rollout_workers`) If result is a fraction AND `worker_index` is provided, will make those workers add additional timesteps, such that the overall batch size (across - the workers) will add up to exactly the `train_batch_size`. + the workers) will add up to exactly the `total_train_batch_size`. Returns: The user-provided `rollout_fragment_length` or a computed one (if user - provided value is "auto"), making sure `train_batch_size` is reached + provided value is "auto"), making sure `total_train_batch_size` is reached exactly in each iteration. 
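As a quick illustration of how the new per-Learner batch sizing composes with the "auto" fragment-length logic documented above, here is a minimal sketch (the concrete numbers are made up for the example):

    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

    config = (
        AlgorithmConfig()
        .resources(num_learner_workers=2)
        .training(train_batch_size_per_learner=2000)
        .rollouts(
            num_rollout_workers=4,
            num_envs_per_worker=2,
            rollout_fragment_length="auto",
        )
    )
    # 2 Learner workers x 2000 timesteps each.
    assert config.total_train_batch_size == 4000
    # "auto" -> 4000 / (2 envs x 4 workers) = 500 timesteps per env per iteration.
    assert config.get_rollout_fragment_length(worker_index=1) == 500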
""" if self.rollout_fragment_length == "auto": @@ -2739,11 +2763,11 @@ def get_rollout_fragment_length(self, worker_index: int = 0) -> int: # 4 workers, 3 envs per worker, 2500 train batch size: # -> 2500 / 12 -> 208.333 -> diff=4 (208 * 12 = 2496) # -> worker 1: 209, workers 2-4: 208 - rollout_fragment_length = self.train_batch_size / ( + rollout_fragment_length = self.total_train_batch_size / ( self.num_envs_per_worker * (self.num_rollout_workers or 1) ) if int(rollout_fragment_length) != rollout_fragment_length: - diff = self.train_batch_size - int( + diff = self.total_train_batch_size - int( rollout_fragment_length ) * self.num_envs_per_worker * (self.num_rollout_workers or 1) if (worker_index * self.num_envs_per_worker) <= diff: @@ -3095,12 +3119,12 @@ def validate_train_batch_size_vs_rollout_fragment_length(self) -> None: Raises: ValueError: If there is a mismatch between user provided - `rollout_fragment_length` and `train_batch_size`. + `rollout_fragment_length` and `total_train_batch_size`. """ if ( self.rollout_fragment_length != "auto" and not self.in_evaluation - and self.train_batch_size > 0 + and self.total_train_batch_size > 0 ): min_batch_size = ( max(self.num_rollout_workers, 1) @@ -3108,23 +3132,25 @@ def validate_train_batch_size_vs_rollout_fragment_length(self) -> None: * self.rollout_fragment_length ) batch_size = min_batch_size - while batch_size < self.train_batch_size: + while batch_size < self.total_train_batch_size: batch_size += min_batch_size - if ( - batch_size - self.train_batch_size > 0.1 * self.train_batch_size - or batch_size - min_batch_size - self.train_batch_size - > (0.1 * self.train_batch_size) + if batch_size - self.total_train_batch_size > ( + 0.1 * self.total_train_batch_size + ) or batch_size - min_batch_size - self.total_train_batch_size > ( + 0.1 * self.total_train_batch_size ): - suggested_rollout_fragment_length = self.train_batch_size // ( + suggested_rollout_fragment_length = self.total_train_batch_size // ( self.num_envs_per_worker * (self.num_rollout_workers or 1) ) raise ValueError( - f"Your desired `train_batch_size` ({self.train_batch_size}) or a " - "value 10% off of that cannot be achieved with your other " + "Your desired `total_train_batch_size` " + f"({self.total_train_batch_size}={self.num_learner_workers} " + f"learners x {self.train_batch_size_per_learner}) " + "or a value 10% off of that cannot be achieved with your other " f"settings (num_rollout_workers={self.num_rollout_workers}; " f"num_envs_per_worker={self.num_envs_per_worker}; " f"rollout_fragment_length={self.rollout_fragment_length})! " - "Try setting `rollout_fragment_length` to 'auto' OR " + "Try setting `rollout_fragment_length` to 'auto' OR to a value of " f"{suggested_rollout_fragment_length}." 
) @@ -3580,8 +3606,7 @@ def _validate_evaluation_settings(self): """Checks, whether evaluation related settings make sense.""" if ( self.evaluation_interval - and self.env_runner_cls is not None - and not issubclass(self.env_runner_cls, RolloutWorker) + and self.uses_new_env_runners and not self.enable_async_evaluation ): raise ValueError( diff --git a/rllib/algorithms/appo/tests/test_appo.py b/rllib/algorithms/appo/tests/test_appo.py index 1a660cddd47df..5bd84f22efe69 100644 --- a/rllib/algorithms/appo/tests/test_appo.py +++ b/rllib/algorithms/appo/tests/test_appo.py @@ -26,19 +26,6 @@ def test_appo_compilation(self): num_iterations = 2 for _ in framework_iterator(config): - print("w/o v-trace") - config.vtrace = False - algo = config.build(env="CartPole-v1") - for i in range(num_iterations): - results = algo.train() - print(results) - check_train_results(results) - - check_compute_single_action(algo) - algo.stop() - - print("w/ v-trace") - config.vtrace = True algo = config.build(env="CartPole-v1") for i in range(num_iterations): results = algo.train() diff --git a/rllib/algorithms/appo/tf/appo_tf_learner.py b/rllib/algorithms/appo/tf/appo_tf_learner.py index 33bf2099035b4..420e3339b0103 100644 --- a/rllib/algorithms/appo/tf/appo_tf_learner.py +++ b/rllib/algorithms/appo/tf/appo_tf_learner.py @@ -8,9 +8,9 @@ OLD_ACTION_DIST_LOGITS_KEY, ) from ray.rllib.algorithms.appo.appo_learner import AppoLearner +from ray.rllib.algorithms.impala.tf.impala_tf_learner import ImpalaTfLearner from ray.rllib.algorithms.impala.tf.vtrace_tf_v2 import make_time_major, vtrace_tf2 from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY -from ray.rllib.core.learner.tf.tf_learner import TfLearner from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.nested_dict import NestedDict @@ -19,10 +19,10 @@ _, tf, _ = try_import_tf() -class APPOTfLearner(AppoLearner, TfLearner): +class APPOTfLearner(AppoLearner, ImpalaTfLearner): """Implements APPO loss / update logic on top of ImpalaTfLearner.""" - @override(TfLearner) + @override(ImpalaTfLearner) def compute_loss_for_module( self, *, @@ -72,12 +72,15 @@ def compute_loss_for_module( trajectory_len=rollout_frag_or_episode_len, recurrent_seq_len=recurrent_seq_len, ) - bootstrap_values_time_major = make_time_major( - batch[SampleBatch.VALUES_BOOTSTRAPPED], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=recurrent_seq_len, - ) - bootstrap_value = bootstrap_values_time_major[-1] + if self.config.uses_new_env_runners: + bootstrap_values = batch[SampleBatch.VALUES_BOOTSTRAPPED] + else: + bootstrap_values_time_major = make_time_major( + batch[SampleBatch.VALUES_BOOTSTRAPPED], + trajectory_len=rollout_frag_or_episode_len, + recurrent_seq_len=recurrent_seq_len, + ) + bootstrap_values = bootstrap_values_time_major[-1] # The discount factor that is used should be gamma except for timesteps where # the episode is terminated. In that case, the discount factor should be 0. 
@@ -100,7 +103,7 @@ def compute_loss_for_module( discounts=discounts_time_major, rewards=rewards_time_major, values=values_time_major, - bootstrap_value=bootstrap_value, + bootstrap_values=bootstrap_values, clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold, clip_rho_threshold=config.vtrace_clip_rho_threshold, ) diff --git a/rllib/algorithms/appo/torch/appo_torch_learner.py b/rllib/algorithms/appo/torch/appo_torch_learner.py index dd8c9ba44451d..bccde22807b03 100644 --- a/rllib/algorithms/appo/torch/appo_torch_learner.py +++ b/rllib/algorithms/appo/torch/appo_torch_learner.py @@ -8,6 +8,7 @@ OLD_ACTION_DIST_LOGITS_KEY, ) from ray.rllib.algorithms.appo.appo_learner import AppoLearner +from ray.rllib.algorithms.impala.torch.impala_torch_learner import ImpalaTorchLearner from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import ( make_time_major, vtrace_torch, @@ -30,10 +31,10 @@ torch, nn = try_import_torch() -class APPOTorchLearner(AppoLearner, TorchLearner): +class APPOTorchLearner(AppoLearner, ImpalaTorchLearner): """Implements APPO loss / update logic on top of ImpalaTorchLearner.""" - @override(TorchLearner) + @override(ImpalaTorchLearner) def compute_loss_for_module( self, *, @@ -86,12 +87,15 @@ def compute_loss_for_module( trajectory_len=rollout_frag_or_episode_len, recurrent_seq_len=recurrent_seq_len, ) - bootstrap_values_time_major = make_time_major( - batch[SampleBatch.VALUES_BOOTSTRAPPED], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=recurrent_seq_len, - ) - bootstrap_value = bootstrap_values_time_major[-1] + if self.config.uses_new_env_runners: + bootstrap_values = batch[SampleBatch.VALUES_BOOTSTRAPPED] + else: + bootstrap_values_time_major = make_time_major( + batch[SampleBatch.VALUES_BOOTSTRAPPED], + trajectory_len=rollout_frag_or_episode_len, + recurrent_seq_len=recurrent_seq_len, + ) + bootstrap_values = bootstrap_values_time_major[-1] # The discount factor that is used should be gamma except for timesteps where # the episode is terminated. In that case, the discount factor should be 0. @@ -111,7 +115,7 @@ def compute_loss_for_module( discounts=discounts_time_major, rewards=rewards_time_major, values=values_time_major, - bootstrap_value=bootstrap_value, + bootstrap_values=bootstrap_values, clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold, clip_rho_threshold=config.vtrace_clip_rho_threshold, ) diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index cc2c2ea6f0798..2d7a653fe210e 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -170,7 +170,6 @@ def __init__(self, algo_class=None): # Deprecated value. self.num_data_loader_buffers = DEPRECATED_VALUE - self.vtrace_drop_last_ts = DEPRECATED_VALUE @override(AlgorithmConfig) def training( @@ -205,8 +204,6 @@ def training( _separate_vf_optimizer: Optional[bool] = NotProvided, _lr_vf: Optional[float] = NotProvided, after_train_step: Optional[Callable[[dict], None]] = NotProvided, - # deprecated. - vtrace_drop_last_ts=None, **kwargs, ) -> "ImpalaConfig": """Sets the training related configuration. @@ -289,16 +286,6 @@ def training( Returns: This updated AlgorithmConfig object. 
""" - if vtrace_drop_last_ts is not None: - deprecation_warning( - old="vtrace_drop_last_ts", - help="The v-trace operations in RLlib have been enhanced and we are " - "now using proper value bootstrapping at the end of each " - "trajectory, such that no timesteps in our loss functions have to " - "be dropped anymore.", - error=True, - ) - # Pass kwargs onto super's `training()` method. super().training(**kwargs) @@ -370,6 +357,30 @@ def validate(self) -> None: # Call the super class' validation method first. super().validate() + # IMPALA and APPO need vtrace (A3C Policies no longer exist). + if not self.vtrace: + raise ValueError( + "IMPALA and APPO do NOT support vtrace=False anymore! Set " + "`config.training(vtrace=True)`." + ) + + # New stack w/ EnvRunners does NOT support aggregation workers yet or a mixin + # replay buffer. + if self.uses_new_env_runners: + if self.num_aggregation_workers > 0: + raise ValueError( + "Aggregation workers not supported on new API stack w/ new " + "EnvRunner API! Set `config.num_aggregation_workers = 0` or " + "disable the new API stack via " + "`config.experimental(_enable_new_api_stack=False)`." + ) + if self.replay_ratio != 0.0: + raise ValueError( + "The new API stack in combination with the new EnvRunner API " + "does NOT support a mixin replay buffer yet for " + f"{self} (set `config.replay_proportion` to 0.0)!" + ) + if self.num_data_loader_buffers != DEPRECATED_VALUE: deprecation_warning( "num_data_loader_buffers", "num_multi_gpu_tower_stacks", error=True @@ -416,17 +427,21 @@ def validate(self) -> None: "config.training(_tf_policy_handles_more_than_one_loss=True)." ) # Learner API specific checks. - if self._enable_new_api_stack: - if not ( + if ( + self._enable_new_api_stack + and self._minibatch_size != "auto" + and not ( (self.minibatch_size % self.rollout_fragment_length == 0) - and self.minibatch_size <= self.train_batch_size - ): - raise ValueError( - f"`minibatch_size` ({self._minibatch_size}) must either be 'auto' " - "or a multiple of `rollout_fragment_length` " - f"({self.rollout_fragment_length}) while at the same time smaller " - f"than or equal to `train_batch_size` ({self.train_batch_size})!" - ) + and self.minibatch_size <= self.total_train_batch_size + ) + ): + raise ValueError( + f"`minibatch_size` ({self._minibatch_size}) must either be 'auto' " + "or a multiple of `rollout_fragment_length` " + f"({self.rollout_fragment_length}) while at the same time smaller " + "than or equal to `total_train_batch_size` " + f"({self.total_train_batch_size})!" + ) @property def replay_ratio(self) -> float: @@ -441,7 +456,11 @@ def minibatch_size(self): # If 'auto', use the train_batch_size (meaning each SGD iter is a single pass # through the entire train batch). Otherwise, use user provided setting. 
return ( - self.train_batch_size + ( + self.train_batch_size_per_learner + if self.uses_new_env_runners + else self.train_batch_size + ) if self._minibatch_size == "auto" else self._minibatch_size ) @@ -554,42 +573,25 @@ def get_default_config(cls) -> AlgorithmConfig: def get_default_policy_class( cls, config: AlgorithmConfig ) -> Optional[Type[Policy]]: - if not config["vtrace"]: - raise ValueError("IMPALA with the learner API does not support non-VTrace ") - - if config["framework"] == "torch": - if config["vtrace"]: - from ray.rllib.algorithms.impala.impala_torch_policy import ( - ImpalaTorchPolicy, - ) - - return ImpalaTorchPolicy - else: - from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy + if config.framework_str == "torch": + from ray.rllib.algorithms.impala.impala_torch_policy import ( + ImpalaTorchPolicy, + ) - return A3CTorchPolicy - elif config["framework"] == "tf": - if config["vtrace"]: - from ray.rllib.algorithms.impala.impala_tf_policy import ( - ImpalaTF1Policy, - ) + return ImpalaTorchPolicy - return ImpalaTF1Policy - else: - from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy + elif config.framework_str == "tf": + from ray.rllib.algorithms.impala.impala_tf_policy import ( + ImpalaTF1Policy, + ) - return A3CTFPolicy + return ImpalaTF1Policy else: - if config["vtrace"]: - from ray.rllib.algorithms.impala.impala_tf_policy import ( - ImpalaTF2Policy, - ) - - return ImpalaTF2Policy - else: - from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy + from ray.rllib.algorithms.impala.impala_tf_policy import ( + ImpalaTF2Policy, + ) - return A3CTFPolicy + return ImpalaTF2Policy @override(Algorithm) def setup(self, config: AlgorithmConfig): @@ -653,7 +655,6 @@ def setup(self, config: AlgorithmConfig): # This variable is used to keep track of the statistics from the most recent # update of the learner group self._results = {} - self._timeout_s_sampler_manager = self.config.timeout_s_sampler_manager if not self.config._enable_new_api_stack: # Create and start the learner thread. @@ -845,7 +846,7 @@ def concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]) -> None: def aggregate_into_larger_batch(): if ( sum(b.count for b in self.batch_being_built) - >= self.config.train_batch_size + >= self.config.total_train_batch_size ): batch_to_add = concat_samples(self.batch_being_built) self.batches_to_place_on_learner.append(batch_to_add) @@ -915,7 +916,7 @@ def get_samples_from_workers( sample_batches: List[ Tuple[int, ObjectRef] ] = self.workers.fetch_ready_async_reqs( - timeout_seconds=self._timeout_s_sampler_manager, + timeout_seconds=self.config.timeout_s_sampler_manager, return_obj_refs=return_object_refs, ) elif ( @@ -1270,10 +1271,10 @@ def _reduce_impala_results(results: List[ResultDict]) -> ResultDict: steps trained (on all modules). Args: - results: result dicts to reduce. + results: List of results dicts to be reduced. Returns: - A reduced result dict. + Final reduced results dict. 
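For intuition, a simplified sketch of the reduction performed by `_reduce_impala_results` (the real result dicts are nested per module; the keys below are stand-ins):

    import numpy as np
    import tree  # dm_tree

    results = [
        {"learner_stats": {"total_loss": 2.0}, "num_agent_steps_trained": 500},
        {"learner_stats": {"total_loss": 4.0}, "num_agent_steps_trained": 500},
    ]

    # Mean-reduce every leaf across the individual results ...
    reduced = tree.map_structure(lambda *x: np.mean(x), *results)
    assert reduced["learner_stats"]["total_loss"] == 3.0

    # ... but step counters are overridden with the sum across results.
    reduced["num_agent_steps_trained"] = sum(
        r["num_agent_steps_trained"] for r in results
    )
    assert reduced["num_agent_steps_trained"] == 1000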
""" result = tree.map_structure(lambda *x: np.mean(x), *results) agent_steps_trained = sum(r[ALL_MODULES][NUM_AGENT_STEPS_TRAINED] for r in results) diff --git a/rllib/algorithms/impala/impala_learner.py b/rllib/algorithms/impala/impala_learner.py index 549e22f8515da..1801d400b46be 100644 --- a/rllib/algorithms/impala/impala_learner.py +++ b/rllib/algorithms/impala/impala_learner.py @@ -1,12 +1,24 @@ +import abc from typing import Any, Dict +import numpy as np + from ray.rllib.algorithms.impala.impala import ( ImpalaConfig, LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY, ) from ray.rllib.core.learner.learner import Learner +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.postprocessing.episodes import ( + add_one_ts_to_episodes_and_truncate, + remove_last_ts_from_data, + remove_last_ts_from_episodes_and_restore_truncateds, +) +from ray.rllib.utils.postprocessing.value_predictions import extract_bootstrapped_values +from ray.rllib.utils.postprocessing.zero_padding import unpad_data_if_necessary from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import ModuleID @@ -29,6 +41,56 @@ def build(self) -> None: ) ) + @override(Learner) + def _preprocess_train_data( + self, + *, + batch, + episodes, + ): + batch = batch or {} + if not episodes: + return batch, episodes + + # Make all episodes one ts longer in order to just have a single batch + # (and distributed forward pass) for both vf predictions AND the bootstrap + # vf computations. + episode_lens = [len(e) for e in episodes] + orig_truncateds = add_one_ts_to_episodes_and_truncate(episodes) + episode_lens_p1 = [len(e) for e in episodes] + + # Call the learner connector (on the artificially elongated episodes) + # in order to get the batch to pass through the module for vf (and + # bootstrapped vf) computations. + batch_for_vf = self._learner_connector( + rl_module=self.module["default_policy"], # TODO: make multi-agent capable + data={}, + episodes=episodes, + ) + # Perform the value model's forward pass. + vf_preds = convert_to_numpy(self._compute_values(batch_for_vf)) + + # Remove all zero-padding again, if applicable, for the upcoming + # GAE computations. + vf_preds = unpad_data_if_necessary(episode_lens_p1, vf_preds) + # Generate the bootstrap value column (with only one entry per batch row). + batch[SampleBatch.VALUES_BOOTSTRAPPED] = extract_bootstrapped_values( + vf_preds=vf_preds, + episode_lengths=episode_lens, + T=self.config.get_rollout_fragment_length(), + ) + # Remove the extra timesteps again from vf_preds and value targets. Now that + # the GAE computation is done, we don't need this last timestep anymore in any + # of our data. + batch[SampleBatch.VF_PREDS] = remove_last_ts_from_data( + episode_lens_p1, vf_preds + ) + + # Remove the extra (artificial) timesteps again at the end of all episodes. + remove_last_ts_from_episodes_and_restore_truncateds(episodes, orig_truncateds) + + return batch, episodes + @override(Learner) def remove_module(self, module_id: str): super().remove_module(module_id) @@ -49,3 +111,17 @@ def additional_update_for_module( results.update({LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY: new_entropy_coeff}) return results + + @abc.abstractmethod + def _compute_values(self, batch) -> np._typing.NDArray: + """Computes the values using the value function module given a batch of data. 
+ + Args: + batch: The input batch to pass through our RLModule (value function + encoder and vf-head). + + Returns: + The batch (numpy) of value function outputs (already squeezed over the last + dimension (which should have shape (1,) b/c of the single value output + node). + """ diff --git a/rllib/algorithms/impala/tests/test_vtrace.py b/rllib/algorithms/impala/tests/test_vtrace.py index c52c3480acc05..6c9a9998b7114 100644 --- a/rllib/algorithms/impala/tests/test_vtrace.py +++ b/rllib/algorithms/impala/tests/test_vtrace.py @@ -46,13 +46,13 @@ def _ground_truth_vtrace_calculation( """Calculates the ground truth for V-trace in Python/Numpy. NOTE: - The discount, log_rhos, rewards, values, and bootstrap_value are all assumed to + The discount, log_rhos, rewards, values, and bootstrap_values are all assumed to come from trajectories of experience. Typically batches of trajectories could be thought of as having the shape [B, T] where B is the batch dimension, and T is the timestep dimension. Computing vtrace returns requires that the data is time major, meaning that it has the shape [T, B]. One can use a function like `make_time_major` to properly format their discount, log_rhos, rewards, values, - and bootstrap_value before calling _ground_truth_vtrace_calculation. + and bootstrap_values before calling _ground_truth_vtrace_calculation. Args: discounts: Array of shape [T*B] of discounts. T is the lenght of the trajectory @@ -64,7 +64,7 @@ def _ground_truth_vtrace_calculation( rewards: Array of shape [T*B] of rewards. values: Array of shape [T*B] of the value function estimated for every timestep in a batch. - bootstrap_value: Array of shape [T] of the value function estimated at the last + bootstrap_values: Array of shape [B] of the value function estimated at the last timestep for each trajectory in the batch. clip_rho_threshold: The threshold for clipping the importance weights. clip_pg_rho_threshold: The threshold for clipping the importance weights for diff --git a/rllib/algorithms/impala/tests/test_vtrace_v2.py b/rllib/algorithms/impala/tests/test_vtrace_v2.py index be381b931130e..06808c6d25c69 100644 --- a/rllib/algorithms/impala/tests/test_vtrace_v2.py +++ b/rllib/algorithms/impala/tests/test_vtrace_v2.py @@ -84,7 +84,7 @@ def setUpClass(cls): values = value_fn_space_w_time.sample() # this is supposed to be the value function at the last timestep of each # trajectory in the batch. 
In IMPALA its bootstrapped at training time - cls.bootstrap_value = np.array(value_fn_space.sample() + 1.0) + cls.bootstrap_values = np.array(value_fn_space.sample() + 1.0) # discount factor used at all of the timesteps discounts = [0.9 for _ in range(trajectory_len * batch_size)] @@ -117,7 +117,7 @@ def setUpClass(cls): log_rhos=log_rhos, rewards=cls.rewards_time_major, values=cls.values_time_major, - bootstrap_value=cls.bootstrap_value, + bootstrap_value=cls.bootstrap_values, clip_rho_threshold=cls.clip_rho_threshold, clip_pg_rho_threshold=cls.clip_pg_rho_threshold, ) @@ -133,7 +133,7 @@ def test_vtrace_tf2(self): discounts=tf.convert_to_tensor(self.discounts_time_major), rewards=tf.convert_to_tensor(self.rewards_time_major), values=tf.convert_to_tensor(self.values_time_major), - bootstrap_value=tf.convert_to_tensor(self.bootstrap_value), + bootstrap_values=tf.convert_to_tensor(self.bootstrap_values), clip_rho_threshold=self.clip_rho_threshold, clip_pg_rho_threshold=self.clip_pg_rho_threshold, ) @@ -150,7 +150,7 @@ def test_vtrace_torch(self): discounts=convert_to_torch_tensor(self.discounts_time_major), rewards=convert_to_torch_tensor(self.rewards_time_major), values=convert_to_torch_tensor(self.values_time_major), - bootstrap_value=convert_to_torch_tensor(self.bootstrap_value), + bootstrap_values=convert_to_torch_tensor(self.bootstrap_values), clip_rho_threshold=self.clip_rho_threshold, clip_pg_rho_threshold=self.clip_pg_rho_threshold, ) diff --git a/rllib/algorithms/impala/tf/impala_tf_learner.py b/rllib/algorithms/impala/tf/impala_tf_learner.py index 192849be746d5..0ccf67cc2dde4 100644 --- a/rllib/algorithms/impala/tf/impala_tf_learner.py +++ b/rllib/algorithms/impala/tf/impala_tf_learner.py @@ -1,11 +1,13 @@ from typing import Dict +import tree from ray.rllib.algorithms.impala.impala import ImpalaConfig from ray.rllib.algorithms.impala.impala_learner import ImpalaLearner from ray.rllib.algorithms.impala.tf.vtrace_tf_v2 import make_time_major, vtrace_tf2 from ray.rllib.core.learner.learner import ENTROPY_KEY from ray.rllib.core.learner.tf.tf_learner import TfLearner -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.core.models.base import CRITIC, ENCODER_OUT +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.nested_dict import NestedDict @@ -57,13 +59,15 @@ def compute_loss_for_module( trajectory_len=rollout_frag_or_episode_len, recurrent_seq_len=recurrent_seq_len, ) - bootstrap_values_time_major = make_time_major( - batch[SampleBatch.VALUES_BOOTSTRAPPED], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=recurrent_seq_len, - ) - bootstrap_value = bootstrap_values_time_major[-1] - rollout_frag_or_episode_len = config.get_rollout_fragment_length() + if self.config.uses_new_env_runners: + bootstrap_values = batch[SampleBatch.VALUES_BOOTSTRAPPED] + else: + bootstrap_values_time_major = make_time_major( + batch[SampleBatch.VALUES_BOOTSTRAPPED], + trajectory_len=rollout_frag_or_episode_len, + recurrent_seq_len=recurrent_seq_len, + ) + bootstrap_values = bootstrap_values_time_major[-1] # the discount factor that is used should be gamma except for timesteps where # the episode is terminated. In that case, the discount factor should be 0. 
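The renamed `bootstrap_values` argument (shape [B]) is what lets v-trace form the one-step-shifted value sequence; a toy numpy sketch of that construction, mirroring the `values_t_plus_1` concatenation in the v-trace code below (values are made up):

    import numpy as np

    T, B = 3, 2
    values = np.array([[0.5, 1.0],   # V(s_0) per rollout
                       [0.6, 1.1],   # V(s_1)
                       [0.7, 1.2]])  # V(s_2), shape [T, B]
    bootstrap_values = np.array([0.8, 1.3])  # V(s_3) per rollout, shape [B]

    # [v_1, ..., v_T, bootstrap]: the "next-step" values used in the
    # v-trace deltas rho * (r_t + gamma * V(s_{t+1}) - V(s_t)).
    values_t_plus_1 = np.concatenate(
        [values[1:], bootstrap_values[None, :]], axis=0
    )
    assert values_t_plus_1.shape == (T, B)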
@@ -86,7 +90,7 @@ def compute_loss_for_module( discounts=discounts_time_major, rewards=rewards_time_major, values=values_time_major, - bootstrap_value=bootstrap_value, + bootstrap_values=bootstrap_values, clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold, clip_rho_threshold=config.vtrace_clip_rho_threshold, ) @@ -129,3 +133,20 @@ def compute_loss_for_module( ) # Return the total loss. return total_loss + + @override(ImpalaLearner) + def _compute_values(self, batch): + infos = batch.pop(SampleBatch.INFOS, None) + batch = tree.map_structure(lambda s: tf.convert_to_tensor(s), batch) + if infos is not None: + batch[SampleBatch.INFOS] = infos + + # TODO (sven): Make multi-agent capable. + module = self.module[DEFAULT_POLICY_ID].unwrapped() + + # Shared encoder. + encoder_outs = module.encoder(batch) + # Value head. + vf_out = module.vf(encoder_outs[ENCODER_OUT][CRITIC]) + # Squeeze out last dimension (single node value head). + return tf.squeeze(vf_out, -1) diff --git a/rllib/algorithms/impala/tf/vtrace_tf_v2.py b/rllib/algorithms/impala/tf/vtrace_tf_v2.py index 4aecc6c655490..7b82c92400a03 100644 --- a/rllib/algorithms/impala/tf/vtrace_tf_v2.py +++ b/rllib/algorithms/impala/tf/vtrace_tf_v2.py @@ -59,7 +59,7 @@ def vtrace_tf2( discounts: "tf.Tensor", rewards: "tf.Tensor", values: "tf.Tensor", - bootstrap_value: "tf.Tensor", + bootstrap_values: "tf.Tensor", clip_rho_threshold: Union[float, "tf.Tensor"] = 1.0, clip_pg_rho_threshold: Union[float, "tf.Tensor"] = 1.0, ): @@ -97,7 +97,7 @@ def vtrace_tf2( following the behaviour policy. values: A float32 tensor of shape [T, B] with the value function estimates wrt. the target policy. - bootstrap_value: A float32 of shape [B] with the value function estimate at + bootstrap_values: A float32 of shape [B] with the value function estimate at time T. clip_rho_threshold: A scalar float32 tensor with the clipping threshold for importance weights (rho) when calculating the baseline targets (vs). @@ -116,7 +116,7 @@ def vtrace_tf2( cs = tf.minimum(1.0, rhos, name="cs") # Append bootstrapped value to get [v1, ..., v_t+1] values_t_plus_1 = tf.concat( - [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0 + [values[1:], tf.expand_dims(bootstrap_values, 0)], axis=0 ) deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) @@ -135,7 +135,7 @@ def scanfunc(acc, sequence_item): return delta_t + discount_t * c_t * acc with tf.device("/cpu:0"): - initial_values = tf.zeros_like(bootstrap_value) + initial_values = tf.zeros_like(bootstrap_values) vs_minus_v_xs = tf.nest.map_structure( tf.stop_gradient, tf.scan( @@ -153,7 +153,7 @@ def scanfunc(acc, sequence_item): vs = tf.add(vs_minus_v_xs, values) # Advantage for policy gradient. 
- vs_t_plus_1 = tf.concat([vs[1:], tf.expand_dims(bootstrap_value, 0)], axis=0) + vs_t_plus_1 = tf.concat([vs[1:], tf.expand_dims(bootstrap_values, 0)], axis=0) if clip_pg_rho_threshold is not None: clipped_pg_rhos = tf.minimum(clip_pg_rho_threshold, rhos) else: diff --git a/rllib/algorithms/impala/torch/impala_torch_learner.py b/rllib/algorithms/impala/torch/impala_torch_learner.py index 78b86fc142eaa..0b220547f532a 100644 --- a/rllib/algorithms/impala/torch/impala_torch_learner.py +++ b/rllib/algorithms/impala/torch/impala_torch_learner.py @@ -8,7 +8,8 @@ ) from ray.rllib.core.learner.learner import ENTROPY_KEY from ray.rllib.core.learner.torch.torch_learner import TorchLearner -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.core.models.base import CRITIC, ENCODER_OUT +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.utils.framework import try_import_torch @@ -66,12 +67,15 @@ def compute_loss_for_module( trajectory_len=rollout_frag_or_episode_len, recurrent_seq_len=recurrent_seq_len, ) - bootstrap_values_time_major = make_time_major( - batch[SampleBatch.VALUES_BOOTSTRAPPED], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=recurrent_seq_len, - ) - bootstrap_value = bootstrap_values_time_major[-1] + if self.config.uses_new_env_runners: + bootstrap_values = batch[SampleBatch.VALUES_BOOTSTRAPPED] + else: + bootstrap_values_time_major = make_time_major( + batch[SampleBatch.VALUES_BOOTSTRAPPED], + trajectory_len=rollout_frag_or_episode_len, + recurrent_seq_len=recurrent_seq_len, + ) + bootstrap_values = bootstrap_values_time_major[-1] # the discount factor that is used should be gamma except for timesteps where # the episode is terminated. In that case, the discount factor should be 0. @@ -95,7 +99,7 @@ def compute_loss_for_module( discounts=discounts_time_major, rewards=rewards_time_major, values=values_time_major, - bootstrap_value=bootstrap_value, + bootstrap_values=bootstrap_values, clip_rho_threshold=config.vtrace_clip_rho_threshold, clip_pg_rho_threshold=config.vtrace_clip_pg_rho_threshold, ) @@ -145,3 +149,21 @@ def compute_loss_for_module( ) # Return the total loss. return total_loss + + @override(ImpalaLearner) + def _compute_values(self, batch): + infos = batch.pop(SampleBatch.INFOS, None) + batch = convert_to_torch_tensor(batch, device=self._device) + # batch = tree.map_structure(lambda s: torch.from_numpy(s), batch) + if infos is not None: + batch[SampleBatch.INFOS] = infos + + # TODO (sven): Make multi-agent capable. + module = self.module[DEFAULT_POLICY_ID].unwrapped() + + # Shared encoder. + encoder_outs = module.encoder(batch) + # Value head. + vf_out = module.vf(encoder_outs[ENCODER_OUT][CRITIC]) + # Squeeze out last dimension (single node value head). 
+ return vf_out.squeeze(-1) diff --git a/rllib/algorithms/impala/torch/vtrace_torch_v2.py b/rllib/algorithms/impala/torch/vtrace_torch_v2.py index 83ba8879d9558..8f8be2a635903 100644 --- a/rllib/algorithms/impala/torch/vtrace_torch_v2.py +++ b/rllib/algorithms/impala/torch/vtrace_torch_v2.py @@ -60,7 +60,7 @@ def vtrace_torch( discounts: "torch.Tensor", rewards: "torch.Tensor", values: "torch.Tensor", - bootstrap_value: "torch.Tensor", + bootstrap_values: "torch.Tensor", clip_rho_threshold: Union[float, "torch.Tensor"] = 1.0, clip_pg_rho_threshold: Union[float, "torch.Tensor"] = 1.0, ): @@ -98,7 +98,7 @@ def vtrace_torch( following the behaviour policy. values: A float32 tensor of shape [T, B] with the value function estimates wrt. the target policy. - bootstrap_value: A float32 of shape [B] with the value function estimate at + bootstrap_values: A float32 of shape [B] with the value function estimate at time T. clip_rho_threshold: A scalar float32 tensor with the clipping threshold for importance weights (rho) when calculating the baseline targets (vs). @@ -117,7 +117,7 @@ def vtrace_torch( cs = torch.clamp(rhos, max=1.0) # Append bootstrapped value to get [v1, ..., v_t+1] values_t_plus_1 = torch.cat( - [values[1:], torch.unsqueeze(bootstrap_value, 0)], axis=0 + [values[1:], torch.unsqueeze(bootstrap_values, 0)], axis=0 ) deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) @@ -126,7 +126,7 @@ def vtrace_torch( discounts_cpu = discounts.to("cpu") cs_cpu = cs.to("cpu") deltas_cpu = deltas.to("cpu") - vs_minus_v_xs_cpu = [torch.zeros_like(bootstrap_value, device="cpu")] + vs_minus_v_xs_cpu = [torch.zeros_like(bootstrap_values, device="cpu")] for i in reversed(range(len(discounts_cpu))): discount_t, c_t, delta_t = discounts_cpu[i], cs_cpu[i], deltas_cpu[i] vs_minus_v_xs_cpu.append(delta_t + discount_t * c_t * vs_minus_v_xs_cpu[-1]) @@ -141,7 +141,7 @@ def vtrace_torch( vs = torch.add(vs_minus_v_xs, values) # Advantage for policy gradient. - vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_value, 0)], axis=0) + vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_values, 0)], axis=0) if clip_pg_rho_threshold is not None: clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold) else: diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index 498baa73a5bc1..96857e98b0b3c 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -18,8 +18,6 @@ from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.evaluation.postprocessing_v2 import postprocess_episodes_to_sample_batch -from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.execution.rollout_ops import ( standardize_fields, synchronous_parallel_sample, @@ -40,7 +38,6 @@ ALL_MODULES, ) from ray.rllib.env.single_agent_episode import SingleAgentEpisode -from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import ResultDict from ray.util.debug import log_once @@ -122,6 +119,8 @@ def __init__(self, algo_class=None): self.kl_coeff = 0.2 self.kl_target = 0.01 self.sgd_minibatch_size = 128 + # Simple logic for now: If None, use `train_batch_size`. 
+ self.mini_batch_size_per_learner = None self.num_sgd_iter = 30 self.shuffle_sequences = True self.vf_loss_coeff = 1.0 @@ -203,6 +202,7 @@ def training( use_kl_loss: Optional[bool] = NotProvided, kl_coeff: Optional[float] = NotProvided, kl_target: Optional[float] = NotProvided, + mini_batch_size_per_learner: Optional[int] = NotProvided, sgd_minibatch_size: Optional[int] = NotProvided, num_sgd_iter: Optional[int] = NotProvided, shuffle_sequences: Optional[bool] = NotProvided, @@ -231,8 +231,20 @@ def training( use_kl_loss: Whether to use the KL-term in the loss function. kl_coeff: Initial coefficient for KL divergence. kl_target: Target value for KL divergence. + mini_batch_size_per_learner: Only use if new API stack is enabled. + The mini batch size per Learner worker. This is the + batch size that each Learner worker's training batch (whose size is + `s`elf.train_batch_size_per_learner`) will be split into. For example, + if the train batch size per Learner worker is 4000 and the mini batch + size per Learner worker is 400, the train batch will be split into 10 + equal sized chunks (or "mini batches"). Each such mini batch will be + used for one SGD update. Overall, the train batch on each Learner + worker will be traversed `self.num_sgd_iter` times. In the above + example, if `self.num_sgd_iter` is 5, we will altogether perform 50 + (10x5) SGD updates per Learner update step. sgd_minibatch_size: Total SGD batch size across all devices for SGD. - This defines the minibatch size within each epoch. + This defines the minibatch size within each epoch. Deprecated on the + new API stack (use `mini_batch_size_per_learner` instead). num_sgd_iter: Number of SGD iterations in each outer loop (i.e., number of epochs to execute per train batch). shuffle_sequences: Whether to shuffle sequences in the batch when training @@ -267,6 +279,8 @@ def training( self.kl_coeff = kl_coeff if kl_target is not NotProvided: self.kl_target = kl_target + if mini_batch_size_per_learner is not NotProvided: + self.mini_batch_size_per_learner = mini_batch_size_per_learner if sgd_minibatch_size is not NotProvided: self.sgd_minibatch_size = sgd_minibatch_size if num_sgd_iter is not NotProvided: @@ -298,7 +312,8 @@ def validate(self) -> None: super().validate() # Synchronous sampling, on-policy/PPO algos -> Check mismatches between - # `rollout_fragment_length` and `train_batch_size` to avoid user confusion. + # `rollout_fragment_length` and `train_batch_size_per_learner` to avoid user + # confusion. # TODO (sven): Make rollout_fragment_length a property and create a private # attribute to store (possibly) user provided value (or "auto") in. Deprecate # `self.get_rollout_fragment_length()`. @@ -307,9 +322,10 @@ def validate(self) -> None: # SGD minibatch size must be smaller than train_batch_size (b/c # we subsample a batch of `sgd_minibatch_size` from the train-batch for # each `num_sgd_iter`). - # Note: Only check this if `train_batch_size` > 0 (DDPPO sets this - # to -1 to auto-calculate the actual batch size later). - if self.sgd_minibatch_size > self.train_batch_size: + if ( + not self._enable_new_api_stack + and self.sgd_minibatch_size > self.train_batch_size + ): raise ValueError( f"`sgd_minibatch_size` ({self.sgd_minibatch_size}) must be <= " f"`train_batch_size` ({self.train_batch_size}). In PPO, the train batch" @@ -317,6 +333,16 @@ def validate(self) -> None: f"is iterated over (used for updating the policy) {self.num_sgd_iter} " "times." 
) + elif self._enable_new_api_stack: + mbs = self.mini_batch_size_per_learner or self.sgd_minibatch_size + tbs = self.train_batch_size_per_learner or self.train_batch_size + if mbs > tbs: + raise ValueError( + f"`mini_batch_size_per_learner` ({mbs}) must be <= " + f"`train_batch_size_per_learner` ({tbs}). In PPO, the train batch" + f" will be split into {mbs} chunks, each of which is iterated over " + f"(used for updating the policy) {self.num_sgd_iter} times." + ) # Episodes may only be truncated (and passed into PPO's # `postprocessing_fn`), iff generalized advantage estimation is used @@ -376,57 +402,122 @@ def get_default_policy_class( return PPOTF2Policy @override(Algorithm) - def training_step(self) -> ResultDict: - use_rollout_worker = self.config.env_runner_cls is None or issubclass( - self.config.env_runner_cls, RolloutWorker - ) + def training_step(self): + # New API stack (RLModule, Learner, EnvRunner, ConnectorV2). + if self.config.uses_new_env_runners: + return self._training_step_new_api_stack() + # Old and hybrid API stacks (Policy, RolloutWorker, Connector, maybe RLModule, + # maybe Learner). + else: + return self._training_step_old_and_hybrid_api_stacks() + def _training_step_new_api_stack(self) -> ResultDict: # Collect SampleBatches from sample workers until we have a full batch. with self._timers[SAMPLE_TIMER]: - # Old RolloutWorker based APIs (returning SampleBatch/MultiAgentBatch). - if use_rollout_worker: - if self.config.count_steps_by == "agent_steps": - train_batch = synchronous_parallel_sample( - worker_set=self.workers, - max_agent_steps=self.config.train_batch_size, - ) - else: - train_batch = synchronous_parallel_sample( - worker_set=self.workers, - max_env_steps=self.config.train_batch_size, - ) - # New Episode-returning EnvRunner API. + # TODO (sven): Make this also use `synchronous_parallel_sample`. + # Which needs to be enhanced to be able to handle episodes as well. + # Also, this would make this sampling with the EnvRunners fault + # tolerant, which it is NOT right now. + if self.workers.num_remote_workers() == 0: + episodes: List[SingleAgentEpisode] = [ + self.workers.local_worker().sample() + ] else: - if self.workers.num_remote_workers() <= 0: - episodes: List[SingleAgentEpisode] = [ - self.workers.local_worker().sample() - ] - else: - episodes: List[SingleAgentEpisode] = self.workers.foreach_worker( - lambda w: w.sample(), local_worker=False - ) - # Perform PPO postprocessing on a (flattened) list of Episodes. - postprocessed_episodes: List[ - SingleAgentEpisode - ] = self.postprocess_episodes(tree.flatten(episodes)) - # Convert list of postprocessed Episodes into a single sample batch. - train_batch: SampleBatch = postprocess_episodes_to_sample_batch( - postprocessed_episodes + episodes: List[SingleAgentEpisode] = self.workers.foreach_worker( + lambda w: w.sample(), local_worker=False + ) + episodes = tree.flatten(episodes) + # TODO (sven): single- vs multi-agent. + self._counters[NUM_AGENT_STEPS_SAMPLED] += sum(len(e) for e in episodes) + self._counters[NUM_ENV_STEPS_SAMPLED] += sum(len(e) for e in episodes) + + # Perform a train step on the collected batch. + train_results = self.learner_group.update_from_episodes( + episodes=episodes, + minibatch_size=( + self.config.mini_batch_size_per_learner + or self.config.sgd_minibatch_size + ), + num_iters=self.config.num_sgd_iter, + ) + + # The train results's loss keys are pids to their loss values. But we also + # return a total_loss key at the same level as the pid keys. 
So we need to + # subtract that to get the total set of pids to update. + # TODO (Kourosh): We should also not be using train_results as a message + # passing medium to infer which policies to update. We could use + # policies_to_train variable that is given by the user to infer this. + policies_to_update = set(train_results.keys()) - {ALL_MODULES} + + # Update weights - after learning on the local worker - on all remote + # workers. + with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: + if self.workers.num_remote_workers() > 0: + self.workers.sync_weights( + # Sync weights from learner_group to all rollout workers. + from_worker_or_learner_group=self.learner_group, + policies=policies_to_update, + global_vars=None, ) + else: + weights = self.learner_group.get_weights() + self.workers.local_worker().set_weights(weights) + + kl_dict = {} + if self.config.use_kl_loss: + for pid in policies_to_update: + kl = train_results[pid][LEARNER_RESULTS_KL_KEY] + kl_dict[pid] = kl + if np.isnan(kl): + logger.warning( + f"KL divergence for Module {pid} is non-finite, this will " + "likely destabilize your model and the training process. " + "Action(s) in a specific state have near-zero probability. " + "This can happen naturally in deterministic environments " + "where the optimal policy has zero mass for a specific " + "action. To fix this issue, consider setting `kl_coeff` to " + "0.0 or increasing `entropy_coeff` in your config." + ) + + # triggers a special update method on RLOptimizer to update the KL values. + additional_results = self.learner_group.additional_update( + module_ids_to_update=policies_to_update, + sampled_kl_values=kl_dict, + timestep=self._counters[NUM_AGENT_STEPS_SAMPLED], + ) + for pid, res in additional_results.items(): + train_results[pid].update(res) + + return train_results - train_batch = train_batch.as_multi_agent() - self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps() - self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps() + def _training_step_old_and_hybrid_api_stacks(self) -> ResultDict: + # Collect SampleBatches from sample workers until we have a full batch. + with self._timers[SAMPLE_TIMER]: + if self.config.count_steps_by == "agent_steps": + train_batch = synchronous_parallel_sample( + worker_set=self.workers, + max_agent_steps=self.config.total_train_batch_size, + ) + else: + train_batch = synchronous_parallel_sample( + worker_set=self.workers, + max_env_steps=self.config.total_train_batch_size, + ) + train_batch = train_batch.as_multi_agent() + self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps() + self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps() + # Standardize advantages. + train_batch = standardize_fields(train_batch, ["advantages"]) - # Standardize advantages. - train_batch = standardize_fields(train_batch, ["advantages"]) - # Train + # Perform a train step on the collected batch. if self.config._enable_new_api_stack: - # TODO (Kourosh) Clearly define what train_batch_size - # vs. sgd_minibatch_size and num_sgd_iter is in the config. 
+ mini_batch_size_per_learner = ( + self.config.mini_batch_size_per_learner + or self.config.sgd_minibatch_size + ) train_results = self.learner_group.update_from_batch( batch=train_batch, - minibatch_size=self.config.sgd_minibatch_size, + minibatch_size=mini_batch_size_per_learner, num_iters=self.config.num_sgd_iter, ) @@ -446,18 +537,15 @@ def training_step(self) -> ResultDict: else: policies_to_update = list(train_results.keys()) - # TODO (Kourosh): num_grad_updates per each policy should be accessible via - # train_results. - if not use_rollout_worker: - global_vars = None - else: - global_vars = { - "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED], - "num_grad_updates_per_policy": { - pid: self.workers.local_worker().policy_map[pid].num_grad_updates - for pid in policies_to_update - }, - } + global_vars = { + "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED], + # TODO (sven): num_grad_updates per each policy should be + # accessible via `train_results` (and get rid of global_vars). + "num_grad_updates_per_policy": { + pid: self.workers.local_worker().policy_map[pid].num_grad_updates + for pid in policies_to_update + }, + } # Update weights - after learning on the local worker - on all remote # workers. @@ -550,24 +638,3 @@ def training_step(self) -> ResultDict: self.workers.local_worker().set_global_vars(global_vars) return train_results - - def postprocess_episodes( - self, episodes: List[SingleAgentEpisode] - ) -> List[SingleAgentEpisode]: - """Calculate advantages and value targets.""" - from ray.rllib.evaluation.postprocessing_v2 import compute_gae_for_episode - - # Bootstrap values. - postprocessed_episodes = [] - # TODO (simon): Remove somehow the double list. - # episodes = [episode for episode_list in episodes for episode in episode_list] - for episode in episodes: - # TODO (sven): Calling 'module' on the 'EnvRunner' only works - # for the 'SingleAgentEnvRunner' not for 'MultiAgentEnvRunner'. 
- postprocessed_episodes.append( - compute_gae_for_episode( - episode, self.config, self.workers.local_worker().module - ) - ) - - return postprocessed_episodes diff --git a/rllib/algorithms/ppo/ppo_learner.py b/rllib/algorithms/ppo/ppo_learner.py index 63d01906c1e24..050390bd271d8 100644 --- a/rllib/algorithms/ppo/ppo_learner.py +++ b/rllib/algorithms/ppo/ppo_learner.py @@ -1,12 +1,25 @@ +import abc from typing import Any, Dict +import numpy as np + from ray.rllib.algorithms.ppo.ppo import ( LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY, PPOConfig, ) from ray.rllib.core.learner.learner import Learner +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets +from ray.rllib.utils.postprocessing.episodes import ( + add_one_ts_to_episodes_and_truncate, + remove_last_ts_from_data, + remove_last_ts_from_episodes_and_restore_truncateds, +) +from ray.rllib.utils.postprocessing.zero_padding import unpad_data_if_necessary from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import ModuleID @@ -39,6 +52,78 @@ def build(self) -> None: ) ) + @override(Learner) + def _preprocess_train_data( + self, + *, + batch, + episodes, + ): + batch = batch or {} + if not episodes: + return batch, episodes + + # Make all episodes one ts longer in order to just have a single batch + # (and distributed forward pass) for both vf predictions AND the bootstrap + # vf computations. + orig_truncateds = add_one_ts_to_episodes_and_truncate(episodes) + episode_lens_p1 = [len(e) for e in episodes] + + # Call the learner connector (on the artificially elongated episodes) + # in order to get the batch to pass through the module for vf (and + # bootstrapped vf) computations. + batch_for_vf = self._learner_connector( + rl_module=self.module["default_policy"], # TODO: make multi-agent capable + data={}, + episodes=episodes, + ) + # Perform the value model's forward pass. + vf_preds = convert_to_numpy(self._compute_values(batch_for_vf)) + # Remove all zero-padding again, if applicable for the upcoming + # GAE computations. + vf_preds = unpad_data_if_necessary(episode_lens_p1, vf_preds) + # Compute value targets. + value_targets = compute_value_targets( + values=vf_preds, + rewards=unpad_data_if_necessary( + episode_lens_p1, batch_for_vf[SampleBatch.REWARDS] + ), + terminateds=unpad_data_if_necessary( + episode_lens_p1, batch_for_vf[SampleBatch.TERMINATEDS] + ), + truncateds=unpad_data_if_necessary( + episode_lens_p1, batch_for_vf[SampleBatch.TRUNCATEDS] + ), + gamma=self.config.gamma, + lambda_=self.config.lambda_, + ) + + # Remove the extra timesteps again from vf_preds and value targets. Now that + # the GAE computation is done, we don't need this last timestep anymore in any + # of our data. + ( + batch[SampleBatch.VF_PREDS], + batch[Postprocessing.VALUE_TARGETS], + ) = remove_last_ts_from_data( + episode_lens_p1, + vf_preds, + value_targets, + ) + advantages = batch[Postprocessing.VALUE_TARGETS] - batch[SampleBatch.VF_PREDS] + # Standardize advantages (used for more stable and better weighted + # policy gradient computations). 
+ batch[Postprocessing.ADVANTAGES] = (advantages - advantages.mean()) / max( + 1e-4, advantages.std() + ) + + # Remove the extra (artificial) timesteps again at the end of all episodes. + remove_last_ts_from_episodes_and_restore_truncateds( + episodes, + orig_truncateds, + ) + + return batch, episodes + @override(Learner) def remove_module(self, module_id: str): super().remove_module(module_id) @@ -68,3 +153,17 @@ def additional_update_for_module( results.update({LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY: new_entropy_coeff}) return results + + @abc.abstractmethod + def _compute_values(self, batch) -> np._typing.NDArray: + """Computes the values using the value function module given a batch of data. + + Args: + batch: The input batch to pass through our RLModule (value function + encoder and vf-head). + + Returns: + The batch (numpy) of value function outputs (already squeezed over the last + dimension (which should have shape (1,) b/c of the single value output + node). + """ diff --git a/rllib/algorithms/ppo/ppo_rl_module.py b/rllib/algorithms/ppo/ppo_rl_module.py index 1a95bd3429ffe..84084a4c6fb64 100644 --- a/rllib/algorithms/ppo/ppo_rl_module.py +++ b/rllib/algorithms/ppo/ppo_rl_module.py @@ -45,7 +45,7 @@ def get_initial_state(self) -> dict: @override(RLModule) def input_specs_inference(self) -> SpecDict: - return self.input_specs_exploration() + return [SampleBatch.OBS] @override(RLModule) def output_specs_inference(self) -> SpecDict: @@ -53,14 +53,11 @@ def output_specs_inference(self) -> SpecDict: @override(RLModule) def input_specs_exploration(self): - return [SampleBatch.OBS] + return self.input_specs_inference() @override(RLModule) def output_specs_exploration(self) -> SpecDict: - return [ - SampleBatch.VF_PREDS, - SampleBatch.ACTION_DIST_INPUTS, - ] + return self.output_specs_inference() @override(RLModule) def input_specs_train(self) -> SpecDict: diff --git a/rllib/algorithms/ppo/tests/test_ppo.py b/rllib/algorithms/ppo/tests/test_ppo.py index 5064b9a111a97..0fb681c8cffb0 100644 --- a/rllib/algorithms/ppo/tests/test_ppo.py +++ b/rllib/algorithms/ppo/tests/test_ppo.py @@ -274,12 +274,11 @@ def test_ppo_exploration_setup(self): """Tests, whether PPO runs with different exploration setups.""" config = ( ppo.PPOConfig() - .experimental(_enable_new_api_stack=True) + # .experimental(_enable_new_api_stack=True) .environment( "FrozenLake-v1", env_config={"is_slippery": False, "map_name": "4x4"}, - ) - .rollouts( + ).rollouts( # Run locally. num_rollout_workers=0, ) @@ -287,7 +286,7 @@ def test_ppo_exploration_setup(self): obs = np.array(0) # Test against all frameworks. - for fw in framework_iterator(config): + for fw, sess in framework_iterator(config, session=True): # Default Agent should be setup with StochasticSampling. algo = config.build() # explore=False, always expect the same (deterministic) action. @@ -295,16 +294,6 @@ def test_ppo_exploration_setup(self): obs, explore=False, prev_action=np.array(2), prev_reward=np.array(1.0) ) - # Test whether this is really the argmax action over the logits. - # TODO (Kourosh): Only meaningful in the ModelV2 stack. 
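The advantage standardization added to `PPOLearner._preprocess_train_data()` above is plain NumPy; a standalone sketch (made-up numbers) shows the effect of the `max(1e-4, std)` clamp guarding against near-constant advantages:

```python
import numpy as np

advantages = np.array([0.5, -1.2, 3.0, 0.1], dtype=np.float32)

# Zero-mean / unit-std advantages; the std is clamped from below so a batch of
# (nearly) identical advantages does not blow up the division.
standardized = (advantages - advantages.mean()) / max(1e-4, advantages.std())

print(standardized.mean())  # ~0.0
print(standardized.std())   # ~1.0
```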
- config.validate() - if not config._enable_new_api_stack and fw != "tf": - last_out = algo.get_policy().model.last_output() - if fw == "torch": - check(a_, np.argmax(last_out.detach().cpu().numpy(), 1)[0]) - else: - check(a_, np.argmax(last_out.numpy(), 1)[0]) - for _ in range(50): a = algo.compute_single_action( obs, diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 517d03b3def22..47aa5b1d1832b 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -103,7 +103,7 @@ def test_ppo_compilation_and_schedule_mixins(self): # TODO (Kourosh) Bring back "FrozenLake-v1" for env in ["CartPole-v1", "Pendulum-v1", "ALE/Breakout-v5"]: print("Env={}".format(env)) - for lstm in [True, False]: + for lstm in [False]: print("LSTM={}".format(lstm)) config.training(model=get_model_config(fw, lstm=lstm)) @@ -144,8 +144,7 @@ def test_ppo_exploration_setup(self): ) .rollouts( # Run locally. - num_rollout_workers=1, - enable_connectors=True, + num_rollout_workers=0, ) ) obs = np.array(0) diff --git a/rllib/algorithms/ppo/tf/ppo_tf_learner.py b/rllib/algorithms/ppo/tf/ppo_tf_learner.py index 18cacb3dfd679..9ac6150b0c235 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_learner.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_learner.py @@ -1,6 +1,7 @@ import logging from typing import Any, Dict +import tree from ray.rllib.algorithms.ppo.ppo import ( LEARNER_RESULTS_KL_KEY, LEARNER_RESULTS_CURR_KL_COEFF_KEY, @@ -11,8 +12,9 @@ from ray.rllib.algorithms.ppo.ppo_learner import PPOLearner from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY from ray.rllib.core.learner.tf.tf_learner import TfLearner +from ray.rllib.core.models.base import ENCODER_OUT, CRITIC from ray.rllib.evaluation.postprocessing import Postprocessing -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import explained_variance from ray.rllib.utils.annotations import override @@ -174,3 +176,20 @@ def additional_update_for_module( results.update({LEARNER_RESULTS_CURR_KL_COEFF_KEY: curr_var.numpy()}) return results + + @override(PPOLearner) + def _compute_values(self, batch): + infos = batch.pop(SampleBatch.INFOS, None) + batch = tree.map_structure(lambda s: tf.convert_to_tensor(s), batch) + if infos is not None: + batch[SampleBatch.INFOS] = infos + + # TODO (sven): Make multi-agent capable. + module = self.module[DEFAULT_POLICY_ID].unwrapped() + + # Shared encoder. + encoder_outs = module.encoder(batch) + # Value head. + vf_out = module.vf(encoder_outs[ENCODER_OUT][CRITIC]) + # Squeeze out last dimension (single node value head). + return tf.squeeze(vf_out, -1) diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index 2b30c810568da..4db309180f5de 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -40,6 +40,11 @@ def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]: the policy distribution to be used for computing KL divergence between the old policy and the new policy during training. """ + # TODO (sven): Make this the only bahevior once PPO has been migrated + # to new API stack (including EnvRunners!). 
+ if self.config.model_config_dict.get("uses_new_env_runners"): + return self._forward_inference(batch=batch) + output = {} # Shared encoder diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py index bee86a97add3b..942f745c904be 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -9,15 +9,15 @@ PPOConfig, ) from ray.rllib.algorithms.ppo.ppo_learner import PPOLearner -from ray.rllib.utils.torch_utils import sequence_mask from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY from ray.rllib.core.learner.torch.torch_learner import TorchLearner +from ray.rllib.core.models.base import ENCODER_OUT, CRITIC from ray.rllib.evaluation.postprocessing import Postprocessing -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.nested_dict import NestedDict -from ray.rllib.utils.torch_utils import explained_variance +from ray.rllib.utils.torch_utils import convert_to_torch_tensor, explained_variance from ray.rllib.utils.typing import ModuleID, TensorType torch, nn = try_import_torch() @@ -41,25 +41,20 @@ def compute_loss_for_module( fwd_out: Dict[str, TensorType], ) -> TensorType: # TODO (Kourosh): batch type is NestedDict. - # TODO (Kourosh): We may or may not user module_id. For example if we have an - # agent based learning rate scheduler, we may want to use module_id to get the - # learning rate for that agent. - - # RNN case: Mask away 0-padded chunks at end of time axis. - if self.module[module_id].is_stateful(): - # In the RNN case, we expect incoming tensors to be padded to the maximum - # sequence length. We infer the max sequence length from the actions - # tensor. - maxlen = torch.max(batch[SampleBatch.SEQ_LENS]) - mask = sequence_mask(batch[SampleBatch.SEQ_LENS], maxlen=maxlen) - num_valid = torch.sum(mask) - - def possibly_masked_mean(t): - return torch.sum(t[mask]) / num_valid - - # non-RNN case: No masking. + + # Possibly apply masking to some sub loss terms and to the total loss term + # at the end. Masking could be used for RNN-based model (zero padded `batch`) + # and for PPO's batched value function (and bootstrap value) computations, + # for which we add an additional (artificial) timestep to each episode to + # simplify the actual computation. + if "loss_mask" in batch: + num_valid = torch.sum(batch["loss_mask"]) + + def possibly_masked_mean(data_): + return torch.sum(data_[batch["loss_mask"]]) / num_valid + else: - mask = None + possibly_masked_mean = torch.mean action_dist_class_train = ( @@ -171,3 +166,20 @@ def additional_update_for_module( results.update({LEARNER_RESULTS_CURR_KL_COEFF_KEY: curr_var.item()}) return results + + @override(PPOLearner) + def _compute_values(self, batch): + infos = batch.pop(SampleBatch.INFOS, None) + batch = convert_to_torch_tensor(batch, device=self._device) + if infos is not None: + batch[SampleBatch.INFOS] = infos + + # TODO (sven): Make multi-agent capable. + module = self.module[DEFAULT_POLICY_ID].unwrapped() + + # Shared encoder. + encoder_outs = module.encoder(batch) + # Value head. + vf_out = module.vf(encoder_outs[ENCODER_OUT][CRITIC]) + # Squeeze out last dimension (single node value head). 
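The `possibly_masked_mean` helper introduced in the PPO torch learner above reduces a loss tensor only over valid (non-padded) timesteps. A small standalone torch sketch of the same idea (tensor values are made up, not RLlib API):

```python
import torch

loss_per_timestep = torch.tensor([1.0, 2.0, 3.0, 4.0])
# True for real timesteps, False for zero-padding / the artificial extra timestep.
loss_mask = torch.tensor([True, True, True, False])

num_valid = torch.sum(loss_mask)
masked_mean = torch.sum(loss_per_timestep[loss_mask]) / num_valid

assert torch.isclose(masked_mean, torch.tensor(2.0))  # (1 + 2 + 3) / 3
```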
+ return vf_out.squeeze(-1) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 745f45bb603f6..eda01130646b3 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -40,6 +40,11 @@ def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]: the policy distribution to be used for computing KL divergence between the old policy and the new policy during training. """ + # TODO (sven): Make this the only bahevior once PPO has been migrated + # to new API stack (including EnvRunners!). + if self.config.model_config_dict.get("uses_new_env_runners"): + return self._forward_inference(batch) + output = {} # Shared encoder diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index fb76293b2a188..dcb8214cc93a2 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -15,18 +15,25 @@ Sequence, Set, Tuple, + TYPE_CHECKING, Union, ) import ray -from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.connectors.learner.learner_connector_pipeline import ( + LearnerConnectorPipeline, +) from ray.rllib.core.learner.reduce_result_dict_fn import _reduce_mean_results from ray.rllib.core.rl_module.marl_module import ( MultiAgentRLModule, MultiAgentRLModuleSpec, ) from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch +from ray.rllib.policy.sample_batch import ( + DEFAULT_POLICY_ID, + MultiAgentBatch, + SampleBatch, +) from ray.rllib.utils.annotations import ( OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, @@ -60,6 +67,9 @@ ) from ray.util.annotations import PublicAPI +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + torch, _ = try_import_torch() tf1, tf, tfv = try_import_tf() @@ -205,7 +215,7 @@ def compute_loss(self, fwd_out, batch): def __init__( self, *, - config: AlgorithmConfig, + config: "AlgorithmConfig", module_spec: Optional[ Union[SingleAgentRLModuleSpec, MultiAgentRLModuleSpec] ] = None, @@ -254,6 +264,8 @@ def __init__( # The actual MARLModule used by this Learner. self._module: Optional[MultiAgentRLModule] = None + # Our Learner connector pipeline. + self._learner_connector: Optional[LearnerConnectorPipeline] = None # These are set for properly applying optimizers and adding or removing modules. self._optimizer_parameters: Dict[Optimizer, List[ParamRef]] = {} self._named_optimizers: Dict[str, Optimizer] = {} @@ -272,6 +284,40 @@ def __init__( # the final results dict in the `self.compile_update_results()` method. self._metrics = defaultdict(dict) + @OverrideToImplementCustomLogic_CallToSuperRecommended + def build(self) -> None: + """Builds the Learner. + + This method should be called before the learner is used. It is responsible for + setting up the RLModule, optimizers, and (optionally) their lr-schedulers. + """ + if self._is_built: + logger.debug("Learner already built. Skipping build.") + return + + # Build learner connector pipeline used on this Learner worker. + # TODO (sven): Support multi-agent cases. 
+ if self.config.uses_new_env_runners and not self.config.is_multi_agent(): + module_spec = self._module_spec.as_multi_agent().module_specs[ + DEFAULT_POLICY_ID + ] + self._learner_connector = self.config.build_learner_connector( + input_observation_space=module_spec.observation_space, + input_action_space=module_spec.action_space, + ) + # Adjust module spec based on connector's (possibly transformed) spaces. + module_spec.observation_space = self._learner_connector.observation_space + module_spec.action_space = self._learner_connector.action_space + + # Build the module to be trained by this learner. + self._module = self._make_module() + + # Configure, construct, and register all optimizers needed to train + # `self.module`. + self.configure_optimizers() + + self._is_built = True + @property def distributed(self) -> bool: """Whether the learner is running in distributed mode.""" @@ -387,7 +433,7 @@ def configure_optimizers(self) -> None: @OverrideToImplementCustomLogic @abc.abstractmethod def configure_optimizers_for_module( - self, module_id: ModuleID, config: AlgorithmConfig = None, hps=None + self, module_id: ModuleID, config: "AlgorithmConfig" = None, hps=None ) -> None: """Configures an optimizer for the given module_id. @@ -479,7 +525,7 @@ def postprocess_gradients_for_module( self, *, module_id: ModuleID, - config: AlgorithmConfig = None, + config: Optional["AlgorithmConfig"] = None, module_gradients_dict: ParamDict, hps=None, ) -> ParamDict: @@ -823,25 +869,6 @@ def remove_module(self, module_id: ModuleID) -> None: self.module.remove_module(module_id) - @OverrideToImplementCustomLogic_CallToSuperRecommended - def build(self) -> None: - """Builds the Learner. - - This method should be called before the learner is used. It is responsible for - setting up the RLModule, optimizers, and (optionally) their lr-schedulers. - """ - if self._is_built: - logger.debug("Learner already built. Skipping build.") - return - self._is_built = True - - # Build the module to be trained by this learner. - self._module = self._make_module() - - # Configure, construct, and register all optimizers needed to train - # `self.module`. - self.configure_optimizers() - @OverrideToImplementCustomLogic def compute_loss( self, @@ -901,7 +928,7 @@ def compute_loss_for_module( self, *, module_id: ModuleID, - config: AlgorithmConfig = None, + config: Optional["AlgorithmConfig"] = None, batch: NestedDict, fwd_out: Dict[str, TensorType], ) -> TensorType: @@ -1053,7 +1080,7 @@ def additional_update_for_module( self, *, module_id: ModuleID, - config: AlgorithmConfig = None, + config: Optional["AlgorithmConfig"] = None, timestep: int, hps=None, **kwargs, @@ -1333,12 +1360,23 @@ def _update_from_batch_or_episodes( # We must do at least one pass on the batch for training. raise ValueError("`num_iters` must be >= 1") - # Call the train data preprocessor. - batch, episodes = self._preprocess_train_data(batch=batch, episodes=episodes) - - # TODO (sven): Insert a call to the Learner ConnectorV2 pipeline here, providing - # it both `batch` and `episode` for further custom processing before the - # actual `Learner._update()` call. + # Call the learner connector. + # TODO (sven): make multi-agent capable. + if self._learner_connector is not None: + # Call the train data preprocessor. 
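With `build()` now owning the learner-connector setup, a custom Learner that needs extra per-worker state can extend it and call `super().build()` first. A hypothetical sketch (the class name and attribute are illustrative only, not part of this diff):

```python
from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class MyPPOTorchLearner(TorchLearner):
    def build(self) -> None:
        # The base class builds the learner connector pipeline, the RLModule,
        # and the optimizers.
        super().build()
        # Any additional per-Learner state goes here.
        self._num_preprocessed_batches = 0
```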
+ batch, episodes = self._preprocess_train_data( + batch=batch, episodes=episodes + ) + batch = self._learner_connector( + rl_module=self.module["default_policy"], + data=batch, + episodes=episodes, + ) + if episodes is not None: + batch = MultiAgentBatch( + policy_batches={DEFAULT_POLICY_ID: SampleBatch(batch)}, + env_steps=sum(len(e) for e in episodes), + ) if minibatch_size: batch_iter = MiniBatchCyclicIterator diff --git a/rllib/core/rl_module/marl_module.py b/rllib/core/rl_module/marl_module.py index e3c597bdebb5c..213e82fc8816f 100644 --- a/rllib/core/rl_module/marl_module.py +++ b/rllib/core/rl_module/marl_module.py @@ -14,10 +14,6 @@ Union, ) -from ray.util.annotations import PublicAPI -from ray.rllib.utils.annotations import override, ExperimentalAPI -from ray.rllib.utils.nested_dict import NestedDict - from ray.rllib.core.models.specs.typing import SpecType from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.core.rl_module.rl_module import ( @@ -28,10 +24,16 @@ ) # TODO (Kourosh): change this to module_id later to enforce consistency -from ray.rllib.utils.annotations import OverrideToImplementCustomLogic +from ray.rllib.utils.annotations import ( + ExperimentalAPI, + override, + OverrideToImplementCustomLogic, +) +from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.utils.policy import validate_policy_id from ray.rllib.utils.serialization import serialize_type, deserialize_type from ray.rllib.utils.typing import T +from ray.util.annotations import PublicAPI ModuleID = str @@ -418,7 +420,6 @@ def _check_module_exists(self, module_id: ModuleID) -> None: class MultiAgentRLModuleSpec: """A utility spec class to make it constructing MARL modules easier. - Users can extend this class to modify the behavior of base class. For example to share neural networks across the modules, the build method can be overriden to create the shared module first and then pass it to custom module classes that would @@ -463,9 +464,7 @@ def get_marl_config(self) -> "MultiAgentRLModuleConfig": return MultiAgentRLModuleConfig(modules=self.module_specs) @OverrideToImplementCustomLogic - def build( - self, module_id: Optional[ModuleID] = None - ) -> Union[SingleAgentRLModuleSpec, "MultiAgentRLModule"]: + def build(self, module_id: Optional[ModuleID] = None) -> RLModule: """Builds either the multi-agent module or the single-agent module. If module_id is None, it builds the multi-agent module. Otherwise, it builds @@ -484,9 +483,11 @@ def build( """ self._check_before_build() + # ModuleID provided, return single-agent RLModule. if module_id: return self.module_specs[module_id].build() + # Return MultiAgentRLModule. 
module_config = self.get_marl_config() module = self.marl_module_class(module_config) return module @@ -585,6 +586,10 @@ def update(self, other: "MultiAgentRLModuleSpec", overwrite=False) -> None: else: self.module_specs.update(other.module_specs) + def as_multi_agent(self) -> "MultiAgentRLModuleSpec": + """Returns self to match `SingleAgentRLModuleSpec.as_multi_agent()`.""" + return self + @ExperimentalAPI @dataclass diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 2c9047da4224d..a2d62b5f6a6ab 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -162,6 +162,15 @@ def update(self, other) -> None: self.catalog_class = other.catalog_class or self.catalog_class self.load_state_path = other.load_state_path or self.load_state_path + def as_multi_agent(self) -> "MultiAgentRLModuleSpec": + """Returns a MultiAgentRLModuleSpec (`self` under DEFAULT_POLICY_ID key).""" + from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec + + return MultiAgentRLModuleSpec( + module_specs={DEFAULT_POLICY_ID: self}, + load_state_path=self.load_state_path, + ) + @ExperimentalAPI @dataclass diff --git a/rllib/evaluation/postprocessing_v2.py b/rllib/evaluation/postprocessing_v2.py index 98a5058330f12..1ad383c1c37e1 100644 --- a/rllib/evaluation/postprocessing_v2.py +++ b/rllib/evaluation/postprocessing_v2.py @@ -1,207 +1,35 @@ -from typing import List - import numpy as np -import tree # pip install dm_tree - -from ray.rllib.algorithms.algorithm_config import AlgorithmConfig -from ray.rllib.core.models.base import STATE_IN -from ray.rllib.core.rl_module.rl_module import RLModule -from ray.rllib.env.single_agent_episode import SingleAgentEpisode -from ray.rllib.evaluation.postprocessing import discount_cumsum -from ray.rllib.policy.sample_batch import concat_samples, SampleBatch -from ray.rllib.utils.annotations import DeveloperAPI -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.nested_dict import NestedDict -from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.utils.torch_utils import convert_to_torch_tensor -from ray.rllib.utils.typing import TensorType - -_, tf, _ = try_import_tf() - -@DeveloperAPI -class Postprocessing: - """Constant definitions for postprocessing.""" - ADVANTAGES = "advantages" - VALUE_TARGETS = "value_targets" - - -@DeveloperAPI -def postprocess_episodes_to_sample_batch( - episodes: List[SingleAgentEpisode], -) -> SampleBatch: - """Converts the results from sampling with an `EnvRunner` to one `SampleBatch'. - - Once the `SampleBatch` will be deprecated this function will be - deprecated, too. - """ - batches = [] - - for episode_or_list in episodes: - # Without postprocessing (explore=True), we could have - # a list. - if isinstance(episode_or_list, list): - for episode in episode_or_list: - batches.append(episode.get_sample_batch()) - # During exploration we have an episode. - else: - batches.append(episode_or_list.get_sample_batch()) - - batch = concat_samples(batches) - # TODO (sven): During evalaution we do not have infos at all. - # On the other side, if we leave in infos in training, conversion - # to tensors throws an exception. - if SampleBatch.INFOS in batch.keys(): - del batch[SampleBatch.INFOS] - # Return the SampleBatch. 
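Hedged usage sketch for the new `as_multi_agent()` helpers shown above: a `SingleAgentRLModuleSpec` gets wrapped under `DEFAULT_POLICY_ID`, while calling it on an already multi-agent spec is a no-op. The spaces and module class below are only illustrative:

```python
import gymnasium as gym

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID

spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
    action_space=gym.spaces.Discrete(2),
    model_config_dict={"fcnet_hiddens": [32]},
)

marl_spec = spec.as_multi_agent()
assert DEFAULT_POLICY_ID in marl_spec.module_specs
# Idempotent on an already multi-agent spec.
assert marl_spec.as_multi_agent() is marl_spec
```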
- return batch - - -@DeveloperAPI -def compute_gae_for_episode( - episode: SingleAgentEpisode, - config: AlgorithmConfig, - module: RLModule, +def compute_value_targets( + values, + rewards, + terminateds, + truncateds, + gamma: float, + lambda_: float, ): - """Adds GAE to a trajectory.""" - # TODO (simon): All of this can be batched over multiple episodes. - # This should increase performance. - # TODO (sven): Shall do postprocessing in the training_step or - # in the env_runner? Here we could batch over episodes as we have - # them now in the training_step. - episode = compute_bootstrap_value(episode, module) - - vf_preds = episode.get_extra_model_outputs(SampleBatch.VF_PREDS) - rewards = episode.get_rewards() - - # TODO (simon): In case of recurrent models sequeeze out time dimension. - - episode = compute_advantages( - episode, - last_r=episode.extra_model_outputs[SampleBatch.VALUES_BOOTSTRAPPED][-1], - gamma=config["gamma"], - lambda_=config["lambda"], - use_gae=config["use_gae"], - use_critic=config.get("use_critic", True), - vf_preds=vf_preds, - rewards=rewards, - ) + """Computes value function (vf) targets given vf predictions and rewards. - # TODO (simon): Add dimension in case of recurrent model. - return episode - - -def compute_bootstrap_value( - episode: SingleAgentEpisode, module: RLModule -) -> SingleAgentEpisode: - if episode.is_terminated: - last_r = 0.0 - else: - # TODO (simon): This has to be made multi-agent ready. - # TODO (sven, simon): We have to change this as soon as the - # Connector API is ready. Episodes do not have states anymore. - initial_states = module.get_initial_state() - state = { - k: initial_states[k] if episode.states is None else episode.states[k] - for k in initial_states.keys() - } - - input_dict = { - STATE_IN: tree.map_structure( - lambda s: convert_to_torch_tensor(s) - if module.framework == "torch" - else tf.convert_to_tensor(s), - state, - ), - SampleBatch.OBS: convert_to_torch_tensor( - np.expand_dims(episode.observations[-1], axis=0) - ) - if module.framework == "torch" - else tf.convert_to_tensor(np.expand_dims(episode.observations[-1], axis=0)), - } - - # TODO (simon): Torch might need the correct device. - - # TODO (sven): If we want to get rid of the policy in the future - # what should we do for adding the time dimension? - # TODO (simon): Add support for recurrent models. - - input_dict = NestedDict(input_dict) - fwd_out = module.forward_exploration(input_dict) - # TODO (simon): Remove time dimension in case of recurrent model. - last_r = fwd_out[SampleBatch.VF_PREDS][-1] - - vf_preds = episode.extra_model_outputs[SampleBatch.VF_PREDS] - # TODO (simon): Squeeze out the time dimension in case of recurrent model. - episode.extra_model_outputs[SampleBatch.VALUES_BOOTSTRAPPED] = np.concatenate( - [ - vf_preds[1:], - np.array([convert_to_numpy(last_r)], dtype=np.float32), - ], - axis=0, - ) - - # TODO (simon): Unsqueeze in case of recurrent model. - - return episode - - -def compute_advantages( - episode: SingleAgentEpisode, - last_r: float, - gamma: float = 0.9, - lambda_: float = 1.0, - use_critic: bool = True, - use_gae: bool = True, - rewards: TensorType = None, - vf_preds: TensorType = None, -): - assert ( - SampleBatch.VF_PREDS in episode.extra_model_outputs or not use_critic - ), "use_critic=True but values not found" - assert use_critic or not use_gae, "Can't use gae without using a value function." - # TODO (simon): Check if we need conversion here. 
- last_r = convert_to_numpy(last_r) - - if rewards is None: - rewards = episode.get_rewards() - if vf_preds is None: - vf_preds = episode.get_extra_model_outs(SampleBatch.VF_PREDS) + Note that advantages can then easily be computeed via the formula: + advantages = targets - vf_predictions + """ + # Force-set all values at terminals (not at truncations!) to 0.0. + orig_values = flat_values = values * (1.0 - terminateds) - if use_gae: - vpred_t = np.concatenate([vf_preds, np.array([last_r])]) - delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1] - # This formula for the advantage comes from: - # Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438 - episode.extra_model_outputs[Postprocessing.ADVANTAGES] = discount_cumsum( - delta_t, gamma * lambda_ - ) - episode.extra_model_outputs[Postprocessing.VALUE_TARGETS] = ( - episode.extra_model_outputs[Postprocessing.ADVANTAGES] + vf_preds - ).astype(np.float32) - else: - rewards_plus_v = np.concatenate([rewards, np.array([last_r])]) - discounted_returns = discount_cumsum(rewards_plus_v, gamma)[:-1].astype( - np.float32 - ) + flat_values = np.append(flat_values, 0.0) + intermediates = rewards + gamma * (1 - lambda_) * flat_values[1:] + continues = 1.0 - terminateds - if use_critic: - episode.extra_model_outputs[Postprocessing.ADVANTAGES] = ( - discounted_returns - vf_preds - ) - episode.extra_model_outputs[ - Postprocessing.VALUE_TARGETS - ] = discounted_returns - else: - episode.extra_model_outputs[Postprocessing.ADVANTAGES] = discounted_returns - episode.extra_model_outputs[Postprocessing.VALUE_TARGETS] = np.zeros_like( - episode.extra_model_outputs[Postprocessing.ADVANTAGES] - ) + Rs = [] + last = flat_values[-1] + for t in reversed(range(intermediates.shape[0])): + last = intermediates[t] + continues[t] * gamma * lambda_ * last + Rs.append(last) + if truncateds[t]: + last = orig_values[t] - # TODO (sven, simon): Maybe change to `BufferWithInfiniteLookback` - episode.extra_model_outputs[ - Postprocessing.ADVANTAGES - ] = episode.extra_model_outputs[Postprocessing.ADVANTAGES].astype(np.float32) + # Reverse back to correct (time) direction. + value_targets = np.stack(list(reversed(Rs)), axis=0) - return episode + return value_targets.astype(np.float32) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index ce1be13994cd8..f53dc1954efec 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -1169,34 +1169,15 @@ def map_(path, value): if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith( "state_in_" ): - if path[0] != SampleBatch.INFOS: - return value[start_padded:stop_padded] - else: - if ( - (isinstance(value, np.ndarray) and value.size > 0) - or ( - torch - and torch.is_tensor(value) - and len(list(value.shape)) > 0 - ) - or (tf and tf.is_tensor(value) and tf.size(value) > 0) - ): - return value[start_unpadded:stop_unpadded] - else: - # Since infos should be stored as lists and not arrays, - # we return the values here and slice them separately - # TODO(Artur): Clean this hack up. - return value + return value[start_padded:stop_padded] else: return value[start_seq_len:stop_seq_len] + infos = self.pop(SampleBatch.INFOS, None) data = tree.map_structure_with_path(map_, self) - - # Since we don't slice in the above map_ function, we do it here. 
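The rewritten `compute_value_targets()` above (also added as `rllib/utils/postprocessing/value_predictions.py` later in this diff) can be exercised on a tiny hand-made batch. The arrays below are made up and only illustrate the call signature and the `advantages = targets - vf_preds` relation stated in the docstring:

```python
import numpy as np

from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets

vf_preds = np.array([0.6, 0.4, 0.2, 0.9], dtype=np.float32)
rewards = np.array([1.0, 1.0, 1.0, 0.0], dtype=np.float32)
# The third timestep terminates one episode; the last one is truncated.
terminateds = np.array([0.0, 0.0, 1.0, 0.0], dtype=np.float32)
truncateds = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32)

targets = compute_value_targets(
    values=vf_preds,
    rewards=rewards,
    terminateds=terminateds,
    truncateds=truncateds,
    gamma=0.99,
    lambda_=0.95,
)
advantages = targets - vf_preds
print(targets, advantages)
```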
- if isinstance(data.get(SampleBatch.INFOS), list): - data[SampleBatch.INFOS] = data[SampleBatch.INFOS][ - start_unpadded:stop_unpadded - ] + if infos is not None and isinstance(infos, (list, np.ndarray)): + self[SampleBatch.INFOS] = infos + data[SampleBatch.INFOS] = infos[start_unpadded:stop_unpadded] return SampleBatch( data, @@ -1207,21 +1188,11 @@ def map_(path, value): _num_grad_updates=self.num_grad_updates, ) else: - - def map_(value): - if ( - isinstance(value, np.ndarray) - or (torch and torch.is_tensor(value)) - or (tf and tf.is_tensor(value)) - ): - return value[start:stop] - else: - # Since infos should be stored as lists and not arrays, - # we return the values here and slice them separately - # TODO(Artur): Clean this hack up. - return value - - data = tree.map_structure(map_, self) + infos = self.pop(SampleBatch.INFOS, None) + data = tree.map_structure(lambda s: s[start:stop], self) + if infos is not None and isinstance(infos, (list, np.ndarray)): + self[SampleBatch.INFOS] = infos + data[SampleBatch.INFOS] = infos[start:stop] return SampleBatch( data, diff --git a/rllib/tests/test_supported_spaces.py b/rllib/tests/test_supported_spaces.py index 72f5eeb4ffd37..8451931231cf8 100644 --- a/rllib/tests/test_supported_spaces.py +++ b/rllib/tests/test_supported_spaces.py @@ -44,11 +44,8 @@ def tearDownClass(cls) -> None: def test_appo(self): config = ( - APPOConfig() - .resources(num_gpus=0) - .training(vtrace=False, model={"fcnet_hiddens": [10]}) + APPOConfig().resources(num_gpus=0).training(model={"fcnet_hiddens": [10]}) ) - config.training(vtrace=True) check_supported_spaces("APPO", config) diff --git a/rllib/tuned_examples/appo/cartpole-appo-vtrace-fake-gpus.yaml b/rllib/tuned_examples/appo/cartpole-appo-fake-gpus.yaml similarity index 100% rename from rllib/tuned_examples/appo/cartpole-appo-vtrace-fake-gpus.yaml rename to rllib/tuned_examples/appo/cartpole-appo-fake-gpus.yaml diff --git a/rllib/tuned_examples/appo/cartpole-appo-vtrace-separate-losses.py b/rllib/tuned_examples/appo/cartpole-appo-separate-losses.py similarity index 100% rename from rllib/tuned_examples/appo/cartpole-appo-vtrace-separate-losses.py rename to rllib/tuned_examples/appo/cartpole-appo-separate-losses.py diff --git a/rllib/tuned_examples/appo/cartpole-appo-vtrace.yaml b/rllib/tuned_examples/appo/cartpole-appo-vtrace.yaml deleted file mode 100644 index 1c4a9755a2145..0000000000000 --- a/rllib/tuned_examples/appo/cartpole-appo-vtrace.yaml +++ /dev/null @@ -1,20 +0,0 @@ -cartpole-appo-vtrace: - env: CartPole-v1 - run: APPO - stop: - sampler_results/episode_reward_mean: 180 - timesteps_total: 200000 - config: - # Works for both torch and tf. 
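The simplified INFOS handling in `SampleBatch._slice()` above keeps infos out of the `tree.map_structure` call and slices them separately. A quick illustration with toy data:

```python
import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({
    "obs": np.arange(5),
    "rewards": np.arange(5, dtype=np.float32),
    "infos": [{"t": i} for i in range(5)],
})

sliced = batch[1:3]
print(sliced["obs"])          # -> [1 2]
print(list(sliced["infos"]))  # -> infos of timesteps 1 and 2 only
assert len(sliced["infos"]) == 2
```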
- framework: torch - num_envs_per_worker: 5 - num_workers: 4 - num_gpus: 0 - observation_filter: MeanStdFilter - num_sgd_iter: 1 - vf_loss_coeff: 0.01 - vtrace: true - model: - fcnet_hiddens: [32] - fcnet_activation: linear - vf_share_layers: true diff --git a/rllib/tuned_examples/appo/cartpole-appo-w-rl-modules-and-learner.yaml b/rllib/tuned_examples/appo/cartpole-appo-w-rl-modules-and-learner.yaml index d2fd76b037af9..f706d696a550c 100644 --- a/rllib/tuned_examples/appo/cartpole-appo-w-rl-modules-and-learner.yaml +++ b/rllib/tuned_examples/appo/cartpole-appo-w-rl-modules-and-learner.yaml @@ -16,7 +16,6 @@ cartpole-appo-w-rl-modules-and-learner: observation_filter: MeanStdFilter num_sgd_iter: 6 vf_loss_coeff: 0.01 - vtrace: false model: fcnet_hiddens: [32] fcnet_activation: linear diff --git a/rllib/tuned_examples/appo/cartpole-appo.yaml b/rllib/tuned_examples/appo/cartpole-appo.yaml index 7ad2cc89be117..b6785c0d3eb04 100644 --- a/rllib/tuned_examples/appo/cartpole-appo.yaml +++ b/rllib/tuned_examples/appo/cartpole-appo.yaml @@ -2,20 +2,19 @@ cartpole-appo: env: CartPole-v1 run: APPO stop: - sampler_results/episode_reward_mean: 150 + sampler_results/episode_reward_mean: 180 timesteps_total: 200000 config: # Works for both torch and tf. framework: torch num_envs_per_worker: 5 - num_workers: 1 + num_workers: 4 num_gpus: 0 observation_filter: MeanStdFilter - num_sgd_iter: 6 + num_sgd_iter: 1 vf_loss_coeff: 0.01 - vtrace: false + vtrace: true model: fcnet_hiddens: [32] fcnet_activation: linear vf_share_layers: true - enable_connectors: true diff --git a/rllib/tuned_examples/appo/stateless-cartpole-appo-vtrace.py b/rllib/tuned_examples/appo/stateless_cartpole_appo.py similarity index 100% rename from rllib/tuned_examples/appo/stateless-cartpole-appo-vtrace.py rename to rllib/tuned_examples/appo/stateless_cartpole_appo.py diff --git a/rllib/utils/minibatch_utils.py b/rllib/utils/minibatch_utils.py index ffee2c3c3ad56..15e0f133df885 100644 --- a/rllib/utils/minibatch_utils.py +++ b/rllib/utils/minibatch_utils.py @@ -69,6 +69,15 @@ def __iter__(self): "the same number of samples for each module_id." ) s = self._start[module_id] # start + # TODO (sven): Fix this bug for LSTMs: + # In an RNN-setting, the Learner connector already has zero-padded + # and added a timerank to the batch. Thus, n_step would still be based + # on the BxT dimension, rather than the new B dimension (excluding T), + # which then leads to minibatches way too large. + # However, changing this already would break APPO/IMPALA w/o LSTMs as + # these setups require sequencing, BUT their batches are not yet time- + # ranked (this is done only in their loss functions via the + # `make_time_major` utility). n_steps = self._minibatch_size samples_to_concat = [] diff --git a/rllib/utils/postprocessing/__init__.py b/rllib/utils/postprocessing/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/rllib/utils/postprocessing/episodes.py b/rllib/utils/postprocessing/episodes.py new file mode 100644 index 0000000000000..b0209571af601 --- /dev/null +++ b/rllib/utils/postprocessing/episodes.py @@ -0,0 +1,142 @@ +from typing import List, Tuple + +import numpy as np + +from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + +def add_one_ts_to_episodes_and_truncate(episodes: List[SingleAgentEpisode]): + """Adds an artificial timestep to an episode at the end. + + In detail: The last observations, infos, actions, and all `extra_model_outputs` + will be duplicated and appended to each episode's data. 
An extra 0.0 reward
+    will be appended to the episode's rewards. The episode's timestep will be
+    increased by 1. Also, adds the truncated=True flag to each episode if the
+    episode is not already done (terminated or truncated).
+
+    Useful for value function bootstrapping, where it is required to compute a
+    forward pass for the very last timestep within the episode,
+    i.e. using the following input dict: {
+    obs=[final obs],
+    state=[final state output],
+    prev. reward=[final reward],
+    etc..
+    }
+
+    Args:
+        episodes: The list of SingleAgentEpisode objects to extend by one timestep
+            and add a truncation flag if necessary.
+
+    Returns:
+        A list of the original episodes' truncated values (so the episodes can be
+        properly restored later into their original states).
+    """
+    orig_truncateds = []
+    for episode in episodes:
+        # Make sure the episode is already in numpy format.
+        assert episode.is_finalized
+        orig_truncateds.append(episode.is_truncated)
+
+        # Add timestep.
+        episode.t += 1
+        # Use the episode API that allows appending (possibly complex) structs
+        # to the data.
+        episode.observations.append(episode.observations[-1])
+        episode.infos.append(episode.infos[-1])
+        episode.actions.append(episode.actions[-1])
+        episode.rewards.append(0.0)
+        for v in list(episode.extra_model_outputs.values()):
+            v.append(v[-1])
+        # Artificially make this episode truncated for the upcoming GAE
+        # computations.
+        if not episode.is_done:
+            episode.is_truncated = True
+        # Validate to make sure everything is in order.
+        episode.validate()
+
+    return orig_truncateds
+
+
+def remove_last_ts_from_data(
+    episode_lens: List[int],
+    *data: Tuple[np._typing.NDArray],
+) -> Tuple[np._typing.NDArray]:
+    """Removes the last timestep of each episode from each given data item.
+
+    Each item in `data` is a concatenated sequence of episode data.
+    For example, if `episode_lens` is [2, 4], then data is a shape=(6,)
+    ndarray. The returned corresponding value will have shape (4,), meaning
+    both episodes have been shortened by exactly one timestep to 1 and 3.
+
+    .. testcode::
+
+        from ray.rllib.utils.postprocessing.episodes import remove_last_ts_from_data
+        import numpy as np
+
+        unpadded = remove_last_ts_from_data(
+            [5, 3],
+            np.array([0, 1, 2, 3, 4, 0, 1, 2]),
+        )
+        assert (unpadded[0] == [0, 1, 2, 3, 0, 1]).all()
+
+        unpadded = remove_last_ts_from_data(
+            [4, 2, 3],
+            np.array([0, 1, 2, 3, 0, 1, 0, 1, 2]),
+            np.array([4, 5, 6, 7, 2, 3, 3, 4, 5]),
+        )
+        assert (unpadded[0] == [0, 1, 2, 0, 0, 1]).all()
+        assert (unpadded[1] == [4, 5, 6, 2, 3, 4]).all()
+
+    Args:
+        episode_lens: A list of current episode lengths. The returned
+            data will have the same lengths minus 1 timestep.
+        data: A tuple of data items (np.ndarrays) representing concatenated episodes
+            to be shortened by one timestep per episode.
+            Note that only arrays with `shape=(n,)` are supported! The
+            returned data will have `shape=(n-len(episode_lens),)` (each
+            episode gets shortened by one timestep).
+
+    Returns:
+        A tuple of new data items shortened by one timestep.
+    """
+    # Figure out the new slices to apply to each data item based on
+    # the given episode_lens.
+    slices = []
+    sum = 0
+    for len_ in episode_lens:
+        slices.append(slice(sum, sum + len_ - 1))
+        sum += len_
+    # Compile the return data by slicing off one timestep at the end of
+    # each episode.
+ ret = [] + for d in data: + ret.append(np.concatenate([d[s] for s in slices])) + return tuple(ret) + + +def remove_last_ts_from_episodes_and_restore_truncateds( + episodes: List[SingleAgentEpisode], + orig_truncateds: List[bool], +) -> None: + """Reverts the effects of `_add_ts_to_episodes_and_truncate`. + + Args: + episodes: The list of SingleAgentEpisode objects to extend by one timestep + and add a truncation flag if necessary. + orig_truncateds: A list of the original episodes' truncated values to be + applied to the `episodes`. + """ + + # Fix all episodes. + for episode, orig_truncated in zip(episodes, orig_truncateds): + # Reduce timesteps by 1. + episode.t -= 1 + # Remove all extra timestep data from the episode's buffers. + episode.observations.pop() + episode.infos.pop() + episode.actions.pop() + episode.rewards.pop() + for v in episode.extra_model_outputs.values(): + v.pop() + # Fix the truncateds flag again. + episode.is_truncated = orig_truncated diff --git a/rllib/utils/postprocessing/tests/__init__.py b/rllib/utils/postprocessing/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/rllib/utils/postprocessing/tests/test_value_predictions.py b/rllib/utils/postprocessing/tests/test_value_predictions.py new file mode 100644 index 0000000000000..89d077a1ac0dc --- /dev/null +++ b/rllib/utils/postprocessing/tests/test_value_predictions.py @@ -0,0 +1,47 @@ +import unittest + +from ray.rllib.utils.postprocessing.value_predictions import extract_bootstrapped_values +from ray.rllib.utils.test_utils import check + + +class TestPostprocessing(unittest.TestCase): + def test_extract_bootstrapped_values(self): + """Tests, whether the extract_bootstrapped_values utility works properly.""" + + # Fake vf_preds sequence. + # Spaces = denote (elongated-by-one-artificial-ts) episode boundaries. + # digits = timesteps within the actual episode. + # [lower case letters] = bootstrap values at episode truncations. + # '-' = bootstrap values at episode terminals (these values are simply zero). + sequence = "012345678a 01234A 0- 0123456b 01c 012- 012345e 012-" + sequence = sequence.replace(" ", "") + sequence = list(sequence) + # The actual, non-elongated, episode lengths. + episode_lengths = [9, 5, 1, 7, 2, 3, 6, 3] + T = 4 + result = extract_bootstrapped_values( + vf_preds=sequence, + episode_lengths=episode_lengths, + T=T, + ) + check(result, [4, 8, 3, 1, 5, "c", 1, 5, "-"]) + + # Another example. + sequence = "0123a 012345b 01234567- 012- 012- 012- 012345- 0123456c" + sequence = sequence.replace(" ", "") + sequence = list(sequence) + episode_lengths = [4, 6, 8, 3, 3, 3, 6, 7] + T = 5 + result = extract_bootstrapped_values( + vf_preds=sequence, + episode_lengths=episode_lengths, + T=T, + ) + check(result, [1, "b", 5, 2, 1, 3, 2, "c"]) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/utils/postprocessing/value_predictions.py b/rllib/utils/postprocessing/value_predictions.py new file mode 100644 index 0000000000000..f8eafcdb0a6b8 --- /dev/null +++ b/rllib/utils/postprocessing/value_predictions.py @@ -0,0 +1,103 @@ +import numpy as np + + +def compute_value_targets( + values, + rewards, + terminateds, + truncateds, + gamma: float, + lambda_: float, +): + """Computes value function (vf) targets given vf predictions and rewards. 
+
+    Note that advantages can then easily be computed via the formula:
+    advantages = targets - vf_predictions
+    """
+    # Force-set all values at terminals (not at truncations!) to 0.0.
+    orig_values = flat_values = values * (1.0 - terminateds)
+
+    flat_values = np.append(flat_values, 0.0)
+    intermediates = rewards + gamma * (1 - lambda_) * flat_values[1:]
+    continues = 1.0 - terminateds
+
+    Rs = []
+    last = flat_values[-1]
+    for t in reversed(range(intermediates.shape[0])):
+        last = intermediates[t] + continues[t] * gamma * lambda_ * last
+        Rs.append(last)
+        if truncateds[t]:
+            last = orig_values[t]
+
+    # Reverse back to correct (time) direction.
+    value_targets = np.stack(list(reversed(Rs)), axis=0)
+
+    return value_targets.astype(np.float32)
+
+
+def extract_bootstrapped_values(vf_preds, episode_lengths, T):
+    """Returns a bootstrapped value batch given value predictions.
+
+    Note that the incoming value predictions must have happened over (artificially)
+    elongated episodes (by 1 timestep at the end). This way, we can either extract the
+    `vf_preds` at these extra timesteps (as "bootstrap values") or skip over them
+    entirely if they lie in the middle of the T-slices.
+
+    For example, given an episode structure like this:
+    01234a 0123456b 01c 012- 0123e 012-
+    where each episode is separated by a space and goes from 0 to n and ends in an
+    artificially elongated timestep (denoted by 'a', 'b', 'c', '-', or 'e'), where '-'
+    means that the episode was terminated and the bootstrap value at the end should be
+    zero and 'a', 'b', 'c', etc. represent truncated episode ends with computed vf
+    estimates.
+    The output for the above sequence (and T=4) should then be:
+    4 3 b 2 3 -
+
+    Args:
+        vf_preds: The computed value function predictions over the artificially
+            elongated episodes (by one timestep at the end).
+        episode_lengths: The original (correct) episode lengths, NOT counting the
+            artificially added timestep at the end.
+        T: The size of the time dimension by which to slice the data. Note that the
+            sum of all episode lengths (`sum(episode_lengths)`) must be divisible by T.
+
+    Returns:
+        The batch of bootstrapped values.
+    """
+    bootstrapped_values = []
+    if sum(episode_lengths) % T != 0:
+        raise ValueError(
+            "Can only extract bootstrapped values if the sum of episode lengths "
+            f"({sum(episode_lengths)}) is divisible by the given T ({T})!"
+        )
+
+    # Loop over all episode lengths and collect bootstrap values.
+    i = -1
+    while i < len(episode_lengths) - 1:
+        i += 1
+        eps_len = episode_lengths[i]
+        # We can make another T-stride inside this episode ->
+        # - Use a vf prediction within the episode as bootstrapped value.
+        # - "Fix" the episode_lengths array and continue within the same episode.
+        if T < eps_len:
+            bootstrapped_values.append(vf_preds[T])
+            vf_preds = vf_preds[T:]
+            episode_lengths[i] -= T
+            i -= 1
+        # We can make another T-stride inside this episode, but will then be at the end
+        # of it ->
+        # - Use the value function prediction at the artificially added timestep
+        # as bootstrapped value.
+        # - Skip the additional timestep at the end and move on to the next episode.
+        elif T == eps_len:
+            bootstrapped_values.append(vf_preds[T])
+            vf_preds = vf_preds[T + 1 :]
+        # The episode fits entirely into the T-stride ->
+        # - Move on to the next episode ("fix" its length by making it seem longer).
+        else:
+            # Skip bootstrap value of current episode (not needed).
+            vf_preds = vf_preds[1:]
+            # Make next episode seem longer.
+            episode_lengths[i + 1] += eps_len
+
+    return np.array(bootstrapped_values)
diff --git a/rllib/connectors/utils/zero_padding.py b/rllib/utils/postprocessing/zero_padding.py
similarity index 73%
rename from rllib/connectors/utils/zero_padding.py
rename to rllib/utils/postprocessing/zero_padding.py
index e34c0eab85cc5..41328f20488bb 100644
--- a/rllib/connectors/utils/zero_padding.py
+++ b/rllib/utils/postprocessing/zero_padding.py
@@ -133,3 +133,70 @@ def split_and_pad_single_record(
     # Send everything through `split_and_pad` to perform the actual splitting into
     # sub-chunks of max len=T and zero-padding.
     return split_and_pad(episodes_data, T)
+
+
+def unpad_data_if_necessary(episode_lens, data):
+    """Removes right-side zero-padding from data based on `episode_lens`.
+
+    .. testcode::
+
+        from ray.rllib.utils.postprocessing.zero_padding import unpad_data_if_necessary
+        import numpy as np
+
+        unpadded = unpad_data_if_necessary(
+            episode_lens=[4, 2],
+            data=np.array([
+                [2, 4, 5, 3, 0, 0, 0, 0],
+                [-1, 3, 0, 0, 0, 0, 0, 0],
+            ]),
+        )
+        assert (unpadded == [2, 4, 5, 3, -1, 3]).all()
+
+        unpadded = unpad_data_if_necessary(
+            episode_lens=[1, 5],
+            data=np.array([
+                [2, 0, 0, 0, 0],
+                [-1, -2, -3, -4, -5],
+            ]),
+        )
+        assert (unpadded == [2, -1, -2, -3, -4, -5]).all()
+
+    Args:
+        episode_lens: A list of actual episode lengths.
+        data: A 2D np.ndarray with right-side zero-padded rows.
+
+    Returns:
+        A 1D np.ndarray resulting from concatenation of the un-padded
+        input data along the 0-axis.
+    """
+    # If data does NOT have a time dimension, return right away.
+    if len(data.shape) == 1:
+        return data
+
+    # Assert we only have B and T dimensions (meaning this function only operates
+    # on single-float data, such as value function predictions).
+    assert len(data.shape) == 2
+
+    new_data = []
+    row_idx = 0
+
+    T = data.shape[1]
+    for len_ in episode_lens:
+        # Calculate how many full rows this array occupies and how many elements are
+        # in the last, potentially partial row.
+        num_rows, col_idx = divmod(len_, T)
+
+        # If the array spans multiple full rows, fully include these rows.
+        for i in range(num_rows):
+            new_data.append(data[row_idx])
+            row_idx += 1
+
+        # If there are elements in the last, potentially partial row, add this
+        # partial row as well.
+        if col_idx > 0:
+            new_data.append(data[row_idx, :col_idx])
+
+        # Move to the next row for the next array (skip the zero-padding zone).
+        row_idx += 1
+
+    return np.concatenate(new_data)