diff --git a/.gitignore b/.gitignore index 3ef1d5a0..de3b08df 100644 --- a/.gitignore +++ b/.gitignore @@ -170,4 +170,5 @@ pytest_* .pypirc mlruns mlartifacts -examples/models \ No newline at end of file +examples/models +session_* \ No newline at end of file diff --git a/README.md b/README.md index ca24de82..27bae407 100644 --- a/README.md +++ b/README.md @@ -358,15 +358,15 @@ For each algorithm, losses are kept in a separate module, so that their implemen ## :card_index_dividers: Buffer -For the buffer implementation, we choose to use a wrapper around a [TensorDict](https://pytorch.org/rl/tensordict/reference/generated/tensordict.TensorDict.html). +For the buffer implementation, we choose to use a wrapper around a dictionary of Numpy arrays. -TensorDict comes in handy since we can easily add custom fields to the buffer as if we are working with dictionaries, but we can also easily perform operations on them as if we are working with tensors. +To enable a simple way to work with numpy memory-mapped arrays, we implemented the `sheeprl.utils.memmap.MemmapArray`, a container that handles the memory-mapped arrays. -This flexibility makes it very simple to implement, with the classes `ReplayBuffer`, `SequentialReplayBuffer`, `EpisodeBuffer`, and `AsyncReplayBuffer`, all the buffers needed for on-policy and off-policy algorithms. +This flexibility makes it very simple to implement, with the classes `ReplayBuffer`, `SequentialReplayBuffer`, `EpisodeBuffer`, and `EnvIndependentReplayBuffer`, all the buffers needed for on-policy and off-policy algorithms. ### :mag: Technical details -The tensor's shape in the TensorDict is `(T, B, *)`, where `T` is the number of timesteps, `B` is the number of parallel environments, and `*` is the shape of the data. +The shape of the Numpy arrays in the dictionary is `(T, B, *)`, where `T` is the number of timesteps, `B` is the number of parallel environments, and `*` is the shape of the data. For the `ReplayBuffer` to be used as a RolloutBuffer, the proper `buffer_size` must be specified. For example, for PPO, the `buffer_size` must be `[T, B]`, where `T` is the number of timesteps and `B` is the number of parallel environments. diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index 98b63d0c..df9d97b2 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -9,11 +9,57 @@ algos ... └── sota ├── __init__.py + ├── agent.py ├── loss.py ├── sota.py └── utils.py ``` +## The agent +The agent is the core of the algorithm and it is defined in the `agent.py` file. It must contain at least single function called `build_agent` that returns a `torch.nn.Module` wrapped with Fabric: + +```python +from __future__ import annotations + +from typing import Any, Dict, Sequence + +import gymnasium +from lightning import Fabric +from lightning.fabric.wrappers import _FabricModule +import torch +from torch import Tensor + + +class SOTAAgent(torch.nn.Module): + def __init__(self, ...): + ... + + def forward(self, obs: Dict[str, torch.Tensor]) -> Tensor: + ... + + +def build_agent( + fabric: Fabric, + actions_dim: Sequence[int], + is_continuous: bool, + cfg: Dict[str, Any], + observation_space: gymnasium.spaces.Dict, + state: Dict[str, Any] | None = None, +) -> _FabricModule: + + # Define the agent here + agent = SOTAAgent(...) 
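+    # NOTE: `...` above is a placeholder for your agent's constructor arguments;
+    # typically these are derived from the `cfg`, `observation_space`, `actions_dim`
+    # and `is_continuous` values received by `build_agent`.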
+ + # Load the state from the checkpoint + if state: + agent.load_state_dict(state) + + # Setup the agent with Fabric + agent = fabric.setup_model(agent) + + return agent +``` + ## Loss functions All the loss functions to be optimized by the agent during the training should be defined under the `loss.py` file, even though is not strictly necessary: @@ -40,20 +86,13 @@ from datetime import datetime import gymnasium as gym import hydra import torch -from gymnasium.vector import SyncVectorEnv from lightning.fabric import Fabric -from lightning.fabric.fabric import _is_using_cli -from omegaconf import DictConfig, OmegaConf -from tensordict import TensorDict, make_tensordict -from tensordict.tensordict import TensorDictBase -from torch.optim import Adam from torchmetrics import MeanMetric, SumMetric -from sheeprl.algos.ppo.agent import build_agent +from sheeprl.algos.sota.agent import build_agent from sheeprl.algos.sota.loss import loss1, loss2 -from sheeprl.algos.sota.utils import test +from sheeprl.algos.sota.utils import normalize_obs, test from sheeprl.data import ReplayBuffer -from sheeprl.models.models import MLP from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.env import make_env @@ -67,7 +106,7 @@ def train( fabric: Fabric, agent: torch.nn.Module, optimizer: torch.optim.Optimizer, - data: TensorDictBase, + data: Dict[str, torch.Tensor], aggregator: MetricAggregator, cfg: Dict[str, Any], ): @@ -99,6 +138,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -108,13 +148,34 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): cfg, cfg.seed + rank * cfg.env.num_envs + i, rank * cfg.env.num_envs, - logger.log_dir if rank == 0 else None, + log_dir if rank == 0 else None, "train", vector_env_idx=i, - ), + ) for i in range(cfg.env.num_envs) ] ) + observation_space = envs.single_observation_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + if cfg.metric.log_level > 0: + fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder + + is_continuous = isinstance(envs.single_action_space, gym.spaces.Box) + is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete) + actions_dim = tuple( + envs.single_action_space.shape + if is_continuous + else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n]) + ) # Create the agent model: this should be a torch.nn.Module to be accelerated with Fabric # Given that the environment has been created with the `make_env` method, the agent @@ -129,29 +190,54 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): state["agent"] if cfg.checkpoint.resume_from else None, ) - # the optimizer and set up it with Fabric + # Define the optimizer optimizer = hydra.utils.instantiate(cfg.algo.optimizer, 
params=agent.parameters(), _convert_="all") + # Load the state from the checkpoint + if cfg.checkpoint.resume_from: + optimizer.load_state_dict(state["optimizer"]) + + # Setup agent and optimizer with Fabric + optimizer = fabric.setup_optimizers(optimizer) + # Create a metric aggregator to log the metrics aggregator = None if not MetricAggregator.disabled: aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device) + if fabric.is_global_zero: + save_configs(cfg, log_dir) + # Local data - rb = ReplayBuffer(cfg.algo.rollout_steps, cfg.env.num_envs, device=device, memmap=cfg.buffer.memmap) - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) + rb = ReplayBuffer( + cfg.buffer.size, + cfg.env.num_envs, + memmap=cfg.buffer.memmap, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + obs_keys=obs_keys, + ) # Global variables - last_log = 0 last_train = 0 train_step = 0 - policy_step = 0 - last_checkpoint = 0 + start_step = ( + # + 1 because the checkpoint is at the end of the update step + # (when resuming from a checkpoint, the update at the checkpoint + # is ended and you have to start with the next one) + (state["update"] // fabric.world_size) + 1 + if cfg.checkpoint.resume_from + else 1 + ) + policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0 + last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 + last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + if cfg.checkpoint.resume_from: + cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size # Warning for log and checkpoint every - if cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " @@ -167,20 +253,14 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] - next_obs = {} - for k in o.keys(): - if k in obs_keys: - torch_obs = torch.from_numpy(o[k]).to(fabric.device) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - step_data[k] = torch_obs - next_obs[k] = torch_obs - next_done = torch.zeros(cfg.env.num_envs, 1, dtype=torch.float32).to(fabric.device) # [N_envs, 1] - - for update in range(1, num_updates + 1): + step_data = {} + next_obs = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] + for k in obs_keys: + if k in cfg.algo.cnn_keys.encoder: + next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) + step_data[k] = next_obs[k][np.newaxis] + + for update in range(start_step, num_updates + 1): for _ in range(0, cfg.algo.rollout_steps): policy_step += cfg.env.num_envs * world_size @@ -189,41 +269,43 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): # Sample an action given the observation received by the environment - # This calls the 
`forward` method of the PyTorch module, escaping from Fabric - # because we don't want this to be a synchronization point - action = agent.module(next_obs) + normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) + torch_obs = { + k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys + } + actions = agent.module(torch_obs) + if is_continuous: + real_actions = torch.cat(actions, -1).cpu().numpy() + else: + real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() + actions = torch.cat(actions, -1).cpu().numpy() # Single environment step - o, reward, done, truncated, info = envs.step(action.cpu().numpy().reshape(envs.action_space.shape)) - - with device: - rewards = torch.tensor(reward).view(cfg.env.num_envs, -1) # [N_envs, 1] - done = torch.logical_or(torch.tensor(done), torch.tensor(truncated)) # [N_envs, 1] - done = done.view(cfg.env.num_envs, -1).float() + obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + rewards = rewards.reshape(cfg.env.num_envs, -1) # Update the step data - step_data["dones"] = next_done - step_data["actions"] = action - step_data["rewards"] = rewards + step_data["dones"] = dones[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["rewards"] = rewards[np.newaxis] + if cfg.buffer.memmap: + step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) + step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) # Append data to buffer - rb.add(step_data.unsqueeze(0)) - - # Update the observation and done - obs = {} - for k in o.keys(): - if k in obs_keys: - torch_obs = torch.from_numpy(o[k]).to(fabric.device) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - step_data[k] = torch_obs - obs[k] = torch_obs - next_obs = obs - next_done = done - - if "final_info" in info: + rb.add(step_data, validate_args=False) + + # Update the observation and dones + next_obs = {} + for k in obs_keys: + _obs = obs[k] + if k in cfg.algo.cnn_keys.encoder: + _obs = _obs.reshape(cfg.env.num_envs, -1, *_obs.shape[-2:]) + step_data[k] = _obs[np.newaxis] + next_obs[k] = _obs + + if cfg.metric.log_level > 0 and "final_info" in info: for i, agent_ep_info in enumerate(info["final_info"]): if agent_ep_info is not None: ep_rew = agent_ep_info["episode"]["r"] @@ -234,8 +316,8 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): aggregator.update("Game/ep_len_avg", ep_len) fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") - # Flatten the batch - local_data = rb.buffer.view(-1) + # Transform the data into PyTorch Tensors + local_data = rb.to_tensor(dtype=None, device=device) # Train the agent train(fabric, agent, optimizer, local_data, aggregator, cfg) @@ -298,7 +380,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): register_model(fabric, log_models, cfg, models_to_log) ``` -where `log_models` has to be defined in the `sheeprl.algo.sota.utils` module, for example like this: +where `log_models`, `test` and `normalize_obs` have to be defined in the `sheeprl.algo.sota.utils` module, for example like this: ```python from __future__ import annotations @@ -307,14 +389,69 @@ import warnings from typing import TYPE_CHECKING, Any, Dict import torch +from lightning import Fabric 
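+# Additional imports used by the `test` and `normalize_obs` helpers below
+from typing import Sequence
+import numpy as np
+from torch import Tensor
+from sheeprl.utils.env import make_env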
from lightning.fabric.wrappers import _FabricModule +from sheeprl.algos.sota.agent import SOTAAgent from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE from sheeprl.utils.utils import unwrap_fabric if TYPE_CHECKING: from mlflow.models.model import ModelInfo + +@torch.no_grad() +def test(agent: SOTAAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): + env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() + agent.eval() + done = False + cumulative_rew = 0 + o = env.reset(seed=cfg.seed)[0] + obs = {} + for k in o.keys(): + if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder: + torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) + if k in cfg.algo.cnn_keys.encoder: + torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5 + if k in cfg.algo.mlp_keys.encoder: + torch_obs = torch_obs.float() + obs[k] = torch_obs + + while not done: + # Act greedly through the environment + if agent.is_continuous: + actions = torch.cat(agent.get_greedy_actions(obs), dim=-1) + else: + actions = torch.cat([act.argmax(dim=-1) for act in agent.get_greedy_actions(obs)], dim=-1) + + # Single environment step + o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) + done = done or truncated + cumulative_rew += reward + obs = {} + for k in o.keys(): + if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder: + torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0) + if k in cfg.algo.cnn_keys.encoder: + torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5 + if k in cfg.algo.mlp_keys.encoder: + torch_obs = torch_obs.float() + obs[k] = torch_obs + + if cfg.dry_run: + done = True + fabric.print("Test - Reward:", cumulative_rew) + if cfg.metric.log_level > 0: + fabric.log_dict({"Test/cumulative_reward": cumulative_rew}, 0) + env.close() + + +def normalize_obs( + obs: Dict[str, np.ndarray | Tensor], cnn_keys: Sequence[str], obs_keys: Sequence[str] +) -> Dict[str, np.ndarray | Tensor]: + return {k: obs[k] / 255 - 0.5 if k in cnn_keys else obs[k] for k in obs_keys} + + def log_models( cfg: Dict[str, Any], models_to_log: Dict[str, torch.nn.Module | _FabricModule], diff --git a/pyproject.toml b/pyproject.toml index 2261b875..55705142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dependencies = [ "gymnasium==0.29.*", "pygame >=2.1.3", "moviepy>=1.0.3", - "tensordict==0.2.*", "tensorboard>=2.10", "python-dotenv>=1.0.0", "lightning==2.1.*", @@ -139,7 +138,7 @@ markers = ["benchmark: mark test as a benchmark"] # Pytest coverage [tool.coverage.run] -omit = ["./sheeprl/envs/*", "./sheeprl/available_agents.py"] +omit = ["./sheeprl/envs/*", "./sheeprl/available_agents.py", "sheeprl/utils/mlflow.py"] [tool.prettier] tab_width = 2 diff --git a/sheeprl/__init__.py b/sheeprl/__init__.py index 441618bc..9334d4f2 100644 --- a/sheeprl/__init__.py +++ b/sheeprl/__init__.py @@ -1,11 +1,11 @@ import os -ROOT_DIR = os.path.dirname(__file__) - import decorator from dotenv import load_dotenv load_dotenv() +ROOT_DIR = os.path.dirname(__file__) + from sheeprl.utils.imports import _IS_TORCH_GREATER_EQUAL_2_0 diff --git a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py index 708baf52..c30d3801 100644 --- a/sheeprl/algos/a2c/a2c.py +++ b/sheeprl/algos/a2c/a2c.py @@ -7,8 +7,6 @@ import numpy as np import torch from lightning.fabric import Fabric -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase from torch.utils.data import BatchSampler, DistributedSampler, 
RandomSampler from torchmetrics import SumMetric @@ -28,7 +26,7 @@ def train( fabric: Fabric, agent: torch.nn.Module, optimizer: torch.optim.Optimizer, - data: TensorDictBase, + data: Dict[str, torch.Tensor], aggregator: MetricAggregator, cfg: Dict[str, Any], ): @@ -38,7 +36,7 @@ def train( # If we are in the distributed setting, we need to use a DistributedSampler, which # will shuffle the data at each epoch and will ensure that each process will get # a different part of the data - indexes = list(range(data.shape[0])) + indexes = list(range(next(iter(data.values())).shape[0])) if cfg.buffer.share_data: sampler = DistributedSampler( indexes, @@ -63,8 +61,8 @@ def train( # we do not do that, instead we take the overall sum (or mean, depending on the loss reduction). # This is achieved by accumulating the gradients and calling the backward method only at the end. for i, batch_idxes in enumerate(sampler): - batch = data[batch_idxes] - obs = {k: batch[k] for k in cfg.algo.mlp_keys.encoder} + batch = {k: v[batch_idxes] for k, v in data.items()} + obs = {k: v for k, v in batch.items() if k in cfg.algo.mlp_keys.encoder} # is_accumulating is True for every i except for the last one is_accumulating = i < len(sampler) - 1 @@ -173,7 +171,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): observation_space, state["agent"] if cfg.checkpoint.resume_from else None, ) - fabric.print(agent.module) # the optimizer and set up it with Fabric optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all") @@ -184,8 +181,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device) # Local data - rb = ReplayBuffer(cfg.algo.rollout_steps, cfg.env.num_envs, device=device, memmap=cfg.buffer.memmap) - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) + rb = ReplayBuffer( + cfg.buffer.size, + cfg.env.num_envs, + memmap=cfg.buffer.memmap, + memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + obs_keys=obs_keys, + ) # Global variables last_log = 0 @@ -213,13 +215,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] - next_obs = {} - for k in o.keys(): - if k in obs_keys: - torch_obs = torch.as_tensor(o[k], dtype=torch.float32, device=fabric.device) - step_data[k] = torch_obs - next_obs[k] = torch_obs + step_data = {} + next_obs = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] + for k in obs_keys: + step_data[k] = next_obs[k][np.newaxis] for update in range(1, num_updates + 1): for _ in range(0, cfg.algo.rollout_steps): @@ -232,40 +231,39 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Sample an action given the observation received by the environment # This calls the `forward` method of the PyTorch module, escaping from Fabric # because we don't want this to be a synchronization point - actions, _, values = agent.module(next_obs) + torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys} + actions, _, values = agent.module(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: - real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) - actions = torch.cat(actions, -1) + real_actions = torch.cat([act.argmax(dim=-1) for act in actions], axis=-1).cpu().numpy() + actions = torch.cat(actions, 
-1).cpu().numpy() # Single environment step - o, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) dones = np.logical_or(done, truncated) - dones = torch.as_tensor(dones, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1) - rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1) + dones = dones.reshape(cfg.env.num_envs, -1) + rewards = rewards.reshape(cfg.env.num_envs, -1) # Update the step data - step_data["dones"] = dones - step_data["values"] = values - step_data["actions"] = actions - step_data["rewards"] = rewards + step_data["dones"] = dones[np.newaxis] + step_data["values"] = values.cpu().numpy()[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["rewards"] = rewards[np.newaxis] if cfg.buffer.memmap: - step_data["returns"] = torch.zeros_like(rewards, dtype=torch.float32) - step_data["advantages"] = torch.zeros_like(rewards, dtype=torch.float32) + step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) + step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) # Append data to buffer - rb.add(step_data.unsqueeze(0)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) - # Update the observation and done - obs = {} - for k in o.keys(): - if k in obs_keys: - torch_obs = torch.as_tensor(o[k], dtype=torch.float32, device=fabric.device) - step_data[k] = torch_obs - obs[k] = torch_obs - next_obs = obs + # Update the observation and dones + next_obs = {} + for k in obs_keys: + _obs = obs[k] + step_data[k] = _obs[np.newaxis] + next_obs[k] = _obs if cfg.metric.log_level > 0 and "final_info" in info: for i, agent_ep_info in enumerate(info["final_info"]): @@ -278,13 +276,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): aggregator.update("Game/ep_len_avg", ep_len) fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + # Transform the data into PyTorch Tensors + local_data = rb.to_tensor(dtype=None, device=device) + # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): - next_values = agent.module.get_value(next_obs) + torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys} + next_values = agent.module.get_value(torch_obs) returns, advantages = gae( - rb["rewards"].to(torch.float64), - rb["values"], - rb["dones"], + local_data["rewards"].to(torch.float64), + local_data["values"], + local_data["dones"], next_values, cfg.algo.rollout_steps, cfg.algo.gamma, @@ -292,11 +294,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Add returns and advantages to the buffer - rb["returns"] = returns.float() - rb["advantages"] = advantages.float() - - # Flatten the batch - local_data = rb.buffer.view(-1) + local_data["returns"] = returns.float() + local_data["advantages"] = advantages.float() # Train the agent train(fabric, agent, optimizer, local_data, aggregator, cfg) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 2793e347..d67f3f71 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -12,18 +12,16 @@ import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase +from torch import Tensor from 
torch.distributions import Bernoulli, Independent, Normal from torch.distributions.utils import logits_to_probs -from torch.utils.data import BatchSampler from torchmetrics import SumMetric from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_agent from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v1.utils import compute_lambda_values from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.data.buffers import AsyncReplayBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator @@ -44,7 +42,7 @@ def train( world_optimizer: _FabricOptimizer, actor_optimizer: _FabricOptimizer, critic_optimizer: _FabricOptimizer, - data: TensorDictBase, + data: Dict[str, Tensor], aggregator: MetricAggregator | None, cfg: Dict[str, Any], ) -> None: @@ -97,7 +95,7 @@ def train( world_optimizer (_FabricOptimizer): the world optimizer. actor_optimizer (_FabricOptimizer): the actor optimizer. critic_optimizer (_FabricOptimizer): the critic optimizer. - data (TensorDictBase): the batch of data to use for training. + data (Dict[str, Tensor]): the batch of data to use for training. aggregator (MetricAggregator, optional): the aggregator to print the metrics. cfg (DictConfig): the configs. """ @@ -107,7 +105,6 @@ def train( recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size stochastic_size = cfg.algo.world_model.stochastic_size device = fabric.device - data = {k: data[k] for k in data.keys()} batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder} batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) @@ -369,6 +366,7 @@ def train( ) critic_optimizer.step() + # Log metrics if aggregator and not aggregator.disabled: aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) @@ -415,6 +413,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -439,7 +438,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r + clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") if ( @@ -510,23 +509,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2 - rb = AsyncReplayBuffer( + rb = EnvIndependentReplayBuffer( buffer_size, cfg.env.num_envs, - device=fabric.device if cfg.buffer.memmap else "cpu", + obs_keys=obs_keys, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), - sequential=True, + buffer_cls=SequentialReplayBuffer, ) if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: if isinstance(state["rb"], list) and world_size == len(state["rb"]): 
rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], AsyncReplayBuffer): + elif isinstance(state["rb"], EnvIndependentReplayBuffer): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=fabric.device if cfg.buffer.memmap else "cpu") - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables train_step = 0 @@ -546,6 +543,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 + expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size @@ -575,18 +573,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] - obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} + step_data = {} + obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: - torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(cfg.env.num_envs, 1) - step_data["actions"] = torch.zeros(cfg.env.num_envs, sum(actions_dim)) - step_data["rewards"] = torch.zeros(cfg.env.num_envs, 1) - rb.add(step_data[None, ...]) + if k in cfg.algo.cnn_keys.encoder: + obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) + step_data[k] = obs[k][np.newaxis] + step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) + step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() for update in range(start_step, num_updates + 1): @@ -605,30 +601,32 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not is_continuous: actions = np.concatenate( [ - F.one_hot(torch.tensor(act), act_dim).numpy() + F.one_hot(torch.as_tensor(act), act_dim).numpy() for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) ], axis=-1, ) else: with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): + normalized_obs = {} + for k in obs_keys: + torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 - else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs = torch_obs / 255 - 0.5 + normalized_obs[k] = torch_obs + mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) - actions = torch.cat(actions, -1).cpu().numpy() + real_actions = actions = player.get_exploration_action(normalized_obs, mask) + actions = torch.cat(actions, 
-1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + real_actions = ( + torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + ) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -641,64 +639,59 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation - real_next_obs = copy.deepcopy(o) + real_next_obs = copy.deepcopy(next_obs) if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: for k, v in final_obs.items(): real_next_obs[k][idx] = v - next_obs = { - k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask") - } - for k in obs_keys: # [N_envs, N_obs] - next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - next_obs[k] = next_obs[k].float() - step_data[k] = step_data[k].float() - actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float() - dones = torch.from_numpy(dones).view(cfg.env.num_envs, -1).float() + for k in obs_keys: + if k in cfg.algo.cnn_keys.encoder: + next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) + real_next_obs[k] = real_next_obs[k].reshape(cfg.env.num_envs, -1, *real_next_obs[k].shape[-2:]) + step_data[k] = real_next_obs[k][np.newaxis] # next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones - step_data["actions"] = actions - step_data["rewards"] = clip_rewards_fn(rewards) - rb.add(step_data[None, ...]) + step_data["dones"] = dones[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] + rb.add(step_data, validate_args=cfg.buffer.validate_args) # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + dones_idxes = dones.nonzero()[0].tolist() reset_envs = len(dones_idxes) if reset_envs > 0: - reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") + reset_data = {} for k in obs_keys: - reset_data[k] = next_obs[k][dones_idxes] - reset_data["dones"] = torch.zeros(reset_envs, 1) - reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)) - reset_data["rewards"] = torch.zeros(reset_envs, 1) - rb.add(reset_data[None, ...], dones_idxes) + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][d] = 
torch.zeros_like(step_data["dones"][d]) + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) # Reset internal agent states - player.init_states(dones_idxes) + player.init_states(reset_envs=dones_idxes) updates_before_training -= 1 # Train the agent if update > learning_starts and updates_before_training <= 0: - local_data = rb.sample( - cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=cfg.algo.per_rank_gradient_steps, - ).to(device) - distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) # Start training with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): - for i in distributed_sampler: + for i in range(cfg.algo.per_rank_gradient_steps): + sample = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=1, + dtype=None, + device=device, + ) # [N_samples, Seq_len, Batch_size, ...] + batch = {k: v[0].float() for k, v in sample.items()} train( fabric, world_model, @@ -707,7 +700,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_optimizer, critic_optimizer, - local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), + batch, aggregator, cfg, ) diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py index 8772ffe6..f69891fa 100644 --- a/sheeprl/algos/dreamer_v1/evaluate.py +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -19,6 +19,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, diff --git a/sheeprl/algos/dreamer_v2/README.md b/sheeprl/algos/dreamer_v2/README.md index 68095441..5aab1e86 100644 --- a/sheeprl/algos/dreamer_v2/README.md +++ b/sheeprl/algos/dreamer_v2/README.md @@ -118,8 +118,10 @@ env=atari \ env.id=AssaultNoFrameskip-v0 \ env.capture_video=True \ env.action_repeat=4 \ -clip_rewards=True \ -total_steps=200000000 \ +env.clip_rewards=True \ +env.max_episode_steps=27000 \ +env.num_envs=1 \ +algo.total_steps=200000000 \ algo.learning_starts=200000 \ algo.per_rank_pretrain_steps=1 \ algo.train_every=4 \ @@ -134,14 +136,14 @@ algo.world_model.use_continues=True \ algo.world_model.representation_model.hidden_size=600 \ algo.world_model.transition_model.hidden_size=600 \ algo.world_model.recurrent_model.recurrent_state_size=600 \ +algo.world_model.kl_free_nats=0.0 \ +algo.per_rank_batch_size=50 \ +algo.cnn_keys.encoder=[rgb] \ buffer.size=2000000 \ buffer.memmap=True \ -algo.world_model.kl_free_nats=0.0 \ -env.max_episode_steps=27000 \ -per_rank_batch_size=50 \ -checkpoint.every=100000 \ buffer.type=episode \ buffer.prioritize_ends=True +checkpoint.every=100000 \ ``` ## DMC environments @@ -158,11 +160,14 @@ PYOPENGL_PLATFORM="" MUJOCO_GL=osmesa python sheeprl.py \ exp=dreamer_v2 \ fabric.devices=1 \ env=dmc \ -env.id=dmc_walker_walk \ +env.id=walker_walk \ env.capture_video=True \ env.action_repeat=2 \ -clip_rewards=False \ -total_steps=5000000 \ +env.clip_rewards=False \ +env.max_episode_steps=1000 \ +env.num_envs=1 \ +algo.cnn_keys.encoder=[rgb] \ +algo.total_steps=5000000 \ algo.learning_starts=1000 \ algo.per_rank_pretrain_steps=100 \ algo.train_every=5 \ @@ -176,14 +181,14 @@ algo.world_model.use_continues=False \ algo.world_model.representation_model.hidden_size=200 \ 
algo.world_model.transition_model.hidden_size=200 \ algo.world_model.recurrent_model.recurrent_state_size=200 \ +algo.per_rank_batch_size=50 \ +algo.world_model.kl_free_nats=1.0 \ +algo.actor.objective_mix=0.0 \ buffer.size=5000000 \ buffer.memmap=True \ -algo.world_model.kl_free_nats=1.0 \ -env.max_episode_steps=1000 \ -per_rank_batch_size=50 \ -checkpoint.every=100000 \ buffer.type=episode \ buffer.prioritize_ends=False +checkpoint.every=100000 \ ``` ## Recommendations diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 50180316..01df22d8 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -15,19 +15,16 @@ import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase from torch import Tensor from torch.distributions import Bernoulli, Distribution, Independent, Normal from torch.distributions.utils import logits_to_probs from torch.optim import Optimizer -from torch.utils.data import BatchSampler from torchmetrics import SumMetric from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_agent from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test -from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer from sheeprl.utils.distribution import OneHotCategoricalValidateArgs from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -50,7 +47,7 @@ def train( world_optimizer: Optimizer, actor_optimizer: Optimizer, critic_optimizer: Optimizer, - data: TensorDictBase, + data: Dict[str, Tensor], aggregator: MetricAggregator | None, cfg: Dict[str, Any], actions_dim: Sequence[int], @@ -98,7 +95,7 @@ def train( world_optimizer (Optimizer): the world optimizer. actor_optimizer (Optimizer): the actor optimizer. critic_optimizer (Optimizer): the critic optimizer. - data (TensorDictBase): the batch of data to use for training. + data (Dict[str, Tensor]): the batch of data to use for training. aggregator (MetricAggregator, optional): the aggregator to print the metrics. cfg (DictConfig): the configs. actions_dim (Sequence[int]): the actions dimension. 
@@ -121,7 +118,6 @@ def train( stochastic_size = cfg.algo.world_model.stochastic_size discrete_size = cfg.algo.world_model.discrete_size device = fabric.device - data = {k: data[k] for k in data.keys()} batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder} batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) @@ -284,7 +280,7 @@ def train( predicted_target_values = target_critic(imagined_trajectories) predicted_rewards = world_model.reward_model(imagined_trajectories) if cfg.algo.world_model.use_continues and world_model.continue_model: - continues = logits_to_probs(logits=world_model.continue_model(imagined_trajectories), is_binary=True) + continues = logits_to_probs(world_model.continue_model(imagined_trajectories), is_binary=True) true_done = (1 - data["dones"]).reshape(1, -1, 1) * cfg.algo.gamma continues = torch.cat((true_done, continues[1:])) else: @@ -366,11 +362,14 @@ def train( critic_optimizer.zero_grad(set_to_none=True) value_loss = -torch.mean(discount[:-1, ..., 0] * qv.log_prob(lambda_values.detach())) fabric.backward(value_loss) + critic_grads = None if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0: critic_grads = fabric.clip_gradients( module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients, error_if_nonfinite=False ) critic_optimizer.step() + + # Log metrics if aggregator and not aggregator.disabled: aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) @@ -437,6 +436,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -461,7 +461,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r + clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") @@ -525,7 +525,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_optimizer, critic_optimizer ) - local_vars = locals() if fabric.is_global_zero: save_configs(cfg, log_dir) @@ -538,19 +537,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2 buffer_type = cfg.buffer.type.lower() if buffer_type == "sequential": - rb = AsyncReplayBuffer( + rb = EnvIndependentReplayBuffer( buffer_size, - cfg.env.num_envs, - device="cpu", + n_envs=cfg.env.num_envs, + obs_keys=obs_keys, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), - sequential=True, + buffer_cls=SequentialReplayBuffer, ) elif buffer_type == "episode": rb = EpisodeBuffer( buffer_size, - sequence_length=cfg.algo.per_rank_sequence_length, - device="cpu", + minimum_episode_length=1 if cfg.dry_run else cfg.algo.per_rank_sequence_length, + n_envs=cfg.env.num_envs, + obs_keys=obs_keys, + prioritize_ends=cfg.buffer.prioritize_ends, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", 
f"rank_{fabric.global_rank}"), ) @@ -559,12 +560,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], (AsyncReplayBuffer, EpisodeBuffer)): + elif isinstance(state["rb"], (EnvIndependentReplayBuffer, EpisodeBuffer)): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables train_step = 0 @@ -584,6 +583,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 + expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size @@ -613,25 +613,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - episode_steps = [[] for _ in range(cfg.env.num_envs)] - o = envs.reset(seed=cfg.seed)[0] - obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} + step_data = {} + obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: - torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - # Images stay uint8 to save space - torch_obs = torch_obs.float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(cfg.env.num_envs, 1) - step_data["actions"] = torch.zeros(cfg.env.num_envs, sum(actions_dim)) - step_data["rewards"] = torch.zeros(cfg.env.num_envs, 1) - step_data["is_first"] = torch.ones_like(step_data["dones"]) - if buffer_type == "sequential": - rb.add(step_data[None, ...]) - else: - for i, env_ep in enumerate(episode_steps): - env_ep.append(step_data[i : i + 1][None, ...]) + step_data[k] = obs[k][np.newaxis] + step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + if cfg.dry_run: + step_data["dones"] = step_data["dones"] + 1 + step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) + step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["is_first"] = np.ones_like(step_data["dones"]) + rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() per_rank_gradient_steps = 0 @@ -651,32 +643,34 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not is_continuous: actions = np.concatenate( [ - F.one_hot(torch.tensor(act), act_dim).numpy() + F.one_hot(torch.as_tensor(act), act_dim).numpy() for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) ], axis=-1, ) else: with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): + normalized_obs = {} + for k in obs_keys: + torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 - else: - preprocessed_obs[k] = v[None, ...].to(device) - 
mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs = torch_obs / 255 - 0.5 + normalized_obs[k] = torch_obs + mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) - actions = torch.cat(actions, -1).cpu().numpy() + real_actions = actions = player.get_exploration_action(normalized_obs, mask) + actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + real_actions = ( + torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + ) step_data["is_first"] = copy.deepcopy(step_data["dones"]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) if cfg.dry_run and buffer_type == "episode": dones = np.ones_like(dones) @@ -691,60 +685,39 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation - real_next_obs = copy.deepcopy(o) + real_next_obs = copy.deepcopy(next_obs) if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: for k, v in final_obs.items(): real_next_obs[k][idx] = v - next_obs: Dict[str, Tensor] = { - k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask") - } - for k in real_next_obs.keys(): # [N_envs, N_obs] - if k in obs_keys: - next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - next_obs[k] = next_obs[k].float() - step_data[k] = step_data[k].float() - actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float() - dones = torch.from_numpy(dones).view(cfg.env.num_envs, -1).float() + for k in obs_keys: # [N_envs, N_obs] + step_data[k] = real_next_obs[k][np.newaxis] # Next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones - step_data["actions"] = actions - step_data["rewards"] = clip_rewards_fn(rewards) - if buffer_type == "sequential": - rb.add(step_data[None, ...]) - else: - for i, env_ep in enumerate(episode_steps): - env_ep.append(step_data[i : i + 1][None, ...]) + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + dones_idxes = dones.nonzero()[0].tolist() reset_envs = len(dones_idxes) if reset_envs > 0: - reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") - for k in next_obs.keys(): - reset_data[k] = next_obs[k][dones_idxes] - reset_data["dones"] = torch.zeros(reset_envs, 1) - reset_data["actions"] = 
torch.zeros(reset_envs, np.sum(actions_dim)) - reset_data["rewards"] = torch.zeros(reset_envs, 1) - reset_data["is_first"] = torch.ones_like(reset_data["dones"]) - if buffer_type == "episode": - for i, d in enumerate(dones_idxes): - if len(episode_steps[d]) >= cfg.algo.per_rank_sequence_length: - rb.add(torch.cat(episode_steps[d], dim=0)) - episode_steps[d] = [reset_data[i : i + 1][None, ...]] - else: - rb.add(reset_data[None, ...], dones_idxes) + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + reset_data["is_first"] = np.ones_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][d] = torch.zeros_like(step_data["dones"][d]) + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) # Reset internal agent states player.init_states(dones_idxes) @@ -752,28 +725,22 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if update >= learning_starts and updates_before_training <= 0: - if buffer_type == "sequential": - local_data = rb.sample( - cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=cfg.algo.per_rank_pretrain_steps - if update == learning_starts - else cfg.algo.per_rank_gradient_steps, - ).to(device) - else: - local_data = rb.sample( - cfg.algo.per_rank_batch_size, - n_samples=cfg.algo.per_rank_pretrain_steps - if update == learning_starts - else cfg.algo.per_rank_gradient_steps, - prioritize_ends=cfg.buffer.prioritize_ends, - ).to(device) - distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) + n_samples = ( + cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps + ) + local_data = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=n_samples, + dtype=None, + device=fabric.device, + ) with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): - for i in distributed_sampler: + for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): tcp.data.copy_(cp.data) + batch = {k: v[i].float() for k, v in local_data.items()} train( fabric, world_model, @@ -783,7 +750,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_optimizer, critic_optimizer, - local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), + batch, aggregator, cfg, actions_dim, diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py index 78ebee5b..2b3b6bab 100644 --- a/sheeprl/algos/dreamer_v2/evaluate.py +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -19,6 +19,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py index dcf606b3..799674fc 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ 
b/sheeprl/algos/dreamer_v2/utils.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union import gymnasium as gym -import numpy as np import torch import torch.nn as nn from lightning import Fabric @@ -150,7 +149,7 @@ def test( if player.actor.is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() # Single environment step next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 4513e528..20cbbb69 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -1,6 +1,7 @@ """Dreamer-V3 implementation from [https://arxiv.org/abs/2301.04104](https://arxiv.org/abs/2301.04104) Adapted from the original implementation from https://github.com/danijar/dreamerv3 """ +from __future__ import annotations import copy import os @@ -15,18 +16,15 @@ import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase from torch import Tensor from torch.distributions import Bernoulli, Distribution, Independent from torch.optim import Optimizer -from torch.utils.data import BatchSampler from torchmetrics import SumMetric from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_agent from sheeprl.algos.dreamer_v3.loss import reconstruction_loss from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test -from sheeprl.data.buffers import AsyncReplayBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.envs.wrappers import RestartOnException from sheeprl.utils.distribution import ( MSEDistribution, @@ -55,8 +53,8 @@ def train( world_optimizer: Optimizer, actor_optimizer: Optimizer, critic_optimizer: Optimizer, - data: TensorDictBase, - aggregator: MetricAggregator, + data: Dict[str, Tensor], + aggregator: MetricAggregator | None, cfg: Dict[str, Any], is_continuous: bool, actions_dim: Sequence[int], @@ -73,7 +71,7 @@ def train( world_optimizer (Optimizer): the world optimizer. actor_optimizer (Optimizer): the actor optimizer. critic_optimizer (Optimizer): the critic optimizer. - data (TensorDictBase): the batch of data to use for training. + data (Dict[str, Tensor]): the batch of data to use for training. aggregator (MetricAggregator, optional): the aggregator to print the metrics. cfg (DictConfig): the configs. is_continuous (bool): whether or not the environment is continuous. 
@@ -97,7 +95,6 @@ def train( stochastic_size = cfg.algo.world_model.stochastic_size discrete_size = cfg.algo.world_model.discrete_size device = fabric.device - data = {k: data[k] for k in data.keys()} batch_obs = {k: data[k] / 255.0 for k in cfg.algo.cnn_keys.encoder} batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :]) @@ -177,6 +174,7 @@ def train( validate_args=validate_args, ) fabric.backward(rec_loss) + world_model_grads = None if cfg.algo.world_model.clip_gradients is not None and cfg.algo.world_model.clip_gradients > 0: world_model_grads = fabric.clip_gradients( module=world_model, @@ -287,6 +285,7 @@ def train( entropy = torch.zeros_like(objective) policy_loss = -torch.mean(discount[:-1].detach() * (objective + entropy.unsqueeze(dim=-1)[:-1])) fabric.backward(policy_loss) + actor_grads = None if cfg.algo.actor.clip_gradients is not None and cfg.algo.actor.clip_gradients > 0: actor_grads = fabric.clip_gradients( module=actor, optimizer=actor_optimizer, max_norm=cfg.algo.actor.clip_gradients, error_if_nonfinite=False @@ -306,14 +305,15 @@ def train( value_loss = torch.mean(value_loss * discount[:-1].squeeze(-1)) fabric.backward(value_loss) + critic_grads = None if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0: critic_grads = fabric.clip_gradients( module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients, error_if_nonfinite=False ) critic_optimizer.step() + # Log metrics if aggregator and not aggregator.disabled: - aggregator.update("Grads/world_model", world_model_grads.mean().detach()) aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) @@ -342,10 +342,14 @@ def train( .mean() .detach(), ) - aggregator.update("Grads/actor", actor_grads.mean().detach()) aggregator.update("Loss/policy_loss", policy_loss.detach()) - aggregator.update("Grads/critic", critic_grads.mean().detach()) aggregator.update("Loss/value_loss", value_loss.detach()) + if world_model_grads: + aggregator.update("Grads/world_model", world_model_grads.mean().detach()) + if actor_grads: + aggregator.update("Grads/actor", actor_grads.mean().detach()) + if critic_grads: + aggregator.update("Grads/critic", critic_grads.mean().detach()) # Reset everything actor_optimizer.zero_grad(set_to_none=True) @@ -376,6 +380,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -403,7 +408,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r + clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") @@ -484,22 +489,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 2 - rb = AsyncReplayBuffer( + rb = 
EnvIndependentReplayBuffer( buffer_size, - cfg.env.num_envs, - device="cpu", + n_envs=cfg.env.num_envs, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), - sequential=True, + buffer_cls=SequentialReplayBuffer, ) if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], AsyncReplayBuffer): + elif isinstance(state["rb"], EnvIndependentReplayBuffer): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables @@ -549,18 +552,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] - obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} + step_data = {} + obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: - torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - # Images stay uint8 to save space - torch_obs = torch_obs.float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(cfg.env.num_envs, 1).float() - step_data["rewards"] = torch.zeros(cfg.env.num_envs, 1).float() - step_data["is_first"] = torch.ones_like(step_data["dones"]).float() + step_data[k] = obs[k][np.newaxis] + step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["is_first"] = np.ones_like(step_data["dones"]) player.init_states() per_rank_gradient_steps = 0 @@ -580,7 +578,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not is_continuous: actions = np.concatenate( [ - F.one_hot(torch.tensor(act), act_dim).numpy() + F.one_hot(torch.as_tensor(act), act_dim).numpy() for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) ], axis=-1, @@ -589,10 +587,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): + preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255.0 - else: - preprocessed_obs[k] = v[None, ...].to(device) + preprocessed_obs[k] = preprocessed_obs[k] / 255.0 mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None @@ -601,24 +598,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + real_actions = ( + torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + ) - step_data["actions"] = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rb.add(step_data[None, ...]) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + next_obs, rewards, dones, truncated, infos = 
envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) - step_data["is_first"] = torch.zeros_like(step_data["dones"]) + step_data["is_first"] = np.zeros_like(step_data["dones"]) if "restart_on_exception" in infos: for i, agent_roe in enumerate(infos["restart_on_exception"]): if agent_roe and not dones[i]: last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = torch.ones_like(rb.buffer[i]["dones"][last_inserted_idx]) - rb.buffer[i]["is_first"][last_inserted_idx] = torch.zeros_like( + rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like(rb.buffer[i]["dones"][last_inserted_idx]) + rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( rb.buffer[i]["is_first"][last_inserted_idx] ) - step_data["is_first"][i] = torch.ones_like(step_data["is_first"][i]) + step_data["is_first"][i] = np.ones_like(step_data["is_first"][i]) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -631,70 +630,61 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation - real_next_obs = copy.deepcopy(o) + real_next_obs = copy.deepcopy(next_obs) if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: for k, v in final_obs.items(): real_next_obs[k][idx] = v - next_obs: Dict[str, Tensor] = { - k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask") - } - for k in real_next_obs.keys(): # [N_envs, N_obs] - if k in obs_keys: - next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - step_data[k] = next_obs[k] - if k in cfg.algo.mlp_keys.encoder: - next_obs[k] = next_obs[k].float() - step_data[k] = step_data[k].float() + for k in obs_keys: + step_data[k] = next_obs[k][np.newaxis] # next_obs becomes the new obs obs = next_obs - rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float() - dones = torch.from_numpy(dones).view(cfg.env.num_envs, -1).float() - step_data["dones"] = dones + rewards = rewards.reshape((1, cfg.env.num_envs, -1)) + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards) - dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + dones_idxes = dones.nonzero()[0].tolist() reset_envs = len(dones_idxes) if reset_envs > 0: - reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") + reset_data = {} for k in obs_keys: - reset_data[k] = real_next_obs[k][dones_idxes] - if k in cfg.algo.mlp_keys.encoder: - reset_data[k] = reset_data[k].float() - reset_data["dones"] = torch.ones(reset_envs, 1).float() - reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)).float() - reset_data["rewards"] = step_data["rewards"][dones_idxes].float() - reset_data["is_first"] = torch.zeros_like(reset_data["dones"]).float() - rb.add(reset_data[None, ...], dones_idxes) + reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = step_data["rewards"][:, dones_idxes] + reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset already inserted step data - step_data["rewards"][dones_idxes] = 
torch.zeros_like(reset_data["rewards"]).float() - step_data["dones"][dones_idxes] = torch.zeros_like(step_data["dones"][dones_idxes]).float() - step_data["is_first"][dones_idxes] = torch.ones_like(step_data["is_first"][dones_idxes]).float() + step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) + step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) updates_before_training -= 1 # Train the agent if update >= learning_starts and updates_before_training <= 0: - local_data = rb.sample( + local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=cfg.algo.per_rank_pretrain_steps - if update == learning_starts - else cfg.algo.per_rank_gradient_steps, - ).to(device) - distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) + n_samples=( + cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps + ), + dtype=None, + device=fabric.device, + ) with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): - for i in distributed_sampler: + for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + batch = {k: v[i].float() for k, v in local_data.items()} train( fabric, world_model, @@ -704,7 +694,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_optimizer, critic_optimizer, - local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), + batch, aggregator, cfg, is_continuous, diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py index 94775a45..17b32ddf 100644 --- a/sheeprl/algos/dreamer_v3/evaluate.py +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -19,6 +19,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index 7fb9a886..27a068a5 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -122,7 +122,7 @@ def test( if player.actor.is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() # Single environment step next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 4f4d95c9..6bdf3868 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -11,7 +11,6 @@ import torch import torch.nn.functional as F from lightning.fabric import Fabric -from tensordict import TensorDict, make_tensordict from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler @@ -41,14 
+40,16 @@ def train( ): # Sample a minibatch in a distributed way: Line 5 - Algorithm 2 # We sample one time to reduce the communications between processes - sample = rb.sample( + sample = rb.sample_tensors( cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs ) - critic_data = fabric.all_gather(sample.to_dict()) - critic_data = make_tensordict(critic_data).view(-1) + critic_data = fabric.all_gather(sample) + flatten_dim = 3 if fabric.world_size > 1 else 2 + critic_data = {k: v.view(-1, *v.shape[flatten_dim:]) for k, v in critic_data.items()} + critic_idxes = range(len(critic_data[next(iter(critic_data.keys()))])) if fabric.world_size > 1: dist_sampler: DistributedSampler = DistributedSampler( - range(len(critic_data)), + critic_idxes, num_replicas=fabric.world_size, rank=fabric.global_rank, shuffle=True, @@ -59,29 +60,27 @@ def train( sampler=dist_sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) else: - critic_sampler = BatchSampler( - sampler=range(len(critic_data)), batch_size=cfg.algo.per_rank_batch_size, drop_last=False - ) + critic_sampler = BatchSampler(sampler=critic_idxes, batch_size=cfg.algo.per_rank_batch_size, drop_last=False) # Sample a different minibatch in a distributed way to update actor and alpha parameter - sample = rb.sample(cfg.algo.per_rank_batch_size) - actor_data = fabric.all_gather(sample.to_dict()) - actor_data = make_tensordict(actor_data).view(-1) + sample = rb.sample_tensors(cfg.algo.per_rank_batch_size) + actor_data = fabric.all_gather(sample) + actor_data = {k: v.view(-1, *v.shape[flatten_dim:]) for k, v in actor_data.items()} if fabric.world_size > 1: actor_sampler: DistributedSampler = DistributedSampler( - range(len(actor_data)), + range(len(actor_data[next(iter(actor_data.keys()))])), num_replicas=fabric.world_size, rank=fabric.global_rank, shuffle=True, seed=cfg.seed, drop_last=False, ) - actor_data = actor_data[next(iter(actor_sampler))] + actor_data = {k: actor_data[k][next(iter(actor_sampler))] for k in actor_data.keys()} with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): # Update the soft-critic for batch_idxes in critic_sampler: - critic_batch_data = critic_data[batch_idxes] + critic_batch_data = {k: critic_data[k][batch_idxes] for k in critic_data.keys()} next_target_qf_value = agent.get_next_target_q_values( critic_batch_data["next_observations"], critic_batch_data["rewards"], @@ -157,6 +156,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -233,7 +233,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables last_train = 0 @@ -273,12 +272,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "policy_steps_per_update value." 
) - with device: - # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] - obs = torch.cat( - [torch.tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ) # [N_envs, N_obs] + step_data = {} + # Get the first environment observation and start the optimization + o = envs.reset(seed=cfg.seed)[0] + obs = np.concatenate([o[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype(np.float32) for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * fabric.world_size @@ -288,9 +285,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): # Sample an action given the observation received by the environment - actions, _ = agent.actor.module(obs) + actions, _ = agent.actor.module(torch.from_numpy(obs).to(device)) actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions) + next_obs, rewards, dones, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) dones = np.logical_or(dones, truncated) if cfg.metric.log_level > 0 and "final_info" in infos: @@ -311,24 +308,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for k, v in final_obs.items(): real_next_obs[k][idx] = v - with device: - next_obs = torch.cat( - [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ) # [N_envs, N_obs] - real_next_obs = torch.cat( - [torch.tensor(real_next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ) # [N_envs, N_obs] - actions = torch.tensor(actions, dtype=torch.float32).view(cfg.env.num_envs, -1) - rewards = torch.tensor(rewards, dtype=torch.float32).view(cfg.env.num_envs, -1) # [N_envs, 1] - dones = torch.tensor(dones, dtype=torch.float32).view(cfg.env.num_envs, -1) - - step_data["dones"] = dones - step_data["actions"] = actions - step_data["observations"] = obs + next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype(np.float32) + real_next_obs = np.concatenate([real_next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype( + np.float32 + ) + + step_data["observations"] = obs[np.newaxis] if not cfg.buffer.sample_next_obs: - step_data["next_observations"] = real_next_obs - step_data["rewards"] = rewards - rb.add(step_data.unsqueeze(0)) + step_data["next_observations"] = real_next_obs[np.newaxis] + step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) + step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + step_data["rewards"] = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + rb.add(step_data, validate_args=cfg.buffer.validate_args) # next_obs becomes the new obs obs = next_obs diff --git a/sheeprl/algos/droq/evaluate.py b/sheeprl/algos/droq/evaluate.py index a80f0ef0..2fa3e0c9 100644 --- a/sheeprl/algos/droq/evaluate.py +++ b/sheeprl/algos/droq/evaluate.py @@ -19,6 +19,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py index e6d2de92..d170e436 100644 --- a/sheeprl/algos/p2e_dv1/evaluate.py +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -20,6 +20,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): 
fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 7385eaa1..6968fcbf 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -12,11 +12,9 @@ import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase +from torch import Tensor from torch.distributions import Bernoulli, Independent, Normal from torch.distributions.utils import logits_to_probs -from torch.utils.data import BatchSampler from torchmetrics import SumMetric from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel @@ -24,7 +22,7 @@ from sheeprl.algos.dreamer_v1.utils import compute_lambda_values from sheeprl.algos.dreamer_v2.utils import test from sheeprl.algos.p2e_dv1.agent import build_agent -from sheeprl.data.buffers import AsyncReplayBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator @@ -44,7 +42,7 @@ def train( world_optimizer: _FabricOptimizer, actor_task_optimizer: _FabricOptimizer, critic_task_optimizer: _FabricOptimizer, - data: TensorDictBase, + data: Dict[str, Tensor], aggregator: MetricAggregator | None, cfg: Dict[str, Any], ensembles: _FabricModule, @@ -87,7 +85,7 @@ def train( world_optimizer (_FabricOptimizer): the world optimizer. actor_task_optimizer (_FabricOptimizer): the actor optimizer for solving the task. critic_task_optimizer (_FabricOptimizer): the critic optimizer for solving the task. - data (TensorDictBase): the batch of data to use for training. + data (Dict[str, Tensor]): the batch of data to use for training. aggregator (MetricAggregator, optional): the aggregator to print the metrics. cfg (DictConfig): the configs. ensembles (_FabricModule): the ensemble models. 
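
The exploration and finetuning scripts now fill `step_data` with plain NumPy arrays in which every value carries a leading time axis of length 1 and an environment axis of length `num_envs`, and terminated environments are addressed on that env axis (e.g. `step_data["dones"][0, d]`). A small illustrative layout, with made-up keys and dimensions rather than the repository's own:

```python
import numpy as np

# Illustrative sizes only
n_envs, obs_dim, act_dim = 4, 17, 6

step_data = {
    "state": np.zeros((1, n_envs, obs_dim), dtype=np.float32),   # one observation key
    "actions": np.zeros((1, n_envs, act_dim), dtype=np.float32),
    "rewards": np.zeros((1, n_envs, 1), dtype=np.float32),
    "dones": np.zeros((1, n_envs, 1), dtype=np.float32),
}

# Environments that terminated are addressed on the env axis (axis 1)
dones_idxes = [0, 2]
step_data["dones"][0, dones_idxes] = 1.0

assert all(v.shape[:2] == (1, n_envs) for v in step_data.values())
```
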
@@ -103,7 +101,6 @@ def train( recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size stochastic_size = cfg.algo.world_model.stochastic_size device = fabric.device - data = {k: data[k] for k in data.keys()} batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder} batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) @@ -415,6 +412,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv envs = vectorized_env( @@ -438,7 +436,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r + clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") @@ -543,23 +541,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4 - rb = AsyncReplayBuffer( + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2 + rb = EnvIndependentReplayBuffer( buffer_size, cfg.env.num_envs, - device="cpu", + obs_keys=obs_keys, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), - sequential=True, + buffer_cls=SequentialReplayBuffer, ) if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], AsyncReplayBuffer): + elif isinstance(state["rb"], EnvIndependentReplayBuffer): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") + step_data = {} expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables @@ -615,18 +613,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] - obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} + step_data = {} + obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: - torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(cfg.env.num_envs, 1) - step_data["actions"] = torch.zeros(cfg.env.num_envs, sum(actions_dim)) - step_data["rewards"] = torch.zeros(cfg.env.num_envs, 1) - rb.add(step_data[None, ...]) + if k in cfg.algo.cnn_keys.encoder: + obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) + step_data[k] = obs[k][np.newaxis] + step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) + 
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() for update in range(start_step, num_updates + 1): @@ -645,31 +641,32 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not is_continuous: actions = np.concatenate( [ - F.one_hot(torch.tensor(act), act_dim).numpy() + F.one_hot(torch.as_tensor(act), act_dim).numpy() for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) ], axis=-1, ) else: with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): + normalized_obs = {} + for k in obs_keys: + torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 - else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs = torch_obs / 255 - 0.5 + normalized_obs[k] = torch_obs + mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) - actions = torch.cat(actions, -1).cpu().numpy() + real_actions = actions = player.get_exploration_action(normalized_obs, mask) + actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + real_actions = ( + torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + ) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -682,65 +679,59 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation - real_next_obs = copy.deepcopy(o) + real_next_obs = copy.deepcopy(next_obs) if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: for k, v in final_obs.items(): - if k == "rgb": - real_next_obs[idx] = v + real_next_obs[k][idx] = v - next_obs = { - k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask") - } - for k in obs_keys: # [N_envs, N_obs] - next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - next_obs[k] = next_obs[k].float() - step_data[k] = step_data[k].float() - actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float() - dones = torch.from_numpy(dones).view(cfg.env.num_envs, -1).float() + for k in obs_keys: + if k in cfg.algo.cnn_keys.encoder: + next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) + real_next_obs[k] = real_next_obs[k].reshape(cfg.env.num_envs, -1, *real_next_obs[k].shape[-2:]) + step_data[k] = real_next_obs[k][np.newaxis] # next_obs 
becomes the new obs obs = next_obs - step_data["dones"] = dones - step_data["actions"] = actions - step_data["rewards"] = clip_rewards_fn(rewards) - rb.add(step_data[None, ...]) + step_data["dones"] = dones[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] + rb.add(step_data, validate_args=cfg.buffer.validate_args) # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + dones_idxes = dones.nonzero()[0].tolist() reset_envs = len(dones_idxes) if reset_envs > 0: - reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") - for k in next_obs.keys(): - reset_data[k] = next_obs[k][dones_idxes] - reset_data["dones"] = torch.zeros(reset_envs, 1) - reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)) - reset_data["rewards"] = torch.zeros(reset_envs, 1) - rb.add(reset_data[None, ...], dones_idxes) + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][d] = torch.zeros_like(step_data["dones"][d]) + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) # Reset internal agent states - player.init_states(dones_idxes) + player.init_states(reset_envs=dones_idxes) updates_before_training -= 1 # Train the agent if update >= learning_starts and updates_before_training <= 0: - local_data = rb.sample( - cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=cfg.algo.per_rank_gradient_steps, - ).to(device) - distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) # Start training with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): - for i in distributed_sampler: + for i in range(cfg.algo.per_rank_gradient_steps): + sample = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=1, + dtype=None, + device=device, + ) # [N_samples, Seq_len, Batch_size, ...] 
+ batch = {k: v[0].float() for k, v in sample.items()} train( fabric, world_model, @@ -749,7 +740,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), + batch, aggregator, cfg, ensembles=ensembles, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 2f18c669..adb53c7e 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -11,15 +11,13 @@ import numpy as np import torch from lightning.fabric import Fabric -from tensordict import TensorDict -from torch.utils.data import BatchSampler from torchmetrics import SumMetric from sheeprl.algos.dreamer_v1.agent import PlayerDV1 from sheeprl.algos.dreamer_v1.dreamer_v1 import train from sheeprl.algos.dreamer_v2.utils import test from sheeprl.algos.p2e_dv1.agent import build_agent -from sheeprl.data.buffers import AsyncReplayBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator @@ -82,6 +80,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -106,7 +105,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r + clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") @@ -185,22 +184,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4 - rb = AsyncReplayBuffer( + rb = EnvIndependentReplayBuffer( buffer_size, cfg.env.num_envs, - device="cpu", + obs_keys=obs_keys, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), - sequential=True, + buffer_cls=SequentialReplayBuffer, ) if resume_from_checkpoint or (cfg.buffer.load_from_exploration and exploration_cfg.buffer.checkpoint): if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], AsyncReplayBuffer): + elif isinstance(state["rb"], EnvIndependentReplayBuffer): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if resume_from_checkpoint else 0 # Global variables @@ -256,18 +254,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] - obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, 
*v.shape[1:]) for k, v in o.items() if k.startswith("mask")} + step_data = {} + obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: - torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(cfg.env.num_envs, 1) - step_data["actions"] = torch.zeros(cfg.env.num_envs, sum(actions_dim)) - step_data["rewards"] = torch.zeros(cfg.env.num_envs, 1) - rb.add(step_data[None, ...]) + if k in cfg.algo.cnn_keys.encoder: + obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) + step_data[k] = obs[k][np.newaxis] + step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) + step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) + player.init_states() player.init_states() for update in range(start_step, num_updates + 1): @@ -277,24 +274,25 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # to get the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): + normalized_obs = {} + for k in obs_keys: + torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 - else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs = torch_obs / 255 - 0.5 + normalized_obs[k] = torch_obs + mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) - actions = torch.cat(actions, -1).cpu().numpy() + real_actions = actions = player.get_exploration_action(normalized_obs, mask) + actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + real_actions = ( + torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + ) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -307,51 +305,43 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation - real_next_obs = copy.deepcopy(o) + real_next_obs = copy.deepcopy(next_obs) if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: for k, v in final_obs.items(): - if k == "rgb": - real_next_obs[idx] = v - - next_obs = { - k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask") - } - for k in obs_keys: # [N_envs, N_obs] - next_obs[k] 
= torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - next_obs[k] = next_obs[k].float() - step_data[k] = step_data[k].float() - actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float() - dones = torch.from_numpy(dones).view(cfg.env.num_envs, -1).float() + real_next_obs[k][idx] = v + + for k in obs_keys: + if k in cfg.algo.cnn_keys.encoder: + next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) + real_next_obs[k] = real_next_obs[k].reshape(cfg.env.num_envs, -1, *real_next_obs[k].shape[-2:]) + step_data[k] = real_next_obs[k][np.newaxis] # next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones - step_data["actions"] = actions - step_data["rewards"] = clip_rewards_fn(rewards) - rb.add(step_data[None, ...]) + step_data["dones"] = dones[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] + rb.add(step_data, validate_args=cfg.buffer.validate_args) # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + dones_idxes = dones.nonzero()[0].tolist() reset_envs = len(dones_idxes) if reset_envs > 0: - reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") - for k in next_obs.keys(): - reset_data[k] = next_obs[k][dones_idxes] - reset_data["dones"] = torch.zeros(reset_envs, 1) - reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)) - reset_data["rewards"] = torch.zeros(reset_envs, 1) - rb.add(reset_data[None, ...], dones_idxes) + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][d] = torch.zeros_like(step_data["dones"][d]) + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) # Reset internal agent states - player.init_states(dones_idxes) + player.init_states(reset_envs=dones_idxes) updates_before_training -= 1 @@ -360,15 +350,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if player.actor_type == "exploration": player.actor = actor_task.module player.actor_type = "task" - local_data = rb.sample( - cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=cfg.algo.per_rank_gradient_steps, - ).to(device) - distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - # Start training with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): - for i in distributed_sampler: + for i in range(cfg.algo.per_rank_gradient_steps): + sample = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=1, + dtype=None, + device=device, + ) # [N_samples, Seq_len, Batch_size, ...] 
+ batch = {k: v[0].float() for k, v in sample.items()} train( fabric, world_model, @@ -377,7 +368,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), + batch, aggregator, cfg, ) diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py index 28757330..66c1173e 100644 --- a/sheeprl/algos/p2e_dv2/evaluate.py +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -20,6 +20,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index f609034f..f6fc1f5a 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -12,19 +12,16 @@ import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase from torch import Tensor, nn from torch.distributions import Bernoulli, Distribution, Independent, Normal from torch.distributions.utils import logits_to_probs -from torch.utils.data import BatchSampler from torchmetrics import SumMetric from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test from sheeprl.algos.p2e_dv2.agent import build_agent -from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer from sheeprl.utils.distribution import OneHotCategoricalValidateArgs from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -46,7 +43,7 @@ def train( world_optimizer: _FabricOptimizer, actor_task_optimizer: _FabricOptimizer, critic_task_optimizer: _FabricOptimizer, - data: TensorDictBase, + data: Dict[str, Tensor], aggregator: MetricAggregator | None, cfg: Dict[str, Any], ensembles: _FabricModule, @@ -93,7 +90,7 @@ def train( world_optimizer (_FabricOptimizer): the world optimizer. actor_task_optimizer (_FabricOptimizer): the actor optimizer for solving the task. critic_task_optimizer (_FabricOptimizer): the critic optimizer for solving the task. - data (TensorDictBase): the batch of data to use for training. + data (Dict[str, Tensor]): the batch of data to use for training. aggregator (MetricAggregator, optional): the aggregator to print the metrics. cfg (DictConfig): the configs. ensembles (_FabricModule): the ensemble models. 
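
The p2e_dv2 `train` hunk below swaps the old `torch.tensor([1.0], ...).expand_as(...)` idiom for `torch.ones_like` while keeping the `/ 255 - 0.5` image preprocessing of the sequence batch. A self-contained sketch of that preprocessing step on fake data; keys and sizes are illustrative, not taken from the repository:

```python
import torch

seq_len, batch_size = 8, 4

# Fake sequence batch: images are stored compactly and cast to float for training
data = {
    "rgb": torch.randint(0, 256, (seq_len, batch_size, 3, 64, 64), dtype=torch.uint8).float(),
    "state": torch.zeros(seq_len, batch_size, 17),
    "is_first": torch.zeros(seq_len, batch_size, 1),
}
cnn_keys, mlp_keys = ["rgb"], ["state"]

# Rescale image observations to [-0.5, 0.5]; pass vector observations through
batch_obs = {k: data[k] / 255 - 0.5 for k in cnn_keys}
batch_obs.update({k: data[k] for k in mlp_keys})

# Flag the first time step of every sampled sequence as an episode start
data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :])

assert batch_obs["rgb"].min() >= -0.5 and batch_obs["rgb"].max() <= 0.5
assert data["is_first"][0].all() and not data["is_first"][1].any()
```
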
@@ -117,7 +114,7 @@ def train( data = {k: data[k] for k in data.keys()} batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder} batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder}) - data["is_first"][0, :] = torch.tensor([1.0], device=fabric.device).expand_as(data["is_first"][0, :]) + data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :]) # Dynamic Learning recurrent_state = torch.zeros(1, batch_size, recurrent_state_size, device=device) @@ -535,6 +532,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -559,7 +557,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r + clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") @@ -666,7 +664,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): critic_exploration_optimizer, ) - local_vars = locals() if fabric.is_global_zero: save_configs(cfg, log_dir) @@ -679,19 +676,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4 buffer_type = cfg.buffer.type.lower() if buffer_type == "sequential": - rb = AsyncReplayBuffer( + rb = EnvIndependentReplayBuffer( buffer_size, cfg.env.num_envs, - device="cpu", + obs_keys=obs_keys, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), - sequential=True, + buffer_cls=SequentialReplayBuffer, ) elif buffer_type == "episode": rb = EpisodeBuffer( buffer_size, - sequence_length=cfg.algo.per_rank_sequence_length, - device="cpu", + minimum_episode_length=1 if cfg.dry_run else cfg.algo.per_rank_sequence_length, + n_envs=cfg.env.num_envs, + obs_keys=obs_keys, + prioritize_ends=cfg.buffer.prioritize_ends, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) @@ -700,11 +699,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], AsyncReplayBuffer): + elif isinstance(state["rb"], (EnvIndependentReplayBuffer, EpisodeBuffer)): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables @@ -760,25 +758,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - episode_steps = [[] for _ in range(cfg.env.num_envs)] - o = envs.reset(seed=cfg.seed)[0] - obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} + step_data = {} + obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: - torch_obs = 
torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - # Images stay uint8 to save space - torch_obs = torch_obs.float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(cfg.env.num_envs, 1) - step_data["actions"] = torch.zeros(cfg.env.num_envs, sum(actions_dim)) - step_data["rewards"] = torch.zeros(cfg.env.num_envs, 1) - step_data["is_first"] = torch.ones_like(step_data["dones"]) - if buffer_type == "sequential": - rb.add(step_data[None, ...]) - else: - for i, env_ep in enumerate(episode_steps): - env_ep.append(step_data[i : i + 1][None, ...]) + step_data[k] = obs[k][np.newaxis] + step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + if cfg.dry_run: + step_data["dones"] = step_data["dones"] + 1 + step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) + step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["is_first"] = np.ones_like(step_data["dones"]) + rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() per_rank_gradient_steps = 0 @@ -798,32 +788,34 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not is_continuous: actions = np.concatenate( [ - F.one_hot(torch.tensor(act), act_dim).numpy() + F.one_hot(torch.as_tensor(act), act_dim).numpy() for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) ], axis=-1, ) else: with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): + normalized_obs = {} + for k in obs_keys: + torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 - else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs = torch_obs / 255 - 0.5 + normalized_obs[k] = torch_obs + mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) - actions = torch.cat(actions, -1).cpu().numpy() + real_actions = actions = player.get_exploration_action(normalized_obs, mask) + actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + real_actions = ( + torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + ) step_data["is_first"] = copy.deepcopy(step_data["dones"]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) if cfg.dry_run and buffer_type == "episode": dones = np.ones_like(dones) @@ -838,59 +830,39 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation - real_next_obs = copy.deepcopy(o) + real_next_obs = copy.deepcopy(next_obs) if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: for k, v in final_obs.items(): real_next_obs[k][idx] = v - next_obs: Dict[str, Tensor] = { - k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for 
k, v in o.items() if k.startswith("mask") - } for k in obs_keys: # [N_envs, N_obs] - next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - next_obs[k] = next_obs[k].float() - step_data[k] = step_data[k].float() - actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float() - dones = torch.from_numpy(dones).view(cfg.env.num_envs, -1).float() - - # next_obs becomes the new obs + step_data[k] = real_next_obs[k][np.newaxis] + + # Next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones - step_data["actions"] = actions - step_data["rewards"] = clip_rewards_fn(rewards) - if buffer_type == "sequential": - rb.add(step_data[None, ...]) - else: - for i, env_ep in enumerate(episode_steps): - env_ep.append(step_data[i : i + 1][None, ...]) + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + dones_idxes = dones.nonzero()[0].tolist() reset_envs = len(dones_idxes) if reset_envs > 0: - reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") - for k in next_obs.keys(): - reset_data[k] = next_obs[k][dones_idxes] - reset_data["dones"] = torch.zeros(reset_envs, 1) - reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)) - reset_data["rewards"] = torch.zeros(reset_envs, 1) - reset_data["is_first"] = torch.ones_like(reset_data["dones"]) - if buffer_type == "episode": - for i, d in enumerate(dones_idxes): - if len(episode_steps[d]) >= cfg.algo.per_rank_sequence_length: - rb.add(torch.cat(episode_steps[d], dim=0)) - episode_steps[d] = [reset_data[i : i + 1][None, ...]] - else: - rb.add(reset_data[None, ...], dones_idxes) + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + reset_data["is_first"] = np.ones_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][d] = torch.zeros_like(step_data["dones"][d]) + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) # Reset internal agent states player.init_states(dones_idxes) @@ -898,26 +870,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if update >= learning_starts and updates_before_training <= 0: - if buffer_type == "sequential": - local_data = rb.sample( - cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=cfg.algo.per_rank_pretrain_steps - if update == learning_starts - else cfg.algo.per_rank_gradient_steps, - ).to(device) - else: - local_data = rb.sample( - cfg.algo.per_rank_batch_size, - n_samples=cfg.algo.per_rank_pretrain_steps - if update == learning_starts - else cfg.algo.per_rank_gradient_steps, - prioritize_ends=cfg.buffer.prioritize_ends, - ).to(device) - distributed_sampler = 
BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) + n_samples = ( + cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps + ) + local_data = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=n_samples, + dtype=None, + device=fabric.device, + ) # Start training with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): - for i in distributed_sampler: + for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(cp.data) @@ -925,6 +890,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): critic_exploration.module.parameters(), target_critic_exploration.parameters() ): tcp.data.copy_(cp.data) + batch = {k: v[i].float() for k, v in local_data.items()} train( fabric, world_model, @@ -934,7 +900,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), + batch, aggregator, cfg, ensembles=ensembles, diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index edb2ec99..a3322410 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -11,16 +11,13 @@ import numpy as np import torch from lightning.fabric import Fabric -from tensordict import TensorDict -from torch import Tensor -from torch.utils.data import BatchSampler from torchmetrics import SumMetric from sheeprl.algos.dreamer_v2.agent import PlayerDV2 from sheeprl.algos.dreamer_v2.dreamer_v2 import train from sheeprl.algos.dreamer_v2.utils import test from sheeprl.algos.p2e_dv2.agent import build_agent -from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator @@ -86,6 +83,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -110,7 +108,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r + clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") @@ -193,19 +191,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4 buffer_type = cfg.buffer.type.lower() if buffer_type == "sequential": - rb = AsyncReplayBuffer( + rb = EnvIndependentReplayBuffer( buffer_size, - cfg.env.num_envs, - 
device="cpu", + n_envs=cfg.env.num_envs, + obs_keys=obs_keys, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), - sequential=True, + buffer_cls=SequentialReplayBuffer, ) elif buffer_type == "episode": rb = EpisodeBuffer( buffer_size, - sequence_length=cfg.algo.per_rank_sequence_length, - device="cpu", + minimum_episode_length=1 if cfg.dry_run else cfg.algo.per_rank_sequence_length, + n_envs=cfg.env.num_envs, + obs_keys=obs_keys, + prioritize_ends=cfg.buffer.prioritize_ends, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) @@ -214,11 +214,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if resume_from_checkpoint or (cfg.buffer.load_from_exploration and exploration_cfg.buffer.checkpoint): if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], AsyncReplayBuffer): + elif isinstance(state["rb"], (EnvIndependentReplayBuffer, EpisodeBuffer)): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if resume_from_checkpoint else 0 # Global variables @@ -274,25 +273,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - episode_steps = [[] for _ in range(cfg.env.num_envs)] - o = envs.reset(seed=cfg.seed)[0] - obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} + step_data = {} + obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: - torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - # Images stay uint8 to save space - torch_obs = torch_obs.float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(cfg.env.num_envs, 1) - step_data["actions"] = torch.zeros(cfg.env.num_envs, sum(actions_dim)) - step_data["rewards"] = torch.zeros(cfg.env.num_envs, 1) - step_data["is_first"] = torch.ones_like(step_data["dones"]) - if buffer_type == "sequential": - rb.add(step_data[None, ...]) - else: - for i, env_ep in enumerate(episode_steps): - env_ep.append(step_data[i : i + 1][None, ...]) + step_data[k] = obs[k][np.newaxis] + step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + if cfg.dry_run: + step_data["dones"] = step_data["dones"] + 1 + step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) + step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["is_first"] = np.ones_like(step_data["dones"]) + rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() per_rank_gradient_steps = 0 @@ -303,25 +294,27 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # to get the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): + normalized_obs = {} + for k in obs_keys: + torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 - else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: 
v for k, v in preprocessed_obs.items() if k.startswith("mask")} + torch_obs = torch_obs / 255 - 0.5 + normalized_obs[k] = torch_obs + mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) - actions = torch.cat(actions, -1).cpu().numpy() + real_actions = actions = player.get_exploration_action(normalized_obs, mask) + actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + real_actions = ( + torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + ) step_data["is_first"] = copy.deepcopy(step_data["dones"]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) if cfg.dry_run and buffer_type == "episode": dones = np.ones_like(dones) @@ -336,59 +329,39 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation - real_next_obs = copy.deepcopy(o) + real_next_obs = copy.deepcopy(next_obs) if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: for k, v in final_obs.items(): real_next_obs[k][idx] = v - next_obs: Dict[str, Tensor] = { - k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask") - } for k in obs_keys: # [N_envs, N_obs] - next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - step_data[k] = torch.from_numpy(real_next_obs[k]).view(cfg.env.num_envs, *real_next_obs[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - next_obs[k] = next_obs[k].float() - step_data[k] = step_data[k].float() - actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float() - dones = torch.from_numpy(dones).view(cfg.env.num_envs, -1).float() - - # next_obs becomes the new obs + step_data[k] = real_next_obs[k][np.newaxis] + + # Next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones - step_data["actions"] = actions - step_data["rewards"] = clip_rewards_fn(rewards) - if buffer_type == "sequential": - rb.add(step_data[None, ...]) - else: - for i, env_ep in enumerate(episode_steps): - env_ep.append(step_data[i : i + 1][None, ...]) + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) # Reset and save the observation coming from the automatic reset - dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + dones_idxes = dones.nonzero()[0].tolist() reset_envs = len(dones_idxes) if reset_envs > 0: - reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") - for k in next_obs.keys(): - reset_data[k] = next_obs[k][dones_idxes] - reset_data["dones"] = torch.zeros(reset_envs, 1) - reset_data["actions"] = torch.zeros(reset_envs, 
np.sum(actions_dim)) - reset_data["rewards"] = torch.zeros(reset_envs, 1) - reset_data["is_first"] = torch.ones_like(reset_data["dones"]) - if buffer_type == "episode": - for i, d in enumerate(dones_idxes): - if len(episode_steps[d]) >= cfg.algo.per_rank_sequence_length: - rb.add(torch.cat(episode_steps[d], dim=0)) - episode_steps[d] = [reset_data[i : i + 1][None, ...]] - else: - rb.add(reset_data[None, ...], dones_idxes) + reset_data = {} + for k in obs_keys: + reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = np.zeros((1, reset_envs, 1)) + reset_data["is_first"] = np.ones_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][d] = torch.zeros_like(step_data["dones"][d]) + step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) # Reset internal agent states player.init_states(dones_idxes) @@ -399,29 +372,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if player.actor_type == "exploration": player.actor = actor_task.module player.actor_type = "task" - if buffer_type == "sequential": - local_data = rb.sample( - cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=cfg.algo.per_rank_pretrain_steps - if update == learning_starts - else cfg.algo.per_rank_gradient_steps, - ).to(device) - else: - local_data = rb.sample( - cfg.algo.per_rank_batch_size, - n_samples=cfg.algo.per_rank_pretrain_steps - if update == learning_starts - else cfg.algo.per_rank_gradient_steps, - prioritize_ends=cfg.buffer.prioritize_ends, - ).to(device) - distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) + n_samples = ( + cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps + ) + local_data = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=n_samples, + dtype=None, + device=fabric.device, + ) # Start training with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): - for i in distributed_sampler: + for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(cp.data) + batch = {k: v[i].float() for k, v in local_data.items()} train( fabric, world_model, @@ -431,7 +398,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), + batch, aggregator, cfg, actions_dim=actions_dim, diff --git a/sheeprl/algos/p2e_dv3/evaluate.py b/sheeprl/algos/p2e_dv3/evaluate.py index b99c2d28..7aadb93f 100644 --- a/sheeprl/algos/p2e_dv3/evaluate.py +++ b/sheeprl/algos/p2e_dv3/evaluate.py @@ -20,6 +20,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py 
b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 619747e5..33d71072 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -11,18 +11,15 @@ from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer from omegaconf import DictConfig -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase from torch import Tensor, nn from torch.distributions import Bernoulli, Distribution, Independent -from torch.utils.data import BatchSampler from torchmetrics import SumMetric from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel from sheeprl.algos.dreamer_v3.loss import reconstruction_loss from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test from sheeprl.algos.p2e_dv3.agent import build_agent -from sheeprl.data.buffers import AsyncReplayBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.distribution import ( MSEDistribution, OneHotCategoricalValidateArgs, @@ -49,7 +46,7 @@ def train( world_optimizer: _FabricOptimizer, actor_task_optimizer: _FabricOptimizer, critic_task_optimizer: _FabricOptimizer, - data: TensorDictBase, + data: Dict[str, Tensor], aggregator: MetricAggregator, cfg: DictConfig, ensembles: _FabricModule, @@ -96,7 +93,7 @@ def train( world_optimizer (_FabricOptimizer): the world optimizer. actor_task_optimizer (_FabricOptimizer): the actor optimizer for solving the task. critic_task_optimizer (_FabricOptimizer): the critic optimizer for solving the task. - data (TensorDictBase): the batch of data to use for training. + data (Dict[str, Tensor]): the batch of data to use for training. aggregator (MetricAggregator): the aggregator to print the metrics. cfg (DictConfig): the configs. ensembles (_FabricModule): the ensemble models. 
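The hunks above and below replace TensorDict-based buffers and batches with plain dictionaries: the buffer holds NumPy arrays laid out as `(T, B, *)`, one-step data is added with a leading time axis of 1, and training consumes `Dict[str, Tensor]` batches produced by `sample_tensors`-style calls. The following is a toy, self-contained sketch of that pattern only; it is not sheeprl code, and all names (`buffer`, `step_data`, `local_data`, the sizes) are illustrative.

```python
# Toy sketch (not sheeprl code): a replay buffer as a plain dict of NumPy arrays
# with layout (T, B, *), plus dict-of-tensor batches like those passed to train().
import numpy as np
import torch

T, B, obs_dim, n_steps = 16, 4, 3, 10  # buffer size, num. envs, obs size, steps to insert

buffer = {
    "obs": np.zeros((T, B, obs_dim), dtype=np.float32),
    "rewards": np.zeros((T, B, 1), dtype=np.float32),
    "dones": np.zeros((T, B, 1), dtype=np.uint8),
}
pos = 0

# Insert step data shaped (1, B, *), mirroring the `rb.add(step_data, ...)` calls above
for _ in range(n_steps):
    step_data = {
        "obs": np.random.randn(1, B, obs_dim).astype(np.float32),
        "rewards": np.random.randn(1, B, 1).astype(np.float32),
        "dones": np.zeros((1, B, 1), dtype=np.uint8),
    }
    for k, v in step_data.items():
        buffer[k][pos] = v[0]
    pos = (pos + 1) % T

# Sample `n_samples` sequences of length `seq_len` and move them to torch,
# loosely mimicking `rb.sample_tensors(...)` followed by `{k: v[i].float() ...}`
n_samples, seq_len, batch_size = 2, 4, 3
starts = np.random.randint(0, n_steps - seq_len + 1, size=(n_samples, batch_size))
env_idx = np.random.randint(0, B, size=(n_samples, batch_size))
local_data = {
    k: torch.as_tensor(
        np.stack([
            np.stack([v[s : s + seq_len, e] for s, e in zip(starts[i], env_idx[i])], axis=1)
            for i in range(n_samples)
        ])
    )
    for k, v in buffer.items()
}  # each value has shape [n_samples, seq_len, batch_size, *]

for i in range(next(iter(local_data.values())).shape[0]):
    # One gradient step consumes a [seq_len, batch_size, *] dict of tensors
    batch = {k: v[i].float() for k, v in local_data.items()}
```

The same conversion step (`batch = {k: v[i].float() ...}`) is what the patched training loops below pass to `train()` in place of the old `local_data[i].view(...)` TensorDict slice.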
@@ -569,6 +566,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -593,7 +591,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r + clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") @@ -756,22 +754,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4 - rb = AsyncReplayBuffer( + rb = EnvIndependentReplayBuffer( buffer_size, - cfg.env.num_envs, - device="cpu", + n_envs=cfg.env.num_envs, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), - sequential=True, + buffer_cls=SequentialReplayBuffer, ) if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], AsyncReplayBuffer): + elif isinstance(state["rb"], EnvIndependentReplayBuffer): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables @@ -827,18 +823,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] - obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} + step_data = {} + obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: - torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - # Images stay uint8 to save space - torch_obs = torch_obs.float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(cfg.env.num_envs, 1).float() - step_data["rewards"] = torch.zeros(cfg.env.num_envs, 1).float() - step_data["is_first"] = torch.ones_like(step_data["dones"]).float() + step_data[k] = obs[k][np.newaxis] + step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["is_first"] = np.ones_like(step_data["dones"]) player.init_states() per_rank_gradient_steps = 0 @@ -858,7 +849,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not is_continuous: actions = np.concatenate( [ - F.one_hot(torch.tensor(act), act_dim).numpy() + F.one_hot(torch.as_tensor(act), act_dim).numpy() for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) ], axis=-1, @@ -867,10 +858,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): + preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) 
/ 255.0 - else: - preprocessed_obs[k] = v[None, ...].to(device) + preprocessed_obs[k] = preprocessed_obs[k] / 255.0 mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None @@ -879,24 +869,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + real_actions = ( + torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + ) - step_data["actions"] = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rb.add(step_data[None, ...]) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) - step_data["is_first"] = torch.zeros_like(step_data["dones"]) + step_data["is_first"] = np.zeros_like(step_data["dones"]) if "restart_on_exception" in infos: for i, agent_roe in enumerate(infos["restart_on_exception"]): if agent_roe and not dones[i]: last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = torch.ones_like(rb.buffer[i]["dones"][last_inserted_idx]) - rb.buffer[i]["is_first"][last_inserted_idx] = torch.zeros_like( + rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like(rb.buffer[i]["dones"][last_inserted_idx]) + rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( rb.buffer[i]["is_first"][last_inserted_idx] ) - step_data["is_first"][i] = torch.ones_like(step_data["is_first"][i]) + step_data["is_first"][i] = np.ones_like(step_data["is_first"][i]) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -909,68 +901,57 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation - real_next_obs = copy.deepcopy(o) + real_next_obs = copy.deepcopy(next_obs) if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: for k, v in final_obs.items(): real_next_obs[k][idx] = v - next_obs: Dict[str, Tensor] = { - k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask") - } - for k in real_next_obs.keys(): # [N_envs, N_obs] - if k in obs_keys: - next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - step_data[k] = next_obs[k] - if k in cfg.algo.mlp_keys.encoder: - next_obs[k] = next_obs[k].float() - step_data[k] = step_data[k].float() + for k in obs_keys: + step_data[k] = next_obs[k][np.newaxis] # next_obs becomes the new obs obs = next_obs - rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float() - dones = torch.from_numpy(dones).view(cfg.env.num_envs, -1).float() - step_data["dones"] = dones + rewards = rewards.reshape((1, cfg.env.num_envs, -1)) + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards) - dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + dones_idxes = dones.nonzero()[0].tolist() reset_envs = len(dones_idxes) if reset_envs > 0: - 
reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") - for k in real_next_obs.keys(): - if k in obs_keys: - reset_data[k] = real_next_obs[k][dones_idxes] - if k in cfg.algo.mlp_keys.encoder: - reset_data[k] = reset_data[k].float() - reset_data["dones"] = torch.ones(reset_envs, 1).float() - reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)).float() - reset_data["rewards"] = step_data["rewards"][dones_idxes].float() - reset_data["is_first"] = torch.zeros_like(reset_data["dones"]).float() - rb.add(reset_data[None, ...], dones_idxes) + reset_data = {} + for k in obs_keys: + reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = step_data["rewards"][:, dones_idxes] + reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset already inserted step data - step_data["rewards"][dones_idxes] = torch.zeros_like(reset_data["rewards"]).float() - step_data["dones"][dones_idxes] = torch.zeros_like(step_data["dones"][dones_idxes]).float() - step_data["is_first"][dones_idxes] = torch.ones_like(step_data["is_first"][dones_idxes]).float() + step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) + step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) updates_before_training -= 1 # Train the agent if update >= learning_starts and updates_before_training <= 0: - local_data = rb.sample( + local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=cfg.algo.per_rank_pretrain_steps - if update == learning_starts - else cfg.algo.per_rank_gradient_steps, - ).to(device) - distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) + n_samples=( + cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps + ), + dtype=None, + device=fabric.device, + ) # Start training with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): - for i in distributed_sampler: + for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): @@ -981,6 +962,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): critics_exploration[k]["target_module"].parameters(), ): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + batch = {k: v[i].float() for k, v in local_data.items()} train( fabric, world_model, @@ -990,7 +972,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), + batch, aggregator, cfg, ensembles=ensembles, diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 3cbe71a8..c2eba822 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -9,16 +9,13 @@ import numpy as np import torch from lightning.fabric import Fabric -from tensordict import TensorDict -from 
torch import Tensor -from torch.utils.data import BatchSampler from torchmetrics import SumMetric from sheeprl.algos.dreamer_v3.agent import PlayerDV3 from sheeprl.algos.dreamer_v3.dreamer_v3 import train from sheeprl.algos.dreamer_v3.utils import Moments, test from sheeprl.algos.p2e_dv3.agent import build_agent -from sheeprl.data.buffers import AsyncReplayBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator @@ -80,6 +77,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -104,7 +102,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r + clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") @@ -201,22 +199,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4 - rb = AsyncReplayBuffer( + rb = EnvIndependentReplayBuffer( buffer_size, - cfg.env.num_envs, - device="cpu", + n_envs=cfg.env.num_envs, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), - sequential=True, + buffer_cls=SequentialReplayBuffer, ) if resume_from_checkpoint or (cfg.buffer.load_from_exploration and exploration_cfg.buffer.checkpoint): if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] - elif isinstance(state["rb"], AsyncReplayBuffer): + elif isinstance(state["rb"], EnvIndependentReplayBuffer): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if resume_from_checkpoint else 0 # Global variables @@ -272,18 +268,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] - obs = {k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask")} + step_data = {} + obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: - torch_obs = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - if k in cfg.algo.mlp_keys.encoder: - # Images stay uint8 to save space - torch_obs = torch_obs.float() - step_data[k] = torch_obs - obs[k] = torch_obs - step_data["dones"] = torch.zeros(cfg.env.num_envs, 1).float() - step_data["rewards"] = torch.zeros(cfg.env.num_envs, 1).float() - step_data["is_first"] = torch.ones_like(step_data["dones"]).float() + step_data[k] = obs[k][np.newaxis] + step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + 
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["is_first"] = np.ones_like(step_data["dones"]) player.init_states() per_rank_gradient_steps = 0 @@ -296,10 +287,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): with torch.no_grad(): preprocessed_obs = {} for k, v in obs.items(): + preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255.0 - else: - preprocessed_obs[k] = v[None, ...].to(device) + preprocessed_obs[k] = preprocessed_obs[k] / 255.0 mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None @@ -308,24 +298,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + real_actions = ( + torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() + ) - step_data["actions"] = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rb.add(step_data[None, ...]) + step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated).astype(np.uint8) - step_data["is_first"] = torch.zeros_like(step_data["dones"]) + step_data["is_first"] = np.zeros_like(step_data["dones"]) if "restart_on_exception" in infos: for i, agent_roe in enumerate(infos["restart_on_exception"]): if agent_roe and not dones[i]: last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = torch.ones_like(rb.buffer[i]["dones"][last_inserted_idx]) - rb.buffer[i]["is_first"][last_inserted_idx] = torch.zeros_like( + rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like(rb.buffer[i]["dones"][last_inserted_idx]) + rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( rb.buffer[i]["is_first"][last_inserted_idx] ) - step_data["is_first"][i] = torch.ones_like(step_data["is_first"][i]) + step_data["is_first"][i] = np.ones_like(step_data["is_first"][i]) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -338,51 +330,39 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation - real_next_obs = copy.deepcopy(o) + real_next_obs = copy.deepcopy(next_obs) if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: for k, v in final_obs.items(): real_next_obs[k][idx] = v - next_obs: Dict[str, Tensor] = { - k: torch.from_numpy(v).view(cfg.env.num_envs, *v.shape[1:]) for k, v in o.items() if k.startswith("mask") - } - for k in real_next_obs.keys(): # [N_envs, N_obs] - if k in obs_keys: - next_obs[k] = torch.from_numpy(o[k]).view(cfg.env.num_envs, *o[k].shape[1:]) - step_data[k] = next_obs[k] - if k in cfg.algo.mlp_keys.encoder: - next_obs[k] = next_obs[k].float() - step_data[k] = step_data[k].float() + for k in obs_keys: + 
step_data[k] = next_obs[k][np.newaxis] # next_obs becomes the new obs obs = next_obs - rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float() - dones = torch.from_numpy(dones).view(cfg.env.num_envs, -1).float() - step_data["dones"] = dones + rewards = rewards.reshape((1, cfg.env.num_envs, -1)) + step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards) - dones_idxes = dones.nonzero(as_tuple=True)[0].tolist() + dones_idxes = dones.nonzero()[0].tolist() reset_envs = len(dones_idxes) if reset_envs > 0: - reset_data = TensorDict({}, batch_size=[reset_envs], device="cpu") - for k in real_next_obs.keys(): - if k in obs_keys: - reset_data[k] = real_next_obs[k][dones_idxes] - if k in cfg.algo.mlp_keys.encoder: - reset_data[k] = reset_data[k].float() - reset_data["dones"] = torch.ones(reset_envs, 1).float() - reset_data["actions"] = torch.zeros(reset_envs, np.sum(actions_dim)).float() - reset_data["rewards"] = step_data["rewards"][dones_idxes].float() - reset_data["is_first"] = torch.zeros_like(reset_data["dones"]).float() - rb.add(reset_data[None, ...], dones_idxes) + reset_data = {} + for k in obs_keys: + reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] + reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) + reset_data["rewards"] = step_data["rewards"][:, dones_idxes] + reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset already inserted step data - step_data["rewards"][dones_idxes] = torch.zeros_like(reset_data["rewards"]).float() - step_data["dones"][dones_idxes] = torch.zeros_like(step_data["dones"][dones_idxes]).float() - step_data["is_first"][dones_idxes] = torch.ones_like(step_data["is_first"][dones_idxes]).float() + step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) + step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) updates_before_training -= 1 @@ -392,21 +372,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if player.actor_type == "exploration": player.actor = actor_task.module player.actor_type = "task" - local_data = rb.sample( + local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=cfg.algo.per_rank_pretrain_steps - if update == learning_starts - else cfg.algo.per_rank_gradient_steps, - ).to(device) - distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) + n_samples=( + cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps + ), + dtype=None, + device=fabric.device, + ) # Start training with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): - for i in distributed_sampler: + for i in range(next(iter(local_data.values())).shape[0]): tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + batch = {k: v[i].float() for k, v in local_data.items()} train( fabric, world_model, @@ -416,7 +398,7 @@ def main(fabric: Fabric, cfg: 
Dict[str, Any], exploration_cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer, - local_data[i].view(cfg.algo.per_rank_sequence_length, cfg.algo.per_rank_batch_size), + batch, aggregator, cfg, is_continuous=is_continuous, diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py index f780098a..d2dd7eb9 100644 --- a/sheeprl/algos/ppo/agent.py +++ b/sheeprl/algos/ppo/agent.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from math import prod from typing import Any, Dict, List, Optional, Sequence, Tuple diff --git a/sheeprl/algos/ppo/evaluate.py b/sheeprl/algos/ppo/evaluate.py index 1d629938..6d66c01a 100644 --- a/sheeprl/algos/ppo/evaluate.py +++ b/sheeprl/algos/ppo/evaluate.py @@ -12,13 +12,14 @@ from sheeprl.utils.registry import register_evaluation -@register_evaluation(algorithms=["ppo"]) +@register_evaluation(algorithms="ppo") def evaluate_ppo(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, @@ -53,6 +54,6 @@ def evaluate_ppo(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): # This is just for showcase -@register_evaluation(algorithms=["ppo_decoupled"]) +@register_evaluation(algorithms="ppo_decoupled") def evaluate_ppo_decoupled(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): evaluate_ppo(fabric, cfg, state) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index a1160a7b..88c0aa0a 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -11,8 +11,6 @@ import torch from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule -from tensordict import TensorDict, make_tensordict -from tensordict.tensordict import TensorDictBase from torch import nn from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler from torchmetrics import SumMetric @@ -20,7 +18,7 @@ from sheeprl.algos.ppo.agent import build_agent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss from sheeprl.algos.ppo.utils import normalize_obs, test -from sheeprl.data import ReplayBuffer +from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator @@ -33,13 +31,12 @@ def train( fabric: Fabric, agent: Union[nn.Module, _FabricModule], optimizer: torch.optim.Optimizer, - data: TensorDictBase, + data: Dict[str, torch.Tensor], aggregator: MetricAggregator | None, cfg: Dict[str, Any], ): """Train the agent on the data collected from the environment.""" - data = {k: data[k] for k in data.keys()} - indexes = list(range(data[next(iter(data.keys()))].shape[0])) + indexes = list(range(next(iter(data.values())).shape[0])) if cfg.buffer.share_data: sampler = DistributedSampler( indexes, @@ -136,6 +133,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -210,12 +208,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = ReplayBuffer( cfg.buffer.size, cfg.env.num_envs, - device=device, memmap=cfg.buffer.memmap, 
memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), obs_keys=obs_keys, ) - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables last_train = 0 @@ -261,16 +257,12 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): scheduler.load_state_dict(state["scheduler"]) # Get the first environment observation and start the optimization - obs = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] - next_obs = {} + step_data = {} + next_obs = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] for k in obs_keys: - torch_obs = torch.as_tensor(obs[k]).to(fabric.device) if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - step_data[k] = torch_obs - next_obs[k] = torch_obs + next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) + step_data[k] = next_obs[k][np.newaxis] for update in range(start_step, num_updates + 1): for _ in range(0, cfg.algo.rollout_steps): @@ -282,12 +274,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) - actions, logprobs, _, values = agent.module(normalized_obs) + torch_obs = { + k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys + } + actions, logprobs, _, values = agent.module(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: - real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) - actions = torch.cat(actions, -1) + real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() + actions = torch.cat(actions, -1).cpu().numpy() # Single environment step obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) @@ -306,38 +301,36 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for k, v in info["final_observation"][truncated_env].items(): torch_v = torch.as_tensor(v, dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - torch_v = torch_v.view(len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5 + torch_v = torch_v.view(cfg.env.num_envs, -1, *v.shape[-2:]) + torch_v = torch_v / 255.0 - 0.5 real_next_obs[k][i] = torch_v with torch.no_grad(): vals = agent.module.get_value(real_next_obs).cpu().numpy() rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape) - dones = np.logical_or(dones, truncated) - dones = torch.as_tensor(dones, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1) - rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1) + dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + rewards = rewards.reshape(cfg.env.num_envs, -1) # Update the step data - step_data["dones"] = dones - step_data["values"] = values - step_data["actions"] = actions - step_data["logprobs"] = logprobs - step_data["rewards"] = rewards + step_data["dones"] = dones[np.newaxis] + step_data["values"] = values.cpu().numpy()[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["logprobs"] = logprobs.cpu().numpy()[np.newaxis] + step_data["rewards"] = rewards[np.newaxis] if cfg.buffer.memmap: - step_data["returns"] = torch.zeros_like(rewards, dtype=torch.float32) - step_data["advantages"] = 
torch.zeros_like(rewards, dtype=torch.float32) + step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) + step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) # Append data to buffer - rb.add(step_data.unsqueeze(0)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) # Update the observation and dones next_obs = {} for k in obs_keys: + _obs = obs[k] if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch.as_tensor(obs[k], device=device) - torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.algo.mlp_keys.encoder: - torch_obs = torch.as_tensor(obs[k], device=device, dtype=torch.float32) - step_data[k] = torch_obs - next_obs[k] = torch_obs + _obs = _obs.reshape(cfg.env.num_envs, -1, *_obs.shape[-2:]) + step_data[k] = _obs[np.newaxis] + next_obs[k] = _obs if cfg.metric.log_level > 0 and "final_info" in info: for i, agent_ep_info in enumerate(info["final_info"]): @@ -350,33 +343,35 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): aggregator.update("Game/ep_len_avg", ep_len) fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + # Transform the data into PyTorch Tensors + local_data = rb.to_tensor(dtype=None, device=device) + # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) - next_values = agent.module.get_value(normalized_obs) + torch_obs = {k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys} + next_values = agent.module.get_value(torch_obs) returns, advantages = gae( - rb["rewards"].to(torch.float64), - rb["values"], - rb["dones"], + local_data["rewards"].to(torch.float64), + local_data["values"], + local_data["dones"], next_values, cfg.algo.rollout_steps, cfg.algo.gamma, cfg.algo.gae_lambda, ) - # Add returns and advantages to the buffer - rb["returns"] = returns.float() - rb["advantages"] = advantages.float() - - # Flatten the batch - local_data = rb.buffer.view(-1) + local_data["returns"] = returns.float() + local_data["advantages"] = advantages.float() if cfg.buffer.share_data and fabric.world_size > 1: # Gather all the tensors from all the world and reshape them - gathered_data = fabric.all_gather(local_data.to_dict()) # Fabric does not work with TensorDict - gathered_data = make_tensordict(gathered_data).view(-1) + gathered_data: Dict[str, torch.Tensor] = fabric.all_gather(local_data) + # Flatten the first three dimensions: [World_Size, Buffer_Size, Num_Envs] + gathered_data = {k: v.flatten(start_dim=0, end_dim=2).float() for k, v in gathered_data.items()} else: - gathered_data = local_data + # Flatten the first two dimensions: [Buffer_Size, Num_Envs] + gathered_data = {k: v.flatten(start_dim=0, end_dim=1).float() for k, v in local_data.items()} with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): train(fabric, agent, optimizer, gathered_data, aggregator, cfg) diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index c71fff37..786be075 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -12,8 +12,6 @@ from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.collectives.collective import CollectibleGroup from lightning.fabric.strategies import DDPStrategy -from tensordict import TensorDict -from tensordict.tensordict import TensorDictBase, make_tensordict from 
torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, RandomSampler from torchmetrics import SumMetric @@ -21,7 +19,7 @@ from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss from sheeprl.algos.ppo.utils import normalize_obs, test -from sheeprl.data import ReplayBuffer +from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir from sheeprl.utils.metric import MetricAggregator @@ -124,13 +122,12 @@ def player( # Local data rb = ReplayBuffer( - cfg.algo.rollout_steps, + cfg.buffer.size, cfg.env.num_envs, - device=device, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), + obs_keys=obs_keys, ) - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables start_step = ( @@ -175,20 +172,16 @@ def player( ] # Broadcast num_updates to all the world - update_t = torch.tensor([num_updates], device=device, dtype=torch.float32) + update_t = torch.as_tensor([num_updates], device=device, dtype=torch.float32) world_collective.broadcast(update_t, src=0) # Get the first environment observation and start the optimization - obs = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] - next_obs = {} + step_data = {} + next_obs = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] for k in obs_keys: - torch_obs = torch.as_tensor(obs[k]).to(fabric.device) if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - step_data[k] = torch_obs - next_obs[k] = torch_obs + next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) + step_data[k] = next_obs[k][np.newaxis] params = {"update": start_step, "last_log": last_log, "last_checkpoint": last_checkpoint} world_collective.scatter_object_list([None], [params] * world_collective.world_size, src=0) @@ -202,15 +195,18 @@ def player( with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) - actions, logprobs, _, values = agent(normalized_obs) + torch_obs = { + k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys + } + actions, logprobs, _, values = agent(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: - real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) - actions = torch.cat(actions, -1) + real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() + actions = torch.cat(actions, -1).cpu().numpy() # Single environment step - obs, rewards, dones, truncated, info = envs.step(real_actions) + obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) truncated_envs = np.nonzero(truncated)[0] if len(truncated_envs) > 0: real_next_obs = { @@ -226,38 +222,36 @@ def player( for k, v in info["final_observation"][truncated_env].items(): torch_v = torch.as_tensor(v, dtype=torch.float32, device=device) if k in cfg.algo.cnn_keys.encoder: - torch_v = torch_v.view(len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5 + torch_v = torch_v.view(cfg.env.num_envs, -1, *v.shape[-2:]) + torch_v = torch_v / 255.0 - 0.5 real_next_obs[k][i] = torch_v with torch.no_grad(): vals = 
agent.get_value(real_next_obs).cpu().numpy() rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape) - dones = np.logical_or(dones, truncated) - dones = torch.as_tensor(dones, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1) - rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).view(cfg.env.num_envs, -1) + dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + rewards = rewards.reshape(cfg.env.num_envs, -1) # Update the step data - step_data["dones"] = dones - step_data["values"] = values - step_data["actions"] = actions - step_data["logprobs"] = logprobs - step_data["rewards"] = rewards + step_data["dones"] = dones[np.newaxis] + step_data["values"] = values.cpu().numpy()[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["logprobs"] = logprobs.cpu().numpy()[np.newaxis] + step_data["rewards"] = rewards[np.newaxis] if cfg.buffer.memmap: - step_data["returns"] = torch.zeros_like(rewards, dtype=torch.float32) - step_data["advantages"] = torch.zeros_like(rewards, dtype=torch.float32) + step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) + step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape)) # Append data to buffer - rb.add(step_data.unsqueeze(0)) + rb.add(step_data, validate_args=cfg.buffer.validate_args) # Update the observation and dones next_obs = {} for k in obs_keys: + _obs = obs[k] if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch.as_tensor(obs[k], device=device) - torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.algo.mlp_keys.encoder: - torch_obs = torch.as_tensor(obs[k], device=device, dtype=torch.float32) - step_data[k] = torch_obs - next_obs[k] = torch_obs + _obs = _obs.reshape(cfg.env.num_envs, -1, *_obs.shape[-2:]) + step_data[k] = _obs[np.newaxis] + next_obs[k] = _obs if cfg.metric.log_level > 0 and "final_info" in info: for i, agent_ep_info in enumerate(info["final_info"]): @@ -269,13 +263,17 @@ def player( aggregator.update("Game/ep_len_avg", ep_len) fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + # Transform the data into PyTorch Tensors + local_data = rb.to_tensor(dtype=None, device=device) + # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) - next_values = agent.get_value(normalized_obs) + torch_obs = {k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys} + next_values = agent.get_value(torch_obs) returns, advantages = gae( - rb["rewards"].to(torch.float64), - rb["values"], - rb["dones"], + local_data["rewards"].to(torch.float64), + local_data["values"], + local_data["dones"], next_values, cfg.algo.rollout_steps, cfg.algo.gamma, @@ -283,17 +281,17 @@ def player( ) # Add returns and advantages to the buffer - rb["returns"] = returns.float() - rb["advantages"] = advantages.float() - rb["rewards"] = rb["rewards"].float() - - # Flatten the batch - local_data = rb.buffer.view(-1) + local_data["returns"] = returns.float() + local_data["advantages"] = advantages.float() + local_data["rewards"] = local_data["rewards"].float() # Send data to the training agents # Split data in an even way, when possible - perm = torch.randperm(local_data.shape[0], device=device) - chunks = local_data[perm].split(chunks_sizes) + perm = torch.randperm(local_data[next(iter(local_data.keys()))].shape[0], device=device) + # chunks = {k1: [k1_chunk_1, 
k1_chunk_2, ...], k2: [k2_chunk_1, k2_chunk_2, ...]} + chunks = {k: v[perm].flatten(0, 1).split(chunks_sizes) for k, v in local_data.items()} + # chunks = [{k1: k1_chunk_1, k2: k2_chunk_1}, {k1: k1_chunk_2, k2: k2_chunk_2}, ...] + chunks = [{k: v[i] for k, v in chunks.items()} for i in range(len(chunks[next(iter(chunks.keys()))]))] world_collective.scatter_object_list([None], [None] + chunks, src=0) # Wait the trainers to finish @@ -452,7 +450,7 @@ def trainer( data = [None] world_collective.scatter_object_list(data, [None for _ in range(world_collective.world_size)], src=0) data = data[0] - if not isinstance(data, TensorDictBase) and data == -1: + if not isinstance(data, dict) and data == -1: # Last Checkpoint if cfg.checkpoint.save_last: state = { @@ -473,8 +471,6 @@ def trainer( state=state, ) return - data = make_tensordict(data, device=device) - data = {k: data[k] for k in data.keys()} train_step += group_world_size diff --git a/sheeprl/algos/ppo_recurrent/evaluate.py b/sheeprl/algos/ppo_recurrent/evaluate.py index 919a6c86..0c5a0ed1 100644 --- a/sheeprl/algos/ppo_recurrent/evaluate.py +++ b/sheeprl/algos/ppo_recurrent/evaluate.py @@ -19,6 +19,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index f59b90eb..e1dbf2a6 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -1,7 +1,6 @@ from __future__ import annotations import copy -import itertools import os import warnings from contextlib import nullcontext @@ -12,8 +11,7 @@ import numpy as np import torch from lightning.fabric import Fabric -from tensordict import TensorDict, pad_sequence -from tensordict.tensordict import TensorDictBase +from torch import Tensor from torch.distributed.algorithms.join import Join from torch.utils.data.sampler import BatchSampler, RandomSampler from torchmetrics import SumMetric @@ -35,11 +33,10 @@ def train( fabric: Fabric, agent: RecurrentPPOAgent, optimizer: torch.optim.Optimizer, - data: TensorDictBase, + data: Dict[str, Tensor], aggregator: MetricAggregator | None, cfg: Dict[str, Any], ): - data = {k: data[k] for k in data.keys()} num_sequences = data[next(iter(data.keys()))].shape[1] if cfg.algo.per_rank_num_batches > 0: batch_size = num_sequences // cfg.algo.per_rank_num_batches @@ -144,6 +141,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -216,7 +214,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) - step_data = TensorDict({}, batch_size=[1, cfg.env.num_envs], device=device) # Check that `rollout_steps` = k * `per_rank_sequence_length` if cfg.algo.rollout_steps % cfg.algo.per_rank_sequence_length != 0: @@ -266,20 +263,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): scheduler.load_state_dict(state["scheduler"]) # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] - obs = {} + 
step_data = {} + obs = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] for k in obs_keys: - torch_obs = torch.as_tensor(o[k], device=fabric.device) if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - step_data[k] = torch_obs[None] # [Seq_len, Batch_size, D] --> [1, num_envs, D] - obs[k] = torch_obs[None] + obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) + obs[k] = obs[k][np.newaxis] + step_data[k] = obs[k] # Get the resetted recurrent states from the agent prev_states = agent.initial_states - prev_actions = torch.zeros(1, cfg.env.num_envs, sum(actions_dim), device=fabric.device) + prev_actions = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) + torch_prev_actions = torch.zeros(1, cfg.env.num_envs, sum(actions_dim), device=device, dtype=torch.float32) for update in range(start_step, num_updates + 1): for _ in range(0, cfg.algo.rollout_steps): @@ -292,14 +287,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Sample an action given the observation received by the environment # [Seq_len, Batch_size, D] --> [1, num_envs, D] normalized_obs = normalize_obs(obs, cfg.algo.cnn_keys.encoder, obs_keys) + torch_obs = {k: torch.as_tensor(v, device=device).float() for k, v in normalized_obs.items()} actions, logprobs, _, values, states = agent.module( - normalized_obs, prev_actions=prev_actions, prev_states=prev_states + torch_obs, prev_actions=torch_prev_actions, prev_states=prev_states ) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: - real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) - actions = torch.cat(actions, dim=-1) + real_actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy() + torch_actions = torch.cat(actions, dim=-1) + actions = torch_actions.cpu().numpy() # Single environment step next_obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) @@ -324,47 +321,44 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.no_grad(): feat = agent.module.feature_extractor(real_next_obs) rnn_out, _ = agent.module.rnn( - torch.cat((feat, actions[:, truncated_envs, :]), dim=-1), + torch.cat((feat, torch_actions[:, truncated_envs, :]), dim=-1), tuple(s[:, truncated_envs, ...] 
for s in states), ) vals = agent.module.get_values(rnn_out).view(rewards[truncated_envs].shape).cpu().numpy() rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape) - dones = np.logical_or(dones, truncated) - dones = torch.as_tensor(dones, dtype=torch.float32, device=device).view(1, cfg.env.num_envs, -1) - rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).view(1, cfg.env.num_envs, -1) - - step_data["dones"] = dones - step_data["values"] = values - step_data["actions"] = actions - step_data["rewards"] = rewards - step_data["logprobs"] = logprobs - step_data["prev_hx"] = prev_states[0] - step_data["prev_cx"] = prev_states[1] - step_data["prev_actions"] = prev_actions + dones = np.logical_or(dones, truncated).reshape(1, cfg.env.num_envs, -1).astype(np.float32) + rewards = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + + step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1) + step_data["values"] = values.cpu().numpy().reshape(1, cfg.env.num_envs, -1) + step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) + step_data["rewards"] = rewards.reshape(1, cfg.env.num_envs, -1) + step_data["logprobs"] = logprobs.cpu().numpy() + step_data["prev_hx"] = prev_states[0].cpu().numpy().reshape(1, cfg.env.num_envs, -1) + step_data["prev_cx"] = prev_states[1].cpu().numpy().reshape(1, cfg.env.num_envs, -1) + step_data["prev_actions"] = prev_actions.reshape(1, cfg.env.num_envs, -1) if cfg.buffer.memmap: - step_data["returns"] = torch.zeros_like(rewards) - step_data["advantages"] = torch.zeros_like(rewards) + step_data["returns"] = np.zeros_like(rewards) + step_data["advantages"] = np.zeros_like(rewards) # Append data to buffer - rb.add(step_data) + rb.add(step_data, validate_args=cfg.buffer.validate_args) # Update actions prev_actions = (1 - dones) * actions + torch_prev_actions = torch.from_numpy(prev_actions).to(device).float() # Update the observation - obs = {} + obs = next_obs for k in obs_keys: + obs[k] = obs[k][np.newaxis] if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch.as_tensor(next_obs[k], device=device) - torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - elif k in cfg.algo.mlp_keys.encoder: - torch_obs = torch.as_tensor(next_obs[k], device=device, dtype=torch.float32) - step_data[k] = torch_obs[None] - obs[k] = torch_obs[None] + obs[k] = obs[k].reshape(1, cfg.env.num_envs, -1, *obs[k].shape[-2:]) + step_data[k] = obs[k] # Reset the states if the episode is done if cfg.algo.reset_recurrent_state_on_done: - prev_states = tuple([(1 - dones) * s for s in states]) + prev_states = tuple([(1 - torch.as_tensor(dones, device=device)) * s for s in states]) else: prev_states = states @@ -378,16 +372,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): aggregator.update("Game/ep_len_avg", ep_len) fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") + # Transform the data into PyTorch Tensors + local_data = rb.to_tensor(dtype=None, device=device) + # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): normalized_obs = normalize_obs(obs, cfg.algo.cnn_keys.encoder, obs_keys) - feat = agent.module.feature_extractor(normalized_obs) - rnn_out, _ = agent.module.rnn(torch.cat((feat, actions), dim=-1), states) + torch_obs = {k: torch.as_tensor(v, device=device).float() for k, v in normalized_obs.items()} + feat = agent.module.feature_extractor(torch_obs) + rnn_out, _ = agent.module.rnn(torch.cat((feat, torch_actions), dim=-1), states) next_values = 
agent.module.get_values(rnn_out) returns, advantages = gae( - rb["rewards"].to(torch.float64), - rb["values"], - rb["dones"], + local_data["rewards"].to(torch.float64), + local_data["values"], + local_data["dones"], next_values, cfg.algo.rollout_steps, cfg.algo.gamma, @@ -395,17 +393,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Add returns and advantages to the buffer - rb["returns"] = returns.float() - rb["advantages"] = advantages.float() - - # Get the training data as a TensorDict - local_data = rb.buffer + local_data["rewards"] = local_data["rewards"].float() + local_data["returns"] = returns.float() + local_data["advantages"] = advantages.float() # Train the agent # 1. Split data into episodes (for every environment) - episodes: List[TensorDictBase] = [] + episodes: List[Dict[str, Tensor]] = [] + lengths = [] for env_id in range(cfg.env.num_envs): - env_data = local_data[:, env_id] # [N_steps, *] + env_data = {k: v[:, env_id].float() for k, v in local_data.items()} # [N_steps, *] episode_ends = env_data["dones"].nonzero(as_tuple=True)[0] episode_ends = episode_ends.tolist() episode_ends.append(cfg.algo.rollout_steps) @@ -414,18 +411,33 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): stop = ep_end_idx # Include the done, since when we encounter a done it means that # the episode has ended - episode = env_data[start : stop + 1] + episode = {k: v[start : stop + 1] for k, v in env_data.items() if len(v[start : stop + 1]) > 0} if len(episode) > 0: episodes.append(episode) + lengths.append(episode[next(iter(episode.keys()))].shape[0]) start = stop + 1 # 2. Split every episode into sequences of length `per_rank_sequence_length` if cfg.algo.per_rank_sequence_length is not None and cfg.algo.per_rank_sequence_length > 0: - sequences = list( - itertools.chain.from_iterable([ep.split(cfg.algo.per_rank_sequence_length) for ep in episodes]) - ) + lengths = [] + sl = cfg.algo.per_rank_sequence_length + sequences = {k: [] for k in episodes[0].keys()} + for ep in episodes: + for k in sequences.keys(): + seq = torch.split(ep[k], sl) + sequences[k].extend(seq) + # Regardless of the key, the shapes are the same + lengths.extend([s.shape[0] for s in seq]) + else: sequences = episodes - padded_sequences = pad_sequence(sequences, batch_first=False, return_mask=True) # [Seq_len, Num_seq, *] + + padded_sequences = { + k: torch.nn.utils.rnn.pad_sequence(v, batch_first=False, padding_value=0) for k, v in sequences.items() + } + max_len = max(lengths) + lengths = torch.as_tensor(lengths) + mask = (torch.arange(max_len).expand(len(lengths), max_len) < lengths.unsqueeze(1)).T + padded_sequences["mask"] = mask.to(device).bool() with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): train(fabric, agent, optimizer, padded_sequences, aggregator, cfg) diff --git a/sheeprl/algos/ppo_recurrent/utils.py b/sheeprl/algos/ppo_recurrent/utils.py index cc922383..34125f76 100644 --- a/sheeprl/algos/ppo_recurrent/utils.py +++ b/sheeprl/algos/ppo_recurrent/utils.py @@ -31,12 +31,12 @@ def test(agent: "RecurrentPPOAgent", fabric: Fabric, cfg: Dict[str, Any], log_di with fabric.device: o = env.reset(seed=cfg.seed)[0] next_obs = { - k: torch.tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1, *o[k].shape[-2:]) / 255 + k: torch.as_tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1, *o[k].shape[-2:]) / 255 for k in cfg.algo.cnn_keys.encoder } next_obs.update( { - k: torch.tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1) + k: 
torch.as_tensor(o[k], dtype=torch.float32, device=fabric.device).view(1, 1, -1) for k in cfg.algo.mlp_keys.encoder } ) diff --git a/sheeprl/algos/sac/evaluate.py b/sheeprl/algos/sac/evaluate.py index 68f05e46..11f70741 100644 --- a/sheeprl/algos/sac/evaluate.py +++ b/sheeprl/algos/sac/evaluate.py @@ -19,6 +19,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 5173bd4c..a64069e0 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -11,7 +11,6 @@ import torch from lightning.fabric import Fabric from lightning.fabric.plugins.collectives.collective import CollectibleGroup -from tensordict import TensorDict, make_tensordict from torch import Tensor from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler @@ -110,6 +109,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -176,7 +176,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = ReplayBuffer( buffer_size, cfg.env.num_envs, - device=device, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) @@ -187,7 +186,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables last_train = 0 @@ -228,12 +226,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "policy_steps_per_update value." 
) - with device: - # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] - obs = torch.cat( - [torch.tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ) # [N_envs, N_obs] + step_data = {} + # Get the first environment observation and start the optimization + obs = envs.reset(seed=cfg.seed)[0] + obs = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -246,10 +242,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): else: # Sample an action given the observation received by the environment with torch.no_grad(): - actions, _ = agent.actor.module(obs) + torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) + actions, _ = agent.actor.module(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, dones, truncated, infos = envs.step(actions) - dones = np.logical_or(dones, truncated) + next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) + dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + rewards = rewards.reshape(cfg.env.num_envs, -1) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -266,27 +265,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - with device: - next_obs = torch.cat( - [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ) # [N_envs, N_obs] - real_next_obs = torch.cat( - [torch.tensor(real_next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ) # [N_envs, N_obs] - actions = torch.tensor(actions, dtype=torch.float32).view(cfg.env.num_envs, -1) - rewards = torch.tensor(rewards, dtype=torch.float32).view(cfg.env.num_envs, -1) - dones = torch.tensor(dones, dtype=torch.float32).view(cfg.env.num_envs, -1) - - step_data["dones"] = dones - step_data["actions"] = actions - step_data["observations"] = obs + real_next_obs[idx] = np.concatenate([v for v in final_obs.values()], axis=-1) + + step_data["dones"] = dones[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["observations"] = obs[np.newaxis] if not cfg.buffer.sample_next_obs: - step_data["next_observations"] = real_next_obs - step_data["rewards"] = rewards - rb.add(step_data.unsqueeze(0)) + step_data["next_observations"] = real_next_obs[np.newaxis] + step_data["rewards"] = rewards[np.newaxis] + rb.add(step_data, validate_args=cfg.buffer.validate_args) # next_obs becomes the new obs obs = next_obs @@ -296,15 +283,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): training_steps = learning_starts if update == learning_starts else 1 # We sample one time to reduce the communications between processes - sample = rb.sample( - training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, + sample = rb.sample_tensors( + batch_size=training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs, - ) # [G*B, 1] - gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] - gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] + dtype=None, + device=device, + ) # [G*B] + gathered_data: Dict[str, torch.Tensor] = 
fabric.all_gather(sample) # [World, G*B] + for k, v in gathered_data.items(): + gathered_data[k] = v.float() # [G*B*World] + if fabric.world_size > 1: + gathered_data[k] = gathered_data[k].flatten(start_dim=0, end_dim=2) + else: + gathered_data[k] = gathered_data[k].flatten(start_dim=0, end_dim=1) + idxes_to_sample = list(range(next(iter(gathered_data.values())).shape[0])) if world_size > 1: dist_sampler: DistributedSampler = DistributedSampler( - range(len(gathered_data)), + idxes_to_sample, num_replicas=world_size, rank=fabric.global_rank, shuffle=True, @@ -316,19 +311,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) else: sampler = BatchSampler( - sampler=range(len(gathered_data)), batch_size=cfg.algo.per_rank_batch_size, drop_last=False + sampler=idxes_to_sample, batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) # Start training with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): for batch_idxes in sampler: + batch = {k: v[batch_idxes] for k, v in gathered_data.items()} train( fabric, agent, actor_optimizer, qf_optimizer, alpha_optimizer, - {k: gathered_data[k][batch_idxes] for k in gathered_data.keys()}, + batch, aggregator, update, cfg, diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index f88c9655..707e85dd 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -13,8 +13,6 @@ from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.collectives.collective import CollectibleGroup from lightning.fabric.strategies import DDPStrategy -from tensordict import TensorDict, make_tensordict -from tensordict.tensordict import TensorDictBase from torch.utils.data.sampler import BatchSampler from torchmetrics import SumMetric @@ -119,7 +117,6 @@ def player( rb = ReplayBuffer( buffer_size, cfg.env.num_envs, - device=device, memmap=cfg.buffer.memmap, memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"), ) @@ -131,7 +128,6 @@ def player( "The replay buffer in the configs must be of type " f"`sheeprl.data.buffers.ReplayBuffer`, got {type(state['rb'])}." ) - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables first_info_sent = False @@ -168,12 +164,10 @@ def player( "policy_steps_per_update value." 
) - with device: - # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] - obs = torch.cat( - [torch.tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ) # [N_envs, N_obs] + step_data = {} + # Get the first environment observation and start the optimization + obs = envs.reset(seed=cfg.seed)[0] + obs = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs @@ -186,10 +180,13 @@ def player( else: # Sample an action given the observation received by the environment with torch.no_grad(): - actions, _ = actor(obs) + torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) + actions, _ = actor(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, dones, truncated, infos = envs.step(actions) - dones = np.logical_or(dones, truncated) + next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) + dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + rewards = rewards.reshape(cfg.env.num_envs, -1) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -206,27 +203,15 @@ def player( if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: - for k, v in final_obs.items(): - real_next_obs[k][idx] = v - - with device: - next_obs = torch.cat( - [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ) # [N_envs, N_obs] - real_next_obs = torch.cat( - [torch.tensor(real_next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 - ) # [N_envs, N_obs] - actions = torch.tensor(actions, dtype=torch.float32).view(cfg.env.num_envs, -1) - rewards = torch.tensor(rewards, dtype=torch.float32).view(cfg.env.num_envs, -1) # [N_envs, 1] - dones = torch.tensor(dones, dtype=torch.float32).view(cfg.env.num_envs, -1) - - step_data["dones"] = dones - step_data["actions"] = actions - step_data["observations"] = obs + real_next_obs[idx] = np.concatenate([v for v in final_obs.values()], axis=-1) + + step_data["dones"] = dones[np.newaxis] + step_data["actions"] = actions[np.newaxis] + step_data["observations"] = obs[np.newaxis] if not cfg.buffer.sample_next_obs: - step_data["next_observations"] = real_next_obs - step_data["rewards"] = rewards - rb.add(step_data.unsqueeze(0)) + step_data["next_observations"] = real_next_obs[np.newaxis] + step_data["rewards"] = rewards[np.newaxis] + rb.add(step_data, validate_args=cfg.buffer.validate_args) # next_obs becomes the new obs obs = next_obs @@ -242,13 +227,22 @@ def player( # Sample data to be sent to the trainers training_steps = learning_starts if update == learning_starts else 1 - chunks = rb.sample( - training_steps + sample = rb.sample_tensors( + batch_size=training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size * (fabric.world_size - 1), sample_next_obs=cfg.buffer.sample_next_obs, - ).split(training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size) + dtype=None, + device=device, + ) + # chunks = {k1: [k1_chunk_1, k1_chunk_2, ...], k2: [k2_chunk_1, k2_chunk_2, ...]} + chunks = { + k: v.float().split(training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size) + for k, v in sample.items() + } + # chunks = [{k1: k1_chunk_1, k2: k2_chunk_1}, {k1: k1_chunk_2, k2: k2_chunk_2}, ...] 
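+ # Build one chunk dictionary per trainer: the scatter below sends one dictionary to each trainer rank, while the player (rank 0) receives None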
+ chunks = [{k: v[i] for k, v in chunks.items()} for i in range(len(chunks[next(iter(chunks.keys()))]))] world_collective.scatter_object_list([None], [None] + chunks, src=0) # Wait the trainers to finish @@ -426,7 +420,7 @@ def trainer( data = [None] world_collective.scatter_object_list(data, [None for _ in range(world_collective.world_size)], src=0) data = data[0] - if not isinstance(data, TensorDictBase) and data == -1: + if not isinstance(data, dict) and data == -1: # Last Checkpoint if cfg.checkpoint.save_last: state = { @@ -452,8 +446,6 @@ def trainer( torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), src=1 ) return - data = make_tensordict(data, device=device) - data = {k: data[k] for k in data.keys()} sampler = BatchSampler( range(len(data[next(iter(data.keys()))])), batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) diff --git a/sheeprl/algos/sac/utils.py b/sheeprl/algos/sac/utils.py index 19905825..5d16e45f 100644 --- a/sheeprl/algos/sac/utils.py +++ b/sheeprl/algos/sac/utils.py @@ -35,7 +35,7 @@ def test(actor: SACActor, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): with fabric.device: o = env.reset(seed=cfg.seed)[0] next_obs = torch.cat( - [torch.tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 + [torch.as_tensor(o[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ).unsqueeze( 0 ) # [N_envs, N_obs] @@ -49,7 +49,7 @@ def test(actor: SACActor, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): cumulative_rew += reward with fabric.device: next_obs = torch.cat( - [torch.tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 + [torch.as_tensor(next_obs[k], dtype=torch.float32) for k in cfg.algo.mlp_keys.encoder], dim=-1 ) if cfg.dry_run: diff --git a/sheeprl/algos/sac_ae/agent.py b/sheeprl/algos/sac_ae/agent.py index df91720a..727e0cb3 100644 --- a/sheeprl/algos/sac_ae/agent.py +++ b/sheeprl/algos/sac_ae/agent.py @@ -13,6 +13,7 @@ from sheeprl.algos.sac_ae.utils import weight_init from sheeprl.models.models import CNN, MLP, DeCNN, MultiDecoder, MultiEncoder +from sheeprl.utils.model import cnn_forward LOG_STD_MAX = 2 LOG_STD_MIN = -10 @@ -71,7 +72,7 @@ def conv_output_shape(self) -> Size: def forward(self, obs: Dict[str, Tensor], *, detach_encoder_features: bool = False, **kwargs) -> Tensor: x = torch.cat([obs[k] for k in self.keys], dim=-3) - x = self.model(x).flatten(1) + x = cnn_forward(self.model, x, x.shape[-3:], (-1,)) if detach_encoder_features: x = x.detach() x = self.fc(x) diff --git a/sheeprl/algos/sac_ae/evaluate.py b/sheeprl/algos/sac_ae/evaluate.py index 9489f4ab..a47c0f28 100644 --- a/sheeprl/algos/sac_ae/evaluate.py +++ b/sheeprl/algos/sac_ae/evaluate.py @@ -19,6 +19,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") env = make_env( cfg, diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index c3e805e8..79eaa7f8 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -14,7 +14,6 @@ from lightning.fabric import Fabric from lightning.fabric.plugins.collectives.collective import CollectibleGroup from lightning.fabric.wrappers import _FabricModule -from tensordict import TensorDict, make_tensordict from torch import Tensor from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler @@ 
-54,7 +53,6 @@ def train( critic_target_network_frequency = cfg.algo.critic.target_network_frequency // policy_steps_per_update + 1 actor_network_frequency = cfg.algo.actor.network_frequency // policy_steps_per_update + 1 decoder_update_freq = cfg.algo.decoder.update_freq // policy_steps_per_update + 1 - data = {k: v.to(fabric.device) for k, v in data.items()} normalized_obs = {} normalized_next_obs = {} for k in cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder: @@ -74,6 +72,8 @@ def train( qf_optimizer.zero_grad(set_to_none=True) fabric.backward(qf_loss) qf_optimizer.step() + if aggregator and not aggregator.disabled: + aggregator.update("Loss/value_loss", qf_loss) # Update the target networks with EMA if update % critic_target_network_frequency == 0: @@ -81,7 +81,6 @@ def train( agent.critic_encoder_target_ema() # Update the actor - actor_loss = alpha_loss = None if update % actor_network_frequency == 0: actions, logprobs = agent.get_actions_and_log_probs(normalized_obs, detach_encoder_features=True) qf_values = agent.get_q_values(normalized_obs, actions, detach_encoder_features=True) @@ -98,8 +97,11 @@ def train( agent.log_alpha.grad = fabric.all_reduce(agent.log_alpha.grad, group=group) alpha_optimizer.step() + if aggregator and not aggregator.disabled: + aggregator.update("Loss/policy_loss", actor_loss) + aggregator.update("Loss/alpha_loss", alpha_loss) + # Update the decoder - reconstruction_loss = None if update % decoder_update_freq == 0: hidden = encoder(normalized_obs) reconstruction = decoder(hidden) @@ -115,14 +117,7 @@ def train( fabric.backward(reconstruction_loss) encoder_optimizer.step() decoder_optimizer.step() - - if aggregator and not aggregator.disabled: - aggregator.update("Loss/value_loss", qf_loss) - if actor_loss: - aggregator.update("Loss/policy_loss", actor_loss) - if alpha_loss: - aggregator.update("Loss/alpha_loss", alpha_loss) - if reconstruction_loss: + if aggregator and not aggregator.disabled: aggregator.update("Loss/reconstruction_loss", reconstruction_loss) @@ -156,6 +151,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + fabric.print(f"Log dir: {log_dir}") # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -176,6 +172,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if not isinstance(envs.single_action_space, gym.spaces.Box): + raise RuntimeError( + f"Unexpected action space, should be of type continuous (of type Box), got: {observation_space}" + ) if ( len(set(cfg.algo.cnn_keys.encoder).intersection(set(cfg.algo.cnn_keys.decoder))) == 0 @@ -197,6 +197,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) fabric.print("Decoder CNN keys:", cfg.algo.cnn_keys.decoder) fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) + obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder # Define the agent and the optimizer and setup them with Fabric agent, encoder, decoder = build_agent( @@ -258,7 +259,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") - step_data = TensorDict({}, batch_size=[cfg.env.num_envs], 
device=fabric.device if cfg.buffer.memmap else "cpu") # Global variables last_train = 0 @@ -300,16 +300,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Get the first environment observation and start the optimization - o = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] - obs = {} - for k in o.keys(): - if k in cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder: - torch_obs = torch.from_numpy(o[k]).to(fabric.device) - if k in cfg.algo.cnn_keys.encoder: - torch_obs = torch_obs.view(cfg.env.num_envs, -1, *torch_obs.shape[-2:]) - if k in cfg.algo.mlp_keys.encoder: - torch_obs = torch_obs.float() - obs[k] = torch_obs + step_data = {} + obs = envs.reset(seed=cfg.seed)[0] # [N_envs, N_obs] + for k in obs_keys: + if k in cfg.algo.cnn_keys.encoder: + obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * fabric.world_size @@ -322,9 +317,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): else: with torch.no_grad(): normalized_obs = {k: v / 255 if k in cfg.algo.cnn_keys.encoder else v for k, v in obs.items()} - actions, _ = agent.actor.module(normalized_obs) + torch_obs = {k: torch.from_numpy(v).to(device).float() for k, v in normalized_obs.items()} + actions, _ = agent.actor.module(torch_obs) actions = actions.cpu().numpy() - o, rewards, dones, truncated, infos = envs.step(actions) + next_obs, rewards, dones, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) dones = np.logical_or(dones, truncated) if cfg.metric.log_level > 0 and "final_info" in infos: @@ -338,38 +334,29 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation - real_next_obs = copy.deepcopy(o) + real_next_obs = copy.deepcopy(next_obs) if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: for k, v in final_obs.items(): real_next_obs[k][idx] = v - next_obs = {} for k in real_next_obs.keys(): - next_obs[k] = torch.from_numpy(o[k]).to(fabric.device) if k in cfg.algo.cnn_keys.encoder: - next_obs[k] = next_obs[k].view(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) - if k in cfg.algo.mlp_keys.encoder: - next_obs[k] = next_obs[k].float() + next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) + step_data[k] = obs[k][np.newaxis] - step_data[k] = obs[k] if not cfg.buffer.sample_next_obs: - step_data[f"next_{k}"] = torch.from_numpy(real_next_obs[k]).to(fabric.device) + step_data[f"next_{k}"] = real_next_obs[k][np.newaxis] if k in cfg.algo.cnn_keys.encoder: - step_data[f"next_{k}"] = step_data[f"next_{k}"].view( - cfg.env.num_envs, -1, *step_data[f"next_{k}"].shape[-2:] + step_data[f"next_{k}"] = step_data[f"next_{k}"].reshape( + 1, cfg.env.num_envs, -1, *step_data[f"next_{k}"].shape[-2:] ) - if k in cfg.algo.mlp_keys.encoder: - step_data[f"next_{k}"] = step_data[f"next_{k}"].float() - actions = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float().to(fabric.device) - rewards = torch.from_numpy(rewards).view(cfg.env.num_envs, -1).float().to(fabric.device) - dones = torch.from_numpy(dones).view(cfg.env.num_envs, -1).float().to(fabric.device) - step_data["dones"] = dones - step_data["actions"] = actions - step_data["rewards"] = rewards - rb.add(step_data.unsqueeze(0)) + step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + step_data["actions"] = actions.reshape(1, cfg.env.num_envs, 
-1).astype(np.float32) + step_data["rewards"] = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + rb.add(step_data, validate_args=cfg.buffer.validate_args) # next_obs becomes the new obs obs = next_obs @@ -379,15 +366,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): training_steps = learning_starts if update == learning_starts - 1 else 1 # We sample one time to reduce the communications between processes - sample = rb.sample( + sample = rb.sample_tensors( training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs, ) # [G*B, 1] - gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] - gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] + gathered_data = fabric.all_gather(sample) # [G*B, World, 1] + flatten_dim = 3 if fabric.world_size > 1 else 2 + gathered_data = {k: v.view(-1, *v.shape[flatten_dim:]) for k, v in gathered_data.items()} # [G*B*World] + len_data = len(gathered_data[next(iter(gathered_data.keys()))]) if fabric.world_size > 1: dist_sampler: DistributedSampler = DistributedSampler( - range(len(gathered_data)), + range(len_data), num_replicas=fabric.world_size, rank=fabric.global_rank, shuffle=True, @@ -399,7 +388,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) else: sampler = BatchSampler( - sampler=range(len(gathered_data)), batch_size=cfg.algo.per_rank_batch_size, drop_last=False + sampler=range(len_data), batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) # Start training @@ -415,7 +404,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): alpha_optimizer, encoder_optimizer, decoder_optimizer, - {k: gathered_data[k][batch_idxes] for k in gathered_data.keys()}, + {k: v[batch_idxes] for k, v in gathered_data.items()}, aggregator, update, cfg, diff --git a/sheeprl/configs/buffer/default.yaml b/sheeprl/configs/buffer/default.yaml index 582c793a..9c329aab 100644 --- a/sheeprl/configs/buffer/default.yaml +++ b/sheeprl/configs/buffer/default.yaml @@ -1,2 +1,3 @@ size: ??? 
-memmap: True \ No newline at end of file +memmap: True +validate_args: False \ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v2_crafter.yaml b/sheeprl/configs/exp/dreamer_v2_crafter.yaml new file mode 100644 index 00000000..db7e5249 --- /dev/null +++ b/sheeprl/configs/exp/dreamer_v2_crafter.yaml @@ -0,0 +1,64 @@ +# @package _global_ + +defaults: + - dreamer_v2 + - override /env: crafter + - _self_ + +# Experiment +seed: 0 +total_steps: 1000000 + +# Environment +env: + id: reward + num_envs: 1 + reward_as_observation: True + +# Checkpoint +checkpoint: + every: 100000 + +# Buffer +buffer: + size: 2000000 + type: episode + checkpoint: True + prioritize_ends: True + +# The CNN and MLP keys of the decoder are the same as those of the encoder by default +cnn_keys: + encoder: + - rgb + decoder: + - rgb +mlp_keys: + encoder: + - reward + decoder: [] + +# Algorithm +algo: + gamma: 0.999 + train_every: 5 + layer_norm: True + learning_starts: 10000 + per_rank_pretrain_steps: 1 + world_model: + kl_free_nats: 0.0 + use_continues: True + recurrent_model: + recurrent_state_size: 1024 + transition_model: + hidden_size: 1024 + representation_model: + hidden_size: 1024 + optimizer: + lr: 1e-4 + actor: + ent_coef: 3e-3 + optimizer: + lr: 1e-4 + critic: + optimizer: + lr: 1e-4 \ No newline at end of file diff --git a/sheeprl/data/__init__.py b/sheeprl/data/__init__.py index f134d62f..a76e4786 100644 --- a/sheeprl/data/__init__.py +++ b/sheeprl/data/__init__.py @@ -1,4 +1,4 @@ -from sheeprl.data.buffers import AsyncReplayBuffer as AsyncReplayBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer as EnvIndependentReplayBuffer from sheeprl.data.buffers import EpisodeBuffer as EpisodeBuffer from sheeprl.data.buffers import ReplayBuffer as ReplayBuffer from sheeprl.data.buffers import SequentialReplayBuffer as SequentialReplayBuffer diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 219ed003..73672e39 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -1,35 +1,51 @@ +from __future__ import annotations + +import logging import os import shutil import typing import uuid -import warnings +from itertools import compress from pathlib import Path -from typing import List, Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Type import numpy as np import torch -from tensordict import MemmapTensor, TensorDict -from tensordict.tensordict import TensorDictBase -from torch import Size, Tensor, device +from torch import Tensor + +from sheeprl.utils.memmap import MemmapArray +from sheeprl.utils.utils import NUMPY_TO_TORCH_DTYPE_DICT class ReplayBuffer: + batch_axis: int = 1 + def __init__( self, buffer_size: int, n_envs: int = 1, - device: Union[device, str] = "cpu", - memmap: bool = False, - memmap_dir: Optional[Union[str, os.PathLike]] = None, obs_keys: Sequence[str] = ("observations",), + memmap: bool = False, + memmap_dir: str | os.PathLike | None = None, + memmap_mode: str = "r+", + **kwargs, ): - """A replay buffer which internally uses a TensorDict. + """A standard replay buffer implementation. Internally this is represented by a + dictionary mapping string to numpy arrays. The first dimension of the arrays is the + buffer size, while the second dimension is the number of environments. Args: - buffer_size (int): The buffer size. - n_envs (int, optional): The number of environments. Defaults to 1. - device (Union[torch.device, str], optional): The device where the buffer is created. Defaults to "cpu". 
- memmap (bool, optional): Whether to memory-mapping the buffer. + buffer_size (int): the buffer size. + n_envs (int, optional): the number of environments. Defaults to 1. + obs_keys (Sequence[str], optional): names of the observation keys. Those are used + to sample the next-observation. Defaults to ("observations",). + memmap (bool, optional): whether to memory-map the numpy arrays saved in the buffer. Defaults to False. + memmap_dir (str | os.PathLike | None, optional): the memory-mapped files directory. + Defaults to None. + memmap_mode (str, optional): memory-map mode. + Possible values are: "r+", "w+", "c", "copyonwrite", "readwrite", "write". + Defaults to "r+". + kwargs: additional keyword arguments. """ if buffer_size <= 0: raise ValueError(f"The buffer size must be greater than zero, got: {buffer_size}") @@ -37,30 +53,33 @@ def __init__( raise ValueError(f"The number of environments must be greater than zero, got: {n_envs}") self._buffer_size = buffer_size self._n_envs = n_envs - if isinstance(device, str): - device = torch.device(device=device) - self._device = device + self._obs_keys = obs_keys self._memmap = memmap self._memmap_dir = memmap_dir + self._memmap_mode = memmap_mode + self._buf: Dict[str, np.ndarray | MemmapArray] = {} if self._memmap: - if memmap_dir is None: - warnings.warn( - "The buffer will be memory-mapped into the `/tmp` folder, this means that there is the" - " possibility to lose the saved files. Set the `memmap_dir` to a known directory.", - UserWarning, + if self._memmap_mode not in ("r+", "w+", "c", "copyonwrite", "readwrite", "write"): + raise ValueError( + 'Accepted values for memmap_mode are "r+", "readwrite", "w+", "write", "c" or ' + '"copyonwrite". PyTorch does not support tensors backed by read-only ' + 'NumPy arrays, so "r" and "readonly" are not supported.' + ) + if self._memmap_dir is None: + raise ValueError( + "The buffer is set to be memory-mapped but the 'memmap_dir' attribute is None. " + "Set the 'memmap_dir' to a known directory.", ) else: self._memmap_dir = Path(self._memmap_dir) self._memmap_dir.mkdir(parents=True, exist_ok=True) - self._buf = None - else: - self._buf = TensorDict({}, batch_size=[buffer_size, n_envs], device=device) self._pos = 0 self._full = False - self.obs_keys = obs_keys + self._memmap_specs = {} + self._rng: np.random.Generator = np.random.default_rng() @property - def buffer(self) -> Optional[TensorDictBase]: + def buffer(self) -> Dict[str, np.ndarray]: return self._buf @property @@ -76,110 +95,165 @@ def n_envs(self) -> int: return self._n_envs @property - def shape(self) -> Optional[Size]: - if self.buffer is None: - return None - return self.buffer.shape + def empty(self) -> bool: + return (self.buffer is not None and len(self.buffer) == 0) or self.buffer is None @property - def device(self) -> device: - return self._device + def is_memmap(self) -> bool: + return self._memmap def __len__(self) -> int: return self.buffer_size + @torch.no_grad() + def to_tensor( + self, + dtype: Optional[torch.dtype] = None, + clone: bool = False, + device: str | torch.dtype = "cpu", + from_numpy: bool = False, + ) -> Dict[str, Tensor]: + """Converts the replay buffer to a dictionary mapping string to torch.Tensor. + + Args: + dtype (Optional[torch.dtype], optional): the torch dtype to convert the arrays to. + If None, then the dtypes of the numpy arrays is maintained. + Defaults to None. + clone (bool, optional): whether to clone the converted tensors. + Defaults to False. 
+ device (str | torch.dtype, optional): the torch device to move the tensors to. + Defaults to "cpu". + from_numpy (bool, optional): whether to convert the numpy arrays to torch tensors + with the 'torch.from_numpy' function. Defaults to False. + + Returns: + Dict[str, Tensor]: the converted buffer. + """ + buf = {} + for k, v in self.buffer.items(): + buf[k] = get_tensor(v, dtype=dtype, clone=clone, device=device, from_numpy=from_numpy) + return buf + @typing.overload - def add(self, data: "ReplayBuffer") -> None: + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: TensorDictBase) -> None: + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... - def add(self, data: Union["ReplayBuffer", TensorDictBase]) -> None: - """Add data to the buffer. + def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None: + """Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten. + If data is a dictionary, then the keys must be strings and the values must be numpy arrays of shape + [sequence_length, n_envs, ...]. Args: - data: data to add. + data (ReplayBuffer | Dict[str, np.ndarray]): the data to add to the replay buffer. + validate_args (bool, optional): whether to validate the arguments. Defaults to False. Raises: - RuntimeError: the number of dimensions (the batch_size of the TensorDictBase) must be 2: - one for the number of environments and one for the sequence length. + ValueError: if the data is not a dictionary. + ValueError: if the values of the data dictionary are not numpy arrays. + RuntimeError: if the data does not have at least 2 dimensions. + RuntimeError: if the data is not congruent in the first 2 dimensions. """ if isinstance(data, ReplayBuffer): data = data.buffer - elif not isinstance(data, TensorDictBase): - raise TypeError("`data` must be a TensorDictBase or a sheeprl.data.ReplayBuffer") - if data is None: - raise RuntimeError("The `data` replay buffer must be not None") - if len(data.shape) != 2: - raise RuntimeError( - "`data` must have 2 batch dimensions: [sequence_length, n_envs]. " - "`sequence_length` and `n_envs` should be 1. Shape is: {}".format(data.shape) - ) - data = data.to(self.device) - data_len = data.shape[0] + if validate_args: + if not isinstance(data, dict): + raise ValueError( + f"'data' must be a dictionary containing Numpy arrays, but 'data' is of type '{type(data)}'" + ) + elif isinstance(data, dict): + for k, v in data.items(): + if not isinstance(v, np.ndarray): + raise ValueError( + f"'data' must be a dictionary containing Numpy arrays. Found key '{k}' " + f"containing a value of type '{type(v)}'" + ) + last_key = next(iter(data.keys())) + last_batch_shape = next(iter(data.values())).shape[:2] + for i, (k, v) in enumerate(data.items()): + if len(v.shape) < 2: + raise RuntimeError( + "'data' must have at least 2 dimensions: [sequence_length, n_envs, ...]. 
" + f"Shape of '{k}' is {v.shape}" + ) + if i > 0: + current_key = k + current_batch_shape = v.shape[:2] + if current_batch_shape != last_batch_shape: + raise RuntimeError( + "Every array in 'data' must be congruent in the first 2 dimensions: " + f"found key '{last_key}' with shape '{last_batch_shape}' " + f"and '{current_key}' with shape '{current_batch_shape}'" + ) + last_key = current_key + last_batch_shape = current_batch_shape + data_len = next(iter(data.values())).shape[0] next_pos = (self._pos + data_len) % self._buffer_size - if next_pos < self._pos or (data_len >= self._buffer_size and not self._full): - idxes = torch.tensor( - list(range(self._pos, self._buffer_size)) + list(range(0, next_pos)), device=self.device - ) + if next_pos <= self._pos or (data_len > self._buffer_size and not self._full): + idxes = np.array(list(range(self._pos, self._buffer_size)) + list(range(0, next_pos))) else: - idxes = torch.tensor(range(self._pos, next_pos), device=self.device) + idxes = np.array(range(self._pos, next_pos)) if data_len > self._buffer_size: - data_to_store = data[-self._buffer_size - next_pos :] + data_to_store = {k: v[-self._buffer_size - next_pos :] for k, v in data.items()} else: data_to_store = data - if self._memmap and self._buf is None: - self._buf = TensorDict( - { - k: MemmapTensor( - (self._buffer_size, self._n_envs, *v.shape[2:]), - dtype=v.dtype, - device=v.device, - filename=None if self._memmap_dir is None else self._memmap_dir / f"{k}.memmap", - ) - for k, v in data_to_store.items() - }, - batch_size=[self._buffer_size, self._n_envs], - device=self.device, - ) - self._buf.memmap_(prefix=self._memmap_dir) - self._buf[idxes, :] = data_to_store + if self._memmap and self.empty: + for k, v in data_to_store.items(): + self.buffer[k] = MemmapArray( + filename=Path(self._memmap_dir / f"{k}.memmap"), + dtype=v.dtype, + shape=(self._buffer_size, self._n_envs, *v.shape[2:]), + mode=self._memmap_mode, + ) + self.buffer[k][idxes] = data_to_store[k] + elif self.empty: + for k, v in data_to_store.items(): + self.buffer[k] = np.empty(shape=(self._buffer_size, self._n_envs, *v.shape[2:]), dtype=v.dtype) + self.buffer[k][idxes] = data_to_store[k] + else: + for k, v in data_to_store.items(): + self.buffer[k][idxes] = data_to_store[k] if self._pos + data_len >= self._buffer_size: self._full = True self._pos = next_pos - def sample(self, batch_size: int, sample_next_obs: bool = False, clone: bool = False, **kwargs) -> TensorDictBase: - """Sample elements from the replay buffer. - - Custom sampling when using memory efficient variant, - as we should not sample the element with index `self.pos` + def sample( + self, batch_size: int, sample_next_obs: bool = False, clone: bool = False, n_samples: int = 1, **kwargs + ) -> Dict[str, np.ndarray]: + """Sample elements from the replay buffer. If the replay buffer is not full, then the samples are taken + from the first 'self.pos' elements. Otherwise, the samples are taken from all the elements. + When 'sample_next_obs' is True we sample until 'self.pos - 1' to avoid sampling the last observation, + which would be invalid. See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 Args: batch_size (int): Number of element to sample - sample_next_obs (bool): whether to sample the next observations from the 'observations' key. + sample_next_obs (bool): whether to sample the next observations from the 'self.obs_keys' keys. Defaults to False. 
- clone (bool): whether to clone the sampled TensorDict + clone (bool): whether to clone the sampled numpy arrays. Defaults to False. + n_samples (int): the number of samples to perform. Defaults to 1. Returns: - TensorDictBase: the sampled TensorDictBase with a `batch_size` of [batch_size, 1] + Dict[str, np.ndarray]: the sampled dictionary with a shape of [n_samples, batch_size, ...]. """ - if batch_size <= 0: - raise ValueError("Batch size must be greater than 0") + if batch_size <= 0 or n_samples <= 0: + raise ValueError(f"'batch_size' ({batch_size}) and 'n_samples' ({n_samples}) must be both greater than 0") if not self._full and self._pos == 0: raise ValueError( - "No sample has been added to the buffer. Please add at least one sample calling `self.add()`" + "No sample has been added to the buffer. Please add at least one sample calling 'self.add()'" ) if self._full: first_range_end = self._pos - 1 if sample_next_obs else self._pos second_range_end = self.buffer_size if first_range_end >= 0 else self.buffer_size + first_range_end - valid_idxes = torch.tensor( - list(range(0, first_range_end)) + list(range(self._pos, second_range_end)), - device=self.device, + valid_idxes = np.array( + list(range(0, first_range_end)) + list(range(self._pos, second_range_end)), dtype=np.intp ) - batch_idxes = valid_idxes[torch.randint(0, len(valid_idxes), size=(batch_size,), device=self.device)] + batch_idxes = valid_idxes[ + self._rng.integers(0, len(valid_idxes), size=(batch_size * n_samples,), dtype=np.intp) + ] else: max_pos_to_sample = self._pos - 1 if sample_next_obs else self._pos if max_pos_to_sample == 0: @@ -187,156 +261,483 @@ def sample(self, batch_size: int, sample_next_obs: bool = False, clone: bool = F "You want to sample the next observations, but one sample has been added to the buffer. " "Make sure that at least two samples are added." ) - batch_idxes = torch.randint(0, max_pos_to_sample, size=(batch_size,), device=self.device) - sample = self._get_samples(batch_idxes, sample_next_obs=sample_next_obs).unsqueeze(-1) - if clone: - return sample.clone() - return sample - - def _get_samples(self, batch_idxes: Tensor, sample_next_obs: bool = False) -> TensorDictBase: - env_idxes = torch.randint(0, self.n_envs, size=(len(batch_idxes),)) - if self._buf is None: + batch_idxes = self._rng.integers(0, max_pos_to_sample, size=(batch_size * n_samples,), dtype=np.intp) + return { + k: v.reshape(n_samples, batch_size, *v.shape[1:]) + for k, v in self._get_samples(batch_idxes=batch_idxes, sample_next_obs=sample_next_obs, clone=clone).items() + } + + def _get_samples( + self, batch_idxes: np.ndarray, sample_next_obs: bool = False, clone: bool = False + ) -> Dict[str, np.ndarray]: + if self.empty: raise RuntimeError("The buffer has not been initialized. 
Try to add some data first.") - buf = self._buf[batch_idxes, env_idxes] + samples: Dict[str, np.ndarray] = {} + env_idxes = self._rng.integers(0, self.n_envs, size=(len(batch_idxes),), dtype=np.intp) + flattened_idxes = (batch_idxes * self.n_envs + env_idxes).flat if sample_next_obs: - for k in self.obs_keys: - buf[f"next_{k}"] = self._buf[k][(batch_idxes + 1) % self._buffer_size, env_idxes] - return buf + flattened_next_idxes = (((batch_idxes + 1) % self._buffer_size) * self.n_envs + env_idxes).flat + for k, v in self.buffer.items(): + samples[k] = np.take(np.reshape(v, (-1, *v.shape[2:])), flattened_idxes, axis=0) + if clone: + samples[k] = samples[k].copy() + if k in self._obs_keys and sample_next_obs: + samples[f"next_{k}"] = np.take(np.reshape(v, (-1, *v.shape[2:])), flattened_next_idxes, axis=0) + if clone: + samples[f"next_{k}"] = samples[f"next_{k}"].copy() + return samples - def __getitem__(self, key: str) -> torch.Tensor: + @torch.no_grad() + def sample_tensors( + self, + batch_size: int, + clone: bool = False, + sample_next_obs: bool = False, + dtype: Optional[torch.dtype] = None, + device: str | torch.dtype = "cpu", + from_numpy: bool = False, + **kwargs, + ) -> Dict[str, Tensor]: + """Sample elements from the replay buffer and convert them to torch tensors. + + Args: + batch_size (int): Number of elements to sample. + clone (bool): whether to clone the sampled numpy arrays. Defaults to False. + sample_next_obs (bool): whether to sample the next observations from the 'self.obs_keys' keys. + Defaults to False. + dtype (Optional[torch.dtype], optional): the torch dtype to convert the arrays to. If None, + then the dtypes of the numpy arrays is maintained. Defaults to None. + device (str | torch.dtype, optional): the torch device to move the tensors to. Defaults to "cpu". + from_numpy (bool, optional): whether to convert the numpy arrays to torch tensors + with the 'torch.from_numpy' function. If False, then the numpy arrays are converted + with the 'torch.as_tensor' function. Defaults to False. + kwargs: additional keyword arguments to be passed to the 'self.sample' method. + + Returns: + Dict[str, Tensor]: the sampled dictionary, containing the sampled array, + one for every key, with a shape of [n_samples, batch_size, ...] + """ + samples = self.sample(batch_size=batch_size, sample_next_obs=sample_next_obs, clone=clone, **kwargs) + return { + k: get_tensor(v, dtype=dtype, clone=clone, device=device, from_numpy=from_numpy) for k, v in samples.items() + } + + def __getitem__(self, key: str) -> np.ndarray | np.memmap | MemmapArray: if not isinstance(key, str): - raise TypeError("`key` must be a string") - if self._buf is None: + raise TypeError("'key' must be a string") + if self.empty: raise RuntimeError("The buffer has not been initialized. Try to add some data first.") - return self._buf.get(key) + return self.buffer.get(key) - def __setitem__(self, key: str, t: Tensor) -> None: - if self._buf is None: + def __setitem__(self, key: str, value: np.ndarray | np.memmap | MemmapArray) -> None: + if not isinstance(value, (np.ndarray, MemmapArray)): + raise ValueError( + "The value to be set must be an instance of 'np.ndarray', 'np.memmap' " + f"or '{MemmapArray.__module__}.{MemmapArray.__qualname__}', " + f"got {type(value)}" + ) + if self.empty: raise RuntimeError("The buffer has not been initialized. 
Try to add some data first.") - self._buf.set(key, t, inplace=True) + if value.shape[:2] != (self._buffer_size, self._n_envs): + raise RuntimeError( + "'value' must have at least two dimensions of dimension [buffer_size, n_envs, ...]. " + f"Shape of 'value' is {value.shape}" + ) + if self._memmap: + if isinstance(value, np.ndarray): + filename = Path(self._memmap_dir / f"{key}.memmap") + elif isinstance(value, MemmapArray): + filename = value.filename + value_to_add = MemmapArray.from_array(value, filename=filename, mode=self._memmap_mode) + else: + if isinstance(value, np.ndarray): + value_to_add = np.copy(value) + elif isinstance(value, MemmapArray): + value_to_add = np.copy(value.array) + self.buffer.update({key: value_to_add}) class SequentialReplayBuffer(ReplayBuffer): - """A replay buffer which internally uses a TensorDict and returns sequential samples. - - Args: - buffer_size (int): The buffer size. - n_envs (int, optional): The number of environments. Defaults to 1. - device (Union[torch.device, str], optional): The device where the buffer is created. Defaults to "cpu". - """ + batch_axis: int = 2 def __init__( self, buffer_size: int, n_envs: int = 1, - device: Union[device, str] = "cpu", + obs_keys: Sequence[str] = ("observations",), memmap: bool = False, - memmap_dir: Optional[Union[str, os.PathLike]] = None, + memmap_dir: str | os.PathLike | None = None, + memmap_mode: str = "r+", + **kwargs, ): - super().__init__(buffer_size, n_envs, device, memmap, memmap_dir) + """A sequential replay buffer implementation. Internally this is represented by a + dictionary mapping string to numpy arrays. The first dimension of the arrays is the + buffer length, while the second dimension is the number of environments. The sequentiality comes + from the fact that the samples are sampled as sequences of consecutive elements. + + Args: + buffer_size (int): the buffer size. + n_envs (int, optional): the number of environments. Defaults to 1. + obs_keys (Sequence[str], optional): names of the observation keys. Those are used + to sample the next-observation. Defaults to ("observations",). + memmap (bool, optional): whether to memory-map the numpy arrays saved in the buffer. Defaults to False. + memmap_dir (str | os.PathLike | None, optional): the memory-mapped files directory. + Defaults to None. + memmap_mode (str, optional): memory-map mode. Possible values are: "r+", "w+", "c", "copyonwrite", + "readwrite", "write". Defaults to "r+". + kwargs: additional keyword arguments. + """ + super().__init__(buffer_size, n_envs, obs_keys, memmap, memmap_dir, memmap_mode, **kwargs) def sample( self, batch_size: int, sample_next_obs: bool = False, clone: bool = False, - sequence_length: int = 1, n_samples: int = 1, - ) -> TensorDictBase: - """Sample elements from the sequential replay buffer, - each one is a sequence of a consecutive items. - - Custom sampling when using memory efficient variant, - as the first element of the sequence cannot be in a position - greater than (pos - sequence_length) % buffer_size. - See comments in the code for more information. + sequence_length: int = 1, + **kwargs, + ) -> Dict[str, np.ndarray]: + """Sample elements from the replay buffer in a sequential manner, without considering the episode + boundaries. Args: batch_size (int): Number of element to sample sample_next_obs (bool): whether to sample the next observations from the 'observations' key. Defaults to False. - clone (bool): whether to clone the sampled TensorDict. 
- sequence_length (int): the length of the sequence of each element. Defaults to 1. + clone (bool): whether to clone the sampled tensors. n_samples (int): the number of samples to perform. Defaults to 1. + sequence_length (int): the length of the sequence of each element. Defaults to 1. Returns: - TensorDictBase: the sampled TensorDictBase with a `batch_size` of [n_samples, sequence_length, batch_size] + Dict[str, np.ndarray]: the sampled dictionary with a shape of + [n_samples, sequence_length, batch_size, ...]. """ # the batch_size can be fused with the number of samples to have single batch size batch_dim = batch_size * n_samples - # Controls - if batch_dim <= 0: - raise ValueError("Batch size must be greater than 0") - if not self._full and self._pos == 0: + # Sanity checks + if batch_size <= 0 or n_samples <= 0: + raise ValueError(f"'batch_size' ({batch_size}) and 'n_samples' ({n_samples}) must be both greater than 0") + if not self.full and self._pos == 0: raise ValueError( - "No sample has been added to the buffer. Please add at least one sample calling `self.add()`" + "No sample has been added to the buffer. Please add at least one sample calling 'self.add()'" ) if self._buf is None: raise RuntimeError("The buffer has not been initialized. Try to add some data first.") - if not self._full and self._pos - sequence_length + 1 < 1: - raise ValueError(f"too long sequence length ({sequence_length})") - if self.full and sequence_length > self._buf.shape[0]: - raise ValueError(f"too long sequence length ({sequence_length})") + if not self.full and self._pos - sequence_length + 1 < 1: + raise ValueError(f"Cannot sample a sequence of length {sequence_length}. Data added so far: {self._pos}") + if self.full and sequence_length > self.__len__(): + raise ValueError( + f"The sequence length ({sequence_length}) is greater than the buffer size ({self.__len__()})" + ) - # Do not sample the element with index `self.pos` as the transitions is invalid - if self._full: - # when the buffer is full, it is necessary to avoid the starting index between (self.pos - sequence_length) - # and self.pos, so it is possible to sample the starting index between (0, self.pos - sequence_length) and - # between (self.pos, self.buffer_size) + # Do not sample the element with index 'self.pos' as the transitions is invalid + if self.full: + # when the buffer is full, it is necessary to avoid the starting index + # to be between (self.pos - sequence_length) + # and self.pos, so it is possible to sample + # the starting index between (0, self.pos - sequence_length) and (self.pos, self.buffer_size) first_range_end = self._pos - sequence_length + 1 # end of the second range, if the first range is empty, then the second range ends # in (buffer_size + (self._pos - sequence_length + 1)), otherwise the sequence will contain # invalid values second_range_end = self.buffer_size if first_range_end >= 0 else self.buffer_size + first_range_end - valid_idxes = torch.tensor( - list(range(0, first_range_end)) + list(range(self._pos, second_range_end)), - device=self.device, + valid_idxes = np.array( + list(range(0, first_range_end)) + list(range(self._pos, second_range_end)), dtype=np.intp ) # start_idxes are the indices of the first elements of the sequences - start_idxes = valid_idxes[torch.randint(0, len(valid_idxes), size=(batch_dim,), device=self.device)] + start_idxes = valid_idxes[self._rng.integers(0, len(valid_idxes), size=(batch_dim,), dtype=np.intp)] else: # when the buffer is not full, we need to start the sequence so that 
it does not go out of bounds - start_idxes = torch.randint(0, self._pos - sequence_length + 1, size=(batch_dim,), device=self.device) + start_idxes = self._rng.integers(0, self._pos - sequence_length + 1, size=(batch_dim,), dtype=np.intp) # chunk_length contains the relative indices of the sequence (0, 1, ..., sequence_length-1) - chunk_length = torch.arange(sequence_length, device=self.device).reshape(1, -1) + chunk_length = np.arange(sequence_length, dtype=np.intp).reshape(1, -1) idxes = (start_idxes.reshape(-1, 1) + chunk_length) % self.buffer_size # (n_samples, sequence_length, batch_size) - sample = self._get_samples(idxes).reshape(n_samples, batch_size, sequence_length).permute(0, -1, -2) - if clone: - return sample.clone() - return sample + return self._get_samples( + idxes, batch_size, n_samples, sequence_length, sample_next_obs=sample_next_obs, clone=clone + ) + + def _get_samples( + self, + batch_idxes: np.ndarray, + batch_size: int, + n_samples: int, + sequence_length: int, + sample_next_obs: bool = False, + clone: bool = False, + ) -> Dict[str, np.ndarray]: + batch_shape = (batch_size * n_samples, sequence_length) # [Batch_size * N_samples, Seq_len] + flattened_batch_idxes = np.ravel(batch_idxes) + + # Each sequence must come from the same environment + if self._n_envs == 1: + env_idxes = np.zeros((np.prod(batch_shape),), dtype=np.intp) + else: + env_idxes = self._rng.integers(0, self.n_envs, size=(batch_shape[0],), dtype=np.intp) + env_idxes = np.reshape(env_idxes, (-1, 1)) + env_idxes = np.tile(env_idxes, (1, sequence_length)) + env_idxes = np.ravel(env_idxes) + + # Flatten indexes + flattened_idxes = (flattened_batch_idxes * self._n_envs + env_idxes).flat + + # Get samples + samples: Dict[str, np.ndarray] = {} + for k, v in self.buffer.items(): + # Retrieve the items by flattening the indices + # (b1_s1, b1_s2, b1_s3, ..., bn_s1, bn_s2, bn_s3, ...) + # where bm_sk is the k-th elements in the sequence of the m-th batch + flattened_v = np.take(np.reshape(v, (-1, *v.shape[2:])), flattened_idxes, axis=0) + # Properly reshape the items: + # [ + # [b1_s1, b1_s2, ...], + # [b2_s1, b2_s2, ...], + # ..., + # [bn_s1, bn_s2, ...] + # ] + batched_v = np.reshape(flattened_v, (n_samples, batch_size, sequence_length) + flattened_v.shape[1:]) + # Reshape back to # [N_samples, Seq_len, Batch_size] + samples[k] = np.swapaxes( + batched_v, + axis1=1, + axis2=2, + ) + if clone: + samples[k] = samples[k].copy() + if sample_next_obs: + flattened_next_v = v[(flattened_batch_idxes + 1) % self._buffer_size, env_idxes] + batched_next_v = np.reshape( + flattened_next_v, (n_samples, batch_size, sequence_length) + flattened_next_v.shape[1:] + ) + samples[f"next_{k}"] = np.swapaxes( + batched_next_v, + axis1=1, + axis2=2, + ) + if clone: + samples[f"next_{k}"] = samples[f"next_{k}"].copy() + return samples + + +class EnvIndependentReplayBuffer: + def __init__( + self, + buffer_size: int, + n_envs: int = 1, + obs_keys: Sequence[str] = ("observations",), + memmap: bool = False, + memmap_dir: str | os.PathLike | None = None, + memmap_mode: str = "r+", + buffer_cls: Type[ReplayBuffer] = ReplayBuffer, + **kwargs, + ): + """A replay buffer implementation that is composed of multiple independent replay buffers. + + Args: + buffer_size (int): the buffer size. + n_envs (int, optional): the number of environments. Defaults to 1. + obs_keys (Sequence[str], optional): names of the observation keys. Those are used + to sample the next-observation. Defaults to ("observations",). 
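For context, a minimal usage sketch of the sequence sampling implemented above, assuming `SequentialReplayBuffer` is exported from `sheeprl.data.buffers` and keeps the same constructor as `ReplayBuffer`: every sampled key comes back with shape `[n_samples, sequence_length, batch_size, ...]`, and each sequence is drawn from a single environment.

```python
import numpy as np

from sheeprl.data.buffers import SequentialReplayBuffer

# 64 timesteps, 2 parallel environments (constructor assumed to mirror ReplayBuffer)
rb = SequentialReplayBuffer(buffer_size=64, n_envs=2)

# Data added to the buffer must be shaped [sequence_length, n_envs, ...]
rb.add({"observations": np.random.rand(32, 2, 4), "dones": np.zeros((32, 2, 1))})

# Draw 8 sequences of 16 consecutive steps, twice
sample = rb.sample(batch_size=8, sequence_length=16, n_samples=2)

# Arrays come back as [n_samples, sequence_length, batch_size, ...]
assert sample["observations"].shape == (2, 16, 8, 4)
```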
+ memmap (bool, optional): whether to memory-map the numpy arrays saved in the buffer. Defaults to False. + memmap_dir (str | os.PathLike | None, optional): the memory-mapped files directory. + Defaults to None. + memmap_mode (str, optional): memory-map mode. Possible values are: "r+", "w+", "c", "copyonwrite", + "readwrite", "write". Defaults to "r+". + buffer_cls (Type[ReplayBuffer], optional): the replay buffer class to use. Defaults to ReplayBuffer. + kwargs: additional keyword arguments. + """ + if buffer_size <= 0: + raise ValueError(f"The buffer size must be greater than zero, got: {buffer_size}") + if n_envs <= 0: + raise ValueError(f"The number of environments must be greater than zero, got: {n_envs}") + if memmap: + if memmap_mode not in ("r+", "w+", "c", "copyonwrite", "readwrite", "write"): + raise ValueError( + 'Accepted values for memmap_mode are "r+", "readwrite", "w+", "write", "c" or ' + '"copyonwrite". PyTorch does not support tensors backed by read-only ' + 'NumPy arrays, so "r" and "readonly" are not supported.' + ) + if memmap_dir is None: + raise ValueError( + "The buffer is set to be memory-mapped but the 'memmap_dir' attribute is None. " + "Set the 'memmap_dir' to a known directory.", + ) + else: + memmap_dir = Path(memmap_dir) + memmap_dir.mkdir(parents=True, exist_ok=True) + self._buf: Sequence[ReplayBuffer] = [ + buffer_cls( + buffer_size=buffer_size, + n_envs=1, + obs_keys=obs_keys, + memmap=memmap, + memmap_dir=memmap_dir / f"env_{i}" if memmap else None, + memmap_mode=memmap_mode, + **kwargs, + ) + for i in range(n_envs) + ] + self._buffer_size = buffer_size + self._n_envs = n_envs + self._rng: np.random.Generator = np.random.default_rng() + self._concat_along_axis = buffer_cls.batch_axis + + @property + def buffer(self) -> Sequence[ReplayBuffer]: + return tuple(self._buf) + + @property + def buffer_size(self) -> int: + return self._buffer_size + + @property + def full(self) -> Sequence[bool]: + return tuple([b.full for b in self.buffer]) + + @property + def n_envs(self) -> int: + return self._n_envs + + @property + def empty(self) -> Sequence[bool]: + return tuple([b.empty for b in self.buffer]) + + @property + def is_memmap(self) -> Sequence[bool]: + return tuple([b.is_memmap for b in self.buffer]) + + def __len__(self) -> int: + return self.buffer_size + + @typing.overload + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: + ... + + @typing.overload + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: + ... + + def add( + self, + data: "ReplayBuffer" | Dict[str, np.ndarray], + indices: Optional[Sequence[int]] = None, + validate_args: bool = False, + ) -> None: + """Add data to the replay buffers specified by the 'indices'. If 'indices' is None, then the data is added + one for every environment. The length of indices must be equal to the second dimension of the arrays in 'data', + which is the number of environments. If data is a dictionary, then the keys must be strings + and the values must be numpy arrays of shape [sequence_length, n_envs, ...]. + + + Args: + data (Union[ReplayBuffer, Dict[str, np.ndarray]]): the data to add to the replay buffers. + indices (Optional[Sequence[int]], optional): the indices of the replay buffers to add the data to. + Defaults to None. + validate_args (bool, optional): whether to validate the arguments. Defaults to False. 
+ """ + if indices is None: + indices = tuple(range(self.n_envs)) + elif len(indices) != next(iter(data.values())).shape[1]: + raise ValueError( + f"The length of 'indices' ({len(indices)}) must be equal to the second dimension of the " + f"arrays in 'data' ({next(iter(data.values())).shape[1]})" + ) + for env_data_idx, env_idx in enumerate(indices): + env_data = {k: v[:, env_data_idx : env_data_idx + 1] for k, v in data.items()} + self._buf[env_idx].add(env_data, validate_args=validate_args) + + def sample( + self, + batch_size: int, + sample_next_obs: bool = False, + clone: bool = False, + n_samples: int = 1, + **kwargs, + ) -> Dict[str, np.ndarray]: + """Samples data from the buffer. The returned samples are sampled given the 'buffer_cls' class + used to initialize the buffer. The samples are concatenated along the 'buffer_cls.batch_axis' axis. + + Args: + batch_size (int): The number of samples to draw from the buffer. + sample_next_obs (bool): Whether to sample the next observation or the current observation. + clone (bool): Whether to clone the data or return a reference to the original data. + n_samples (int): The number of samples to draw for each batch element. + **kwargs: Additional keyword arguments to pass to the underlying buffer's `sample` method. - def _get_samples(self, batch_idxes: Tensor, sample_next_obs: bool = False) -> TensorDictBase: - """Retrieves the items and return the TensorDict of sampled items. + Returns: + Dict[str, np.ndarray]: the sampled dictionary with a shape of + [n_samples, sequence_length, batch_size, ...] if 'buffer_cls' is a 'SequentialReplayBuffer', + otherwise [n_samples, batch_size, ...] if 'buffer_cls' is a 'ReplayBuffer'. + """ + if batch_size <= 0 or n_samples <= 0: + raise ValueError(f"'batch_size' ({batch_size}) and 'n_samples' ({n_samples}) must be both greater than 0") + if self._buf is None: + raise RuntimeError("The buffer has not been initialized. Try to add some data first.") + + bs_per_buf = np.bincount(self._rng.integers(0, self._n_envs, (batch_size,))) + per_buf_samples = [ + b.sample( + batch_size=bs, + sample_next_obs=sample_next_obs, + clone=clone, + n_samples=n_samples, + **kwargs, + ) + for b, bs in zip(self._buf, bs_per_buf) + if bs > 0 + ] + samples = {} + for k in per_buf_samples[0].keys(): + samples[k] = np.concatenate([s[k] for s in per_buf_samples], axis=self._concat_along_axis) + return samples + + @torch.no_grad() + def sample_tensors( + self, + batch_size: int, + sample_next_obs: bool = False, + clone: bool = False, + n_samples: int = 1, + dtype: Optional[torch.dtype] = None, + device: str | torch.dtype = "cpu", + from_numpy: bool = False, + **kwargs, + ) -> Dict[str, Tensor]: + """Sample elements from the replay buffer and convert them to torch tensors. Args: - batch_idxes (Tensor): the indices to retrieve of dimension (batch_dim, sequence_length). + batch_size (int): Number of elements to sample. sample_next_obs (bool): whether to sample the next observations from the 'observations' key. Defaults to False. + clone (bool): whether to clone the sampled tensors. + n_samples (int): the number of samples per batch_size to retrieve. Defaults to 1. + dtype (Optional[torch.dtype], optional): the torch dtype to convert the arrays to. If None, + then the dtypes of the numpy arrays is maintained. Defaults to None. + device (str | torch.dtype, optional): the torch device to move the tensors to. Defaults to "cpu". 
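As a usage note for the class above, a hedged sketch of adding per-environment data and sampling it back: `add` routes column `i` of the incoming arrays to the `i`-th internal single-environment buffer, and `sample` concatenates the per-buffer samples along the batch axis of `buffer_cls` (here the default `ReplayBuffer`, so the expected shape is `[n_samples, batch_size, ...]`).

```python
import numpy as np

from sheeprl.data.buffers import EnvIndependentReplayBuffer

rb = EnvIndependentReplayBuffer(buffer_size=100, n_envs=3)

# 'step' is shaped [sequence_length, n_envs, ...]: one column per environment
step = {
    "observations": np.random.rand(1, 3, 8),
    "dones": np.zeros((1, 3, 1)),
}
rb.add(step)

# Add data only for environments 0 and 2: the second dimension must match len(indices)
partial = {k: v[:, [0, 2]] for k, v in step.items()}
rb.add(partial, indices=[0, 2])

# Per-buffer samples are concatenated along the batch axis
sample = rb.sample(batch_size=16)
assert sample["observations"].shape == (1, 16, 8)
```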
+ from_numpy (bool, optional): whether to convert the numpy arrays to torch tensors + with the 'torch.from_numpy' function. If False, then the numpy arrays are converted + with the 'torch.as_tensor' function. Defaults to False. + kwargs: additional keyword arguments to be passed to the 'self.sample' method. Returns: - TensorDictBase: the sampled TensorDictBase with a `batch_size` of [batch_dim, sequence_length] + Dict[str, Tensor]: the sampled dictionary, containing the sampled array, + one for every key, with a shape of [n_samples, sequence_length, batch_size, ...] if 'buffer_cls' is a + 'SequentialReplayBuffer', otherwise [n_samples, batch_size, ...] if 'buffer_cls' is a 'ReplayBuffer'. """ - unflatten_shape = batch_idxes.shape - # each sequence must come from the same environment - env_idxes = ( - torch.randint(0, self.n_envs, size=(unflatten_shape[0],)).view(-1, 1).repeat(1, unflatten_shape[1]).view(-1) + samples = self.sample( + batch_size=batch_size, + sample_next_obs=sample_next_obs, + clone=clone, + n_samples=n_samples, + **kwargs, ) - # retrieve the items by flattening the indices - # (b1_s1, b1_s2, b1_s3, ..., bn_s1, bn_s2, bn_s3, ...) - # where bm_sk is the k-th elements in the sequence of the m-th batch - sample = self._buf[batch_idxes.flatten(), env_idxes] - # properly reshape the items: - # [ - # [b1_s1, b1_s2, ...], - # [b2_s1, b2_s2, ...], - # ..., - # [bn_s1, bn_s2, ...] - # ] - return sample.view(*unflatten_shape) + return { + k: get_tensor(v, dtype=dtype, clone=clone, device=device, from_numpy=from_numpy) for k, v in samples.items() + } class EpisodeBuffer: @@ -346,62 +747,103 @@ class EpisodeBuffer: buffer_size (int): The capacity of the buffer. sequence_length (int): The length of the sequences of the samples (an episode cannot be shorter than the episode length). - device (Union[torch.device, str]): The device where the buffer is created. Defaults to "cpu". + n_envs (int): The number of environments. + Default to 1. + obs_keys (Sequence[str]): The observations keys to store in the buffer. + Default to ("observations",). + prioritize_ends (bool): Whether to prioritize the ends of the episodes when sampling. + Default to False. memmap (bool): Whether to memory-mapping the buffer. + Default to False. + memmap_dir (str | os.PathLike, optional): The directory for the memmap. + Default to None. + memmap_mode (str, optional): memory-map mode. + Possible values are: "r+", "w+", "c", "copyonwrite", "readwrite", "write". + Defaults to "r+". 
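The `sample_tensors` method shown above is a thin wrapper around `sample` followed by a NumPy-to-torch conversion through `get_tensor`; a small sketch of the intended call, with keyword names taken from the docstring:

```python
import numpy as np
import torch

from sheeprl.data.buffers import EnvIndependentReplayBuffer

rb = EnvIndependentReplayBuffer(buffer_size=32, n_envs=2)
rb.add({"observations": np.random.rand(4, 2, 8), "dones": np.zeros((4, 2, 1))})

# Same sampling as `sample`, but the arrays are returned as torch tensors
batch = rb.sample_tensors(batch_size=8, dtype=torch.float32, device="cpu", from_numpy=False)
assert isinstance(batch["observations"], torch.Tensor)
assert batch["observations"].shape == (1, 8, 8)
assert batch["observations"].dtype == torch.float32
```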
""" + batch_axis: int = 2 + def __init__( self, buffer_size: int, - sequence_length: int, - device: Union[device, str] = "cpu", + minimum_episode_length: int, + n_envs: int = 1, + obs_keys: Sequence[str] = ("observations",), + prioritize_ends: bool = False, memmap: bool = False, - memmap_dir: Optional[Union[str, os.PathLike]] = None, + memmap_dir: str | os.PathLike | None = None, + memmap_mode: str = "r+", ) -> None: if buffer_size <= 0: raise ValueError(f"The buffer size must be greater than zero, got: {buffer_size}") - if sequence_length <= 0: - raise ValueError(f"The sequence length must be greater than zero, got: {sequence_length}") - if buffer_size < sequence_length: + if minimum_episode_length <= 0: + raise ValueError(f"The sequence length must be greater than zero, got: {minimum_episode_length}") + if buffer_size < minimum_episode_length: raise ValueError( "The sequence length must be lower than the buffer size, " - f"got: bs = {buffer_size} and sl = {sequence_length}" + f"got: bs = {buffer_size} and sl = {minimum_episode_length}" ) + self._n_envs = n_envs + self._obs_keys = obs_keys self._buffer_size = buffer_size - self._sequence_length = sequence_length - self._buf: List[TensorDictBase] = [] - self._cum_lengths: List[int] = [] - if isinstance(device, str): - device = torch.device(device=device) - self._device = device + self._minimum_episode_length = minimum_episode_length + self._prioritize_ends = prioritize_ends + + # One list for each environment that contains open episodes: + # one open episode per environment + self._open_episodes = [[] for _ in range(n_envs)] + # Contain the cumulative length of the episodes in the buffer + self._cum_lengths: Sequence[int] = [] + # List of stored episodes + self._buf: Sequence[Dict[str, np.ndarray | MemmapArray]] = [] + self._memmap = memmap self._memmap_dir = memmap_dir - if memmap_dir is None: - warnings.warn( - "The buffer will be memory-mapped into the `/tmp` folder, this means that there is the" - " possibility to lose the saved files. Set the `memmap_dir` to a known directory.", - UserWarning, - ) - else: - self._memmap_dir = Path(self._memmap_dir) - self._memmap_dir.mkdir(parents=True, exist_ok=True) - self._chunk_length = torch.arange(sequence_length, device=self.device).reshape(1, -1) + self._memmap_mode = memmap_mode + if self._memmap: + if self._memmap_mode not in ("r+", "w+", "c", "copyonwrite", "readwrite", "write"): + raise ValueError( + 'Accepted values for memmap_mode are "r+", "readwrite", "w+", "write", "c" or ' + '"copyonwrite". PyTorch does not support tensors backed by read-only ' + 'NumPy arrays, so "r" and "readonly" are not supported.' + ) + if self._memmap_dir is None: + raise ValueError( + "The buffer is set to be memory-mapped but the `memmap_dir` attribute is None. 
" + "Set the `memmap_dir` to a known directory.", + ) + else: + self._memmap_dir = Path(self._memmap_dir) + self._memmap_dir.mkdir(parents=True, exist_ok=True) @property - def buffer(self) -> Optional[List[TensorDictBase]]: + def prioritize_ends(self) -> bool: + return self._prioritize_ends + + @prioritize_ends.setter + def prioritize_ends(self, prioritize_ends: bool) -> None: + self._prioritize_ends = prioritize_ends + + @property + def buffer(self) -> Sequence[Dict[str, np.ndarray | MemmapArray]]: return self._buf @property - def buffer_size(self) -> int: - return self._buffer_size + def obs_keys(self) -> Sequence[str]: + return self._obs_keys @property - def sequence_length(self) -> int: - return self._sequence_length + def n_envs(self) -> int: + return self._n_envs @property - def device(self) -> device: - return self._device + def buffer_size(self) -> int: + return self._buffer_size + + @property + def minimum_episode_length(self) -> int: + return self._minimum_episode_length @property def is_memmap(self) -> bool: @@ -409,282 +851,321 @@ def is_memmap(self) -> bool: @property def full(self) -> bool: - return self._cum_lengths[-1] + self._sequence_length > self._buffer_size if len(self._buf) > 0 else False - - def __getitem__(self, key: int) -> torch.Tensor: - if not isinstance(key, int): - raise TypeError("`key` must be an integer") - return self._buf[key] + return self._cum_lengths[-1] + self._minimum_episode_length > self._buffer_size if len(self._buf) > 0 else False def __len__(self) -> int: return self._cum_lengths[-1] if len(self._buf) > 0 else 0 - def add(self, episode: TensorDictBase) -> None: - """Add an episode to the buffer. + @typing.overload + def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None: + ... - Args: - episode (TensorDictBase): data to add. + @typing.overload + def add( + self, + data: Dict[str, np.ndarray], + env_idxes: Sequence[int] | None = None, + validate_args: bool = False, + ) -> None: + ... - Raises: - RuntimeError: - - The episode must contain exactly one done at the end of the episode. - - The length of the episode must be at least sequence lenght. - - The length of the episode cannot be greater than the buffer size. + def add( + self, + data: "ReplayBuffer" | Dict[str, np.ndarray], + env_idxes: Sequence[int] | None = None, + validate_args: bool = False, + ) -> None: + """Add data to the replay buffer in episodes. If data is a dictionary, then the keys must be strings + and the values must be numpy arrays of shape [sequence_length, n_envs, ...]. + + Args: + data (ReplayBuffer | Dict[str, np.ndarray]]): data to add. + env_idxes (Sequence[int], optional): the indices of the environments in which to add the data. + Default to None. + validate_args (bool): whether to validate the arguments or not. + Default to None. 
""" - if len(torch.nonzero(episode["dones"])) != 1: - raise RuntimeError( - f"The episode must contain exactly one done, got: {len(torch.nonzero(episode['dones']))}" - ) - if episode["dones"][-1] != 1: - raise RuntimeError(f"The last step must contain a done, got: {episode['dones'][-1]}") - if episode.shape[0] < self._sequence_length: + if isinstance(data, ReplayBuffer): + data = data.buffer + if validate_args: + if data is None: + raise ValueError("The `data` replay buffer must be not None") + if not isinstance(data, dict): + raise ValueError( + f"`data` must be a dictionary containing Numpy arrays, but `data` is of type `{type(data)}`" + ) + elif isinstance(data, dict): + for k, v in data.items(): + if not isinstance(v, np.ndarray): + raise ValueError( + f"`data` must be a dictionary containing Numpy arrays. Found key `{k}` " + f"containing a value of type `{type(v)}`" + ) + last_key = next(iter(data.keys())) + last_batch_shape = next(iter(data.values())).shape[:2] + for i, (k, v) in enumerate(data.items()): + if len(v.shape) < 2: + raise RuntimeError( + "`data` must have at least 2: [sequence_length, n_envs, ...]. " f"Shape of `{k}` is {v.shape}" + ) + if i > 0: + current_key = k + current_batch_shape = v.shape[:2] + if current_batch_shape != last_batch_shape: + raise RuntimeError( + "Every array in `data` must be congruent in the first 2 dimensions: " + f"found key `{last_key}` with shape `{last_batch_shape}` " + f"and `{current_key}` with shape `{current_batch_shape}`" + ) + last_key = current_key + last_batch_shape = current_batch_shape + + if "dones" not in data: + raise RuntimeError(f"The episode must contain the `dones` key, got: {data.keys()}") + + if env_idxes is not None and (np.array(env_idxes) >= self._n_envs).any(): + raise ValueError( + f"The indices of the environment must be integers in [0, {self._n_envs}), given {env_idxes}" + ) + + # For each environment + if env_idxes is None: + env_idxes = range(self._n_envs) + for i, env in enumerate(env_idxes): + # Take the data from a single environment + env_data = {k: v[:, i] for k, v in data.items()} + done = env_data["dones"] + # Take episode ends + episode_ends = done.nonzero()[0].tolist() + # If there is not any done, then add the data to the respective open episode + if len(episode_ends) == 0: + self._open_episodes[env].append(env_data) + else: + # In case there is at leas one done, then split the environment data into episodes + episode_ends.append(len(done)) + start = 0 + # For each episode in the received data + for ep_end_idx in episode_ends: + stop = ep_end_idx + # Take the episode from the data + episode = {k: env_data[k][start : stop + 1] for k in env_data.keys()} + # If the episode length is greater than zero, then add it to the open episode + # of the corresponding environment. + if len(episode["dones"]) > 0: + self._open_episodes[env].append(episode) + start = stop + 1 + # If the open episode is not empty and the last element is a done, then save the episode + # in the buffer and clear the open episode + if len(self._open_episodes[env]) > 0 and self._open_episodes[env][-1]["dones"][-1] == 1: + self._save_episode(self._open_episodes[env]) + self._open_episodes[env] = [] + + def _save_episode(self, episode_chunks: Sequence[Dict[str, np.ndarray | MemmapArray]]) -> None: + if len(episode_chunks) == 0: + raise RuntimeError("Invalid episode, an empty sequence is given. 
You must pass a non-empty sequence.") + # Concatenate all the chunks of the episode + episode = {k: [] for k in episode_chunks[0].keys()} + for chunk in episode_chunks: + for k in chunk.keys(): + episode[k].append(chunk[k]) + episode = {k: np.concatenate(v, axis=0) for k, v in episode.items()} + + # Control the validity of the episode + ep_len = episode["dones"].shape[0] + if len(episode["dones"].nonzero()[0]) != 1 or episode["dones"][-1] != 1: + raise RuntimeError(f"The episode must contain exactly one done, got: {len(np.nonzero(episode['dones']))}") + if ep_len < self._minimum_episode_length: raise RuntimeError( - f"Episode too short (at least {self._sequence_length} steps), got: {episode.shape[0]} steps" + f"Episode too short (at least {self._minimum_episode_length} steps), got: {ep_len} steps" ) - if episode.shape[0] > self._buffer_size: - raise RuntimeError(f"Episode too long (at most {self._buffer_size} steps), got: {episode.shape[0]} steps") + if ep_len > self._buffer_size: + raise RuntimeError(f"Episode too long (at most {self._buffer_size} steps), got: {ep_len} steps") - ep_len = episode.shape[0] + # If the buffer is full, then remove the oldest episodes if self.full or len(self) + ep_len > self._buffer_size: + # Compute the index of the last episode to remove cum_lengths = np.array(self._cum_lengths) mask = (len(self) - cum_lengths + ep_len) <= self._buffer_size last_to_remove = mask.argmax() # Remove all memmaped episodes if self._memmap and self._memmap_dir is not None: for _ in range(last_to_remove + 1): - filename = self._buf[0][self._buf[0].sorted_keys[0]].filename - for k in self._buf[0].sorted_keys: - f = self._buf[0][k].file - if f is not None: - f.close() + dirname = os.path.dirname(self._buf[0][next(iter(self._buf[0].keys()))].filename) + for v in self._buf[0].values(): + del v del self._buf[0] - shutil.rmtree(os.path.dirname(filename)) + try: + shutil.rmtree(dirname) + except Exception as e: + logging.error(e) else: self._buf = self._buf[last_to_remove + 1 :] + # Update the cum_lengths lists cum_lengths = cum_lengths[last_to_remove + 1 :] - cum_lengths[last_to_remove] self._cum_lengths = cum_lengths.tolist() self._cum_lengths.append(len(self) + ep_len) + episode_to_store = episode if self._memmap: - episode_dir = None - if self._memmap_dir is not None: - episode_dir = self._memmap_dir / f"episode_{str(uuid.uuid4())}" - episode_dir.mkdir(parents=True, exist_ok=True) + episode_dir = self._memmap_dir / f"episode_{str(uuid.uuid4())}" + episode_dir.mkdir(parents=True, exist_ok=True) + episode_to_store = {} for k, v in episode.items(): - episode[k] = MemmapTensor.from_tensor( - v, - filename=None if episode_dir is None else episode_dir / f"{k}.memmap", - transfer_ownership=False, + path = Path(episode_dir / f"{k}.memmap") + filename = str(path) + episode_to_store[k] = MemmapArray( + filename=str(filename), + dtype=v.dtype, + shape=v.shape, + mode=self._memmap_mode, ) - episode.memmap_(prefix=episode_dir) - episode.to(self.device) - self._buf.append(episode) + episode_to_store[k][:] = episode[k] + self._buf.append(episode_to_store) def sample( self, batch_size: int, + sample_next_obs: bool = False, n_samples: int = 1, - prioritize_ends: bool = False, clone: bool = False, - ) -> TensorDictBase: + sequence_length: int = 1, + **kwargs, + ) -> Dict[str, np.ndarray]: """Sample trajectories from the replay buffer. Args: batch_size (int): Number of element in the batch. - n_samples (bool): The number of samples to be retrieved. 
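To make the `add` contract above concrete, a hedged sketch of how incoming data is split on the `dones` flags: episodes are flushed to the buffer (through `_save_episode`) as soon as a done closes them, while trailing steps stay in the per-environment open episode until a later done arrives.

```python
import numpy as np

from sheeprl.data.buffers import EpisodeBuffer

buf = EpisodeBuffer(buffer_size=100, minimum_episode_length=4, n_envs=1)

# 10 steps for one environment: the episode terminates at step 7 (dones[7] == 1)
dones = np.zeros((10, 1, 1), dtype=np.float32)
dones[7] = 1
buf.add({"observations": np.random.rand(10, 1, 3).astype(np.float32), "dones": dones})

# Steps 0..7 are stored as an 8-step episode; steps 8..9 stay in the open episode
assert len(buf) == 8

# A later done closes the open episode: the 2 leftover steps plus these 3 form a 5-step episode
tail_dones = np.zeros((3, 1, 1), dtype=np.float32)
tail_dones[-1] = 1
buf.add({"observations": np.random.rand(3, 1, 3).astype(np.float32), "dones": tail_dones})
assert len(buf) == 8 + 2 + 3
```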
+ sample_next_obs (bool): Whether to sample the next obs. + Default to False. + n_samples (bool): The number of samples per batch_size to be retrieved. Defaults to 1. - prioritize_ends (bool): Whether to clone prioritize the ends of the episodes. - Defaults to False. + clone (bool): Whether to clone the samples. + Default to False. + sequence_length (int): The length of the sequences to sample. + Default to 1. Returns: - TensorDictBase: the sampled TensorDictBase with a `batch_size` of [batch_size, 1] + Dict[str, np.ndarray]: the sampled dictionary with a shape of + [n_samples, sequence_length, batch_size, ...]. """ if batch_size <= 0: raise ValueError(f"Batch size must be greater than 0, got: {batch_size}") if n_samples <= 0: raise ValueError(f"The number of samples must be greater than 0, got: {n_samples}") - if len(self) == 0: + if sample_next_obs: + valid_episode_idxes = np.array(self._cum_lengths) - np.array([0] + self._cum_lengths[:-1]) > sequence_length + else: + valid_episode_idxes = ( + np.array(self._cum_lengths) - np.array([0] + self._cum_lengths[:-1]) >= sequence_length + ) + valid_episodes = list(compress(self._buf, valid_episode_idxes)) + if len(valid_episodes) == 0: raise RuntimeError( - "No sample has been added to the buffer. Please add at least one sample calling `self.add()`" + "No valid episodes has been added to the buffer. Please add at least one episode of length greater " + f"than or equal to {sequence_length} calling `self.add()`" ) - nsample_per_eps = torch.bincount(torch.randint(0, len(self._buf), (batch_size * n_samples,))) - samples = [] + chunk_length = np.arange(sequence_length, dtype=np.intp).reshape(1, -1) + nsample_per_eps = np.bincount(np.random.randint(0, len(valid_episodes), (batch_size * n_samples,))).astype( + np.intp + ) + samples_per_eps = {k: [] for k in valid_episodes[0].keys()} + if sample_next_obs: + samples_per_eps.update({f"next_{k}": [] for k in self._obs_keys}) for i, n in enumerate(nsample_per_eps): - ep_len = self._buf[i].shape[0] - upper = ep_len - self._sequence_length + 1 - if prioritize_ends: - upper += self._sequence_length - start_idxes = torch.min( - torch.randint(0, upper, size=(n,)).reshape(-1, 1), torch.tensor(ep_len - self._sequence_length) - ) - indices = start_idxes + self._chunk_length - samples.append(self._buf[i][indices]) - samples = torch.cat(samples, 0).reshape(n_samples, batch_size, self._sequence_length).permute(0, -1, -2) - if clone: - return samples.clone() - return samples - - -class AsyncReplayBuffer: - def __init__( - self, - buffer_size: int, - n_envs: int = 1, - device: Union[device, str] = "cpu", - memmap: bool = False, - memmap_dir: Optional[Union[str, os.PathLike]] = None, - sequential: bool = False, - ): - """An async replay buffer which internally uses a TensorDict. This replay buffer - saves a experiences independently for every environment. When new data has to be added, it expects - the TensorDict or the ReplayBuffer to be added to have a 2D shape dimension as [T, B], where T` - represents the sequence length, while `B` is the number of environments to be batched. - - Args: - buffer_size (int): The buffer size. - n_envs (int, optional): The number of environments. Defaults to 1. - device (Union[torch.device, str], optional): The device where the buffer is created. Defaults to "cpu". - memmap (bool, optional): Whether to memory-mapping the buffer. 
- """ - if buffer_size <= 0: - raise ValueError(f"The buffer size must be greater than zero, got: {buffer_size}") - if n_envs <= 0: - raise ValueError(f"The number of environments must be greater than zero, got: {n_envs}") - self._buffer_size = buffer_size - self._n_envs = n_envs - if isinstance(device, str): - device = torch.device(device=device) - self._device = device - self._memmap = memmap - self._memmap_dir = memmap_dir - self._sequential = sequential - self._buf: Optional[Sequence[ReplayBuffer]] = None - if self._memmap_dir is not None: - self._memmap_dir = Path(self._memmap_dir) - if self._memmap: - if memmap_dir is None: - warnings.warn( - "The buffer will be memory-mapped into the `/tmp` folder, this means that there is the" - " possibility to lose the saved files. Set the `memmap_dir` to a known directory.", - UserWarning, + if n > 0: + ep_len = valid_episodes[i]["dones"].shape[0] + if sample_next_obs: + ep_len -= 1 + # Define the maximum index that can be sampled in the episodes + upper = ep_len - sequence_length + 1 + # If you want to prioritize ends, then all the indices of the episode + # can be sampled as starting index + if self._prioritize_ends: + upper += sequence_length + # Sample the starting indices and upper bound with `ep_len - sequence_length` + start_idxes = np.minimum( + np.random.randint(0, upper, size=(n,)).reshape(-1, 1), ep_len - sequence_length, dtype=np.intp ) - else: - self._memmap_dir.mkdir(parents=True, exist_ok=True) - - @property - def buffer(self) -> Optional[Sequence[ReplayBuffer]]: - return tuple(self._buf) - - @property - def buffer_size(self) -> int: - return self._buffer_size - - @property - def full(self) -> Optional[Sequence[bool]]: - if self.buffer is None: - return None - return tuple([b.full for b in self.buffer]) - - @property - def n_envs(self) -> int: - return self._n_envs - - @property - def shape(self) -> Optional[Sequence[Size]]: - if self.buffer is None: - return None - return tuple([b.shape for b in self.buffer]) - - @property - def device(self) -> Optional[Sequence[device]]: - if self.buffer is None: - return None - return self._device - - def __len__(self) -> int: - return self.buffer_size - - def add(self, data: TensorDictBase, indices: Optional[Sequence[int]] = None) -> None: - """Add data to the buffer. - - Args: - data: data to add. - indices (Sequence[int], optional): the indices where to add the data. - If None, then data will be added on every indices. - Defaults to None. - - Raises: - RuntimeError: the number of dimensions (the batch_size of the TensorDictBase) must be 2: - one for the number of environments and one for the sequence length. - """ - if not isinstance(data, TensorDictBase): - raise TypeError("`data` must be a TensorDictBase") - if data is None: - raise RuntimeError("The `data` parameter must be not None") - if len(data.shape) != 2: - raise RuntimeError( - "`data` must have 2 batch dimensions: [sequence_length, n_envs]. " - "`sequence_length` and `n_envs` should be 1. 
Shape is: {}".format(data.shape) - ) - if self._buf is None: - buf_cls = SequentialReplayBuffer if self._sequential else ReplayBuffer - self._buf = [ - buf_cls( - self.buffer_size, - n_envs=1, - device=self._device, - memmap=self._memmap, - memmap_dir=self._memmap_dir / f"env_{i}" if self._memmap_dir is not None else None, + # Compute the indices of the sequences + indices = start_idxes + chunk_length + # Retrieve the data + for k in valid_episodes[0].keys(): + samples_per_eps[k].append( + np.take(valid_episodes[i][k], indices.flat, axis=0).reshape( + n, sequence_length, *valid_episodes[i][k].shape[1:] + ) + ) + if sample_next_obs and k in self._obs_keys: + samples_per_eps[f"next_{k}"].append(valid_episodes[i][k][indices + 1]) + # Concatenate all the trajectories on the batch dimension and properly reshape them + samples = {} + for k, v in samples_per_eps.items(): + if len(v) > 0: + samples[k] = np.moveaxis( + np.concatenate(v, axis=0).reshape(n_samples, batch_size, sequence_length, *v[0].shape[2:]), + 2, + 1, ) - for i in range(self._n_envs) - ] - if indices is None: - indices = tuple(range(self.n_envs)) - for env_data_idx, env_idx in enumerate(indices): - self._buf[env_idx].add(data[:, env_data_idx : env_data_idx + 1]) + if clone: + samples[k] = samples[k].copy() + return samples - def sample( + @torch.no_grad() + def sample_tensors( self, batch_size: int, sample_next_obs: bool = False, + n_samples: int = 1, clone: bool = False, sequence_length: int = 1, - n_samples: int = 1, - ) -> TensorDictBase: - """Sample elements from the sequential replay buffer, - each one is a sequence of a consecutive items. - - Custom sampling when using memory efficient variant, - as the first element of the sequence cannot be in a position - greater than (pos - sequence_length) % buffer_size. - See comments in the code for more information. + dtype: Optional[torch.dtype] = None, + device: str | torch.dtype = "cpu", + from_numpy: bool = False, + **kwargs, + ) -> Dict[str, Tensor]: + """Sample elements from the replay buffer and convert them to torch tensors. Args: - batch_size (int): Number of element to sample + batch_size (int): Number of elements to sample. sample_next_obs (bool): whether to sample the next observations from the 'observations' key. Defaults to False. - clone (bool): whether to clone the sampled TensorDict. + clone (bool): whether to clone the sampled tensors. + n_samples (int): the number of samples per batch_size. Defaults to 1. sequence_length (int): the length of the sequence of each element. Defaults to 1. - n_samples (int): the number of samples to perform. Defaults to 1. - - Returns: - TensorDictBase: the sampled TensorDictBase with a `batch_size` of [n_samples, sequence_length, batch_size] + dtype (Optional[torch.dtype], optional): the torch dtype to convert the arrays to. If None, + then the dtypes of the numpy arrays is maintained. Defaults to None. + device (str | torch.dtype, optional): the torch device to move the tensors to. Defaults to "cpu". + from_numpy (bool, optional): whether to convert the numpy arrays to torch tensors + with the 'torch.from_numpy' function. If False, then the numpy arrays are converted + with the 'torch.as_tensor' function. Defaults to False. + kwargs: additional keyword arguments to be passed to the 'self.sample' method. 
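A hedged sketch of drawing training sequences from the episode buffer defined above: `sample` returns arrays shaped `[n_samples, sequence_length, batch_size, ...]`, and `sample_tensors` performs the same sampling but converts the result to torch tensors on the requested device.

```python
import numpy as np
import torch

from sheeprl.data.buffers import EpisodeBuffer

buf = EpisodeBuffer(buffer_size=500, minimum_episode_length=16, n_envs=1)

# Store one 64-step episode that terminates on its last step
dones = np.zeros((64, 1, 1), dtype=np.float32)
dones[-1] = 1
buf.add({"observations": np.random.rand(64, 1, 3).astype(np.float32), "dones": dones})

sample = buf.sample(batch_size=8, sequence_length=16, n_samples=2)
assert sample["observations"].shape == (2, 16, 8, 3)

# Same batch, converted to torch tensors
batch = buf.sample_tensors(batch_size=8, sequence_length=16, n_samples=2, device="cpu")
assert isinstance(batch["dones"], torch.Tensor)
```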
""" - if batch_size <= 0 or n_samples <= 0: - raise ValueError(f"`batch_size` ({batch_size}) and `n_samples` ({n_samples}) must be both greater than 0") - if self._buf is None: - raise RuntimeError("The buffer has not been initialized. Try to add some data first.") - - bs_per_buf = torch.bincount(torch.randint(0, self._n_envs, (batch_size,))) - samples = [ - b.sample( - batch_size=bs, - sample_next_obs=sample_next_obs, - clone=clone, - n_samples=n_samples, - sequence_length=sequence_length, - ) - for b, bs in zip(self._buf, bs_per_buf) - if bs > 0 - ] - return torch.cat(samples, dim=2 if self._sequential else 0) + samples = self.sample(batch_size, sample_next_obs, n_samples, clone, sequence_length) + return { + k: get_tensor(v, dtype=dtype, clone=clone, device=device, from_numpy=from_numpy) for k, v in samples.items() + } + + +def get_tensor( + array: np.ndarray | MemmapArray, + dtype: Optional[torch.dtype] = None, + clone: bool = False, + device: str | torch.dtype = "cpu", + from_numpy: bool = False, +) -> Tensor: + if isinstance(array, MemmapArray): + array = array.array + if clone: + array = array.copy() + if from_numpy: + torch_v = torch.from_numpy(array).to( + dtype=NUMPY_TO_TORCH_DTYPE_DICT[array.dtype] if dtype is None else dtype, + device=device, + ) + else: + torch_v = torch.as_tensor( + array, + dtype=NUMPY_TO_TORCH_DTYPE_DICT[array.dtype] if dtype is None else dtype, + device=device, + ) + return torch_v diff --git a/sheeprl/utils/callback.py b/sheeprl/utils/callback.py index db6859d4..d8712a08 100644 --- a/sheeprl/utils/callback.py +++ b/sheeprl/utils/callback.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, Optional, Union +from __future__ import annotations + +from typing import Any, Dict, Optional, Sequence, Union -import torch from lightning.fabric import Fabric from lightning.fabric.plugins.collectives import TorchCollective +from torch import Tensor -from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer, ReplayBuffer +from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, ReplayBuffer class CheckpointCallback: @@ -25,30 +27,17 @@ def on_checkpoint_coupled( fabric: Fabric, ckpt_path: str, state: Dict[str, Any], - replay_buffer: Optional[Union["AsyncReplayBuffer", "ReplayBuffer", "EpisodeBuffer"]] = None, + replay_buffer: Optional[Union["EnvIndependentReplayBuffer", "ReplayBuffer", "EpisodeBuffer"]] = None, ): if replay_buffer is not None: - if isinstance(replay_buffer, ReplayBuffer): - # clone the true done - true_done = replay_buffer["dones"][(replay_buffer._pos - 1) % replay_buffer.buffer_size, :].clone() - # substitute the last done with all True values (all the environment are truncated) - replay_buffer["dones"][(replay_buffer._pos - 1) % replay_buffer.buffer_size, :] = True - elif isinstance(replay_buffer, AsyncReplayBuffer): - true_dones = [] - for b in replay_buffer.buffer: - true_dones.append(b["dones"][(b._pos - 1) % b.buffer_size, :].clone()) - b["dones"][(b._pos - 1) % b.buffer_size, :] = True + rb_state = self._ckpt_rb(replay_buffer) state["rb"] = replay_buffer if fabric.world_size > 1: # We need to collect the buffers from all the ranks # The collective it is needed because the `gather_object` function is not implemented in Fabric checkpoint_collective = TorchCollective() # gloo is the torch.distributed backend that works on cpu - if replay_buffer.device == torch.device("cpu"): - backend = "gloo" - else: - backend = "nccl" - checkpoint_collective.create_group(backend=backend, ranks=list(range(fabric.world_size))) + 
checkpoint_collective.create_group(backend="gloo", ranks=list(range(fabric.world_size))) gathered_rb = [None for _ in range(fabric.world_size)] if fabric.global_rank == 0: checkpoint_collective.gather_object(replay_buffer, gathered_rb) @@ -56,12 +45,8 @@ def on_checkpoint_coupled( else: checkpoint_collective.gather_object(replay_buffer, None) fabric.save(ckpt_path, state) - if replay_buffer is not None and isinstance(replay_buffer, ReplayBuffer): - # reinsert the true dones in the buffer - replay_buffer["dones"][(replay_buffer._pos - 1) % replay_buffer.buffer_size, :] = true_done - elif isinstance(replay_buffer, AsyncReplayBuffer): - for i, b in enumerate(replay_buffer.buffer): - b["dones"][(b._pos - 1) % b.buffer_size, :] = true_dones[i] + if replay_buffer is not None: + self._experiment_consistent_rb(replay_buffer, rb_state) def on_checkpoint_player( self, @@ -74,15 +59,11 @@ def on_checkpoint_player( player_trainer_collective.broadcast_object_list(state, src=1) state = state[0] if replay_buffer is not None: - # clone the true done - true_done = replay_buffer["dones"][(replay_buffer._pos - 1) % replay_buffer.buffer_size, :].clone() - # substitute the last done with all True values (all the environment are truncated) - replay_buffer["dones"][(replay_buffer._pos - 1) % replay_buffer.buffer_size, :] = True + rb_state = self._ckpt_rb(replay_buffer) state["rb"] = replay_buffer fabric.save(ckpt_path, state) if replay_buffer is not None: - # reinsert the true dones in the buffer - replay_buffer["dones"][(replay_buffer._pos - 1) % replay_buffer.buffer_size, :] = true_done + self._experiment_consistent_rb(replay_buffer, rb_state) def on_checkpoint_trainer( self, fabric: Fabric, player_trainer_collective: TorchCollective, state: Dict[str, Any], ckpt_path: str @@ -90,3 +71,60 @@ def on_checkpoint_trainer( if fabric.global_rank == 1: player_trainer_collective.broadcast_object_list([state], src=1) fabric.save(ckpt_path, state) + + def _ckpt_rb( + self, rb: ReplayBuffer | EnvIndependentReplayBuffer | EpisodeBuffer + ) -> Tensor | Sequence[Tensor] | Sequence[Sequence[Tensor]]: + """Modify the replay buffer in order to be consistent for the checkpoint. + There could be 3 cases, depending on the buffers: + + 1. The `ReplayBuffer` or `SequentialReplayBuffer`: a done is inserted in the last pos because the + state of the environment is not saved in the checkpoint. + 2. The `EnvIndependentReplayBuffer`: for each buffer, the done in the last position is set to True + (for the same reason of the point 1.). + 3. The `EpisodeBuffer`: the open episodes are discarded because the + state of the environment is not saved in the checkpoint. + + Args: + rb (ReplayBuffer | EnvIndependentReplayBuffer | EpisodeBuffer): the buffer. + + Returns: + The original state of the buffer. 
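The done-flag trick described above can be summarized with a standalone NumPy sketch, not tied to the callback: the last written dones are forced to True before saving, so that on resume the stored trajectories are treated as truncated, and the real flags are restored right after the checkpoint is written.

```python
import numpy as np

buffer_size, n_envs = 8, 2
dones = np.zeros((buffer_size, n_envs), dtype=np.float32)
pos = 5  # next write position of the buffer

# 1. Keep a copy of the real flags at the last written row
true_done = dones[(pos - 1) % buffer_size, :].copy()
# 2. Mark every environment as truncated before writing the checkpoint
dones[(pos - 1) % buffer_size, :] = True
# ... the checkpoint (e.g. fabric.save) would be written here ...
# 3. Restore the real flags so the running experiment is unaffected
dones[(pos - 1) % buffer_size, :] = true_done
```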
+ """ + if isinstance(rb, ReplayBuffer): + # clone the true done + state = rb["dones"][(rb._pos - 1) % rb.buffer_size, :].copy() + # substitute the last done with all True values (all the environment are truncated) + rb["dones"][(rb._pos - 1) % rb.buffer_size, :] = True + elif isinstance(rb, EnvIndependentReplayBuffer): + state = [] + for b in rb.buffer: + state.append(b["dones"][(b._pos - 1) % b.buffer_size, :].copy()) + b["dones"][(b._pos - 1) % b.buffer_size, :] = True + elif isinstance(rb, EpisodeBuffer): + # remove open episodes from the buffer because the state of the environment is not saved + state = rb._open_episodes + rb._open_episodes = [[] for _ in range(rb.n_envs)] + return state + + def _experiment_consistent_rb( + self, + rb: ReplayBuffer | EnvIndependentReplayBuffer | EpisodeBuffer, + state: Tensor | Sequence[Tensor] | Sequence[Sequence[Tensor]], + ): + """Restore the state of the buffer consistent with the execution of the experiment. + I.e., it undoes the changes in the _ckpt_rb function. + + Args: + rb (ReplayBuffer | EnvIndependentReplayBuffer | EpisodeBuffer): the buffer. + state (Tensor | Sequence[Tensor] | Sequence[Sequence[Tensor]]): the original state of the buffer. + """ + if isinstance(rb, ReplayBuffer): + # reinsert the true dones in the buffer + rb["dones"][(rb._pos - 1) % rb.buffer_size, :] = state + elif isinstance(rb, EnvIndependentReplayBuffer): + for i, b in enumerate(rb.buffer): + b["dones"][(b._pos - 1) % b.buffer_size, :] = state[i] + elif isinstance(rb, EpisodeBuffer): + # reinsert the open episodes to continue the training + rb._open_episodes = state diff --git a/sheeprl/utils/logger.py b/sheeprl/utils/logger.py index 0e23cc58..1c14d297 100644 --- a/sheeprl/utils/logger.py +++ b/sheeprl/utils/logger.py @@ -32,7 +32,7 @@ def get_logger(fabric: Fabric, cfg: Dict[str, Any]) -> Optional[Logger]: ) cfg.metric.logger.root_dir = root_dir cfg.metric.logger.name = cfg.run_name - logger = hydra.utils.instantiate(cfg.metric.logger) + logger = hydra.utils.instantiate(cfg.metric.logger, _convert_="all") return logger diff --git a/sheeprl/utils/memmap.py b/sheeprl/utils/memmap.py new file mode 100644 index 00000000..98fe9726 --- /dev/null +++ b/sheeprl/utils/memmap.py @@ -0,0 +1,272 @@ +"""Inspired by: https://github.com/pytorch/tensordict/blob/main/tensordict/memmap.py""" + +from __future__ import annotations + +import os +import tempfile +import warnings +from io import TextIOWrapper +from pathlib import Path +from sys import getrefcount +from tempfile import _TemporaryFileWrapper +from typing import Any, Tuple + +import numpy as np +from numpy.typing import DTypeLike + +from sheeprl.utils.imports import _IS_WINDOWS + + +def is_shared(array: np.ndarray) -> bool: + return isinstance(array, np.ndarray) and hasattr(array, "_mmap") + + +class MemmapArray(np.lib.mixins.NDArrayOperatorsMixin): + def __init__( + self, + shape: None | int | Tuple[int, ...], + dtype: DTypeLike = None, + mode: str = "r+", + reset: bool = False, + filename: str | os.PathLike | None = None, + ): + """Create a memory-mapped array. The memory-mapped array is stored in a file on disk and is + lazily loaded on demand. The array can be modified in-place and is automatically flushed to + disk when the array is deleted. 
The ownership of the file can be transferred only when: + + * the array is created from an already mamory-mapped array (i.e., `MemmapArray.from_array`) + * the array is set from an already memory-mapped array (i.e., `MemmapArray.array = ...`) + + Args: + dtype (DTypeLike): the data type of the array. + shape (None | int | Tuple[int, ...]): the shape of the array. + mode (str, optional): the mode to open the file with. Defaults to "r+". + reset (bool, optional): whether to reset the opened array to 0s. Defaults to False. + filename (str | os.PathLike | None, optional): an optional filename. If the filename is None, + then a temporary file will be opened. + Defaults to None. + """ + if filename is None: + fd, path = tempfile.mkstemp(".memmap") + self._filename = Path(path).resolve() + self._file = _TemporaryFileWrapper(open(fd, mode="r+"), path, delete=False) + else: + path = Path(filename).resolve() + if os.path.exists(path): + warnings.warn( + "The specified filename already exists. " + "Please be aware that any modification will be possibly reflected.", + category=UserWarning, + ) + path.parent.mkdir(parents=True, exist_ok=True) + path.touch(exist_ok=True) + self._filename = path + self._file = open(path, mode="r+") + os.close(self._file.fileno()) + self._dtype = dtype + self._shape = shape + self._mode = mode + self._array = np.memmap( + filename=self._filename, + dtype=self._dtype, + shape=self._shape, + mode=self._mode, + ) + if reset: + self._array[:] = np.zeros_like(self._array) + self._has_ownership = True + self._array_dir = self._array.__dir__() + self.__array_interface__ = self._array.__array_interface__ + + @property + def filename(self) -> Path: + """Return the filename of the memory-mapped array.""" + return self._filename + + @property + def file(self) -> TextIOWrapper: + """Return the file object of the memory-mapped array.""" + return self._file + + @property + def dtype(self) -> DTypeLike: + """Return the data type of the memory-mapped array.""" + return self._dtype + + @property + def mode(self) -> str: + """Return the mode of the memory-mapped array that has been opened with.""" + return self._mode + + @property + def shape(self) -> None | int | Tuple[int, ...]: + """Return the shape of the memory-mapped array.""" + return self._shape + + @property + def has_ownership(self) -> bool: + """Return whether the memory-mapped array has ownership of the file.""" + return self._has_ownership + + @has_ownership.setter + def has_ownership(self, value: bool): + """Set whether the memory-mapped array has ownership of the file.""" + self._has_ownership = value + + @property + def array(self) -> np.memmap: + """Return the memory-mapped array.""" + if not os.path.isfile(self._filename): + self._array = None + if self._array is None: + self._array = np.memmap( + filename=self._filename, + dtype=self._dtype, + shape=self._shape, + mode=self._mode, + ) + return self._array + + @array.setter + def array(self, v: np.memmap | np.ndarray): + """Set the memory-mapped array. If the array to be set is already memory-mapped, then the ownership of the + file will not be transferred to this memory-mapped array; this instance will lose previous + ownership on its memory mapped file. Otherwise, the array will be copied into + the memory-mapped array. In this last case, the shape of the array to be set must be the same as the + shape of the memory-mapped array. + + Args: + v (np.memmap | np.ndarray): the array to be set. 
+ + Raises: + ValueError: if the value to be set is not an instance of `np.memmap` or `np.ndarray`. + """ + if not isinstance(v, (np.memmap, np.ndarray)): + raise ValueError(f"The value to be set must be an instance of 'np.memmap' or 'np.ndarray', got '{type(v)}'") + if is_shared(v): + self.__del__() + tmpfile = _TemporaryFileWrapper(None, v.filename, delete=True) + tmpfile.name = v.filename + tmpfile._closer.name = v.filename + self._file = tmpfile + self._filename = v.filename + self._shape = v.shape + self._dtype = v.dtype + self._has_ownership = False + self.__array_interface__ = v.__array_interface__ + self._array = np.memmap( + filename=self._filename, + dtype=self._dtype, + shape=self._shape, + mode=self._mode, + ) + else: + if self._array.size != v.size: + raise ValueError( + "The shape of the value to be set must be the same as the shape of the memory-mapped array. " + f"Got {v.shape} and {self._shape}" + ) + reshaped_v = np.reshape(v, self._shape) + self._array[:] = reshaped_v + self._array.flush() + + @classmethod + def from_array( + cls, + array: np.ndarray | np.memmap | MemmapArray, + mode: str = "r+", + filename: str | os.PathLike | None = None, + ) -> MemmapArray: + """Create a memory-mapped array from an array. If the array is already memory-mapped, then the ownership of + the file will not be transferred to this memory-mapped array; this instance will lose previous ownership on + its memory mapped file. Otherwise, the array will be copied into the memory-mapped array. In this last case, + the shape of the array to be set must be the same as the shape of the memory-mapped array. + + Args: + array (np.ndarray | np.memmap | MemmapArray): the array to be set. + mode (str, optional): the mode to open the file with. Defaults to "r+". + filename (str | os.PathLike | None, optional): the filename. Defaults to None. + + Returns: + MemmapArray: the memory-mapped array. + """ + filename = Path(filename).resolve() if filename is not None else None + is_memmap_array = isinstance(array, MemmapArray) + is_shared_array = is_shared(array) + if isinstance(array, (np.ndarray, MemmapArray)): + out = cls(filename=filename, dtype=array.dtype, shape=array.shape, mode=mode, reset=False) + if is_memmap_array or is_shared_array: + if is_memmap_array: + array = array.array + if filename is not None and filename == Path(array.filename).resolve(): + out.array = array # Lose previous ownership + else: + out.array[:] = array[:] + else: + if filename is not None and os.path.exists(filename): + warnings.warn( + "The specified filename already exists. " + "Please be aware that any modification will be possibly reflected.", + category=UserWarning, + ) + out.array = array + return out + + def __del__(self) -> None: + """Delete the memory-mapped array. If the memory-mapped array has ownership of the file and no other + reference to the memory-mapped array exists or the OS is Windows-based, + then the memory-mapped array will be flushed to disk and both the memory-mapped array and + the file will be closed. If the memory-mapped array is mapped to a temporary file then the file is + removed. 
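A hedged usage sketch of the `MemmapArray` container defined in this file: with `filename=None` the data is backed by a temporary `.memmap` file, writes through the NumPy indexing interface go straight to disk, and `from_array` copies a plain `ndarray` into a new memory-mapped file.

```python
import numpy as np

from sheeprl.utils.memmap import MemmapArray

# Backed by a temporary ".memmap" file; reset=True zeroes the mapped region
arr = MemmapArray(shape=(4, 3), dtype=np.float32, reset=True)
arr[0] = np.arange(3, dtype=np.float32)  # writes go straight to the file on disk
assert isinstance(arr.array, np.memmap)
assert arr.filename.suffix == ".memmap"

# from_array copies an in-memory ndarray into a new memory-mapped file
copied = MemmapArray.from_array(np.ones((2, 2), dtype=np.float32))
assert np.allclose(copied[:], 1.0)
```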
+ """ + if (self._has_ownership and getrefcount(self._file) <= 2) or _IS_WINDOWS: + self._array.flush() + self._array._mmap.close() + del self._array._mmap + self._array = None + if isinstance(self._file, _TemporaryFileWrapper) and os.path.isfile(self._filename): + os.unlink(self._filename) + del self._file + + def __array__(self) -> np.memmap: + return self.array + + def __getattr__(self, attr: str) -> Any: + if attr in self.__dir__(): + return self.__getattribute__(attr) + if ("_array_dir" not in self.__dir__()) or (attr not in self.__getattribute__("_array_dir")): + raise AttributeError(f"'MemmapArray' object has no attribute '{attr}'") + array = self.__getattribute__("array") + return getattr(array, attr) + + def __getstate__(self): + # Copy the object's state from self.__dict__ which contains + # all our instance attributes. Always use the dict.copy() + # method to avoid modifying the original state. + state = self.__dict__.copy() + # Remove the unpicklable entries. + state["_file"] = None + state["_array"] = None + state["_has_ownership"] = False + return state + + def __setstate__(self, state): + filename = state["_filename"] + if state["_file"] is None: + tmpfile = _TemporaryFileWrapper(None, filename, delete=True) + tmpfile.name = filename + tmpfile._closer.name = filename + state["_file"] = tmpfile + self.__dict__.update(state) + + def __getitem__(self, idx: Any) -> np.ndarray: + return self.array[idx] + + def __setitem__(self, idx: Any, value: Any): + self.array[idx] = value + + def __repr__(self) -> str: + return f"MemmapArray(shape={self._shape}, dtype={self._dtype}, mode={self._mode}, filename={self._filename})" + + def __len__(self): + return len(self.array) diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py index f705be9b..d85eb16a 100644 --- a/sheeprl/utils/utils.py +++ b/sheeprl/utils/utils.py @@ -4,6 +4,7 @@ import os from typing import Any, Dict, Optional, Sequence, Tuple, Union +import numpy as np import rich.syntax import rich.tree import torch @@ -13,6 +14,21 @@ from pytorch_lightning.utilities import rank_zero_only from torch import Tensor +NUMPY_TO_TORCH_DTYPE_DICT = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} +TORCH_TO_NUMPY_DTYPE_DICT = {value: key for key, value in NUMPY_TO_TORCH_DTYPE_DICT.items()} + class dotdict(dict): """ @@ -177,5 +193,5 @@ def unwrap_fabric(model: _FabricModule | nn.Module) -> nn.Module: return model -def save_configs(cfg: Dict[str, Any], log_dir: str): +def save_configs(cfg: dotdict, log_dir: str): OmegaConf.save(cfg.as_dict(), os.path.join(log_dir, "config.yaml"), resolve=True) diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 53a7607f..6948b93d 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -438,7 +438,6 @@ def test_dreamer_v3(standard_args, env_id, start_time): "algo.world_model.recurrent_model.recurrent_state_size=8", "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", - "algo.cnn_keys.encoder=[rgb]", "algo.layer_norm=True", "algo.train_every=1", "algo.cnn_keys.encoder=[rgb]", diff --git a/tests/test_data/test_buffers.py 
b/tests/test_data/test_buffers.py index ab2b0016..c5c069cc 100644 --- a/tests/test_data/test_buffers.py +++ b/tests/test_data/test_buffers.py @@ -2,12 +2,12 @@ import shutil import time +import numpy as np import pytest -import torch from lightning import Fabric -from tensordict import TensorDict from sheeprl.data.buffers import ReplayBuffer +from sheeprl.utils.memmap import MemmapArray def test_replay_buffer_wrong_buffer_size(): @@ -20,93 +20,176 @@ def test_replay_buffer_wrong_n_envs(): ReplayBuffer(1, -1) +@pytest.mark.parametrize("memmap_mode", ["r", "x", "w", "z"]) +def test_replay_buffer_wrong_memmap_mode(memmap_mode): + with pytest.raises(ValueError, match="Accepted values for memmap_mode are"): + ReplayBuffer(10, 10, memmap_mode=memmap_mode, memmap=True) + + def test_replay_buffer_add_single_td_not_full(): buf_size = 5 n_envs = 1 rb = ReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(2, 1, 1)}, batch_size=[2, n_envs]) + td1 = {"a": np.random.rand(2, 1, 1)} rb.add(td1) assert not rb.full assert rb._pos == 2 - torch.testing.assert_close(rb["t"][:2], td1["t"]) + np.testing.assert_allclose(rb["a"][:2], td1["a"]) def test_replay_buffer_add_tds(): buf_size = 5 n_envs = 1 rb = ReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(2, 1, 1)}, batch_size=[2, n_envs]) - td2 = TensorDict({"t": torch.rand(2, 1, 1)}, batch_size=[2, n_envs]) - td3 = TensorDict({"t": torch.rand(3, 1, 1)}, batch_size=[3, n_envs]) + td1 = {"a": np.random.rand(2, 1, 1)} + td2 = {"a": np.random.rand(2, 1, 1)} + td3 = {"a": np.random.rand(3, 1, 1)} rb.add(td1) rb.add(td2) rb.add(td3) assert rb.full - assert rb["t"][0] == td3["t"][-2] - assert rb["t"][1] == td3["t"][-1] + assert rb["a"][0] == td3["a"][-2] + assert rb["a"][1] == td3["a"][-1] assert rb._pos == 2 - torch.testing.assert_close(rb["t"][2:4], td2["t"]) + np.testing.assert_allclose(rb["a"][2:4], td2["a"]) def test_replay_buffer_add_tds_exceeding_buf_size_multiple_times(): buf_size = 7 n_envs = 1 rb = ReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(2, 1, 1)}, batch_size=[2, n_envs]) - td2 = TensorDict({"t": torch.rand(1, 1, 1)}, batch_size=[1, n_envs]) - td3 = TensorDict({"t": torch.rand(9, 1, 1)}, batch_size=[9, n_envs]) + td1 = {"a": np.random.rand(2, 1, 1)} + td2 = {"a": np.random.rand(1, 1, 1)} + td3 = {"a": np.random.rand(9, 1, 1)} rb.add(td1) rb.add(td2) assert not rb.full rb.add(td3) assert rb.full assert rb._pos == 5 - remainder = len(td3) % buf_size - torch.testing.assert_close(rb["t"][: rb._pos], td3["t"][rb.buffer_size - rb._pos + remainder :]) + remainder = len(td3["a"]) % buf_size + np.testing.assert_allclose(rb["a"][: rb._pos], td3["a"][rb.buffer_size - rb._pos + remainder :]) def test_replay_buffer_add_single_td_size_is_not_multiple(): buf_size = 5 n_envs = 1 rb = ReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(17, 1, 1)}, batch_size=[17, n_envs]) + td1 = {"a": np.random.rand(17, 1, 1)} rb.add(td1) assert rb.full assert rb._pos == 2 - remainder = len(td1) % buf_size - torch.testing.assert_close(rb["t"][:remainder], td1["t"][-remainder:]) - torch.testing.assert_close(rb["t"][remainder:], td1["t"][-buf_size:-remainder]) + remainder = len(td1["a"]) % buf_size + np.testing.assert_allclose(rb["a"][:remainder], td1["a"][-remainder:]) + np.testing.assert_allclose(rb["a"][remainder:], td1["a"][-buf_size:-remainder]) def test_replay_buffer_add_single_td_size_is_multiple(): buf_size = 5 n_envs = 1 rb = ReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(20, 1, 1)}, 
batch_size=[20, n_envs]) + td1 = {"a": np.random.rand(20, 1, 1)} rb.add(td1) assert rb.full assert rb._pos == 0 - torch.testing.assert_close(rb["t"], td1["t"][-buf_size:]) + np.testing.assert_allclose(rb["a"], td1["a"][-buf_size:]) -def test_replay_buffer_sample(): +def test_replay_buffer_add_replay_buffer(): buf_size = 5 n_envs = 1 + rb1 = ReplayBuffer(buf_size, n_envs) + rb1.add({"a": np.random.rand(6, 1, 1)}) + rb2 = ReplayBuffer(buf_size, n_envs) + rb2.add(rb1) + assert (rb1.buffer["a"] == rb2.buffer["a"]).all() + + +def test_replay_buffer_add_error(): + import torch + + buf_size = 5 + n_envs = 3 rb = ReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(6, 1, 1)}, batch_size=[6, n_envs]) + with pytest.raises(ValueError, match="must be a dictionary containing Numpy arrays"): + rb.add([i for i in range(5)], validate_args=True) + with pytest.raises(ValueError, match=r"must be a dictionary containing Numpy arrays\. Found key"): + rb.add({"a": torch.rand(6, 1, 1)}, validate_args=True) + + with pytest.raises(RuntimeError, match="must have at least 2 dimensions"): + rb.add( + { + "a": np.random.rand( + 6, + ) + }, + validate_args=True, + ) + + with pytest.raises(RuntimeError, match="Every array in 'data' must be congruent in the first 2 dimensions"): + rb.add( + { + "a": np.random.rand(6, n_envs, 4), + "b": np.random.rand(6, n_envs, 4), + "c": np.random.rand(6, 1, 4), + }, + validate_args=True, + ) + + +def test_replay_buffer_sample(): + buf_size = 5 + n_envs = 1 + rb = ReplayBuffer(buf_size, n_envs, obs_keys=("a",)) + td1 = {"a": np.random.rand(6, 1, 1)} rb.add(td1) s = rb.sample(4) - assert s.shape == torch.Size([4, 1]) + assert s["a"].shape == tuple([1, 4, 1]) + s = rb.sample(4, n_samples=3) + assert s["a"].shape == tuple([3, 4, 1]) + s = rb.sample(4, n_samples=2, clone=True, sample_next_obs=True) + assert s["a"].shape == tuple([2, 4, 1]) + assert s["next_a"].shape == tuple([2, 4, 1]) + + +def test_replay_buffer_sample_one_sample_next_obs_error(): + buf_size = 5 + n_envs = 1 + rb = ReplayBuffer(buf_size, n_envs) + td1 = {"a": np.random.rand(1, 1, 1)} + rb.add(td1) + with pytest.raises(RuntimeError, match="You want to sample the next observations"): + rb.sample(1, sample_next_obs=True) + + +def test_replay_buffer_getitem_error(): + buf_size = 5 + n_envs = 1 + rb = ReplayBuffer(buf_size, n_envs) + with pytest.raises(RuntimeError, match="The buffer has not been initialized"): + rb["a"] + td = {"a": np.random.rand(1, 1, 1)} + rb.add(td) + with pytest.raises(TypeError, match="'key' must be a string"): + rb[0] + + +def test_replay_buffer_get_sample_empty_error(): + buf_size = 5 + n_envs = 1 + rb = ReplayBuffer(buf_size, n_envs) + with pytest.raises(RuntimeError, match="The buffer has not been initialized"): + rb._get_samples(np.zeros((1,)), sample_next_obs=True) def test_replay_buffer_sample_next_obs_not_full(): buf_size = 5 n_envs = 1 rb = ReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"observations": torch.arange(4).view(-1, 1, 1)}, batch_size=[4, n_envs]) + td1 = {"observations": np.arange(4).reshape(-1, 1, 1)} rb.add(td1) s = rb.sample(10, sample_next_obs=True) - assert s.shape == torch.Size([10, 1]) + assert s["observations"].shape == tuple([1, 10, 1]) assert td1["observations"][-1] not in s["observations"] @@ -114,10 +197,10 @@ def test_replay_buffer_sample_next_obs_full(): buf_size = 5 n_envs = 1 rb = ReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"observations": torch.arange(8).view(-1, 1, 1)}, batch_size=[8, n_envs]) + td1 = {"observations": 
np.arange(8).reshape(-1, 1, 1)} rb.add(td1) s = rb.sample(10, sample_next_obs=True) - assert s.shape == torch.Size([10, 1]) + assert s["observations"].shape == tuple([1, 10, 1]) assert td1["observations"][-1] not in s["observations"] @@ -125,22 +208,22 @@ def test_replay_buffer_sample_full(): buf_size = 5 n_envs = 1 rb = ReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(6, 1, 1)}, batch_size=[6, n_envs]) + td1 = {"a": np.random.rand(6, 1, 1)} rb.add(td1) s = rb.sample(6) - assert s.shape == torch.Size([6, 1]) + assert s["a"].shape == tuple([1, 6, 1]) def test_replay_buffer_sample_one_element(): buf_size = 1 n_envs = 1 rb = ReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"observations": torch.rand(1, 1, 1)}, batch_size=[1, n_envs]) + td1 = {"observations": np.random.rand(1, 1, 1)} rb.add(td1) sample = rb.sample(1) assert rb.full assert sample["observations"] == td1["observations"] - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): rb.sample(1, sample_next_obs=True) @@ -150,24 +233,26 @@ def test_replay_buffer_sample_fail(): rb = ReplayBuffer(buf_size, n_envs) with pytest.raises(ValueError, match="No sample has been added to the buffer"): rb.sample(1) - with pytest.raises(ValueError, match="Batch size must be greater than 0"): + with pytest.raises(ValueError, match="must be both greater than 0"): rb.sample(-1) def test_memmap_replay_buffer(): buf_size = 10 n_envs = 4 - with pytest.warns( - UserWarning, - match="The buffer will be memory-mapped into the `/tmp` folder, this means that there is the" - " possibility to lose the saved files. Set the `memmap_dir` to a known directory.", + with pytest.raises( + ValueError, + match="The buffer is set to be memory-mapped but the 'memmap_dir'", ): rb = ReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=None) - td = TensorDict( - {"observations": torch.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=torch.uint8)}, batch_size=[10, n_envs] - ) + root_dir = os.path.join("pytest_" + str(int(time.time()))) + memmap_dir = os.path.join(root_dir, "memmap_buffer") + rb = ReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir) + td = {"observations": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8)} rb.add(td) - assert rb.buffer.is_memmap() + assert rb.is_memmap + del rb + shutil.rmtree(root_dir) def test_memmap_to_file_replay_buffer(): @@ -176,13 +261,9 @@ def test_memmap_to_file_replay_buffer(): root_dir = os.path.join("pytest_" + str(int(time.time()))) memmap_dir = os.path.join(root_dir, "memmap_buffer") rb = ReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir) - td = TensorDict( - {"observations": torch.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=torch.uint8)}, batch_size=[10, n_envs] - ) + td = {"observations": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8)} rb.add(td) - assert rb.buffer.is_memmap() - assert os.path.exists(os.path.join(memmap_dir, "meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, "observations.meta.pt")) + assert rb.is_memmap assert os.path.exists(os.path.join(memmap_dir, "observations.memmap")) fabric = Fabric(devices=1, accelerator="cpu") ckpt_file = os.path.join(root_dir, "checkpoint", "ckpt.ckpt") @@ -197,15 +278,14 @@ def test_memmap_to_file_replay_buffer(): def test_obs_keys_replay_buffer(): buf_size = 10 n_envs = 4 - rb = ReplayBuffer(buf_size, n_envs, memmap=True, obs_keys=("rgb", "state", "tmp")) - td = TensorDict( - { - "rgb": torch.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=torch.uint8), - "state": 
torch.randint(0, 256, (10, n_envs, 8), dtype=torch.uint8), - "tmp": torch.randint(0, 256, (10, n_envs, 5), dtype=torch.uint8), - }, - batch_size=[10, n_envs], - ) + root_dir = os.path.join("pytest_" + str(int(time.time()))) + memmap_dir = os.path.join(root_dir, "memmap_buffer") + rb = ReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir, obs_keys=("rgb", "state", "tmp")) + td = { + "rgb": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8), + "state": np.random.randint(0, 256, (10, n_envs, 8), dtype=np.uint8), + "tmp": np.random.randint(0, 256, (10, n_envs, 5), dtype=np.uint8), + } rb.add(td) sample = rb.sample(10, True) sample_keys = sample.keys() @@ -215,23 +295,24 @@ def test_obs_keys_replay_buffer(): assert "next_rgb" in sample_keys assert "next_state" in sample_keys assert "next_tmp" in sample_keys + del rb + shutil.rmtree(root_dir) def test_obs_keys_replay_no_sample_next_obs_buffer(): buf_size = 10 n_envs = 4 - rb = ReplayBuffer(buf_size, n_envs, memmap=True, obs_keys=("rgb", "state", "tmp")) - td = TensorDict( - { - "rgb": torch.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=torch.uint8), - "state": torch.randint(0, 256, (10, n_envs, 8), dtype=torch.uint8), - "tmp": torch.randint(0, 256, (10, n_envs, 5), dtype=torch.uint8), - "next_rgb": torch.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=torch.uint8), - "next_state": torch.randint(0, 256, (10, n_envs, 8), dtype=torch.uint8), - "next_tmp": torch.randint(0, 256, (10, n_envs, 5), dtype=torch.uint8), - }, - batch_size=[10, n_envs], - ) + root_dir = os.path.join("pytest_" + str(int(time.time()))) + memmap_dir = os.path.join(root_dir, "memmap_buffer") + rb = ReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir, obs_keys=("rgb", "state", "tmp")) + td = { + "rgb": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8), + "state": np.random.randint(0, 256, (10, n_envs, 8), dtype=np.uint8), + "tmp": np.random.randint(0, 256, (10, n_envs, 5), dtype=np.uint8), + "next_rgb": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8), + "next_state": np.random.randint(0, 256, (10, n_envs, 8), dtype=np.uint8), + "next_tmp": np.random.randint(0, 256, (10, n_envs, 5), dtype=np.uint8), + } rb.add(td) sample = rb.sample(10, False) sample_keys = sample.keys() @@ -241,3 +322,128 @@ def test_obs_keys_replay_no_sample_next_obs_buffer(): assert "next_rgb" in sample_keys assert "next_state" in sample_keys assert "next_tmp" in sample_keys + del rb + shutil.rmtree(root_dir) + + +def test_sample_tensors(): + import torch + + buf_size = 5 + n_envs = 1 + rb = ReplayBuffer(buf_size, n_envs) + td1 = {"observations": np.arange(8).reshape(-1, 1, 1)} + rb.add(td1) + s = rb.sample_tensors(10, sample_next_obs=True, n_samples=3) + assert isinstance(s["observations"], torch.Tensor) + assert s["observations"].shape == torch.Size([3, 10, 1]) + + +def test_sample_tensor_memmap(): + import torch + + buf_size = 10 + n_envs = 4 + root_dir = os.path.join("pytest_" + str(int(time.time()))) + memmap_dir = os.path.join(root_dir, "memmap_buffer") + rb = ReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir, obs_keys=("observations")) + td = { + "observations": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8), + } + rb.add(td) + sample = rb.sample_tensors(10, False, n_samples=3) + assert isinstance(sample["observations"], torch.Tensor) + assert sample["observations"].shape == torch.Size([3, 10, 3, 64, 64]) + del rb + shutil.rmtree(root_dir) + + +def test_to_tensor(): + import torch + + 
buf_size = 5 + n_envs = 4 + root_dir = os.path.join("pytest_" + str(int(time.time()))) + memmap_dir = os.path.join(root_dir, "memmap_buffer") + rb = ReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir, obs_keys=("observations")) + td = { + "observations": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8), + } + rb.add(td) + sample = rb.to_tensor() + assert isinstance(sample["observations"], torch.Tensor) + assert sample["observations"].shape == torch.Size([buf_size, n_envs, 3, 64, 64]) + assert (td["observations"][5:] == sample["observations"].cpu().numpy()).all() + del rb + shutil.rmtree(root_dir) + + +def test_setitem(): + buf_size = 5 + n_envs = 4 + rb = ReplayBuffer(buf_size, n_envs) + td1 = {"observations": np.arange(8).reshape(-1, 1, 1)} + rb.add(td1) + a = np.random.rand(buf_size, n_envs, 10) + rb["a"] = a + assert rb["a"].shape == tuple([buf_size, n_envs, 10]) + assert (rb["a"] == a).all() + + m = MemmapArray(filename="test.memmap", dtype=np.float32, shape=(buf_size, n_envs, 4)) + m.array = np.random.rand(buf_size, n_envs, 4) + rb["m"] = m + assert isinstance(rb["m"], np.ndarray) and not isinstance(rb["m"], (MemmapArray, np.memmap)) + assert rb["m"].shape == tuple([buf_size, n_envs, 4]) + assert (rb["m"] == m.array).all() + + del m + os.unlink("test.memmap") + + +def test_setitem_memmap(): + buf_size = 5 + n_envs = 4 + root_dir = os.path.join("pytest_" + str(int(time.time()))) + memmap_dir = os.path.join(root_dir, "memmap_buffer") + rb = ReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir, obs_keys=("observations")) + td = { + "observations": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8), + } + rb.add(td) + a = np.random.rand(buf_size, n_envs, 10) + rb["a"] = a + assert isinstance(rb["a"], MemmapArray) + assert rb["a"].shape == tuple([buf_size, n_envs, 10]) + assert (rb["a"] == a).all() + + m = MemmapArray(filename=f"{root_dir}/test.memmap", dtype=np.float32, shape=(buf_size, n_envs, 4)) + m.array = np.random.rand(buf_size, n_envs, 4) + rb["m"] = m + assert isinstance(rb["m"], MemmapArray) + assert rb["m"].shape == tuple([buf_size, n_envs, 4]) + assert (rb["m"].array == m.array).all() + + del m + del rb + shutil.rmtree(root_dir) + + +def test_setitem_error(): + import torch + + buf_size = 5 + n_envs = 4 + rb = ReplayBuffer(buf_size, n_envs) + with pytest.raises(RuntimeError, match="The buffer has not been initialized"): + rb["no_init"] = np.zeros((buf_size, n_envs, 1)) + + td1 = {"observations": np.arange(8).reshape(-1, 1, 1)} + rb.add(td1) + + with pytest.raises(ValueError, match=r"The value to be set must be an instance of 'np\.ndarray', 'np\.memmap'"): + rb["torch"] = torch.zeros(buf_size, n_envs, 1) + + with pytest.raises(RuntimeError, match="must have at least two dimensions of dimension"): + rb["wrong_buffer_size"] = np.zeros((buf_size + 3, n_envs, 1)) + rb["wrong_n_envs"] = np.zeros((buf_size, n_envs - 1, 1)) + rb["wrong_dims"] = np.zeros((10,)) diff --git a/tests/test_data/test_env_independent_rb.py b/tests/test_data/test_env_independent_rb.py new file mode 100644 index 00000000..2364993d --- /dev/null +++ b/tests/test_data/test_env_independent_rb.py @@ -0,0 +1,114 @@ +import numpy as np +import pytest + +from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer + + +def test_env_idependent_wrong_buffer_size(): + with pytest.raises(ValueError): + EnvIndependentReplayBuffer(-1) + + +def test_env_idependent_wrong_n_envs(): + with pytest.raises(ValueError): + 
EnvIndependentReplayBuffer(1, -1) + + +def test_env_independent_missing_memmap_dir(): + with pytest.raises(ValueError): + EnvIndependentReplayBuffer(10, 4, memmap=True, memmap_dir=None) + + +def test_env_independent_wrong_memmap_mode(): + with pytest.raises(ValueError): + EnvIndependentReplayBuffer(10, 4, memmap=True, memmap_mode="a+") + + +def test_env_independent_add(): + bs = 20 + n_envs = 4 + rb = EnvIndependentReplayBuffer(bs, n_envs) + stps1 = {"dones": np.zeros((10, 4, 1))} + rb.add(stps1) + for i in range(n_envs): + assert rb._buf[i]._pos == 10 + stps2 = {"dones": np.zeros((10, 2, 1))} + rb.add(stps2, [0, 3]) + assert rb._buf[0]._pos == 0 + assert rb._buf[1]._pos == 10 + assert rb._buf[2]._pos == 10 + assert rb._buf[0]._pos == 0 + + +def test_env_independent_add_error(): + bs = 10 + n_envs = 4 + rb = EnvIndependentReplayBuffer(bs, n_envs) + stps = {"dones": np.zeros((10, 3, 1))} + with pytest.raises(ValueError): + rb.add(stps) + + +def test_env_independent_sample_shape(): + bs = 20 + n_envs = 4 + rb = EnvIndependentReplayBuffer(bs, n_envs) + stps1 = {"dones": np.ones((10, 4, 1))} + rb.add(stps1) + stps2 = {"dones": np.ones((10, 2, 1))} + rb.add(stps2, [0, 3]) + sample = rb.sample(10, n_samples=10) + assert sample["dones"].shape == tuple([10, 10, 1]) + + +def test_env_independent_sample(): + bs = 20 + n_envs = 4 + rb = EnvIndependentReplayBuffer(bs, n_envs) + stps1 = {"dones": np.ones((10, 4, 1))} + for i in range(n_envs): + stps1["dones"][:, i] *= i + rb.add(stps1) + stps2 = {"dones": np.ones((10, 2, 1))} + for i, env in enumerate([0, 3]): + stps2["dones"][:, i] *= env + rb.add(stps2, [0, 3]) + sample = rb.sample(2000, n_samples=2) + for i in range(n_envs): + assert (sample["dones"] == i).any() + + +def test_env_independent_sample_error(): + bs = 20 + n_envs = 4 + rb = EnvIndependentReplayBuffer(bs, n_envs) + with pytest.raises(ValueError, match="No sample has been added to the buffer"): + rb.sample(10, n_samples=10) + stps1 = {"dones": np.zeros((10, 4, 1))} + rb.add(stps1) + stps2 = {"dones": np.zeros((10, 2, 1))} + rb.add(stps2, [0, 3]) + + with pytest.raises(ValueError, match="must be both greater than 0"): + rb.sample(0, n_samples=10) + rb.sample(10, n_samples=0) + rb.sample(-1, n_samples=10) + rb.sample(10, n_samples=-1) + + +def test_env_independent_sample_tensors(): + import torch + + bs = 20 + n_envs = 4 + rb = EnvIndependentReplayBuffer(bs, n_envs, buffer_cls=SequentialReplayBuffer) + with pytest.raises(ValueError, match="No sample has been added to the buffer"): + rb.sample(10, n_samples=10) + stps1 = {"dones": np.zeros((10, 4, 1))} + rb.add(stps1) + stps2 = {"dones": np.zeros((10, 2, 1))} + rb.add(stps2, [0, 3]) + + s = rb.sample_tensors(10, n_samples=3, sequence_length=5) + assert isinstance(s["dones"], torch.Tensor) + assert s["dones"].shape == torch.Size([3, 5, 10, 1]) diff --git a/tests/test_data/test_episode_buffer.py b/tests/test_data/test_episode_buffer.py index 07ef524c..967abfbf 100644 --- a/tests/test_data/test_episode_buffer.py +++ b/tests/test_data/test_episode_buffer.py @@ -2,215 +2,386 @@ import shutil import time +import numpy as np import pytest import torch -from tensordict import TensorDict -from sheeprl.data.buffers import EpisodeBuffer +from sheeprl.data.buffers import EpisodeBuffer, ReplayBuffer +from sheeprl.utils.memmap import MemmapArray def test_episode_buffer_wrong_buffer_size(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="The buffer size must be greater than zero"): EpisodeBuffer(-1, 10) def 
test_episode_buffer_wrong_sequence_length(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="The sequence length must be greater than zero"): EpisodeBuffer(1, -1) def test_episode_buffer_sequence_length_greater_than_batch_size(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="The sequence length must be lower than the buffer size"): EpisodeBuffer(5, 10) +@pytest.mark.parametrize("memmap_mode", ["r", "x", "w", "z"]) +def test_replay_buffer_wrong_memmap_mode(memmap_mode): + with pytest.raises(ValueError, match="Accepted values for memmap_mode are"): + EpisodeBuffer(10, 10, memmap_mode=memmap_mode, memmap=True) + + def test_episode_buffer_add_episodes(): buf_size = 30 sl = 5 - rb = EpisodeBuffer(buf_size, sl) - td1 = TensorDict({"dones": torch.zeros(sl, 1)}, batch_size=[sl]) - td2 = TensorDict({"dones": torch.zeros(sl + 5, 1)}, batch_size=[sl + 5]) - td3 = TensorDict({"dones": torch.zeros(sl + 10, 1)}, batch_size=[sl + 10]) - td4 = TensorDict({"dones": torch.zeros(sl, 1)}, batch_size=[sl]) - td1["dones"][-1] = 1 - td2["dones"][-1] = 1 - td3["dones"][-1] = 1 - td4["dones"][-1] = 1 - rb.add(td1) - rb.add(td2) - rb.add(td3) - rb.add(td4) + n_envs = 1 + obs_keys = ("dones",) + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + ep1 = {"dones": np.zeros((sl, n_envs, 1))} + ep2 = {"dones": np.zeros((sl + 5, n_envs, 1))} + ep3 = {"dones": np.zeros((sl + 10, n_envs, 1))} + ep4 = {"dones": np.zeros((sl, n_envs, 1))} + ep1["dones"][-1] = 1 + ep2["dones"][-1] = 1 + ep3["dones"][-1] = 1 + ep4["dones"][-1] = 1 + rb.add(ep1) + rb.add(ep2) + rb.add(ep3) + rb.add(ep4) assert rb.full - assert (rb[-1]["dones"] == td4["dones"]).all() - assert (rb[0]["dones"] == td2["dones"]).all() + assert (rb._buf[-1]["dones"] == ep4["dones"][:, 0]).all() + assert (rb._buf[0]["dones"] == ep2["dones"][:, 0]).all() -def test_episode_buffer_add_single_td(): +def test_episode_buffer_add_single_dict(): buf_size = 5 sl = 5 - rb = EpisodeBuffer(buf_size, sl) - td1 = TensorDict({"dones": torch.zeros(sl, 1)}, batch_size=[sl]) - td1["dones"][-1] = 1 - rb.add(td1) + n_envs = 4 + obs_keys = ("dones",) + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + ep1 = {"dones": np.zeros((sl, n_envs, 1))} + ep1["dones"][-1] = 1 + rb.add(ep1) assert rb.full - assert (rb[0]["dones"] == td1["dones"]).all() + for env in range(n_envs): + assert (rb._buf[0]["dones"] == ep1["dones"][:, env]).all() def test_episode_buffer_error_add(): buf_size = 10 sl = 5 - rb = EpisodeBuffer(buf_size, sl) - td1 = TensorDict({"dones": torch.zeros(sl - 2, 1)}, batch_size=[sl - 2]) - with pytest.raises(RuntimeError, match="The episode must contain exactly one done"): - rb.add(td1) + n_envs = 4 + obs_keys = ("dones",) + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + + ep1 = torch.zeros(sl, n_envs, 1) + with pytest.raises(ValueError, match="`data` must be a dictionary containing Numpy arrays, but `data` is of type"): + rb.add(ep1, validate_args=True) + + ep2 = {"dones": torch.zeros((sl, n_envs, 1))} + with pytest.raises(ValueError, match="`data` must be a dictionary containing Numpy arrays. 
Found key"): + rb.add(ep2, validate_args=True) + + ep3 = None + with pytest.raises(ValueError, match="The `data` replay buffer must be not None"): + rb.add(ep3, validate_args=True) + + ep4 = {"dones": np.zeros((1,))} + with pytest.raises(RuntimeError, match=r"`data` must have at least 2: \[sequence_length, n_envs"): + rb.add(ep4, validate_args=True) + + obs_keys = ("dones", "obs") + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + ep5 = {"dones": np.zeros((sl, n_envs, 1)), "obs": np.zeros((sl, 1, 6))} + with pytest.raises(RuntimeError, match="Every array in `data` must be congruent in the first 2 dimensions"): + rb.add(ep5, validate_args=True) + + ep6 = {"obs": np.zeros((sl, 1, 6))} + with pytest.raises(RuntimeError, match="The episode must contain the `dones` key"): + rb.add(ep6, validate_args=True) + + ep7 = {"dones": np.zeros((sl, 1, 1))} + ep7["dones"][-1] = 1 + with pytest.raises(ValueError, match="The indices of the environment must be integers in"): + rb.add(ep7, validate_args=True, env_idxes=[10]) + - td1["dones"][-3:] = 1 +def test_add_only_for_some_envs(): + buf_size = 10 + sl = 5 + n_envs = 4 + obs_keys = ("dones",) + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + ep1 = {"dones": np.zeros((sl, n_envs - 2, 1))} + rb.add(ep1, env_idxes=[0, 3]) + assert len(rb._open_episodes[0]) > 0 + assert len(rb._open_episodes[1]) == 0 + assert len(rb._open_episodes[2]) == 0 + assert len(rb._open_episodes[3]) > 0 + + +def test_save_episode(): + buf_size = 100 + sl = 5 + n_envs = 4 + obs_keys = ("dones",) + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)] + ep_chunks[-1]["dones"][-1] = 1 + rb._save_episode(ep_chunks) + + assert len(rb._buf) == 1 + assert ( + np.concatenate([e["dones"] for e in rb.buffer], axis=0) + == np.concatenate([c["dones"] for c in ep_chunks], axis=0) + ).all() + + +def test_save_episode_errors(): + buf_size = 100 + sl = 5 + n_envs = 4 + obs_keys = ("dones",) + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + + with pytest.raises(RuntimeError, match="Invalid episode, an empty sequence is given"): + rb._save_episode([]) + + ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)] + ep_chunks[-1]["dones"][-1] = 1 + ep_chunks[0]["dones"][-1] = 1 with pytest.raises(RuntimeError, match="The episode must contain exactly one done"): - rb.add(td1) + rb._save_episode(ep_chunks) - td1["dones"][-2:] = 0 - with pytest.raises(RuntimeError, match="The last step must contain a done"): - rb.add(td1) + ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)] + ep_chunks[0]["dones"][-1] = 1 + with pytest.raises(RuntimeError, match="The episode must contain exactly one done"): + rb._save_episode(ep_chunks) - td1["dones"][-3] = 0 - td1["dones"][-1] = 1 + ep_chunks = [{"dones": np.ones((1, 1))}] with pytest.raises(RuntimeError, match="Episode too short"): - rb.add(td1) + rb._save_episode(ep_chunks) - td1 = TensorDict({"dones": torch.zeros(15, 1)}, batch_size=[15]) - td1["dones"][-1] = 1 + ep_chunks = [{"dones": np.zeros((110, 1))} for _ in range(8)] + ep_chunks[-1]["dones"][-1] = 1 with pytest.raises(RuntimeError, match="Episode too long"): - rb.add(td1) - - td1 = TensorDict({"t": torch.zeros(15, 1)}, batch_size=[15]) - td1["t"][-1] = 1 - with pytest.raises(KeyError, match='key "dones" not found'): - rb.add(td1) + rb._save_episode(ep_chunks) def 
test_episode_buffer_sample_one_element(): buf_size = 5 sl = 5 - rb = EpisodeBuffer(buf_size, sl) - td1 = TensorDict({"dones": torch.zeros(sl, 1), "t": torch.rand(sl, 1)}, batch_size=[sl]) - td1["dones"][-1] = 1 - rb.add(td1) - sample = rb.sample(1, 1) + n_envs = 1 + obs_keys = ("dones", "a") + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + ep = {"dones": np.zeros((sl, n_envs, 1)), "a": np.random.rand(sl, n_envs, 1)} + ep["dones"][-1] = 1 + rb.add(ep) + sample = rb.sample(1, n_samples=1, sequence_length=sl) assert rb.full - assert (sample["dones"][0, :, 0] == td1["dones"]).all() - assert (sample["t"][0, :, 0] == td1["t"]).all() + assert (sample["dones"][0, :, 0] == ep["dones"][:, 0]).all() + assert (sample["a"][0, :, 0] == ep["a"][:, 0]).all() def test_episode_buffer_sample_shapes(): buf_size = 30 sl = 2 - rb = EpisodeBuffer(buf_size, sl) - t = TensorDict({"dones": torch.zeros(sl, 1)}, batch_size=[sl]) - t["dones"][-1] = 1 - rb.add(t) - sample = rb.sample(3, n_samples=2) - assert sample.shape == torch.Size([2, sl, 3]) + n_envs = 1 + obs_keys = ("dones",) + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + ep = {"dones": np.zeros((sl, n_envs, 1))} + ep["dones"][-1] = 1 + rb.add(ep) + sample = rb.sample(3, n_samples=2, sequence_length=sl) + assert sample["dones"].shape[:-1] == tuple([2, sl, 3]) + sample = rb.sample(3, n_samples=2, sequence_length=sl, clone=True) + assert sample["dones"].shape[:-1] == tuple([2, sl, 3]) def test_episode_buffer_sample_more_episodes(): buf_size = 100 sl = 15 - rb = EpisodeBuffer(buf_size, sl) - td1 = TensorDict({"dones": torch.zeros(40, 1), "t": torch.ones(40, 1) * -1}, batch_size=[40]) - td2 = TensorDict({"dones": torch.zeros(45, 1), "t": torch.ones(45, 1) * -2}, batch_size=[45]) - td3 = TensorDict({"dones": torch.zeros(50, 1), "t": torch.ones(50, 1) * -3}, batch_size=[50]) - td1["dones"][-1] = 1 - td2["dones"][-1] = 1 - td3["dones"][-1] = 1 - rb.add(td1) - rb.add(td2) - rb.add(td3) - samples = rb.sample(50, n_samples=5) - assert samples.shape == torch.Size([5, sl, 50]) - for seq in samples.permute(0, -1, -2).reshape(-1, sl, 1): - assert torch.isin(seq["t"], -1).all() or torch.isin(seq["t"], -2).all() or torch.isin(seq["t"], -3).all() - assert len(torch.nonzero(seq["dones"])) == 0 or seq["dones"][-1] == 1 + n_envs = 1 + obs_keys = ("dones", "a") + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + ep1 = {"dones": np.zeros((40, n_envs, 1)), "a": np.ones((40, n_envs, 1)) * -1} + ep2 = {"dones": np.zeros((45, n_envs, 1)), "a": np.ones((45, n_envs, 1)) * -2} + ep3 = {"dones": np.zeros((50, n_envs, 1)), "a": np.ones((50, n_envs, 1)) * -3} + ep1["dones"][-1] = 1 + ep2["dones"][-1] = 1 + ep3["dones"][-1] = 1 + rb.add(ep1) + rb.add(ep2) + rb.add(ep3) + samples = rb.sample(50, n_samples=5, sequence_length=sl) + assert samples["dones"].shape[:-1] == tuple([5, sl, 50]) + samples = {k: np.moveaxis(samples[k], 2, 1).reshape(-1, sl, 1) for k in obs_keys} + for i in range(len(samples["dones"])): + assert ( + np.isin(samples["a"][i], -1).all() + or np.isin(samples["a"][i], -2).all() + or np.isin(samples["a"][i], -3).all() + ) + assert len(samples["dones"][i].nonzero()[0]) == 0 or samples["dones"][i][-1] == 1 def test_episode_buffer_error_sample(): buf_size = 10 sl = 5 rb = EpisodeBuffer(buf_size, sl) - with pytest.raises(RuntimeError, match="No sample has been added"): - rb.sample(2, 2) + with pytest.raises(RuntimeError, match="No valid episodes has been added to the buffer"): + rb.sample(2, n_samples=2) with 
pytest.raises(ValueError, match="Batch size must be greater than 0"): rb.sample(-1, n_samples=2) with pytest.raises(ValueError, match="The number of samples must be greater than 0"): - rb.sample(2, -1) + rb.sample(2, n_samples=-1) + ep1 = {"dones": np.zeros((15, 1, 1))} + rb.add(ep1) + with pytest.raises(RuntimeError, match="No valid episodes has been added to the buffer"): + rb.sample(2, n_samples=2, sequence_length=20) + rb.sample(2, n_samples=2, sample_next_obs=True, sequence_length=sl) def test_episode_buffer_prioritize_ends(): buf_size = 100 sl = 15 - rb = EpisodeBuffer(buf_size, sl) - td1 = TensorDict({"dones": torch.zeros(15, 1)}, batch_size=[15]) - td2 = TensorDict({"dones": torch.zeros(25, 1)}, batch_size=[25]) - td3 = TensorDict({"dones": torch.zeros(30, 1)}, batch_size=[30]) - td1["dones"][-1] = 1 - td2["dones"][-1] = 1 - td3["dones"][-1] = 1 - rb.add(td1) - rb.add(td2) - rb.add(td3) - samples = rb.sample(50, n_samples=5, prioritize_ends=True) - assert samples.shape == torch.Size([5, sl, 50]) - assert torch.isin(samples["dones"], 1).any() > 0 + n_envs = 1 + obs_keys = ("dones",) + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys, prioritize_ends=True) + ep1 = {"dones": np.zeros((15, n_envs, 1))} + ep2 = {"dones": np.zeros((25, n_envs, 1))} + ep3 = {"dones": np.zeros((30, n_envs, 1))} + ep1["dones"][-1] = 1 + ep2["dones"][-1] = 1 + ep3["dones"][-1] = 1 + rb.add(ep1) + rb.add(ep2) + rb.add(ep3) + samples = rb.sample(50, n_samples=5, sequence_length=sl) + assert samples["dones"].shape[:-1] == tuple([5, sl, 50]) + assert np.isin(samples["dones"], 1).any() > 0 + + +def test_sample_next_obs(): + buf_size = 10 + sl = 5 + n_envs = 4 + obs_keys = ("dones",) + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) + ep1 = {"dones": np.zeros((sl, n_envs, 1))} + ep1["dones"][-1] = 1 + rb.add(ep1) + sample = rb.sample(10, True, n_samples=5, sequence_length=sl - 1) + assert "next_dones" in sample + assert (sample["next_dones"][:, -1] == 1).all() def test_memmap_episode_buffer(): buf_size = 10 bs = 4 sl = 4 - with pytest.warns( - UserWarning, - match="The buffer will be memory-mapped into the `/tmp` folder, this means that there is the" - " possibility to lose the saved files. 
Set the `memmap_dir` to a known directory.", + n_envs = 1 + obs_keys = ("dones", "observations") + with pytest.raises( + ValueError, + match="The buffer is set to be memory-mapped but the `memmap_dir` attribute is None", ): - rb = EpisodeBuffer(buf_size, sl, memmap=True) + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, memmap=True) + + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys, memmap=True, memmap_dir="test_episode_buffer") for _ in range(buf_size // bs): - td = TensorDict( - {"observations": torch.randint(0, 256, (bs, 3, 64, 64), dtype=torch.uint8), "dones": torch.zeros(bs)}, - batch_size=[bs], - ) - td["dones"][-1] = 1 - rb.add(td) - assert rb[-1].is_memmap() + ep = { + "observations": np.random.randint(0, 256, (bs, n_envs, 3, 64, 64), dtype=np.uint8), + "dones": np.zeros((bs, n_envs, 1)), + } + ep["dones"][-1] = 1 + rb.add(ep) + assert isinstance(rb._buf[-1]["dones"], MemmapArray) + assert isinstance(rb._buf[-1]["observations"], MemmapArray) assert rb.is_memmap + del rb + shutil.rmtree(os.path.abspath("test_episode_buffer")) def test_memmap_to_file_episode_buffer(): buf_size = 10 bs = 5 sl = 4 - root_dir = os.path.join("pytest_" + str(int(time.time()))) - memmap_dir = os.path.join(root_dir, "memmap_buffer") - rb = EpisodeBuffer(buf_size, sl, memmap=True, memmap_dir=memmap_dir) + n_envs = 1 + obs_keys = ("dones", "observations") + memmap_dir = "test_episode_buffer" + rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys, memmap=True, memmap_dir=memmap_dir) for i in range(4): if i >= 2: bs = 7 else: bs = 5 - td = TensorDict( - {"observations": torch.randint(0, 256, (bs, 3, 64, 64), dtype=torch.uint8), "dones": torch.zeros(bs)}, - batch_size=[bs], - ) - td["dones"][-1] = 1 - rb.add(td) - del td - assert rb[-1].is_memmap() - memmap_dir = os.path.dirname(rb.buffer[-1][rb.buffer[-1].sorted_keys[0]].filename) - assert os.path.exists(os.path.join(memmap_dir, "meta.pt")) - assert os.path.exists(os.path.join(memmap_dir, "dones.meta.pt")) + ep = { + "observations": np.random.randint(0, 256, (bs, n_envs, 3, 64, 64), dtype=np.uint8), + "dones": np.zeros((bs, n_envs, 1)), + } + ep["dones"][-1] = 1 + rb.add(ep) + del ep + assert isinstance(rb._buf[-1]["dones"], MemmapArray) + assert isinstance(rb._buf[-1]["observations"], MemmapArray) + memmap_dir = os.path.dirname(rb._buf[-1]["dones"].filename) assert os.path.exists(os.path.join(memmap_dir, "dones.memmap")) - assert os.path.exists(os.path.join(memmap_dir, "observations.meta.pt")) assert os.path.exists(os.path.join(memmap_dir, "observations.memmap")) assert rb.is_memmap for ep in rb.buffer: del ep del rb + shutil.rmtree(os.path.abspath("test_episode_buffer")) + + +def test_sample_tensors(): + import torch + + buf_size = 10 + n_envs = 1 + rb = EpisodeBuffer(buf_size, n_envs) + td = {"observations": np.arange(8).reshape(-1, 1, 1), "dones": np.zeros((8, 1, 1))} + td["dones"][-1] = 1 + rb.add(td) + s = rb.sample_tensors(10, sample_next_obs=True, n_samples=3, sequence_length=5) + assert isinstance(s["observations"], torch.Tensor) + assert s["observations"].shape == torch.Size([3, 5, 10, 1]) + s = rb.sample_tensors(10, sample_next_obs=True, n_samples=3, sequence_length=5, from_numpy=True, clone=True) + assert isinstance(s["observations"], torch.Tensor) + assert s["observations"].shape == torch.Size([3, 5, 10, 1]) + + +def test_sample_tensor_memmap(): + import torch + + buf_size = 10 + n_envs = 4 + root_dir = os.path.join("pytest_" + str(int(time.time()))) + memmap_dir = os.path.join(root_dir, "memmap_buffer") + rb = 
EpisodeBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir, obs_keys=("observations")) + td = { + "observations": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8), + "dones": np.zeros((buf_size, n_envs, 1)), + } + td["dones"][-1] = 1 + rb.add(td) + sample = rb.sample_tensors(10, False, n_samples=3, sequence_length=5) + assert isinstance(sample["observations"], torch.Tensor) + assert sample["observations"].shape == torch.Size([3, 5, 10, 3, 64, 64]) + del rb shutil.rmtree(root_dir) + + +def test_add_rb(): + buf_size = 10 + n_envs = 3 + rb = ReplayBuffer(buf_size, n_envs) + rb.add({"dones": np.zeros((buf_size, n_envs, 1)), "a": np.random.rand(buf_size, n_envs, 5)}) + rb["dones"][-1] = 1 + epb = EpisodeBuffer(buf_size * n_envs, minimum_episode_length=2, n_envs=n_envs) + epb.add(rb) + assert (rb["a"][:, 0] == epb._buf[0]["a"]).all() + assert (rb["a"][:, 1] == epb._buf[1]["a"]).all() + assert (rb["a"][:, 2] == epb._buf[2]["a"]).all() diff --git a/tests/test_data/test_sequential_buffer.py b/tests/test_data/test_sequential_buffer.py index 294c8381..95bf8a8b 100644 --- a/tests/test_data/test_sequential_buffer.py +++ b/tests/test_data/test_sequential_buffer.py @@ -1,6 +1,9 @@ +import os +import shutil +import time + +import numpy as np import pytest -import torch -from tensordict import TensorDict from sheeprl.data.buffers import SequentialReplayBuffer @@ -19,47 +22,47 @@ def test_seq_replay_buffer_add_tds(): buf_size = 5 n_envs = 1 rb = SequentialReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(2, 1, 1)}, batch_size=[2, n_envs]) - td2 = TensorDict({"t": torch.rand(2, 1, 1)}, batch_size=[2, n_envs]) - td3 = TensorDict({"t": torch.rand(3, 1, 1)}, batch_size=[3, n_envs]) + td1 = {"a": np.random.rand(2, 1, 1)} + td2 = {"a": np.random.rand(2, 1, 1)} + td3 = {"a": np.random.rand(3, 1, 1)} rb.add(td1) rb.add(td2) rb.add(td3) assert rb.full - assert rb["t"][0] == td3["t"][-2] - assert rb["t"][1] == td3["t"][-1] - torch.testing.assert_close(rb["t"][2:4], td2["t"]) + assert rb["a"][0] == td3["a"][-2] + assert rb["a"][1] == td3["a"][-1] + np.testing.assert_allclose(rb["a"][2:4], td2["a"]) def test_seq_replay_buffer_add_single_td(): buf_size = 5 n_envs = 1 rb = SequentialReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(6, 1, 1)}, batch_size=[6, n_envs]) + td1 = {"a": np.random.rand(6, 1, 1)} rb.add(td1) assert rb.full - assert rb["t"][0] == td1["t"][-1] + assert rb["a"][0] == td1["a"][-1] def test_seq_replay_buffer_sample(): buf_size = 10 n_envs = 1 rb = SequentialReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(11, 1, 1)}, batch_size=[11, n_envs]) + td1 = {"a": np.random.rand(11, 1, 1)} rb.add(td1) s = rb.sample(4, sequence_length=2) - assert s.shape == torch.Size([1, 2, 4]) + assert s["a"].shape == tuple([1, 2, 4, 1]) def test_seq_replay_buffer_sample_one_element(): buf_size = 1 n_envs = 1 rb = SequentialReplayBuffer(buf_size, n_envs) - td1 = TensorDict({"t": torch.rand(1, 1, 1)}, batch_size=[1, n_envs]) + td1 = {"a": np.random.rand(1, 1, 1)} rb.add(td1) sample = rb.sample(1, sequence_length=1) assert rb.full - assert sample["t"] == td1["t"] + assert sample["a"] == td1["a"] with pytest.raises(ValueError): rb.sample(1, sequence_length=2) @@ -67,11 +70,14 @@ def test_seq_replay_buffer_sample_one_element(): def test_seq_replay_buffer_sample_shapes(): buf_size = 30 n_envs = 2 - rb = SequentialReplayBuffer(buf_size, n_envs) - t = TensorDict({"t": torch.arange(60).reshape(-1, 2, 1) % buf_size}, batch_size=[30, n_envs]) + rb = 
SequentialReplayBuffer(buf_size, n_envs, obs_keys=("a",)) + t = {"a": np.arange(60).reshape(-1, 2, 1) % buf_size} rb.add(t) sample = rb.sample(3, sequence_length=5, n_samples=2) - assert sample.shape == torch.Size([2, 5, 3]) + assert sample["a"].shape == tuple([2, 5, 3, 1]) + sample = rb.sample(3, sequence_length=5, n_samples=2, sample_next_obs=True, clone=True) + assert sample["a"].shape == tuple([2, 5, 3, 1]) + assert sample["next_a"].shape == tuple([2, 5, 3, 1]) def test_seq_replay_buffer_sample_full(): @@ -79,10 +85,10 @@ def test_seq_replay_buffer_sample_full(): n_envs = 1 seq_len = 50 rb = SequentialReplayBuffer(buf_size, n_envs) - t = TensorDict({"t": torch.arange(10500).reshape(-1, 1, 1) % buf_size}, batch_size=[10500, n_envs]) + t = {"a": np.arange(10500).reshape(-1, 1, 1) % buf_size} rb.add(t) samples = rb.sample(1000, sequence_length=seq_len, n_samples=5) - assert not torch.logical_and((samples["t"][:, 0, :] < rb._pos), (samples["t"][:, -1, :] >= rb._pos)).any() + assert not np.logical_and((samples["a"][:, 0, :] < rb._pos), (samples["a"][:, -1, :] >= rb._pos)).any() def test_seq_replay_buffer_sample_full_large_sl(): @@ -90,13 +96,13 @@ def test_seq_replay_buffer_sample_full_large_sl(): n_envs = 1 seq_len = 1000 rb = SequentialReplayBuffer(buf_size, n_envs) - t = TensorDict({"t": torch.arange(10500).reshape(-1, 1, 1) % buf_size}, batch_size=[10500, n_envs]) + t = {"a": np.arange(10500).reshape(-1, 1, 1) % buf_size} rb.add(t) samples = rb.sample(1000, sequence_length=seq_len, n_samples=5) - assert not torch.logical_and( - (samples["t"][:, 0, :] >= buf_size + rb._pos - seq_len + 1), (samples["t"][:, -1, :] < rb._pos) + assert not np.logical_and( + (samples["a"][:, 0, :] >= buf_size + rb._pos - seq_len + 1), (samples["a"][:, -1, :] < rb._pos) ).any() - assert not torch.logical_and((samples["t"][:, 0, :] < rb._pos), (samples["t"][:, -1, :] >= rb._pos)).any() + assert not np.logical_and((samples["a"][:, 0, :] < rb._pos), (samples["a"][:, -1, :] >= rb._pos)).any() def test_seq_replay_buffer_sample_fail_not_full(): @@ -104,9 +110,9 @@ def test_seq_replay_buffer_sample_fail_not_full(): n_envs = 1 seq_len = 8 rb = SequentialReplayBuffer(buf_size, n_envs) - t = TensorDict({"t": torch.arange(5).reshape(-1, 1, 1)}, batch_size=[5, n_envs]) + t = {"a": np.arange(5).reshape(-1, 1, 1)} rb.add(t) - with pytest.raises(ValueError, match="too long sequence length"): + with pytest.raises(ValueError, match="Cannot sample a sequence of length"): rb.sample(5, sequence_length=seq_len, n_samples=1) @@ -114,11 +120,11 @@ def test_seq_replay_buffer_sample_not_full(): buf_size = 10 n_envs = 1 rb = SequentialReplayBuffer(buf_size, n_envs) - rb._buf = TensorDict({"t": torch.ones(10, n_envs, 1) * 20}, batch_size=[10, n_envs]) - t = TensorDict({"t": torch.arange(7).reshape(-1, 1, 1) * 1.0}, batch_size=[7, n_envs]) + rb._buf = {"a": np.ones((10, n_envs, 1)) * 20} + t = {"a": np.arange(7).reshape(-1, 1, 1) * 1.0} rb.add(t) sample = rb.sample(2, sequence_length=5, n_samples=2) - assert (sample["t"] < 7).all() + assert (sample["a"] < 7).all() def test_seq_replay_buffer_sample_no_add(): @@ -127,3 +133,47 @@ def test_seq_replay_buffer_sample_no_add(): rb = SequentialReplayBuffer(buf_size, n_envs) with pytest.raises(ValueError, match="No sample has been added"): rb.sample(2, sequence_length=5, n_samples=2) + + +def test_seq_replay_buffer_sample_error(): + buf_size = 10 + n_envs = 1 + rb = SequentialReplayBuffer(buf_size, n_envs) + with pytest.raises(ValueError, match="must be both greater than "): + rb.sample(-1, 
sequence_length=5, n_samples=2) + rb.sample(2, sequence_length=-1, n_samples=2) + + +def test_sample_tensors(): + import torch + + buf_size = 5 + n_envs = 1 + rb = SequentialReplayBuffer(buf_size, n_envs) + td = {"observations": np.arange(8).reshape(-1, 1, 1), "dones": np.zeros((8, 1, 1))} + td["dones"][-1] = 1 + rb.add(td) + s = rb.sample_tensors(10, sample_next_obs=True, n_samples=3, sequence_length=5) + assert isinstance(s["observations"], torch.Tensor) + assert s["observations"].shape == torch.Size([3, 5, 10, 1]) + + +def test_sample_tensor_memmap(): + import torch + + buf_size = 10 + n_envs = 4 + root_dir = os.path.join("pytest_" + str(int(time.time()))) + memmap_dir = os.path.join(root_dir, "memmap_buffer") + rb = SequentialReplayBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir, obs_keys=("observations")) + td = { + "observations": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8), + "dones": np.zeros((buf_size, n_envs, 1)), + } + td["dones"][-1] = 1 + rb.add(td) + sample = rb.sample_tensors(10, False, n_samples=3, sequence_length=5) + assert isinstance(sample["observations"], torch.Tensor) + assert sample["observations"].shape == torch.Size([3, 5, 10, 3, 64, 64]) + del rb + shutil.rmtree(root_dir) diff --git a/tests/test_utils/test_memmap.py b/tests/test_utils/test_memmap.py new file mode 100644 index 00000000..40aee2d8 --- /dev/null +++ b/tests/test_utils/test_memmap.py @@ -0,0 +1,207 @@ +import os +import pickle +import shutil +import tempfile +from pathlib import Path + +import numpy as np +import pytest + +from sheeprl.utils.imports import _IS_WINDOWS +from sheeprl.utils.memmap import MemmapArray + + +@pytest.mark.parametrize( + "dtype", + [ + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.float32, + np.float64, + ], +) +@pytest.mark.parametrize("shape", [[2], [1, 2]]) +def test_memmap_data_type(dtype: np.dtype, shape): + """Test that MemmapArray can be created with a given data type and shape.""" + a = np.array([1, 0], dtype=dtype).reshape(shape) + m = MemmapArray.from_array(a) + assert m.dtype == a.dtype + assert (m == a).all() + assert m.shape == a.shape + + +def test_memmap_del(): + a = np.array([1]) + m = MemmapArray.from_array(a) + filename = m.filename + assert os.path.isfile(filename) + del m + assert not os.path.isfile(filename) + + +def test_memmap_pickling(): + a = np.array([1]) + m1 = MemmapArray.from_array(a) + filename = m1.filename + m1_pickle = pickle.dumps(m1) + assert m1._has_ownership + m2 = pickle.loads(m1_pickle) + assert m2.filename == m1.filename + assert not m2._has_ownership + del m1, m2 + assert not os.path.isfile(filename) + + +def test_memmap_array_get_not_none(): + a = np.ones((10,)) * 2 + m1 = MemmapArray.from_array(a) + assert m1.array is not None + + +def test_memmap_array_get_none(): + a = np.ones((10,)) * 2 + m1 = MemmapArray.from_array(a) + m1.__del__() + with pytest.raises(Exception): + m1.array + + +@pytest.mark.skipif( + _IS_WINDOWS, reason="'test_memmap_array_get_none_linux_only' should be run and succeed only on Linux" +) +def test_memmap_array_get_none_linux_only(): + a = np.ones((10,)) * 2 + m1 = MemmapArray.from_array(a) + m2 = MemmapArray.from_array(m1, filename=m1.filename) + del m1 + with pytest.raises(FileNotFoundError): + m2.array + + +def test_memmap_array_set_from_numpy(): + a = np.ones((10,)) * 2 + m1 = MemmapArray.from_array(a) + a = np.ones((10,)) * 3 + m1.array = a + assert (m1.array == a).all() + a = np.ones((10,)) * 4 + assert not (m1.array == a).all() + del m1 + + 
+def test_memmap_array_set_from_numpy_wrong_shape(): + a = np.ones((10,)) * 2 + m1 = MemmapArray.from_array(a) + a = np.ones((11,)) + with pytest.raises( + ValueError, match="The shape of the value to be set must be the same as the shape of the memory-mapped array. " + ): + m1.array = a + del m1 + + +def test_memmap_array_set_from_np_memmap(): + a = np.ones((10,)) * 2 + tmpfd, filename = tempfile.mkstemp(".memmap") + os.close(tmpfd) + memmap = np.memmap(filename, shape=a.shape, dtype=a.dtype) + memmap[:] = a[:] + m = MemmapArray(dtype=memmap.dtype, shape=memmap.shape) + assert m.has_ownership + m.array = memmap + m.array[:] = m.array * 2 + assert not m.has_ownership + del m + assert os.path.isfile(filename) + assert (memmap == 4).all() + memmap._mmap.close() + del memmap + Path.unlink(Path(filename), missing_ok=True) + + +def test_memmap_array_set_from_memmap_array(): + a = np.ones((10,)) * 2 + m1 = MemmapArray.from_array(a) + m2 = MemmapArray(dtype=m1.dtype, shape=m1.shape, mode=m1.mode) + filename = m1.filename + assert m2.has_ownership + with pytest.raises( + ValueError, + match="The value to be set must be an instance of 'np.memmap' or 'np.ndarray', " + "got ''", + ): + m2.array = m1 + m2.array = m1.array + m2.array[:] = m2 * 2 + assert not m2.has_ownership + del m2 + assert os.path.isfile(filename) + assert (m1.array == 4).all() + del m1 + assert not os.path.isfile(filename) + + +def test_memmap_from_array_memmap_array_different_filename(): + a = np.ones((10,)) * 2 + m1 = MemmapArray.from_array(a) + m2 = MemmapArray.from_array(m1) + m1_filename = m1.filename + m2_filename = m2.filename + assert m1.has_ownership + assert m2.has_ownership + assert m1_filename != m2_filename + assert (m1.array == m2.array).all() + del m1 + del m2 + assert not os.path.isfile(m1_filename) + assert not os.path.isfile(m2_filename) + + +def test_memmap_from_array_memmap_array(): + a = np.ones((10,)) * 2 + m1 = MemmapArray.from_array(a) + m2 = MemmapArray.from_array(m1, filename=m1.filename) + filename = m1.filename + assert m1.has_ownership + assert not m2.has_ownership + del m2 + assert os.path.isfile(filename) + del m1 + assert not os.path.isfile(filename) + + +@pytest.mark.parametrize("mode", ["r", "r+", "w+", "c", "readonly", "readwrite", "write", "copyonwrite"]) +def test_memmap_mode(mode): + ma = MemmapArray(shape=10, dtype=np.float32, filename="./test_memmap_mode/test.memmap") + ma[:] = np.ones(10) * 1.5 + del ma + + ma = MemmapArray(shape=10, dtype=np.float32, filename="./test_memmap_mode/test.memmap", mode=mode) + if mode in ("r", "readonly", "r+", "readwrite", "c", "copyonwrite"): + # data in memmap persists + assert (ma.array == 1.5).all() + elif mode in ("w+", "write"): + # memmap is initialized to zero + assert (ma.array == 0).all() + + if mode in ("r", "readonly"): + with pytest.raises(ValueError): + ma[:] = np.ones(10) * 2.5 + del ma + else: + ma[:] = np.ones(10) * 2.5 + assert (ma.array == 2.5).all() + del ma + + mt2 = MemmapArray(shape=10, dtype=np.float32, filename="./test_memmap_mode/test.memmap") + if mode in ("c", "copyonwrite"): + # tensor was only mutated in memory, not on disk + assert (mt2.array == 1.5).all() + else: + assert (mt2.array == 2.5).all() + del mt2 + shutil.rmtree("./test_memmap_mode")
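
To make the intent of the new numpy-based buffer stack easier to follow, here is a minimal usage sketch that ties together `MemmapArray`, the memory-mapped `ReplayBuffer`, and the `NUMPY_TO_TORCH_DTYPE_DICT` mapping added in `sheeprl/utils/utils.py`. It is not part of the patch: constructor arguments and sample shapes are assumed from the tests above, and the buffer directory name is illustrative only.

```python
import shutil

import numpy as np
import torch

from sheeprl.data.buffers import ReplayBuffer
from sheeprl.utils.memmap import MemmapArray
from sheeprl.utils.utils import NUMPY_TO_TORCH_DTYPE_DICT

# MemmapArray wraps a numpy memory-mapped file; the instance created with
# `from_array` owns the backing file and removes it when it is deleted
# (see test_memmap_del and test_memmap_pickling above).
m = MemmapArray.from_array(np.ones((10,), dtype=np.float32))
assert (m.array == 1.0).all()
del m  # the memory-mapped file is unlinked here

# ReplayBuffer stores a dictionary of numpy arrays shaped (T, B, *); with
# `memmap=True` every key is backed by a MemmapArray under `memmap_dir`
# (the directory name below is made up for this example).
rb = ReplayBuffer(16, 2, memmap=True, memmap_dir="./session_example_buffer")
rb.add({"observations": np.random.rand(4, 2, 3).astype(np.float32)})

# `sample` returns numpy arrays shaped (n_samples, batch_size, *), as the
# tests above check; torch conversion uses the numpy-to-torch dtype mapping.
batch = rb.sample(8, n_samples=2)
obs = batch["observations"]  # expected shape: (2, 8, 3)
obs_t = torch.as_tensor(obs, dtype=NUMPY_TO_TORCH_DTYPE_DICT[obs.dtype])

del rb  # release the memory-mapped files before removing the directory
shutil.rmtree("./session_example_buffer")
```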