fix: correct finite horizon buffer calculation #398

Merged 1 commit on Feb 6, 2024.
10 changes: 5 additions & 5 deletions stable_learning_control/algos/common/buffers.py
@@ -1,6 +1,7 @@
"""This module contains several replay buffers that are used in multiple Pytorch and
TensorFlow algorithms.
"""

import numpy as np

from stable_learning_control.algos.common.helpers import discount_cumsum
@@ -113,15 +114,15 @@ def sample_batch(self, batch_size=32):


 class FiniteHorizonReplayBuffer(ReplayBuffer):
-    """A first-in-first-out (FIFO) experience replay buffer that also stores the
+    r"""A first-in-first-out (FIFO) experience replay buffer that also stores the
     expected cumulative finite-horizon reward.

     .. note::
         The expected cumulative finite-horizon reward is calculated using the following
         formula:

         .. math::
-            L_{target}(s,a) = \\sum_{t}^{t+N} \\mathbb{E}_{c_{t}}
+            L_{target}(s,a) = \sum_{t}^{t+N} \mathbb{E}_{c_{t}}

     Attributes:
         horizon_length (int): The length of the finite-horizon.
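
A side note on the docstring change above (inferred, not stated in the PR): the added r""" prefix turns the docstring into a raw string, so the LaTeX in the .. math:: directive can be written with single backslashes instead of escaped ones. A quick illustration:

    # Hypothetical snippet, not part of the PR: raw vs. regular string escaping.
    plain = "\\sum_{t}^{t+N}"  # regular string: the backslash must be escaped
    raw = r"\sum_{t}^{t+N}"    # raw string: the backslash is kept literally
    assert plain == raw        # both hold the same single-backslash text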
@@ -193,8 +194,8 @@ def store(self, obs, act, rew, next_obs, done, truncated):
         # Calculate the expected cumulative finite-horizon reward.
         path_rew = np.pad(path_rew, (0, self.horizon_length), mode="edge")
         horizon_rew = [
-            np.sum(path_rew[i : i + self.horizon_length + 1])
-            for i in range(len(path_rew) - self.horizon_length)
+            np.sum(path_rew[i : i + self.horizon_length])
+            for i in range(len(path_ptrs))
         ]

         # Store the expected cumulative finite-horizon reward.
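
For intuition, here is a minimal standalone sketch (toy numbers, not part of the PR) of what the corrected calculation stores: after edge-padding the episode rewards, each timestep gets the sum of the next horizon_length rewards, and the loop runs once per stored transition (len(path_ptrs)).

    import numpy as np

    # Hypothetical toy episode to illustrate the corrected window sums.
    horizon_length = 3                              # finite horizon N (made up)
    path_rew = np.array([1.0, 2.0, 3.0, 4.0, 5.0])  # episode rewards (made up)
    path_ptrs = np.arange(len(path_rew))            # stand-in for the buffer pointers

    # Edge-pad the tail so windows near the episode end keep full width,
    # mirroring the buffer's np.pad(..., mode="edge") call.
    padded_rew = np.pad(path_rew, (0, horizon_length), mode="edge")

    # Corrected calculation: sum `horizon_length` rewards starting at each timestep.
    horizon_rew = [
        np.sum(padded_rew[i : i + horizon_length]) for i in range(len(path_ptrs))
    ]
    # horizon_rew -> 6, 9, 12, 14, 15

The removed lines above instead summed horizon_length + 1 rewards per window (10, 14, 17, 19, 20 on this toy data) and derived the loop count from the padded array length; that is the calculation the PR title flags as incorrect.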
@@ -413,7 +414,6 @@ def finish_path(self, last_val=0):
         the reward-to-go calculation to account for timesteps beyond the arbitrary
         episode horizon (or epoch cutoff).
         """
-
         # Calculate the advantage and rewards-to-go if buffer contains vals
         if self._contains_vals:
             # Get the current trajectory.