
Commit

Merge branch 'feature/buffer-np' of https://github.com/Eclectic-Sheep/sheeprl into feature/buffer-np
belerico committed Dec 19, 2023
2 parents 63b1f7e + ba04da0 commit ab3ee64
Showing 20 changed files with 49 additions and 54 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -358,15 +358,15 @@ For each algorithm, losses are kept in a separate module, so that their implemen

## :card_index_dividers: Buffer

For the buffer implementation, we choose to use a wrapper around a [TensorDict](https://pytorch.org/rl/tensordict/reference/generated/tensordict.TensorDict.html).
For the buffer implementation, we choose to use a wrapper around a dictionary of Numpy arrays.

TensorDict comes in handy since we can easily add custom fields to the buffer as if we are working with dictionaries, but we can also easily perform operations on them as if we are working with tensors.
To enable a simple way to work with numpy memory-mapped arrays, we implemented the `sheeprl.utils.memmap.MemmapArray`, a container that handles the memory-mapped arrays.

This flexibility makes it very simple to implement, with the classes `ReplayBuffer`, `SequentialReplayBuffer`, `EpisodeBuffer`, and `AsyncReplayBuffer`, all the buffers needed for on-policy and off-policy algorithms.
This flexibility makes it very simple to implement, with the classes `ReplayBuffer`, `SequentialReplayBuffer`, `EpisodeBuffer`, and `EnvIndependentReplayBuffer`, all the buffers needed for on-policy and off-policy algorithms.

### :mag: Technical details

The tensor's shape in the TensorDict is `(T, B, *)`, where `T` is the number of timesteps, `B` is the number of parallel environments, and `*` is the shape of the data.
The shape of the Numpy arrays in the dictionary is `(T, B, *)`, where `T` is the number of timesteps, `B` is the number of parallel environments, and `*` is the shape of the data.

For the `ReplayBuffer` to be used as a RolloutBuffer, the proper `buffer_size` must be specified. For example, for PPO, the `buffer_size` must be `[T, B]`, where `T` is the number of timesteps and `B` is the number of parallel environments.
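
A minimal sketch of the `(T, B, *)` layout described above, using a plain dictionary of Numpy arrays (the key names and sizes here are illustrative, not the internals of `ReplayBuffer`):

```python
import numpy as np

T, B = 128, 4  # timesteps and number of parallel environments

# Every entry shares the leading (T, B) dimensions; the trailing shape depends on the data.
buffer = {
    "observations": np.zeros((T, B, 3, 64, 64), dtype=np.uint8),
    "actions": np.zeros((T, B, 6), dtype=np.float32),
    "rewards": np.zeros((T, B, 1), dtype=np.float32),
    "dones": np.zeros((T, B, 1), dtype=np.float32),
}

# A single environment step is stored with a time dimension of 1, i.e. shape (1, B, *),
# which is the convention followed by the `step_data` dictionaries in the diffs below.
step_data = {k: np.zeros((1, B) + v.shape[2:], dtype=v.dtype) for k, v in buffer.items()}
```

Used as a rollout buffer for PPO, `T` corresponds to the number of rollout steps and `B` to the number of parallel environments, which is what the `[T, B]` `buffer_size` above refers to.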

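The `MemmapArray` mentioned above is a container around Numpy memory-mapped arrays; the underlying mechanism is Numpy's own `np.memmap`, sketched here for context (this is not the `sheeprl.utils.memmap.MemmapArray` API, and the file path is made up):

```python
import numpy as np
import os
import tempfile

# A disk-backed array: the data lives in a file and is paged in on access,
# so large replay buffers do not have to fit entirely in RAM.
path = os.path.join(tempfile.mkdtemp(), "rewards.memmap")
rewards = np.memmap(path, dtype=np.float32, mode="w+", shape=(128, 4, 1))

rewards[0, :, 0] = 1.0  # writes go through to the underlying file
rewards.flush()         # persist any pending changes to disk
```
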
2 changes: 1 addition & 1 deletion howto/register_new_algorithm.md
@@ -190,7 +190,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
)

# Define the optimizer
optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters())
optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all")

# Load the state from the checkpoint
if cfg.checkpoint.resume_from:
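For context on the `_convert_="all"` argument added above: by default Hydra's `instantiate` passes nested config nodes to the target as OmegaConf containers, while `_convert_="all"` converts them to plain Python dicts and lists first. A small, self-contained sketch (the optimizer config here is made up for illustration, not taken from sheeprl's configs):

```python
import hydra
import torch
from omegaconf import OmegaConf

optim_cfg = OmegaConf.create({"_target_": "torch.optim.Adam", "lr": 1e-3, "betas": [0.9, 0.999]})
params = [torch.nn.Parameter(torch.zeros(1))]

# Without _convert_, `betas` would reach Adam as an OmegaConf ListConfig;
# with _convert_="all" it arrives as a plain Python list.
optimizer = hydra.utils.instantiate(optim_cfg, params=params, _convert_="all")
```
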
1 change: 0 additions & 1 deletion pyproject.toml
@@ -83,7 +83,6 @@ minerl = ["setuptools==66.0.0", "minerl==0.4.4"]
diambra = ["diambra==0.0.16", "diambra-arena==2.2.2"]
crafter = ["crafter==1.8.1"]
mlflow = ["mlflow==2.8.0"]
pokemon = ["hnswlib", "mediapy", "pyboy", "scikit-image"]

[tool.ruff]
line-length = 120
9 changes: 3 additions & 6 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -365,9 +365,6 @@ def train(
module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients
)
critic_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/critic", critic_grads.mean().detach())
aggregator.update("Loss/value_loss", value_loss.detach())

# Log metrics
if aggregator and not aggregator.disabled:
@@ -584,7 +581,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim)))
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1))
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)
player.init_states()

for update in range(start_step, num_updates + 1):
@@ -660,7 +657,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["dones"] = dones[np.newaxis]
step_data["actions"] = actions[np.newaxis]
step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis]
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Reset and save the observation coming from the automatic reset
dones_idxes = dones.nonzero()[0].tolist()
@@ -672,7 +669,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
reset_data["dones"] = np.zeros((1, reset_envs, 1))
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = np.zeros((1, reset_envs, 1))
rb.add(reset_data, dones_idxes, validate_args=False)
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)
# Reset dones so that `is_first` is updated
for d in dones_idxes:
step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d])
13 changes: 5 additions & 8 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -123,7 +123,7 @@ def train(

# Given how the environment interaction works, we assume that the first element in a sequence
# is the first one, as if the environment has been reset
data["is_first"][0, :] = torch.full_like(data["is_first"][0, :], 1.0)
data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :])

# Dynamic Learning
stoch_state_size = stochastic_size * discrete_size
@@ -280,7 +280,7 @@ def train(
predicted_target_values = target_critic(imagined_trajectories)
predicted_rewards = world_model.reward_model(imagined_trajectories)
if cfg.algo.world_model.use_continues and world_model.continue_model:
continues = logits_to_probs(world_model.continue_model(imagined_trajectories))
continues = logits_to_probs(world_model.continue_model(imagined_trajectories), is_binary=True)
true_done = (1 - data["dones"]).reshape(1, -1, 1) * cfg.algo.gamma
continues = torch.cat((true_done, continues[1:]))
else:
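
A side note on the `is_binary=True` fix above: `torch.distributions.utils.logits_to_probs` applies a softmax over the last dimension by default, which collapses to all ones for the single Bernoulli logit produced by the continue model; with `is_binary=True` it applies a sigmoid instead. A minimal illustration in plain PyTorch (the tensors are made up):

```python
import torch
from torch.distributions.utils import logits_to_probs

logits = torch.tensor([[2.0], [-1.0]])           # one Bernoulli logit per imagined state
print(logits_to_probs(logits))                   # softmax over a size-1 dim -> all ones
print(logits_to_probs(logits, is_binary=True))   # sigmoid -> per-state continue probability
```
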
@@ -348,9 +348,6 @@ def train(
module=actor, optimizer=actor_optimizer, max_norm=cfg.algo.actor.clip_gradients, error_if_nonfinite=False
)
actor_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/actor", actor_grads.mean().detach())
aggregator.update("Loss/policy_loss", policy_loss.detach())

# Predict the values distribution only for the first H (horizon)
# imagined states (to match the dimension with the lambda values),
@@ -625,7 +622,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim)))
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["is_first"] = np.ones_like(step_data["dones"])
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)
player.init_states()

per_rank_gradient_steps = 0
@@ -703,7 +700,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1))
step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1))
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Reset and save the observation coming from the automatic reset
dones_idxes = dones.nonzero()[0].tolist()
@@ -716,7 +713,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = np.zeros((1, reset_envs, 1))
reset_data["is_first"] = np.ones_like(reset_data["dones"])
rb.add(reset_data, dones_idxes, validate_args=False)
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)
# Reset dones so that `is_first` is updated
for d in dones_idxes:
step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d])
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -97,7 +97,7 @@ def train(
device = fabric.device
batch_obs = {k: data[k] / 255.0 for k in cfg.algo.cnn_keys.encoder}
batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder})
data["is_first"][0, :] = torch.full_like(data["is_first"][0, :], 1.0)
data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :])

# Given how the environment interaction works, we remove the last actions
# and add the first one as the zero action
@@ -656,7 +656,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = step_data["rewards"][:, dones_idxes]
reset_data["is_first"] = np.zeros_like(reset_data["dones"])
rb.add(reset_data, dones_idxes)
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)

# Reset already inserted step data
step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"])
4 changes: 2 additions & 2 deletions sheeprl/algos/droq/droq.py
@@ -214,7 +214,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Metrics
aggregator = None
if not MetricAggregator.disabled:
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator).to(device)
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

# Local data
buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1
@@ -318,7 +318,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1)
step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1).astype(np.float32)
step_data["rewards"] = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32)
rb.add(step_data)
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# next_obs becomes the new obs
obs = next_obs
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -621,7 +621,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim)))
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1))
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)
player.init_states()

for update in range(start_step, num_updates + 1):
@@ -697,7 +697,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["dones"] = dones[np.newaxis]
step_data["actions"] = actions[np.newaxis]
step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis]
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Reset and save the observation coming from the automatic reset
dones_idxes = dones.nonzero()[0].tolist()
@@ -709,7 +709,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
reset_data["dones"] = np.zeros((1, reset_envs, 1))
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = np.zeros((1, reset_envs, 1))
rb.add(reset_data, dones_idxes, validate_args=False)
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)
# Reset dones so that `is_first` is updated
for d in dones_idxes:
step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d])
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -262,7 +262,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim)))
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1))
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)
player.init_states()
player.init_states()

@@ -323,7 +323,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
step_data["dones"] = dones[np.newaxis]
step_data["actions"] = actions[np.newaxis]
step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis]
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Reset and save the observation coming from the automatic reset
dones_idxes = dones.nonzero()[0].tolist()
@@ -335,7 +335,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
reset_data["dones"] = np.zeros((1, reset_envs, 1))
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = np.zeros((1, reset_envs, 1))
rb.add(reset_data, dones_idxes, validate_args=False)
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)
# Reset dones so that `is_first` is updated
for d in dones_idxes:
step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d])
8 changes: 4 additions & 4 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -114,7 +114,7 @@ def train(
data = {k: data[k] for k in data.keys()}
batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder}
batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder})
data["is_first"][0, :] = torch.full_like(data["is_first"][0, :], 1.0)
data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :])

# Dynamic Learning
recurrent_state = torch.zeros(1, batch_size, recurrent_state_size, device=device)
@@ -767,7 +767,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim)))
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["is_first"] = np.ones_like(step_data["dones"])
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)
player.init_states()

per_rank_gradient_steps = 0
@@ -845,7 +845,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1))
step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1))
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Reset and save the observation coming from the automatic reset
dones_idxes = dones.nonzero()[0].tolist()
@@ -858,7 +858,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = np.zeros((1, reset_envs, 1))
reset_data["is_first"] = np.ones_like(reset_data["dones"])
rb.add(reset_data, dones_idxes, validate_args=False)
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)
# Reset dones so that `is_first` is updated
for d in dones_idxes:
step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d])
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -282,7 +282,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim)))
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["is_first"] = np.ones_like(step_data["dones"])
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)
player.init_states()

per_rank_gradient_steps = 0
@@ -344,7 +344,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1))
step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1))
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Reset and save the observation coming from the automatic reset
dones_idxes = dones.nonzero()[0].tolist()
@@ -357,7 +357,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = np.zeros((1, reset_envs, 1))
reset_data["is_first"] = np.ones_like(reset_data["dones"])
rb.add(reset_data, dones_idxes, validate_args=False)
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)
# Reset dones so that `is_first` is updated
for d in dones_idxes:
step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d])
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -927,7 +927,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = step_data["rewards"][:, dones_idxes]
reset_data["is_first"] = np.zeros_like(reset_data["dones"])
rb.add(reset_data, dones_idxes)
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)

# Reset already inserted step data
step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"])
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -356,7 +356,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = step_data["rewards"][:, dones_idxes]
reset_data["is_first"] = np.zeros_like(reset_data["dones"])
rb.add(reset_data, dones_idxes)
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)

# Reset already inserted step data
step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"])
6 changes: 3 additions & 3 deletions sheeprl/algos/ppo/ppo.py
@@ -36,7 +36,7 @@ def train(
cfg: Dict[str, Any],
):
"""Train the agent on the data collected from the environment."""
indexes = list(range(data[next(iter(data.keys()))].shape[0]))
indexes = list(range(next(iter(data.values())).shape[0]))
if cfg.buffer.share_data:
sampler = DistributedSampler(
indexes,
@@ -196,7 +196,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Create a metric aggregator to log the metrics
aggregator = None
if not MetricAggregator.disabled:
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator).to(device)
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

# Local data
if cfg.buffer.size < cfg.algo.rollout_steps:
@@ -320,7 +320,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape))

# Append data to buffer
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Update the observation and dones
next_obs = {}
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo/ppo_decoupled.py
@@ -242,7 +242,7 @@ def player(
step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape))

# Append data to buffer
rb.add(step_data, validate_args=False)
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Update the observation and dones
next_obs = {}
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -203,7 +203,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Create a metric aggregator to log the metrics
aggregator = None
if not MetricAggregator.disabled:
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator).to(device)
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

# Local data
rb = ReplayBuffer(
@@ -341,7 +341,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
step_data["advantages"] = np.zeros_like(rewards)

# Append data to buffer
rb.add(step_data)
rb.add(step_data, validate_args=cfg.buffer.validate_args)

# Update actions
prev_actions = (1 - dones) * actions