From 9e80963158b84fbed260a65d68ed9cc3f0025411 Mon Sep 17 00:00:00 2001
From: Rick Staa
Date: Tue, 6 Feb 2024 22:29:41 +0100
Subject: [PATCH] fix: correct finite horizon buffer calculation

This commit addresses a bug in the `FiniteHorizonReplayBuffer` that resulted
in an erroneous inclusion of an extra horizon step in the finite horizon
reward computation.
---
 stable_learning_control/algos/common/buffers.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/stable_learning_control/algos/common/buffers.py b/stable_learning_control/algos/common/buffers.py
index f69992bbb..59e163e47 100644
--- a/stable_learning_control/algos/common/buffers.py
+++ b/stable_learning_control/algos/common/buffers.py
@@ -1,6 +1,7 @@
 """This module contains several replay buffers that are used in multiple Pytorch and
 TensorFlow algorithms.
 """
+
 import numpy as np
 
 from stable_learning_control.algos.common.helpers import discount_cumsum
@@ -113,7 +114,7 @@ def sample_batch(self, batch_size=32):
 
 
 class FiniteHorizonReplayBuffer(ReplayBuffer):
-    """A first-in-first-out (FIFO) experience replay buffer that also stores the
+    r"""A first-in-first-out (FIFO) experience replay buffer that also stores the
     expected cumulative finite-horizon reward.
 
     .. note::
@@ -121,7 +122,7 @@ class FiniteHorizonReplayBuffer(ReplayBuffer):
         formula:
 
         .. math::
-            L_{target}(s,a) = \\sum_{t}^{t+N} \\mathbb{E}_{c_{t}}
+            L_{target}(s,a) = \sum_{t}^{t+N} \mathbb{E}_{c_{t}}
 
     Attributes:
         horizon_length (int): The length of the finite-horizon.
@@ -193,8 +194,8 @@ def store(self, obs, act, rew, next_obs, done, truncated):
         # Calculate the expected cumulative finite-horizon reward.
         path_rew = np.pad(path_rew, (0, self.horizon_length), mode="edge")
         horizon_rew = [
-            np.sum(path_rew[i : i + self.horizon_length + 1])
-            for i in range(len(path_rew) - self.horizon_length)
+            np.sum(path_rew[i : i + self.horizon_length])
+            for i in range(len(path_ptrs))
         ]
 
         # Store the expected cumulative finite-horizon reward.
@@ -413,7 +414,6 @@ def finish_path(self, last_val=0):
         the reward-to-go calculation to account for timesteps beyond the
         arbitrary episode horizon (or epoch cutoff).
         """
-        # Calculate the advantage and rewards-to-go if buffer contains vals
         if self._contains_vals:
             # Get the current trajectory.
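
Note (reviewer sketch, not part of the patch): the snippet below illustrates the
corrected window arithmetic in isolation. The helper name `finite_horizon_returns`,
its arguments, and the example rewards are illustrative assumptions rather than code
from the repository; only the edge padding and the `horizon_length`-wide slice mirror
the hunk above.

    import numpy as np

    def finite_horizon_returns(path_rew, horizon_length):
        # Pad the end of the path by repeating the last reward, as the buffer does
        # with np.pad(..., mode="edge"), so late steps still have a full window.
        padded = np.pad(path_rew, (0, horizon_length), mode="edge")
        # Sum a window of exactly `horizon_length` rewards per step; the pre-fix code
        # summed `horizon_length + 1` rewards, i.e. one extra step per window.
        return np.array(
            [np.sum(padded[i : i + horizon_length]) for i in range(len(path_rew))]
        )

    # With constant rewards of 1 and a horizon of 2, every step should accumulate a
    # reward of 2; the pre-fix window would have reported 3.
    print(finite_horizon_returns(np.array([1.0, 1.0, 1.0]), horizon_length=2))  # [2. 2. 2.]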