Feature/episode buffer np #121

Merged
merged 11 commits on Oct 12, 2023
322 changes: 322 additions & 0 deletions sheeprl/data/buffers_np.py
@@ -1,7 +1,10 @@
from __future__ import annotations

import logging
import os
import shutil
import typing
import uuid
from pathlib import Path
from typing import Dict, Optional, Sequence, Type

@@ -687,3 +690,322 @@ def sample_tensors(
)
samples[k] = torch_v
return samples


class EpisodeBuffer:
"""A replay buffer that stores separately the episodes.

Args:
buffer_size (int): The capacity of the buffer.
sequence_length (int): The length of the sequences of the samples
(an episode cannot be shorter than the sequence length).
n_envs (int): The number of environments.
Default to 1.
obs_keys (Sequence[str]): The observations keys to store in the buffer.
Default to ("observations",).
prioritize_ends (bool): Whether to prioritize the ends of the episodes when sampling.
Default to False.
memmap (bool): Whether to memory-map the buffer.
Default to False.
memmap_dir (str | os.PathLike, optional): The directory for the memmap.
Default to None.
memmap_mode (str, optional): memory-map mode.
Possible values are: "r+", "w+", "c", "copyonwrite", "readwrite", "write".
Defaults to "r+".
"""

def __init__(
self,
buffer_size: int,
sequence_length: int,
n_envs: int = 1,
obs_keys: Sequence[str] = ("observations",),
prioritize_ends: bool = False,
memmap: bool = False,
memmap_dir: str | os.PathLike | None = None,
memmap_mode: str = "r+",
) -> None:
if buffer_size <= 0:
raise ValueError(f"The buffer size must be greater than zero, got: {buffer_size}")
if sequence_length <= 0:
raise ValueError(f"The sequence length must be greater than zero, got: {sequence_length}")
if buffer_size < sequence_length:
raise ValueError(
"The sequence length must be lower than the buffer size, "
f"got: bs = {buffer_size} and sl = {sequence_length}"
)
self._n_envs = n_envs
self._obs_keys = obs_keys
self._buffer_size = buffer_size
self._sequence_length = sequence_length
self._prioritize_ends = prioritize_ends

# One list for each environment that contains open episodes:
# one open episode per environment
self._open_episodes = [[] for _ in range(n_envs)]
# Contains the cumulative lengths of the episodes in the buffer
self._cum_lengths: Sequence[int] = []
# List of stored episodes
self._buf: Sequence[Dict[str, np.ndarray | MemmapArray]] = []

self._memmap = memmap
self._memmap_dir = memmap_dir
self._memmap_mode = memmap_mode
if self._memmap:
if self._memmap_mode not in ("r+", "w+", "c", "copyonwrite", "readwrite", "write"):
raise ValueError(
'Accepted values for memmap_mode are "r+", "readwrite", "w+", "write", "c" or '
'"copyonwrite". PyTorch does not support tensors backed by read-only '
'NumPy arrays, so "r" and "readonly" are not supported.'
)
if self._memmap_dir is None:
raise ValueError(
"The buffer is set to be memory-mapped but the `memmap_dir` attribute is None. "
"Set the `memmap_dir` to a known directory.",
)
else:
self._memmap_dir = Path(self._memmap_dir)
self._memmap_dir.mkdir(parents=True, exist_ok=True)
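# Per-step offsets [0, 1, ..., sequence_length - 1]: added to a sampled start index
# to build the indices of a whole sequence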
self._chunk_length = np.arange(sequence_length, dtype=np.intp).reshape(1, -1)

@property
def prioritize_ends(self) -> bool:
return self._prioritize_ends

@prioritize_ends.setter
def prioritize_ends(self, prioritize_ends: bool) -> None:
self._prioritize_ends = prioritize_ends

@property
def buffer(self) -> Optional[Dict[str, np.ndarray | MemmapArray]]:
if len(self._buf) > 0:
return {k: np.concatenate([v[k] for v in self._buf]) for k in self._obs_keys}
else:
return {}

@property
def obs_keys(self) -> Sequence[str]:
return self._obs_keys

@property
def n_envs(self) -> int:
return self._n_envs

@property
def buffer_size(self) -> int:
return self._buffer_size

@property
def sequence_length(self) -> int:
return self._sequence_length

@property
def is_memmap(self) -> bool:
return self._memmap

@property
def full(self) -> bool:
return self._cum_lengths[-1] + self._sequence_length > self._buffer_size if len(self._buf) > 0 else False

def __len__(self) -> int:
return self._cum_lengths[-1] if len(self._buf) > 0 else 0

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...

@typing.overload
def add(self, data: Dict[str, np.ndarray | MemmapArray], validate_args: bool = False) -> None:
...

def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray | MemmapArray], validate_args: bool = False) -> None:
"""_summary_

Args:
data (&quot;ReplayBuffer&quot; | Dict[str, np.ndarray | MemmapArray]]): data to add.
"""
if isinstance(data, ReplayBuffer):
data = data.buffer
if validate_args:
if data is None:
raise ValueError("The `data` replay buffer must be not None")
if not isinstance(data, dict):
raise ValueError(
f"`data` must be a dictionary containing Numpy arrays, but `data` is of type `{type(data)}`"
)
elif isinstance(data, dict):
for k, v in data.items():
if not isinstance(v, np.ndarray):
raise ValueError(
f"`data` must be a dictionary containing Numpy arrays. Found key `{k}` "
f"containing a value of type `{type(v)}`"
)
last_key = next(iter(data.keys()))
last_batch_shape = next(iter(data.values())).shape[:2]
for i, (k, v) in enumerate(data.items()):
if len(v.shape) < 2:
raise RuntimeError(
"`data` must have at least 2: [sequence_length, n_envs, ...]. " f"Shape of `{k}` is {v.shape}"
)
if i > 0:
current_key = k
current_batch_shape = v.shape[:2]
if current_batch_shape != last_batch_shape:
raise RuntimeError(
"Every array in `data` must be congruent in the first 2 dimensions: "
f"found key `{last_key}` with shape `{last_batch_shape}` "
f"and `{current_key}` with shape `{current_batch_shape}`"
)
last_key = current_key
last_batch_shape = current_batch_shape

if "dones" not in data:
raise RuntimeError(f"The episode must contain the `dones` key, got: {data.keys()}")

# For each environment
for env in range(self._n_envs):
# Take the data from a single environment
env_data = {k: v[:, env] for k, v in data.items()}
done = env_data["dones"]
# Take episode ends
episode_ends = done.nonzero()[0].tolist()
# If there is no done, then append the data to the open episode of this environment
if len(episode_ends) == 0:
self._open_episodes[env].append(env_data)
else:
# If there is at least one done, then split the environment data into episodes
episode_ends.append(len(done))
start = 0
# For each episode in the received data
for ep_end_idx in episode_ends:
stop = ep_end_idx
# Take the episode from the data
episode = {k: env_data[k][start : stop + 1] for k in self._obs_keys}
# If the episode length is greater than zero, then add it to the open episode
# of the corresponding environment.
if len(episode["dones"]) > 0:
self._open_episodes[env].append(episode)
start = stop + 1
# If the open episode is not empty and the last element is a done, then save the episode
# in the buffer and clear the open episode
if len(self._open_episodes[env]) > 0 and self._open_episodes[env][-1]["dones"][-1] == 1:
self._save_episode(self._open_episodes[env])
self._open_episodes[env] = []

def _save_episode(self, episode_chunks: Sequence[Dict[str, np.ndarray | MemmapArray]]) -> None:
if len(episode_chunks) == 0:
raise RuntimeError("Invalid episode, an empty sequence is given. You must pass a non-empty sequence.")
# Concatenate all the chunks of the episode
episode = {k: [] for k in self._obs_keys}
for chunk in episode_chunks:
for k in self._obs_keys:
episode[k].append(chunk[k])
episode = {k: np.concatenate(episode[k], axis=0) for k in self._obs_keys}

# Check the validity of the episode
ep_len = episode["dones"].shape[0]
if len(episode["dones"].nonzero()[0]) != 1 or episode["dones"][-1] != 1:
raise RuntimeError(f"The episode must contain exactly one done, got: {len(np.nonzero(episode['dones']))}")
if ep_len < self._sequence_length:
raise RuntimeError(f"Episode too short (at least {self._sequence_length} steps), got: {ep_len} steps")
if ep_len > self._buffer_size:
raise RuntimeError(f"Episode too long (at most {self._buffer_size} steps), got: {ep_len} steps")

# If the buffer is full, then remove the oldest episodes
if self.full or len(self) + ep_len > self._buffer_size:
# Compute the index of the last episode to remove
cum_lengths = np.array(self._cum_lengths)
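# For each episode i, len(self) - cum_lengths[i] is the length of the episodes stored after it:
# the mask is True where removing episodes 0..i leaves room for the new episode,
# so argmax gives the smallest number of oldest episodes to drop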
mask = (len(self) - cum_lengths + ep_len) <= self._buffer_size
last_to_remove = mask.argmax()
# Remove the evicted episodes and their memory-mapped files
if self._memmap and self._memmap_dir is not None:
for _ in range(last_to_remove + 1):
try:
shutil.rmtree(os.path.dirname(self._buf[0][self._obs_keys[0]].filename))
except Exception as e:
logging.error(e)
del self._buf[0]
else:
self._buf = self._buf[last_to_remove + 1 :]
# Update the cumulative lengths list
cum_lengths = cum_lengths[last_to_remove + 1 :] - cum_lengths[last_to_remove]
self._cum_lengths = cum_lengths.tolist()
self._cum_lengths.append(len(self) + ep_len)
episode_to_store = episode
if self._memmap:
episode_dir = self._memmap_dir / f"episode_{str(uuid.uuid4())}"
episode_dir.mkdir(parents=True, exist_ok=True)
episode_to_store = {}
for k, v in episode.items():
filename = str(episode_dir / f"{k}.memmap")
episode_to_store[k] = MemmapArray(
filename=filename,
dtype=v.dtype,
shape=v.shape,
mode=self._memmap_mode,
)
episode_to_store[k][:] = episode[k]
self._buf.append(episode_to_store)

def sample(
self,
batch_size: int,
n_samples: int = 1,
clone: bool = False,
) -> Dict[str, np.ndarray]:
"""Sample trajectories from the replay buffer.

Args:
batch_size (int): Number of element in the batch.
n_samples (int): The number of samples to be retrieved.
Defaults to 1.
clone (bool): Whether to clone the samples.
Default to False.

Returns:
Dict[str, np.ndarray]: the sampled dictionary of arrays, each with shape [n_samples, sequence_length, batch_size, ...]
"""
if batch_size <= 0:
raise ValueError(f"Batch size must be greater than 0, got: {batch_size}")
if n_samples <= 0:
raise ValueError(f"The number of samples must be greater than 0, got: {n_samples}")
if len(self) == 0:
raise RuntimeError(
"No sample has been added to the buffer. Please add at least one sample calling `self.add()`"
)

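# Assign each of the batch_size * n_samples sequences to a randomly chosen episode:
# np.bincount counts how many sequences must be drawn from each episode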
nsample_per_eps = np.bincount(np.random.randint(0, len(self._buf), (batch_size * n_samples,))).astype(np.intp)
samples = {k: [] for k in self._obs_keys}
for i, n in enumerate(nsample_per_eps):
ep_len = self._buf[i]["dones"].shape[0]
# Define the maximum index that can be sampled in the episodes
upper = ep_len - self._sequence_length + 1
# If ends are prioritized, then every index of the episode
# can be sampled as a starting index (it is clipped below)
if self._prioritize_ends:
upper += self._sequence_length
# Sample the starting indices and clip them to `ep_len - self._sequence_length`
start_idxes = np.minimum(
np.random.randint(0, upper, size=(n,)).reshape(-1, 1), ep_len - self._sequence_length, dtype=np.intp
)
# Compute the indices of the sequences
indices = start_idxes + self._chunk_length
# Retrieve the data
for k in self._obs_keys:
samples[k].append(self._buf[i][k][indices])
# Concatenate all the trajectories on the batch dimension and properly reshape them
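# The resulting arrays have shape [n_samples, sequence_length, batch_size, ...]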
samples = {
k: np.moveaxis(
np.concatenate(samples[k], axis=0).reshape(
n_samples, batch_size, self._sequence_length, *samples[k][0].shape[2:]
),
2,
1,
)
for k in self._obs_keys
if len(samples[k]) > 0
}
if clone:
return {k: v.copy() for k, v in samples.items()}
return samples
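
A minimal usage sketch of the new EpisodeBuffer (illustrative, not part of the diff): one complete episode is added, then fixed-length sequences are sampled from it. The observation shape, buffer and sequence sizes, and dtypes below are assumptions; "dones" is included in obs_keys because the buffer reads it both when splitting episodes and when sampling.

import numpy as np
from sheeprl.data.buffers_np import EpisodeBuffer

seq_len, n_envs = 8, 1  # hypothetical sizes, for illustration only
buf = EpisodeBuffer(
    buffer_size=128,
    sequence_length=seq_len,
    n_envs=n_envs,
    obs_keys=("observations", "dones"),
)

# One complete 16-step episode per environment: only the last step is a done
steps = 16
data = {
    "observations": np.random.rand(steps, n_envs, 3).astype(np.float32),
    "dones": np.zeros((steps, n_envs, 1), dtype=np.float32),
}
data["dones"][-1] = 1
buf.add(data)

# Sample 4 sequences of length seq_len from the stored episode
sample = buf.sample(batch_size=4, n_samples=1)
print(sample["observations"].shape)  # (1, 8, 4, 3): [n_samples, sequence_length, batch_size, features]
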
2 changes: 1 addition & 1 deletion sheeprl/utils/memmap.py
@@ -65,7 +65,7 @@ def shape(self) -> None | int | Tuple[int, ...]:
def has_ownership(self) -> bool:
return self._has_ownership

- @has_ownership
+ @has_ownership.setter
def has_ownership(self, value: bool):
self._has_ownership = value
