refactor(buffers): remove unused 'rew_dim' argument (#327)
This commit removes the `rew_dim` argument from the `ReplayBuffer` and
`TrajectoryBuffer` classes, since it is uncommon for Gymnasium environments to
have multi-dimensional reward spaces. The argument was only added for a
prototype some time ago.

BREAKING CHANGE: the `ReplayBuffer` and `TrajectoryBuffer` classes no longer
accept a `rew_dim` argument.
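
For downstream code, migrating is a one-line change: drop the `rew_dim` keyword when constructing either buffer. A minimal sketch of the before/after (the shapes and buffer size are illustrative; only the constructor signatures come from this commit):

from stable_learning_control.algos.common.buffers import ReplayBuffer

obs_dim, act_dim = (4,), (1,)  # illustrative shapes, not from the commit

# Before this commit:
# buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, rew_dim=(1,), size=int(1e6))

# After this commit:
buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=int(1e6))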
rickstaa authored Aug 11, 2023
1 parent e0a0b9d commit a69a7f6
Showing 6 changed files with 9 additions and 26 deletions.
4 changes: 1 addition & 3 deletions sandbox/test_traj_buffer.py
@@ -7,12 +7,11 @@
 
 if __name__ == "__main__":
     # Create dummy environment.
-    env = gym.make("CartPole-v1")
+    env = gym.make("stable_gym:CartPoleCost-v1")
 
     # Dummy algorithm settings.
     obs_dim = env.observation_space.shape[0]
     act_dim = env.action_space.shape[0]
-    rew_dim = env.cost_range.shape[0]
     buffer_size = int(1e6)
     epochs = 10
     local_steps_per_epoch = 100
@@ -21,7 +20,6 @@
     buffer = TrajectoryBuffer(
         obs_dim=obs_dim,
         act_dim=act_dim,
-        rew_dim=rew_dim,
         size=buffer_size,
         preempt=True,
         incomplete=True,
23 changes: 8 additions & 15 deletions stable_learning_control/algos/common/buffers.py
@@ -21,13 +21,12 @@ class ReplayBuffer:
         ptr (int): The current buffer index.
     """
 
-    def __init__(self, obs_dim, act_dim, rew_dim, size):
+    def __init__(self, obs_dim, act_dim, size):
         """Initialise the ReplayBuffer object.
 
         Args:
             obs_dim (tuple): The size of the observation space.
             act_dim (tuple): The size of the action space.
-            rew_dim (tuple): The size of the reward space.
             size (int): The replay buffer size.
         """
         # Preallocate memory for experience buffer (s, s', a, r, d)
@@ -41,9 +40,7 @@ def __init__(self, obs_dim, act_dim, rew_dim, size):
         self.act_buf = atleast_2d(
             np.zeros(combine_shapes(int(size), act_dim), dtype=np.float32).squeeze()
         )
-        self.rew_buf = np.zeros(
-            combine_shapes(int(size), rew_dim), dtype=np.float32
-        ).squeeze()
+        self.rew_buf = np.zeros(int(size), dtype=np.float32)
         self.done_buf = np.zeros(int(size), dtype=np.float32)
         self.ptr, self.size, self._max_size = 0, 0, int(size)

@@ -78,15 +75,15 @@ def store(self, obs, act, rew, next_obs, done):
             self.rew_buf[self.ptr] = rew
         except ValueError as e:
             error_msg = (
-                f"{e.args[0].capitalize()} please make sure you set the "
-                "ReplayBuffer 'rew_dim' equal to your environment 'reward_space'."
+                f"{e.args[0].capitalize()} please make sure your 'rew' ReplayBuffer "
+                "element is of dimension 1."
             )
             raise ValueError(error_msg)
         try:
             self.done_buf[self.ptr] = done
         except ValueError as e:
             error_msg = (
-                f"{e.args[0].capitalize()} please make sure your 'done' ReplayBuffer"
+                f"{e.args[0].capitalize()} please make sure your 'done' ReplayBuffer "
                 "element is of dimension 1."
             )
             raise ValueError(error_msg)
@@ -141,7 +138,6 @@ def __init__(
         self,
         obs_dim,
         act_dim,
-        rew_dim,
         size,
         preempt=False,
         min_trajectory_size=3,
@@ -154,7 +150,6 @@ def __init__(
         Args:
             obs_dim (tuple): The size of the observation space.
             act_dim (tuple): The size of the action space.
-            rew_dim (tuple): The size of the reward space.
             size (int): The replay buffer size.
             preempt (bool, optional): Whether the buffer can be retrieved before it is
                 full. Defaults to ``False``.
@@ -186,9 +181,7 @@ def __init__(
         self.act_buf = atleast_2d(
             np.zeros(combine_shapes(size, act_dim), dtype=np.float32).squeeze()
         )
-        self.rew_buf = np.zeros(
-            combine_shapes(int(size), rew_dim), dtype=np.float32
-        ).squeeze()
+        self.rew_buf = np.zeros(int(size), dtype=np.float32)
         self.done_buf = np.zeros(int(size), dtype=np.float32)
 
         # Optional buffers.
@@ -245,8 +238,8 @@ def store(self, obs, act, rew, next_obs, done, val=None, logp=None):
             self.rew_buf[self.ptr] = rew
         except ValueError as e:
             error_msg = (
-                f"{e.args[0].capitalize()} please make sure you set the "
-                "TrajectoryBuffer 'rew_dim' equal to your environment 'reward_space'."
+                f"{e.args[0].capitalize()} please make sure your 'rew' "
+                "TrajectoryBuffer element is of dimension 1."
             )
             raise ValueError(error_msg)
         try:
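
With `rew_dim` gone, both buffers preallocate a flat `(size,)` reward array, so `store` expects a scalar reward per step; anything multi-dimensional now triggers the reworded `ValueError` above. A minimal sketch of that behaviour, with illustrative shapes and values:

import numpy as np

from stable_learning_control.algos.common.buffers import ReplayBuffer

buffer = ReplayBuffer(obs_dim=(4,), act_dim=(1,), size=1000)
obs, act, next_obs = np.zeros(4), np.zeros(1), np.ones(4)

# A scalar reward fits the flat reward buffer.
buffer.store(obs, act, 1.0, next_obs, False)

# A multi-dimensional reward no longer fits and raises the new error.
try:
    buffer.store(obs, act, np.array([1.0, 2.0]), next_obs, False)
except ValueError as e:
    print(e)  # ...please make sure your 'rew' ReplayBuffer element is of dimension 1.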
2 changes: 0 additions & 2 deletions stable_learning_control/algos/pytorch/lac/lac.py
@@ -1037,7 +1037,6 @@ def lac(
         test_env = gym.wrappers.FlattenObservation(test_env)
     obs_dim = env.observation_space.shape
     act_dim = env.action_space.shape
-    rew_dim = 1
 
     # Setup logger.
     logger_kwargs["quiet"] = (
@@ -1151,7 +1150,6 @@ def lac(
     replay_buffer = ReplayBuffer(
         obs_dim=obs_dim,
         act_dim=act_dim,
-        rew_dim=rew_dim,
         size=replay_size,
         device=policy.device,
     )
2 changes: 0 additions & 2 deletions stable_learning_control/algos/pytorch/sac/sac.py
@@ -922,7 +922,6 @@ def sac(
         test_env = gym.wrappers.FlattenObservation(test_env)
     obs_dim = env.observation_space.shape
     act_dim = env.action_space.shape
-    rew_dim = 1
 
     # Setup logger.
     logger_kwargs["quiet"] = (
@@ -1034,7 +1033,6 @@ def sac(
     replay_buffer = ReplayBuffer(
         obs_dim=obs_dim,
         act_dim=act_dim,
-        rew_dim=rew_dim,
         size=replay_size,
         device=policy.device,
     )
2 changes: 0 additions & 2 deletions stable_learning_control/algos/tf2/lac/lac.py
@@ -970,7 +970,6 @@ def lac(
         test_env = gym.wrappers.FlattenObservation(test_env)
     obs_dim = env.observation_space.shape
     act_dim = env.action_space.shape
-    rew_dim = 1
 
     # Setup logger.
     logger_kwargs["quiet"] = (
@@ -1093,7 +1092,6 @@ def lac(
     replay_buffer = ReplayBuffer(
         obs_dim=obs_dim,
         act_dim=act_dim,
-        rew_dim=rew_dim,
         size=replay_size,
     )

2 changes: 0 additions & 2 deletions stable_learning_control/algos/tf2/sac/sac.py
@@ -855,7 +855,6 @@ def sac(
         test_env = gym.wrappers.FlattenObservation(test_env)
     obs_dim = env.observation_space.shape
     act_dim = env.action_space.shape
-    rew_dim = 1
 
     # Setup logger.
     logger_kwargs["quiet"] = (
@@ -976,7 +975,6 @@ def sac(
     replay_buffer = ReplayBuffer(
         obs_dim=obs_dim,
         act_dim=act_dim,
-        rew_dim=rew_dim,
         size=replay_size,
     )
