From feeb440f6f92d77324a8ffa3124ab29c9be24195 Mon Sep 17 00:00:00 2001
From: simonsays1980
Date: Mon, 20 Jan 2025 11:31:53 +0100
Subject: [PATCH] [RLlib] Add metrics to buffers. (#49822)

---
 rllib/algorithms/algorithm.py                 |  10 +
 rllib/algorithms/dqn/dqn.py                   |   7 +
 rllib/utils/metrics/__init__.py               |  34 ++
 .../replay_buffers/episode_replay_buffer.py   | 388 +++++++++++++++++-
 .../prioritized_episode_buffer.py             | 112 ++++-
 5 files changed, 542 insertions(+), 9 deletions(-)

diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index beba22d4f719..ad9797e3812c 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -3962,6 +3962,16 @@ def _create_local_replay_buffer_if_necessary(
         ):
             return
 
+        # Add parameters, if necessary.
+        if config["replay_buffer_config"]["type"] in [
+            "EpisodeReplayBuffer",
+            "PrioritizedEpisodeReplayBuffer",
+        ]:
+            # TODO (simon): If all episode buffers have metrics, check for subclassing.
+            config["replay_buffer_config"][
+                "metrics_num_episodes_for_smoothing"
+            ] = self.config.metrics_num_episodes_for_smoothing
+
         return from_config(ReplayBuffer, config["replay_buffer_config"])
 
     @OldAPIStack
diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py
index 24e9afe26dde..b328b664e0d3 100644
--- a/rllib/algorithms/dqn/dqn.py
+++ b/rllib/algorithms/dqn/dqn.py
@@ -51,6 +51,7 @@
     NUM_ENV_STEPS_SAMPLED_LIFETIME,
     NUM_TARGET_UPDATES,
     REPLAY_BUFFER_ADD_DATA_TIMER,
+    REPLAY_BUFFER_RESULTS,
     REPLAY_BUFFER_SAMPLE_TIMER,
     REPLAY_BUFFER_UPDATE_PRIOS_TIMER,
     SAMPLE_TIMER,
@@ -689,6 +690,12 @@ def _training_step_new_api_stack(self):
                 sample_episodes=True,
             )
 
+        # Get the replay buffer metrics.
+        replay_buffer_results = self.local_replay_buffer.get_metrics()
+        self.metrics.merge_and_log_n_dicts(
+            [replay_buffer_results], key=REPLAY_BUFFER_RESULTS
+        )
+
         # Perform an update on the buffer-sampled train batch.
         with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
             learner_results = self.learner_group.update_from_episodes(
diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py
index 82d2ad63862b..50dfe780b250 100644
--- a/rllib/utils/metrics/__init__.py
+++ b/rllib/utils/metrics/__init__.py
@@ -36,6 +36,40 @@
 ENV_TO_MODULE_SUM_EPISODES_LENGTH_IN = "env_to_module_sum_episodes_length_in"
 ENV_TO_MODULE_SUM_EPISODES_LENGTH_OUT = "env_to_module_sum_episodes_length_out"
 
+# Counters for adding and evicting in replay buffers.
+ACTUAL_N_STEP = "actual_n_step" +AGENT_ACTUAL_N_STEP = "agent_actual_n_step" +AGENT_STEP_UTILIZATION = "agent_step_utilization" +ENV_STEP_UTILIZATION = "env_step_utilization" +NUM_AGENT_EPISODES_STORED = "num_agent_episodes" +NUM_AGENT_EPISODES_ADDED = "num_agent_episodes_added" +NUM_AGENT_EPISODES_ADDED_LIFETIME = "num_agent_episodes_added_lifetime" +NUM_AGENT_EPISODES_EVICTED = "num_agent_episodes_evicted" +NUM_AGENT_EPISODES_EVICTED_LIFETIME = "num_agent_episodes_evicted_lifetime" +NUM_AGENT_EPISODES_PER_SAMPLE = "num_agent_episodes_per_sample" +NUM_AGENT_RESAMPLES = "num_agent_resamples" +NUM_AGENT_STEPS_ADDED = "num_agent_steps_added" +NUM_AGENT_STEPS_ADDED_LIFETIME = "num_agent_steps_added_lifetime" +NUM_AGENT_STEPS_EVICTED = "num_agent_steps_evicted" +NUM_AGENT_STEPS_EVICTED_LIFETIME = "num_agent_steps_evicted_lifetime" +NUM_AGENT_STEPS_PER_SAMPLE = "num_agent_steps_per_sample" +NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME = "num_agent_steps_per_sample_lifetime" +NUM_AGENT_STEPS_STORED = "num_agent_steps" +NUM_ENV_STEPS_STORED = "num_env_steps" +NUM_ENV_STEPS_ADDED = "num_env_steps_added" +NUM_ENV_STEPS_ADDED_LIFETIME = "num_env_steps_added_lifetime" +NUM_ENV_STEPS_EVICTED = "num_env_steps_evicted" +NUM_ENV_STEPS_EVICTED_LIFETIME = "num_env_steps_evicted_lifetime" +NUM_ENV_STEPS_PER_SAMPLE = "num_env_steps_per_sample" +NUM_ENV_STEPS_PER_SAMPLE_LIFETIME = "num_env_steps_per_sample_lifetime" +NUM_EPISODES_STORED = "num_episodes" +NUM_EPISODES_ADDED = "num_episodes_added" +NUM_EPISODES_ADDED_LIFETIME = "num_episodes_added_lifetime" +NUM_EPISODES_EVICTED = "num_episodes_evicted" +NUM_EPISODES_EVICTED_LIFETIME = "num_episodes_evicted_lifetime" +NUM_EPISODES_PER_SAMPLE = "num_episodes_per_sample" +NUM_RESAMPLES = "num_resamples" + EPISODE_DURATION_SEC_MEAN = "episode_duration_sec_mean" EPISODE_LEN_MEAN = "episode_len_mean" EPISODE_LEN_MAX = "episode_len_max" diff --git a/rllib/utils/replay_buffers/episode_replay_buffer.py b/rllib/utils/replay_buffers/episode_replay_buffer.py index d524cd013aa5..52197e5de0e0 100644 --- a/rllib/utils/replay_buffers/episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/episode_replay_buffer.py @@ -1,16 +1,58 @@ from collections import deque import copy +import hashlib from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import scipy +from ray.rllib.core import DEFAULT_AGENT_ID from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer -from ray.rllib.utils.annotations import override -from ray.rllib.utils.replay_buffers.base import ReplayBufferInterface -from ray.rllib.utils.typing import SampleBatchType from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.rllib.utils.metrics import ( + ACTUAL_N_STEP, + AGENT_ACTUAL_N_STEP, + AGENT_STEP_UTILIZATION, + ENV_STEP_UTILIZATION, + NUM_AGENT_EPISODES_STORED, + NUM_AGENT_EPISODES_ADDED, + NUM_AGENT_EPISODES_ADDED_LIFETIME, + NUM_AGENT_EPISODES_EVICTED, + NUM_AGENT_EPISODES_EVICTED_LIFETIME, + NUM_AGENT_EPISODES_PER_SAMPLE, + NUM_AGENT_STEPS_STORED, + NUM_AGENT_STEPS_ADDED, + NUM_AGENT_STEPS_ADDED_LIFETIME, + NUM_AGENT_STEPS_EVICTED, + NUM_AGENT_STEPS_EVICTED_LIFETIME, + NUM_AGENT_STEPS_PER_SAMPLE, + NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME, + NUM_AGENT_STEPS_SAMPLED, + NUM_AGENT_STEPS_SAMPLED_LIFETIME, + NUM_ENV_STEPS_STORED, + NUM_ENV_STEPS_ADDED, + NUM_ENV_STEPS_ADDED_LIFETIME, + NUM_ENV_STEPS_EVICTED, + 
NUM_ENV_STEPS_EVICTED_LIFETIME, + NUM_ENV_STEPS_PER_SAMPLE, + NUM_ENV_STEPS_PER_SAMPLE_LIFETIME, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_EPISODES_STORED, + NUM_EPISODES_ADDED, + NUM_EPISODES_ADDED_LIFETIME, + NUM_EPISODES_EVICTED, + NUM_EPISODES_EVICTED_LIFETIME, + NUM_EPISODES_PER_SAMPLE, +) +from ray.rllib.utils.metrics.metrics_logger import MetricsLogger +from ray.rllib.utils.replay_buffers.base import ReplayBufferInterface +from ray.rllib.utils.typing import SampleBatchType, ResultDict class EpisodeReplayBuffer(ReplayBufferInterface): @@ -65,6 +107,7 @@ def __init__( *, batch_size_B: int = 16, batch_length_T: int = 64, + metrics_num_episodes_for_smoothing: int = 100, ): """Initializes an EpisodeReplayBuffer instance. @@ -112,6 +155,10 @@ def __init__( self.rng = np.random.default_rng(seed=None) + # Initialize the metrics. + self.metrics = MetricsLogger() + self._metrics_num_episodes_for_smoothing = metrics_num_episodes_for_smoothing + @override(ReplayBufferInterface) def __len__(self) -> int: return self.get_num_timesteps() @@ -124,6 +171,12 @@ def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"]) """ episodes = force_list(episodes) + # Set up some counters for metrics. + num_env_steps_added = 0 + num_episodes_added = 0 + num_episodes_evicted = 0 + num_env_steps_evicted = 0 + for eps in episodes: # Make sure we don't change what's coming in from the user. # TODO (sven): It'd probably be better to make sure in the EnvRunner to not @@ -134,8 +187,12 @@ def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"]) # actually preferred). eps = copy.deepcopy(eps) - self._num_timesteps += len(eps) - self._num_timesteps_added += len(eps) + eps_len = len(eps) + # TODO (simon): Check, if we can deprecate these two + # variables and instead peek into the metrics. + self._num_timesteps += eps_len + self._num_timesteps_added += eps_len + num_env_steps_added += eps_len # Ongoing episode, concat to existing record. if eps.id_ in self.episode_id_to_index: @@ -146,6 +203,7 @@ def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"]) existing_eps.concat_episode(eps) # New episode. Add to end of our episodes deque. else: + num_episodes_added += 1 self.episodes.append(eps) eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted self.episode_id_to_index[eps.id_] = eps_idx @@ -157,6 +215,8 @@ def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"]) # Eject oldest episode. evicted_eps = self.episodes.popleft() evicted_eps_len = len(evicted_eps) + num_episodes_evicted += 1 + num_env_steps_evicted += evicted_eps_len # Correct our size. self._num_timesteps -= evicted_eps_len @@ -206,6 +266,168 @@ def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"]) # Increase episode evicted counter. self._num_episodes_evicted += 1 + self._update_add_metrics( + num_env_steps_added, + num_episodes_added, + num_episodes_evicted, + num_env_steps_evicted, + ) + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def _update_add_metrics( + self, + num_timesteps_added: int, + num_episodes_added: int, + num_episodes_evicted: int, + num_env_steps_evicted: int, + **kwargs, + ) -> None: + """Updates the replay buffer's adding metrics. + + Args: + num_timesteps_added: The total number of environment steps added to the + buffer in the `EpisodeReplayBuffer.add` call. + num_episodes_added: The total number of episodes added to the + buffer in the `EpisodeReplayBuffer.add` call. 
+            num_episodes_evicted: The total number of episodes evicted from
+                the buffer in the `EpisodeReplayBuffer.add` call. Note that this
+                does not include episodes that were evicted before ever being
+                added to the buffer (which can happen if many episodes are added
+                and the buffer's capacity is not large enough).
+            num_env_steps_evicted: The total number of environment steps evicted
+                from the buffer in the `EpisodeReplayBuffer.add` call. Note that
+                this does not include steps that were evicted before ever being
+                added to the buffer (which can happen if many episodes are added
+                and the buffer's capacity is not large enough).
+        """
+        # Get the actual number of agent steps residing in the buffer.
+        # TODO (simon): Write the same counters and getters as for the
+        # multi-agent buffers.
+        self.metrics.log_value(
+            (NUM_AGENT_STEPS_STORED, DEFAULT_AGENT_ID),
+            self.get_num_timesteps(),
+            reduce="mean",
+            window=self._metrics_num_episodes_for_smoothing,
+        )
+        # Number of timesteps added.
+        self.metrics.log_value(
+            (NUM_AGENT_STEPS_ADDED, DEFAULT_AGENT_ID),
+            num_timesteps_added,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+        self.metrics.log_value(
+            (NUM_AGENT_STEPS_ADDED_LIFETIME, DEFAULT_AGENT_ID),
+            num_timesteps_added,
+            reduce="sum",
+        )
+        self.metrics.log_value(
+            (NUM_AGENT_STEPS_EVICTED, DEFAULT_AGENT_ID),
+            num_env_steps_evicted,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+        self.metrics.log_value(
+            (NUM_AGENT_STEPS_EVICTED_LIFETIME, DEFAULT_AGENT_ID),
+            num_env_steps_evicted,
+            reduce="sum",
+        )
+        # Whole buffer step metrics.
+        self.metrics.log_value(
+            NUM_ENV_STEPS_STORED,
+            self.get_num_timesteps(),
+            reduce="mean",
+            window=self._metrics_num_episodes_for_smoothing,
+        )
+        self.metrics.log_value(
+            NUM_ENV_STEPS_ADDED,
+            num_timesteps_added,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+        self.metrics.log_value(
+            NUM_ENV_STEPS_ADDED_LIFETIME,
+            num_timesteps_added,
+            reduce="sum",
+        )
+        self.metrics.log_value(
+            NUM_ENV_STEPS_EVICTED,
+            num_env_steps_evicted,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+        self.metrics.log_value(
+            NUM_ENV_STEPS_EVICTED_LIFETIME,
+            num_env_steps_evicted,
+            reduce="sum",
+        )
+
+        # Episode metrics.
+
+        # Number of episodes in the buffer.
+        self.metrics.log_value(
+            (NUM_AGENT_EPISODES_STORED, DEFAULT_AGENT_ID),
+            self.get_num_episodes(),
+            reduce="mean",
+            window=self._metrics_num_episodes_for_smoothing,
+        )
+        # Number of new episodes added. Note that this metric can
+        # be zero.
+        self.metrics.log_value(
+            (NUM_AGENT_EPISODES_ADDED, DEFAULT_AGENT_ID),
+            num_episodes_added,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+        self.metrics.log_value(
+            (NUM_AGENT_EPISODES_ADDED_LIFETIME, DEFAULT_AGENT_ID),
+            num_episodes_added,
+            reduce="sum",
+        )
+        self.metrics.log_value(
+            (NUM_AGENT_EPISODES_EVICTED, DEFAULT_AGENT_ID),
+            num_episodes_evicted,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+        self.metrics.log_value(
+            (NUM_AGENT_EPISODES_EVICTED_LIFETIME, DEFAULT_AGENT_ID),
+            num_episodes_evicted,
+            reduce="sum",
+        )
+
+        # Whole buffer episode metrics.
+        self.metrics.log_value(
+            NUM_EPISODES_STORED,
+            self.get_num_episodes(),
+            reduce="mean",
+            window=self._metrics_num_episodes_for_smoothing,
+        )
+        # Number of new episodes added. Note that this metric can
+        # be zero.
+        self.metrics.log_value(
+            NUM_EPISODES_ADDED,
+            num_episodes_added,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+        self.metrics.log_value(
+            NUM_EPISODES_ADDED_LIFETIME,
+            num_episodes_added,
+            reduce="sum",
+        )
+        self.metrics.log_value(
+            NUM_EPISODES_EVICTED,
+            num_episodes_evicted,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+        self.metrics.log_value(
+            NUM_EPISODES_EVICTED_LIFETIME,
+            num_episodes_evicted,
+            reduce="sum",
+        )
+
     @override(ReplayBufferInterface)
     def sample(
         self,
@@ -349,6 +571,11 @@ def _sample_batch(
         is_terminated = [[False] * batch_length_T for _ in range(batch_size_B)]
         is_truncated = [[False] * batch_length_T for _ in range(batch_size_B)]
 
+        # Record all the env step buffer indices that are contained in the sample.
+        sampled_env_step_idxs = set()
+        # Record all the episode buffer indices that are contained in the sample.
+        sampled_episode_idxs = set()
+
         B = 0
         T = 0
         while B < batch_size_B:
@@ -413,10 +640,24 @@ def _sample_batch(
                 # Start filling the next row.
                 B += 1
                 T = 0
+                # Add the episode buffer index to the set of episode indices.
+                sampled_episode_idxs.add(episode_idx)
+                # Record a hash for the episode ID and timestep inside the episode.
+                sampled_env_step_idxs.add(
+                    hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest()
+                )
 
         # Update our sampled counter.
         self.sampled_timesteps += batch_size_B * batch_length_T
 
+        # Update the sample metrics.
+        self._update_sample_metrics(
+            num_env_steps_sampled=batch_size_B * batch_length_T,
+            num_episodes_per_sample=len(sampled_episode_idxs),
+            num_env_steps_per_sample=len(sampled_env_step_idxs),
+            sampled_n_step=None,
+        )
+
         # TODO: Return SampleBatch instead of this simpler dict.
         ret = {
             "obs": np.array(observations),
@@ -532,6 +773,12 @@ def _sample_episodes(
         self._last_sampled_indices = []
 
         sampled_episodes = []
+        # Record all the env step buffer indices that are contained in the sample.
+        sampled_env_step_idxs = set()
+        # Record all the episode buffer indices that are contained in the sample.
+        sampled_episode_idxs = set()
+        # Record all n-steps that have been used.
+        sampled_n_steps = []
 
         B = 0
         while B < batch_size_B:
@@ -619,7 +866,10 @@ def _sample_episodes(
                 ),
                 len_lookback_buffer=lookback,
             )
-
+            # Record a hash for the episode ID and timestep inside the episode.
+            sampled_env_step_idxs.add(
+                hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest()
+            )
             # Remove reference to sampled episode.
             del episode
 
@@ -636,15 +886,135 @@ def _sample_episodes(
             # Append the sampled episode.
             sampled_episodes.append(sampled_episode)
+            sampled_episode_idxs.add(episode_idx)
+            sampled_n_steps.append(actual_n_step)
 
             # Increment counter.
-            B += (actual_length - episode_ts + 1) or 1
+            B += (actual_length - episode_ts - (actual_n_step - 1) + 1) or 1
 
         # Update the metric.
         self.sampled_timesteps += batch_size_B
 
+        # Update the sample metrics.
+        self._update_sample_metrics(
+            batch_size_B,
+            len(sampled_episode_idxs),
+            len(sampled_env_step_idxs),
+            sum(sampled_n_steps) / batch_size_B,
+        )
+
         return sampled_episodes
 
+    @OverrideToImplementCustomLogic_CallToSuperRecommended
+    def _update_sample_metrics(
+        self,
+        num_env_steps_sampled: int,
+        num_episodes_per_sample: int,
+        num_env_steps_per_sample: int,
+        sampled_n_step: Optional[float],
+        **kwargs: Dict[str, Any],
+    ) -> None:
+        """Updates the replay buffer's sample metrics.
+
+        Args:
+            num_env_steps_sampled: The number of environment steps sampled
+                this iteration in the `sample` method.
+            num_episodes_per_sample: The number of unique episodes in the
+                sample.
+ num_env_steps_per_sample: The number of unique environment steps + in the sample. + sampled_n_step: The mean n-step used in the sample. Note, this + is constant, if the n-step is not sampled. + """ + if sampled_n_step: + self.metrics.log_value( + ACTUAL_N_STEP, + sampled_n_step, + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + self.metrics.log_value( + (AGENT_ACTUAL_N_STEP, DEFAULT_AGENT_ID), + sampled_n_step, + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + self.metrics.log_value( + (NUM_AGENT_EPISODES_PER_SAMPLE, DEFAULT_AGENT_ID), + num_episodes_per_sample, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_PER_SAMPLE, DEFAULT_AGENT_ID), + num_env_steps_per_sample, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME, DEFAULT_AGENT_ID), + num_env_steps_per_sample, + reduce="sum", + ) + self.metrics.log_value( + (NUM_AGENT_STEPS_SAMPLED, DEFAULT_AGENT_ID), + num_env_steps_sampled, + reduce="sum", + clear_on_reduce=True, + ) + # TODO (simon): Check, if we can then deprecate + # self.sampled_timesteps. + self.metrics.log_value( + (NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID), + num_env_steps_sampled, + reduce="sum", + ) + self.metrics.log_value( + (AGENT_STEP_UTILIZATION, DEFAULT_AGENT_ID), + self.metrics.peek((NUM_AGENT_STEPS_PER_SAMPLE_LIFETIME, DEFAULT_AGENT_ID)) + / self.metrics.peek((NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID)), + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + # Whole buffer sampled env steps metrics. + self.metrics.log_value( + NUM_EPISODES_PER_SAMPLE, + num_episodes_per_sample, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_ENV_STEPS_PER_SAMPLE, + num_env_steps_per_sample, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_ENV_STEPS_PER_SAMPLE_LIFETIME, + num_env_steps_per_sample, + reduce="sum", + ) + self.metrics.log_value( + NUM_ENV_STEPS_SAMPLED, + num_env_steps_sampled, + reduce="sum", + clear_on_reduce=True, + ) + self.metrics.log_value( + NUM_ENV_STEPS_SAMPLED_LIFETIME, + num_env_steps_sampled, + reduce="sum", + ) + self.metrics.log_value( + ENV_STEP_UTILIZATION, + self.metrics.peek(NUM_ENV_STEPS_PER_SAMPLE_LIFETIME) + / self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME), + reduce="mean", + window=self._metrics_num_episodes_for_smoothing, + ) + + # TODO (simon): Check, if we can instead peek into the metrics + # and deprecate all variables. def get_num_episodes(self) -> int: """Returns number of episodes (completed or truncated) stored in the buffer.""" return len(self.episodes) @@ -665,6 +1035,10 @@ def get_added_timesteps(self) -> int: """Returns number of timesteps that have been added in buffer's lifetime.""" return self._num_timesteps_added + def get_metrics(self) -> ResultDict: + """Returns the metrics of the buffer and reduces them.""" + return self.metrics.reduce() + @override(ReplayBufferInterface) def get_state(self) -> Dict[str, Any]: """Gets a pickable state of the buffer. 
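The two hooks added above, `_update_add_metrics()` and `_update_sample_metrics()`, are decorated with `OverrideToImplementCustomLogic_CallToSuperRecommended`, so subclasses can log additional buffer metrics while keeping the base ones (the `PrioritizedEpisodeReplayBuffer` changes below do exactly this for its resample counters). A minimal sketch of that pattern follows; the subclass and the extra metric key are hypothetical and not part of this patch:

from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer


class MyEpisodeBuffer(EpisodeReplayBuffer):
    @override(EpisodeReplayBuffer)
    def _update_sample_metrics(
        self,
        num_env_steps_sampled,
        num_episodes_per_sample,
        num_env_steps_per_sample,
        sampled_n_step,
        **kwargs,
    ):
        # Keep all base sample metrics.
        super()._update_sample_metrics(
            num_env_steps_sampled,
            num_episodes_per_sample,
            num_env_steps_per_sample,
            sampled_n_step,
        )
        # Log one extra, custom stat ("num_sample_calls" is a made-up key).
        self.metrics.log_value(
            "num_sample_calls", 1, reduce="sum", clear_on_reduce=True
        )

Calling `get_metrics()` on such a buffer then returns the reduced base metrics plus the custom key.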
diff --git a/rllib/utils/replay_buffers/prioritized_episode_buffer.py b/rllib/utils/replay_buffers/prioritized_episode_buffer.py index 02982d51ef6a..f6ca7e548c48 100644 --- a/rllib/utils/replay_buffers/prioritized_episode_buffer.py +++ b/rllib/utils/replay_buffers/prioritized_episode_buffer.py @@ -1,4 +1,5 @@ import copy +import hashlib import numpy as np import scipy @@ -6,11 +7,19 @@ from numpy.typing import NDArray from typing import Any, Dict, List, Optional, Tuple, Union +from ray.rllib.core import DEFAULT_AGENT_ID from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.execution.segment_tree import MinSegmentTree, SumSegmentTree from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.rllib.utils.metrics import ( + NUM_AGENT_RESAMPLES, + NUM_RESAMPLES, +) from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer -from ray.rllib.utils.annotations import override from ray.rllib.utils.typing import ModuleID, SampleBatchType @@ -118,6 +127,7 @@ def __init__( batch_size_B: int = 16, batch_length_T: int = 1, alpha: float = 1.0, + metrics_num_episodes_for_smoothing: int = 100, **kwargs, ): """Initializes a `PrioritizedEpisodeReplayBuffer` object @@ -132,7 +142,10 @@ def __init__( prioritization, `alpha=0.0` means no prioritization. """ super().__init__( - capacity=capacity, batch_size_B=batch_size_B, batch_length_T=batch_length_T + capacity=capacity, + batch_size_B=batch_size_B, + batch_length_T=batch_length_T, + metrics_num_episodes_for_smoothing=metrics_num_episodes_for_smoothing, ) # `alpha` should be non-negative. @@ -196,6 +209,12 @@ def add( episodes = force_list(episodes) + # Set up some counters for metrics. + num_env_steps_added = 0 + num_episodes_added = 0 + num_episodes_evicted = 0 + num_env_steps_evicted = 0 + # Add first the timesteps of new episodes to have info about how many # episodes should be evicted to stay below capacity. new_episode_ids = [] @@ -215,6 +234,8 @@ def add( eps_evicted.append(self.episodes.popleft()) eps_evicted_ids.append(eps_evicted[-1].id_) eps_evicted_idxs.append(self.episode_id_to_index.pop(eps_evicted_ids[-1])) + num_episodes_evicted += 1 + num_env_steps_evicted += len(eps_evicted[-1]) # If this episode has a new chunk in the new episodes added, # we subtract it again. # TODO (sven, simon): Should we just treat such an episode chunk @@ -282,6 +303,7 @@ def add( existing_eps.concat_episode(eps) # Otherwise, create a new entry. else: + num_episodes_added += 1 self.episodes.append(eps) eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted self.episode_id_to_index[eps.id_] = eps_idx @@ -295,9 +317,18 @@ def add( for i in range(len(eps)) ] ) + num_env_steps_added += len(eps) # Increase index to the new length of `self._indices`. j = len(self._indices) + # Increase metrics. + self._update_add_metrics( + num_env_steps_added, + num_episodes_added, + num_episodes_evicted, + num_env_steps_evicted, + ) + @override(EpisodeReplayBuffer) def sample( self, @@ -391,6 +422,14 @@ def sample( self._last_sampled_indices = [] sampled_episodes = [] + # Record the sampled episode buffer indices to check the number of + # episodes per sample. + sampled_episode_idxs = set() + # Record sampled env step hashes to check the number of different + # env steps per sample. 
+        sampled_env_steps_idxs = set()
+        num_resamples = 0
+        sampled_n_steps = []
 
         # Sample proportionally from replay buffer's segments using the weights.
         total_segment_sum = self._sum_segment.sum()
@@ -429,6 +468,7 @@ def sample(
             # Skip, if we are too far to the end and `episode_ts` + n_step would go
             # beyond the episode's end.
             if episode_ts + actual_n_step > len(episode):
+                num_resamples += 1
                 continue
 
             # Note, this will be the reward after executing action
@@ -492,20 +532,88 @@ def sample(
                 len_lookback_buffer=0,
                 t_started=episode_ts,
             )
+            # Record here the episode time step via a hash code.
+            sampled_env_steps_idxs.add(
+                hashlib.sha256(f"{episode.id_}-{episode_ts}".encode()).hexdigest()
+            )
+
             # Convert to numpy arrays, if required.
             if to_numpy:
                 sampled_episode.to_numpy()
 
             sampled_episodes.append(sampled_episode)
+            # Add the episode buffer index to the sampled indices.
+            sampled_episode_idxs.add(episode_idx)
+            # Record the actual n-step for this sample.
+            sampled_n_steps.append(actual_n_step)
+
             # Increment counter.
             B += 1
 
             # Keep track of sampled indices for updating priorities later.
             self._last_sampled_indices.append(idx)
 
+        # Add to the sampled timesteps counter of the buffer.
         self.sampled_timesteps += batch_size_B
 
+        # Update the sample metrics.
+        self._update_sample_metrics(
+            batch_size_B,
+            len(sampled_episode_idxs),
+            len(sampled_env_steps_idxs),
+            sum(sampled_n_steps) / batch_size_B,
+            num_resamples,
+        )
+
         return sampled_episodes
 
+    @override(EpisodeReplayBuffer)
+    @OverrideToImplementCustomLogic_CallToSuperRecommended
+    def _update_sample_metrics(
+        self,
+        num_env_steps_sampled: int,
+        num_episodes_per_sample: int,
+        num_env_steps_per_sample: int,
+        sampled_n_step: Optional[float],
+        num_resamples: int,
+        **kwargs: Dict[str, Any],
+    ) -> None:
+        """Updates the replay buffer's sample metrics.
+
+        Args:
+            num_env_steps_sampled: The number of environment steps sampled
+                this iteration in the `sample` method.
+            num_episodes_per_sample: The number of unique episodes in the
+                sample.
+            num_env_steps_per_sample: The number of unique environment steps
+                in the sample.
+            sampled_n_step: The mean n-step used in the sample. Note that this
+                is constant if the n-step is not sampled.
+            num_resamples: The total number of times environment steps had to
+                be resampled. Resampling happens if the sampled timestep is
+                too near to the episode's end to cover the complete n-step.
+        """
+        # Call the super's method to update all regular sample metrics.
+        super()._update_sample_metrics(
+            num_env_steps_sampled,
+            num_episodes_per_sample,
+            num_env_steps_per_sample,
+            sampled_n_step,
+        )
+
+        # Add the metrics for resamples.
+        self.metrics.log_value(
+            (NUM_AGENT_RESAMPLES, DEFAULT_AGENT_ID),
+            num_resamples,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+        self.metrics.log_value(
+            NUM_RESAMPLES,
+            num_resamples,
+            reduce="sum",
+            clear_on_reduce=True,
+        )
+
     @override(EpisodeReplayBuffer)
     def get_state(self) -> Dict[str, Any]:
         """Gets the state of a `PrioritizedEpisodeReplayBuffer`.
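Tying this back to the `dqn.py` hunk at the top of the patch: each training step, the algorithm now pulls `self.local_replay_buffer.get_metrics()` (the reduced `MetricsLogger` stats shown above) and merges them into its own metrics under the `REPLAY_BUFFER_RESULTS` key, so counters such as `num_env_steps_added`, `num_episodes_evicted`, or `env_step_utilization` surface in the training results. A rough, hypothetical sketch of inspecting them; the exact nesting of the result dict is an assumption here, not something this patch defines:

from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.utils.metrics import REPLAY_BUFFER_RESULTS

config = (
    DQNConfig()
    .environment("CartPole-v1")
    .training(replay_buffer_config={"type": "PrioritizedEpisodeReplayBuffer"})
)
algo = config.build()
results = algo.train()

# The reduced buffer stats should appear under the REPLAY_BUFFER_RESULTS key of
# the iteration results; printing the whole sub-dict avoids guessing individual
# key paths (an assumption about where the merged metrics land).
print(results.get(REPLAY_BUFFER_RESULTS, {}))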