fix: correct finite horizon buffer calculation (#398)
This commit fixes a bug in the `FiniteHorizonReplayBuffer` that caused an
extra horizon step to be included in the finite-horizon reward computation.
rickstaa authored Feb 6, 2024
1 parent 3aaf7ab commit 779201c
Showing 1 changed file with 5 additions and 5 deletions.
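
To illustrate the off-by-one, here is a minimal sketch (not from the repository; `horizon_length`, the toy rewards, and the variable names are made up for illustration) contrasting the old and new slices on a short edge-padded reward array:

import numpy as np

# Hypothetical values for illustration only; these are not taken from the
# repository.
horizon_length = 2
path_rew = np.array([1.0, 2.0, 3.0])
path_ptrs = np.arange(len(path_rew))  # one buffer pointer per stored step

# `store` pads the path rewards so steps near the end still see a full horizon.
padded = np.pad(path_rew, (0, horizon_length), mode="edge")  # [1, 2, 3, 3, 3]

# Old slice: horizon_length + 1 rewards per step, one step too many.
buggy = [float(np.sum(padded[i : i + horizon_length + 1]))
         for i in range(len(padded) - horizon_length)]

# Fixed slice: exactly horizon_length rewards per step.
fixed = [float(np.sum(padded[i : i + horizon_length]))
         for i in range(len(path_ptrs))]

print(buggy)  # [6.0, 8.0, 9.0]
print(fixed)  # [3.0, 5.0, 6.0]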
10 changes: 5 additions & 5 deletions stable_learning_control/algos/common/buffers.py
@@ -1,6 +1,7 @@
 """This module contains several replay buffers that are used in multiple Pytorch and
 TensorFlow algorithms.
 """
+
 import numpy as np
 
 from stable_learning_control.algos.common.helpers import discount_cumsum
@@ -113,15 +114,15 @@ def sample_batch(self, batch_size=32):
 
 
 class FiniteHorizonReplayBuffer(ReplayBuffer):
-    """A first-in-first-out (FIFO) experience replay buffer that also stores the
+    r"""A first-in-first-out (FIFO) experience replay buffer that also stores the
     expected cumulative finite-horizon reward.
 
     .. note::
         The expected cumulative finite-horizon reward is calculated using the following
         formula:
 
         .. math::
-            L_{target}(s,a) = \\sum_{t}^{t+N} \\mathbb{E}_{c_{t}}
+            L_{target}(s,a) = \sum_{t}^{t+N} \mathbb{E}_{c_{t}}
 
     Attributes:
         horizon_length (int): The length of the finite-horizon.
@@ -193,8 +194,8 @@ def store(self, obs, act, rew, next_obs, done, truncated):
             # Calculate the expected cumulative finite-horizon reward.
             path_rew = np.pad(path_rew, (0, self.horizon_length), mode="edge")
             horizon_rew = [
-                np.sum(path_rew[i : i + self.horizon_length + 1])
-                for i in range(len(path_rew) - self.horizon_length)
+                np.sum(path_rew[i : i + self.horizon_length])
+                for i in range(len(path_ptrs))
             ]
 
             # Store the expected cumulative finite-horizon reward.
@@ -413,7 +414,6 @@ def finish_path(self, last_val=0):
         the reward-to-go calculation to account for timesteps beyond the arbitrary
         episode horizon (or epoch cutoff).
         """
-
         # Calculate the advantage and rewards-to-go if buffer contains vals
         if self._contains_vals:
             # Get the current trajectory.
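The final hunk touches `finish_path`, whose docstring describes bootstrapping the reward-to-go with `last_val` when an episode is cut off at the epoch boundary. A minimal sketch of that pattern, assuming a Spinning Up-style `discount_cumsum` matching the helper the module imports (the discount factor and reward values below are made up):

import numpy as np
import scipy.signal

def discount_cumsum(x, discount):
    # Spinning Up-style discounted cumulative sum; reproduced here only so
    # the sketch is self-contained, assuming the imported helper matches it.
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]

gamma = 0.99                      # assumed discount factor
rews = np.array([1.0, 1.0, 1.0])  # toy rewards for a truncated episode
last_val = 0.5                    # bootstrap value for the cut-off state

# Append the bootstrap value, take the discounted reward-to-go, and drop
# the final entry, which accounts for timesteps beyond the epoch cutoff.
rtg = discount_cumsum(np.append(rews, last_val), gamma)[:-1]
print(rtg)  # approximately [3.46, 2.48, 1.50]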
