Feature/episode buffer np (#121)
* feat: added episode buffer numpy

* fix: memmap episode buffer numpy

* fix: checkpoint when memmap=True EpisodeBufferNumpy

* fix: memmap episode buffer np

* tests: added tests for episode buffer np

* feat: update episode buffer, added MemmapArray
michele-milesi authored Oct 12, 2023
1 parent 7b81238 commit b134b4f
Showing 3 changed files with 608 additions and 1 deletion.
322 changes: 322 additions & 0 deletions sheeprl/data/buffers_np.py
@@ -1,7 +1,10 @@
from __future__ import annotations

import logging
import os
import shutil
import typing
import uuid
from pathlib import Path
from typing import Dict, Optional, Sequence, Type

@@ -687,3 +690,322 @@ def sample_tensors(
)
samples[k] = torch_v
return samples


class EpisodeBuffer:
"""A replay buffer that stores separately the episodes.
Args:
buffer_size (int): The capacity of the buffer.
sequence_length (int): The length of the sequences of the samples
(an episode cannot be shorter than the episode length).
n_envs (int): The number of environments.
Default to 1.
obs_keys (Sequence[str]): The observations keys to store in the buffer.
Default to ("observations",).
prioritize_ends (bool): Whether to prioritize the ends of the episodes when sampling.
Default to False.
memmap (bool): Whether to memory-mapping the buffer.
Default to False.
memmap_dir (str | os.PathLike, optional): The directory for the memmap.
Default to None.
memmap_mode (str, optional): memory-map mode.
Possible values are: "r+", "w+", "c", "copyonwrite", "readwrite", "write".
Defaults to "r+".
"""

def __init__(
self,
buffer_size: int,
sequence_length: int,
n_envs: int = 1,
obs_keys: Sequence[str] = ("observations",),
prioritize_ends: bool = False,
memmap: bool = False,
memmap_dir: str | os.PathLike | None = None,
memmap_mode: str = "r+",
) -> None:
if buffer_size <= 0:
raise ValueError(f"The buffer size must be greater than zero, got: {buffer_size}")
if sequence_length <= 0:
raise ValueError(f"The sequence length must be greater than zero, got: {sequence_length}")
if buffer_size < sequence_length:
raise ValueError(
"The sequence length must be lower than the buffer size, "
f"got: bs = {buffer_size} and sl = {sequence_length}"
)
self._n_envs = n_envs
self._obs_keys = obs_keys
self._buffer_size = buffer_size
self._sequence_length = sequence_length
self._prioritize_ends = prioritize_ends

# One list per environment, each holding the chunks of that
# environment's currently open episode
self._open_episodes = [[] for _ in range(n_envs)]
# Cumulative lengths of the episodes stored in the buffer
self._cum_lengths: Sequence[int] = []
# List of stored episodes
self._buf: Sequence[Dict[str, np.ndarray | MemmapArray]] = []

self._memmap = memmap
self._memmap_dir = memmap_dir
self._memmap_mode = memmap_mode
if self._memmap:
if self._memmap_mode not in ("r+", "w+", "c", "copyonwrite", "readwrite", "write"):
raise ValueError(
'Accepted values for memmap_mode are "r+", "readwrite", "w+", "write", "c" or '
'"copyonwrite". PyTorch does not support tensors backed by read-only '
'NumPy arrays, so "r" and "readonly" are not supported.'
)
if self._memmap_dir is None:
raise ValueError(
"The buffer is set to be memory-mapped but the `memmap_dir` attribute is None. "
"Set the `memmap_dir` to a known directory.",
)
else:
self._memmap_dir = Path(self._memmap_dir)
self._memmap_dir.mkdir(parents=True, exist_ok=True)
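# Row vector of offsets [0, 1, ..., sequence_length - 1]: added to a column of
# sampled start indices, it yields the full index grid of each sequence at once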
self._chunk_length = np.arange(sequence_length, dtype=np.intp).reshape(1, -1)

@property
def prioritize_ends(self) -> bool:
return self._prioritize_ends

@prioritize_ends.setter
def prioritize_ends(self, prioritize_ends: bool) -> None:
self._prioritize_ends = prioritize_ends

@property
def buffer(self) -> Dict[str, np.ndarray | MemmapArray]:
if len(self._buf) > 0:
return {k: np.concatenate([v[k] for v in self._buf]) for k in self._obs_keys}
else:
return {}

@property
def obs_keys(self) -> Sequence[str]:
return self._obs_keys

@property
def n_envs(self) -> int:
return self._n_envs

@property
def buffer_size(self) -> int:
return self._buffer_size

@property
def sequence_length(self) -> int:
return self._sequence_length

@property
def is_memmap(self) -> bool:
return self._memmap

@property
def full(self) -> bool:
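# Full when appending one more sequence of length `sequence_length`
# after the last stored step would exceed the buffer capacity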
return self._cum_lengths[-1] + self._sequence_length > self._buffer_size if len(self._buf) > 0 else False

def __len__(self) -> int:
return self._cum_lengths[-1] if len(self._buf) > 0 else 0

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...

@typing.overload
def add(self, data: Dict[str, np.ndarray | MemmapArray], validate_args: bool = False) -> None:
...

def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray | MemmapArray], validate_args: bool = False) -> None:
"""_summary_
Args:
data (&quot;ReplayBuffer&quot; | Dict[str, np.ndarray | MemmapArray]]): data to add.
"""
if isinstance(data, ReplayBuffer):
data = data.buffer
if validate_args:
if data is None:
raise ValueError("The `data` replay buffer must be not None")
if not isinstance(data, dict):
raise ValueError(
f"`data` must be a dictionary containing Numpy arrays, but `data` is of type `{type(data)}`"
)
elif isinstance(data, dict):
for k, v in data.items():
if not isinstance(v, np.ndarray):
raise ValueError(
f"`data` must be a dictionary containing Numpy arrays. Found key `{k}` "
f"containing a value of type `{type(v)}`"
)
last_key = next(iter(data.keys()))
last_batch_shape = next(iter(data.values())).shape[:2]
for i, (k, v) in enumerate(data.items()):
if len(v.shape) < 2:
raise RuntimeError(
f"`data` must have at least 2 dimensions: [sequence_length, n_envs, ...]. Shape of `{k}` is {v.shape}"
)
if i > 0:
current_key = k
current_batch_shape = v.shape[:2]
if current_batch_shape != last_batch_shape:
raise RuntimeError(
"Every array in `data` must be congruent in the first 2 dimensions: "
f"found key `{last_key}` with shape `{last_batch_shape}` "
f"and `{current_key}` with shape `{current_batch_shape}`"
)
last_key = current_key
last_batch_shape = current_batch_shape

if "dones" not in data:
raise RuntimeError(f"The episode must contain the `dones` key, got: {data.keys()}")

# For each environment
for env in range(self._n_envs):
# Take the data from a single environment
env_data = {k: v[:, env] for k, v in data.items()}
done = env_data["dones"]
# Take episode ends
episode_ends = done.nonzero()[0].tolist()
# If there is no done, then append the data to the open episode of this environment
if len(episode_ends) == 0:
self._open_episodes[env].append(env_data)
else:
# If there is at least one done, then split the environment data into episodes
episode_ends.append(len(done))
start = 0
# For each episode in the received data
for ep_end_idx in episode_ends:
stop = ep_end_idx
# Take the episode from the data
episode = {k: env_data[k][start : stop + 1] for k in self._obs_keys}
# If the episode length is greater than zero, then add it to the open episode
# of the corresponding environment.
if len(episode["dones"]) > 0:
self._open_episodes[env].append(episode)
start = stop + 1
# If the open episode is not empty and the last element is a done, then save the episode
# in the buffer and clear the open episode
if len(self._open_episodes[env]) > 0 and self._open_episodes[env][-1]["dones"][-1] == 1:
self._save_episode(self._open_episodes[env])
self._open_episodes[env] = []

def _save_episode(self, episode_chunks: Sequence[Dict[str, np.ndarray | MemmapArray]]) -> None:
if len(episode_chunks) == 0:
raise RuntimeError("Invalid episode, an empty sequence is given. You must pass a non-empty sequence.")
# Concatenate all the chunks of the episode
episode = {k: [] for k in self._obs_keys}
for chunk in episode_chunks:
for k in self._obs_keys:
episode[k].append(chunk[k])
episode = {k: np.concatenate(episode[k], axis=0) for k in self._obs_keys}

# Check the validity of the episode
ep_len = episode["dones"].shape[0]
if len(episode["dones"].nonzero()[0]) != 1 or episode["dones"][-1] != 1:
raise RuntimeError(f"The episode must contain exactly one done, got: {len(episode['dones'].nonzero()[0])}")
if ep_len < self._sequence_length:
raise RuntimeError(f"Episode too short: it must contain at least {self._sequence_length} steps, got {ep_len}")
if ep_len > self._buffer_size:
raise RuntimeError(f"Episode too long: it must contain at most {self._buffer_size} steps, got {ep_len}")

# If the buffer is full, then remove the oldest episodes
if self.full or len(self) + ep_len > self._buffer_size:
# Compute the index of the last episode to remove
cum_lengths = np.array(self._cum_lengths)
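# mask[i] is True when dropping episodes 0..i leaves room for the new episode.
# E.g., with cum_lengths=[10, 25, 40], ep_len=30 and buffer_size=50:
# (40 - [10, 25, 40] + 30) <= 50 -> [False, True, True], argmax=1,
# so episodes 0 and 1 are evicted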
mask = (len(self) - cum_lengths + ep_len) <= self._buffer_size
last_to_remove = mask.argmax()
# If memory-mapped, delete the evicted episodes' directories from disk
if self._memmap and self._memmap_dir is not None:
for _ in range(last_to_remove + 1):
try:
shutil.rmtree(os.path.dirname(self._buf[0][self._obs_keys[0]].filename))
except Exception as e:
logging.error(e)
del self._buf[0]
else:
self._buf = self._buf[last_to_remove + 1 :]
# Update the cum_lengths lists
cum_lengths = cum_lengths[last_to_remove + 1 :] - cum_lengths[last_to_remove]
self._cum_lengths = cum_lengths.tolist()
self._cum_lengths.append(len(self) + ep_len)
episode_to_store = episode
if self._memmap:
episode_dir = self._memmap_dir / f"episode_{str(uuid.uuid4())}"
episode_dir.mkdir(parents=True, exist_ok=True)
episode_to_store = {}
for k, v in episode.items():
episode_to_store[k] = MemmapArray(
filename=str(episode_dir / f"{k}.memmap"),
dtype=v.dtype,
shape=v.shape,
mode=self._memmap_mode,
)
episode_to_store[k][:] = episode[k]
self._buf.append(episode_to_store)

def sample(
self,
batch_size: int,
n_samples: int = 1,
clone: bool = False,
) -> Dict[str, np.ndarray]:
"""Sample trajectories from the replay buffer.
Args:
batch_size (int): Number of element in the batch.
n_samples (bool): The number of samples to be retrieved.
Defaults to 1.
clone (bool): Whether to clone the samples.
Default to False.
Returns:
TensorDictBase: the sampled TensorDictBase with a `batch_size` of [batch_size, 1]
"""
if batch_size <= 0:
raise ValueError(f"Batch size must be greater than 0, got: {batch_size}")
if n_samples <= 0:
raise ValueError(f"The number of samples must be greater than 0, got: {n_samples}")
if len(self) == 0:
raise RuntimeError(
"No sample has been added to the buffer. Please add at least one sample calling `self.add()`"
)

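# Spread the batch_size * n_samples requested sequences uniformly over the
# stored episodes: nsample_per_eps[i] is how many sequences to draw from episode i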
nsample_per_eps = np.bincount(np.random.randint(0, len(self._buf), (batch_size * n_samples,))).astype(np.intp)
samples = {k: [] for k in self._obs_keys}
for i, n in enumerate(nsample_per_eps):
ep_len = self._buf[i]["dones"].shape[0]
# Define the exclusive upper bound for sampling the start indices
upper = ep_len - self._sequence_length + 1
# If ends are prioritized, let start indices run past the last valid one:
# the clamp below maps them to the final window, oversampling episode ends
if self._prioritize_ends:
upper += self._sequence_length
# Sample the start indices and clamp them to `ep_len - self._sequence_length`
start_idxes = np.minimum(
np.random.randint(0, upper, size=(n,)).reshape(-1, 1), ep_len - self._sequence_length, dtype=np.intp
)
# Compute the indices of the sequences: shape (n, sequence_length)
indices = start_idxes + self._chunk_length
# Retrieve the data
for k in self._obs_keys:
samples[k].append(self._buf[i][k][indices])
# Concatenate the trajectories on the batch dimension and reshape them to
# [n_samples, sequence_length, batch_size, ...]
samples = {
k: np.moveaxis(
np.concatenate(samples[k], axis=0).reshape(
n_samples, batch_size, self._sequence_length, *samples[k][0].shape[2:]
),
2,
1,
)
for k in self._obs_keys
if len(samples[k]) > 0
}
if clone:
return {k: v.copy() for k, v in samples.items()}
return samples
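For reference, a minimal, hedged sketch of how the new `EpisodeBuffer` is meant to be driven; the observation dimensionality, the constants, and the printed shape are illustrative, derived from the code above, not part of the commit:

import numpy as np

from sheeprl.data.buffers_np import EpisodeBuffer

# A buffer holding at most 1000 steps, sampling windows of 16 consecutive steps
buf = EpisodeBuffer(buffer_size=1000, sequence_length=16, obs_keys=("observations", "dones"))

# One 32-step chunk from a single environment, shaped [sequence_length, n_envs, ...]
steps, n_envs = 32, 1
data = {
    "observations": np.random.rand(steps, n_envs, 4).astype(np.float32),
    "dones": np.zeros((steps, n_envs), dtype=np.float32),
}
data["dones"][-1] = 1  # close the episode so it is moved into the buffer
buf.add(data)

# Draw 8 sequences of 16 consecutive steps each, twice
sample = buf.sample(batch_size=8, n_samples=2)
print(sample["observations"].shape)  # (2, 16, 8, 4)

Note that "dones" has to appear in `obs_keys`: `add` splits episodes on it, and `_save_episode` stores only the keys listed there.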
2 changes: 1 addition & 1 deletion sheeprl/utils/memmap.py
Expand Up @@ -65,7 +65,7 @@ def shape(self) -> None | int | Tuple[int, ...]:
def has_ownership(self) -> bool:
return self._has_ownership

@has_ownership
@has_ownership.setter
def has_ownership(self, value: bool):
self._has_ownership = value

