From b134b4fa6e307fe3427edfd133099994a80ebef1 Mon Sep 17 00:00:00 2001
From: michele-milesi <74559684+michele-milesi@users.noreply.github.com>
Date: Thu, 12 Oct 2023 10:02:58 +0200
Subject: [PATCH] Feature/episode buffer np (#121)

* feat: added episode buffer numpy

* fix: memmap episode buffer numpy

* fix: checkpoint when memmap=True EpisodeBufferNumpy

* fix: memmap episode buffer np

* tests: added tests for episode buffer np

* feat: update episode buffer, added MemmapArray

---
 sheeprl/data/buffers_np.py                | 322 ++++++++++++++++++++++
 sheeprl/utils/memmap.py                   |   2 +-
 tests/test_data/test_episode_buffer_np.py | 285 +++++++++++++++++++
 3 files changed, 608 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_data/test_episode_buffer_np.py

diff --git a/sheeprl/data/buffers_np.py b/sheeprl/data/buffers_np.py
index 35209489..b86300e5 100644
--- a/sheeprl/data/buffers_np.py
+++ b/sheeprl/data/buffers_np.py
@@ -1,7 +1,10 @@
 from __future__ import annotations
 
+import logging
 import os
+import shutil
 import typing
+import uuid
 from pathlib import Path
 from typing import Dict, Optional, Sequence, Type
 
@@ -687,3 +690,322 @@ def sample_tensors(
         )
         samples[k] = torch_v
     return samples
+
+
+class EpisodeBuffer:
+    """A replay buffer that stores episodes separately.
+
+    Args:
+        buffer_size (int): The capacity of the buffer.
+        sequence_length (int): The length of the sequences of the samples
+            (an episode cannot be shorter than the sequence length).
+        n_envs (int): The number of environments.
+            Defaults to 1.
+        obs_keys (Sequence[str]): The observation keys to store in the buffer.
+            Defaults to ("observations",).
+        prioritize_ends (bool): Whether to prioritize the ends of the episodes when sampling.
+            Defaults to False.
+        memmap (bool): Whether to memory-map the buffer.
+            Defaults to False.
+        memmap_dir (str | os.PathLike, optional): The directory for the memmap.
+            Defaults to None.
+        memmap_mode (str, optional): memory-map mode.
+            Possible values are: "r+", "w+", "c", "copyonwrite", "readwrite", "write".
+            Defaults to "r+".
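+
+    Example:
+        A minimal usage sketch (for illustration only: the `dones` key is
+        required by the buffer, while the shapes and sizes below are arbitrary
+        choices, not part of the API)::
+
+            buf = EpisodeBuffer(buffer_size=100, sequence_length=5, obs_keys=("dones",))
+            episode = {"dones": np.zeros((10, 1, 1))}  # [sequence_length, n_envs, ...]
+            episode["dones"][-1] = 1  # every stored episode ends with exactly one done
+            buf.add(episode)
+            batch = buf.sample(batch_size=4, n_samples=2)
+            # batch["dones"].shape == (2, 5, 4, 1): [n_samples, sequence_length, batch_size, ...]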
+    """
+
+    def __init__(
+        self,
+        buffer_size: int,
+        sequence_length: int,
+        n_envs: int = 1,
+        obs_keys: Sequence[str] = ("observations",),
+        prioritize_ends: bool = False,
+        memmap: bool = False,
+        memmap_dir: str | os.PathLike | None = None,
+        memmap_mode: str = "r+",
+    ) -> None:
+        if buffer_size <= 0:
+            raise ValueError(f"The buffer size must be greater than zero, got: {buffer_size}")
+        if sequence_length <= 0:
+            raise ValueError(f"The sequence length must be greater than zero, got: {sequence_length}")
+        if buffer_size < sequence_length:
+            raise ValueError(
+                "The sequence length must be lower than the buffer size, "
+                f"got: bs = {buffer_size} and sl = {sequence_length}"
+            )
+        self._n_envs = n_envs
+        self._obs_keys = obs_keys
+        self._buffer_size = buffer_size
+        self._sequence_length = sequence_length
+        self._prioritize_ends = prioritize_ends
+
+        # One list per environment containing the chunks of the currently open
+        # (i.e., not yet terminated) episode of that environment
+        self._open_episodes = [[] for _ in range(n_envs)]
+        # Contains the cumulative lengths of the episodes stored in the buffer
+        self._cum_lengths: Sequence[int] = []
+        # List of stored episodes
+        self._buf: Sequence[Dict[str, np.ndarray | MemmapArray]] = []
+
+        self._memmap = memmap
+        self._memmap_dir = memmap_dir
+        self._memmap_mode = memmap_mode
+        if self._memmap:
+            if self._memmap_mode not in ("r+", "w+", "c", "copyonwrite", "readwrite", "write"):
+                raise ValueError(
+                    'Accepted values for memmap_mode are "r+", "readwrite", "w+", "write", "c" or '
+                    '"copyonwrite". PyTorch does not support tensors backed by read-only '
+                    'NumPy arrays, so "r" and "readonly" are not supported.'
+                )
+            if self._memmap_dir is None:
+                raise ValueError(
+                    "The buffer is set to be memory-mapped but the `memmap_dir` attribute is None. "
+                    "Set the `memmap_dir` to a known directory.",
+                )
+            else:
+                self._memmap_dir = Path(self._memmap_dir)
+                self._memmap_dir.mkdir(parents=True, exist_ok=True)
+        self._chunk_length = np.arange(sequence_length, dtype=np.intp).reshape(1, -1)
+
+    @property
+    def prioritize_ends(self) -> bool:
+        return self._prioritize_ends
+
+    @prioritize_ends.setter
+    def prioritize_ends(self, prioritize_ends: bool) -> None:
+        self._prioritize_ends = prioritize_ends
+
+    @property
+    def buffer(self) -> Dict[str, np.ndarray]:
+        if len(self._buf) > 0:
+            return {k: np.concatenate([v[k] for v in self._buf]) for k in self._obs_keys}
+        else:
+            return {}
+
+    @property
+    def obs_keys(self) -> Sequence[str]:
+        return self._obs_keys
+
+    @property
+    def n_envs(self) -> int:
+        return self._n_envs
+
+    @property
+    def buffer_size(self) -> int:
+        return self._buffer_size
+
+    @property
+    def sequence_length(self) -> int:
+        return self._sequence_length
+
+    @property
+    def is_memmap(self) -> bool:
+        return self._memmap
+
+    @property
+    def full(self) -> bool:
+        return self._cum_lengths[-1] + self._sequence_length > self._buffer_size if len(self._buf) > 0 else False
+
+    def __len__(self) -> int:
+        return self._cum_lengths[-1] if len(self._buf) > 0 else 0
+
+    @typing.overload
+    def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
+        ...
+
+    @typing.overload
+    def add(self, data: Dict[str, np.ndarray | MemmapArray], validate_args: bool = False) -> None:
+        ...
+
+    def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray | MemmapArray], validate_args: bool = False) -> None:
+        """Add data to the buffer. The data are split per environment and appended
+        to the corresponding open episode; every episode that terminates (i.e.,
+        whose last `dones` flag is set) is moved into the buffer.
+
+        Args:
+            data ("ReplayBuffer" | Dict[str, np.ndarray | MemmapArray]): data to add,
+                with arrays of shape [sequence_length, n_envs, ...].
+            validate_args (bool): whether to validate the given data.
+                Defaults to False.
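+
+        Example:
+            An illustrative sketch of the splitting logic (shapes assumed): with
+            `n_envs = 1` and `dones[:, 0, 0] == [0, 0, 1, 0, 0]`, the first three
+            steps form a completed episode and are saved in the buffer (provided
+            the episode is at least `sequence_length` steps long), while the last
+            two steps remain in the open episode of that environment until a done
+            is received.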
+ """ + if isinstance(data, ReplayBuffer): + data = data.buffer + if validate_args: + if data is None: + raise ValueError("The `data` replay buffer must be not None") + if not isinstance(data, dict): + raise ValueError( + f"`data` must be a dictionary containing Numpy arrays, but `data` is of type `{type(data)}`" + ) + elif isinstance(data, dict): + for k, v in data.items(): + if not isinstance(v, np.ndarray): + raise ValueError( + f"`data` must be a dictionary containing Numpy arrays. Found key `{k}` " + f"containing a value of type `{type(v)}`" + ) + last_key = next(iter(data.keys())) + last_batch_shape = next(iter(data.values())).shape[:2] + for i, (k, v) in enumerate(data.items()): + if len(v.shape) < 2: + raise RuntimeError( + "`data` must have at least 2: [sequence_length, n_envs, ...]. " f"Shape of `{k}` is {v.shape}" + ) + if i > 0: + current_key = k + current_batch_shape = v.shape[:2] + if current_batch_shape != last_batch_shape: + raise RuntimeError( + "Every array in `data` must be congruent in the first 2 dimensions: " + f"found key `{last_key}` with shape `{last_batch_shape}` " + f"and `{current_key}` with shape `{current_batch_shape}`" + ) + last_key = current_key + last_batch_shape = current_batch_shape + + if "dones" not in data: + raise RuntimeError(f"The episode must contain the `dones` key, got: {data.keys()}") + + # For each environment + for env in range(self._n_envs): + # Take the data from a single environment + env_data = {k: v[:, env] for k, v in data.items()} + done = env_data["dones"] + # Take episode ends + episode_ends = done.nonzero()[0].tolist() + # If there is not any done, then add the data to the respective open episode + if len(episode_ends) == 0: + self._open_episodes[env].append(env_data) + else: + # In case there is at leas one done, then split the environment data into episodes + episode_ends.append(len(done)) + start = 0 + # For each episode in the received data + for ep_end_idx in episode_ends: + stop = ep_end_idx + # Take the episode from the data + episode = {k: env_data[k][start : stop + 1] for k in self._obs_keys} + # If the episode length is greater than zero, then add it to the open episode + # of the corresponding environment. + if len(episode["dones"]) > 0: + self._open_episodes[env].append(episode) + start = stop + 1 + # If the open episode is not empty and the last element is a done, then save the episode + # in the buffer and clear the open episode + if len(self._open_episodes[env]) > 0 and self._open_episodes[env][-1]["dones"][-1] == 1: + self._save_episode(self._open_episodes[env]) + self._open_episodes[env] = [] + + def _save_episode(self, episode_chunks: Sequence[Dict[str, np.ndarray | MemmapArray]]) -> None: + if len(episode_chunks) == 0: + raise RuntimeError("Invalid episode, an empty sequence is given. 
+    def _save_episode(self, episode_chunks: Sequence[Dict[str, np.ndarray | MemmapArray]]) -> None:
+        if len(episode_chunks) == 0:
+            raise RuntimeError("Invalid episode, an empty sequence is given. You must pass a non-empty sequence.")
+        # Concatenate all the chunks of the episode
+        episode = {k: [] for k in self._obs_keys}
+        for chunk in episode_chunks:
+            for k in self._obs_keys:
+                episode[k].append(chunk[k])
+        episode = {k: np.concatenate(episode[k], axis=0) for k in self._obs_keys}
+
+        # Check the validity of the episode
+        ep_len = episode["dones"].shape[0]
+        n_dones = len(episode["dones"].nonzero()[0])
+        if n_dones != 1 or episode["dones"][-1] != 1:
+            raise RuntimeError(f"The episode must contain exactly one done, got: {n_dones}")
+        if ep_len < self._sequence_length:
+            raise RuntimeError(f"Episode too short (at least {self._sequence_length} steps), got: {ep_len} steps")
+        if ep_len > self._buffer_size:
+            raise RuntimeError(f"Episode too long (at most {self._buffer_size} steps), got: {ep_len} steps")
+
+        # If the buffer is full, then remove the oldest episodes
+        if self.full or len(self) + ep_len > self._buffer_size:
+            # Compute the index of the last episode to remove
+            cum_lengths = np.array(self._cum_lengths)
+            mask = (len(self) - cum_lengths + ep_len) <= self._buffer_size
+            last_to_remove = mask.argmax()
+            # Remove all the memory-mapped episodes to be evicted
+            if self._memmap and self._memmap_dir is not None:
+                for _ in range(last_to_remove + 1):
+                    try:
+                        shutil.rmtree(os.path.dirname(self._buf[0][self._obs_keys[0]].filename))
+                    except Exception as e:
+                        logging.error(e)
+                    del self._buf[0]
+            else:
+                self._buf = self._buf[last_to_remove + 1 :]
+            # Update the cum_lengths list
+            cum_lengths = cum_lengths[last_to_remove + 1 :] - cum_lengths[last_to_remove]
+            self._cum_lengths = cum_lengths.tolist()
+        self._cum_lengths.append(len(self) + ep_len)
+        episode_to_store = episode
+        if self._memmap:
+            episode_dir = self._memmap_dir / f"episode_{str(uuid.uuid4())}"
+            episode_dir.mkdir(parents=True, exist_ok=True)
+            episode_to_store = {}
+            for k, v in episode.items():
+                filename = str(episode_dir / f"{k}.memmap")
+                episode_to_store[k] = MemmapArray(
+                    filename=filename,
+                    dtype=v.dtype,
+                    shape=v.shape,
+                    mode=self._memmap_mode,
+                )
+                episode_to_store[k][:] = episode[k]
+        self._buf.append(episode_to_store)
+
+    def sample(
+        self,
+        batch_size: int,
+        n_samples: int = 1,
+        clone: bool = False,
+    ) -> Dict[str, np.ndarray]:
+        """Sample trajectories from the replay buffer.
+
+        Args:
+            batch_size (int): Number of elements in the batch.
+            n_samples (int): The number of samples to be retrieved.
+                Defaults to 1.
+            clone (bool): Whether to return a copy of the samples.
+                Defaults to False.
+
+        Returns:
+            Dict[str, np.ndarray]: the sampled dictionary, with arrays of shape
+                [n_samples, sequence_length, batch_size, ...].
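+
+        Example:
+            An illustrative call (sizes assumed): with `sequence_length = 5`,
+            `sample(batch_size=4, n_samples=2)` returns arrays of shape
+            `[2, 5, 4, ...]`, i.e. `n_samples` batches of `batch_size` sequences
+            of `sequence_length` consecutive steps each.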
+        """
+        if batch_size <= 0:
+            raise ValueError(f"Batch size must be greater than 0, got: {batch_size}")
+        if n_samples <= 0:
+            raise ValueError(f"The number of samples must be greater than 0, got: {n_samples}")
+        if len(self) == 0:
+            raise RuntimeError(
+                "No sample has been added to the buffer. Please add at least one sample calling `self.add()`"
+            )
+
+        nsample_per_eps = np.bincount(np.random.randint(0, len(self._buf), (batch_size * n_samples,))).astype(np.intp)
+        samples = {k: [] for k in self._obs_keys}
+        for i, n in enumerate(nsample_per_eps):
+            ep_len = self._buf[i]["dones"].shape[0]
+            # Define the maximum index that can be sampled in the episode
+            upper = ep_len - self._sequence_length + 1
+            # If the ends are to be prioritized, then every index of the episode
+            # can be sampled as a starting index
+            if self._prioritize_ends:
+                upper += self._sequence_length
+            # Sample the starting indices, clipping them to `ep_len - self._sequence_length`
+            # so that every sampled sequence fits inside the episode
+            start_idxes = np.minimum(
+                np.random.randint(0, upper, size=(n,)).reshape(-1, 1), ep_len - self._sequence_length, dtype=np.intp
+            )
+            # Compute the indices of the sequences
+            indices = start_idxes + self._chunk_length
+            # Retrieve the data
+            for k in self._obs_keys:
+                samples[k].append(self._buf[i][k][indices])
+        # Concatenate all the trajectories on the batch dimension and properly reshape them
+        samples = {
+            k: np.moveaxis(
+                np.concatenate(samples[k], axis=0).reshape(
+                    n_samples, batch_size, self._sequence_length, *samples[k][0].shape[2:]
+                ),
+                2,
+                1,
+            )
+            for k in self._obs_keys
+            if len(samples[k]) > 0
+        }
+        if clone:
+            return {k: v.copy() for k, v in samples.items()}
+        return samples
diff --git a/sheeprl/utils/memmap.py b/sheeprl/utils/memmap.py
index 79eb4290..a37bc4b4 100644
--- a/sheeprl/utils/memmap.py
+++ b/sheeprl/utils/memmap.py
@@ -65,7 +65,7 @@ def shape(self) -> None | int | Tuple[int, ...]:
     def has_ownership(self) -> bool:
         return self._has_ownership
 
-    @has_ownership
+    @has_ownership.setter
     def has_ownership(self, value: bool):
         self._has_ownership = value
 
diff --git a/tests/test_data/test_episode_buffer_np.py b/tests/test_data/test_episode_buffer_np.py
new file mode 100644
index 00000000..b470a093
--- /dev/null
+++ b/tests/test_data/test_episode_buffer_np.py
@@ -0,0 +1,285 @@
+import os
+import shutil
+
+import numpy as np
+import pytest
+import torch
+from tensordict import TensorDict
+
+from sheeprl.data.buffers_np import EpisodeBuffer
+from sheeprl.utils.memmap import MemmapArray
+
+
+def test_episode_buffer_wrong_buffer_size():
+    with pytest.raises(ValueError, match="The buffer size must be greater than zero"):
+        EpisodeBuffer(-1, 10)
+
+
+def test_episode_buffer_wrong_sequence_length():
+    with pytest.raises(ValueError, match="The sequence length must be greater than zero"):
+        EpisodeBuffer(1, -1)
+
+
+def test_episode_buffer_sequence_length_greater_than_batch_size():
+    with pytest.raises(ValueError, match="The sequence length must be lower than the buffer size"):
+        EpisodeBuffer(5, 10)
+
+
+def test_episode_buffer_add_episodes():
+    buf_size = 30
+    sl = 5
+    n_envs = 1
+    obs_keys = ("dones",)
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys)
+    ep1 = {"dones": np.zeros((sl, n_envs, 1))}
+    ep2 = {"dones": np.zeros((sl + 5, n_envs, 1))}
+    ep3 = {"dones": np.zeros((sl + 10, n_envs, 1))}
+    ep4 = {"dones": np.zeros((sl, n_envs, 1))}
+    ep1["dones"][-1] = 1
+    ep2["dones"][-1] = 1
+    ep3["dones"][-1] = 1
+    ep4["dones"][-1] = 1
+    rb.add(ep1)
+    rb.add(ep2)
+    rb.add(ep3)
+    rb.add(ep4)
+    assert rb.full
+    assert (rb._buf[-1]["dones"] == ep4["dones"][:, 0]).all()
+    assert (rb._buf[0]["dones"] == ep2["dones"][:, 0]).all()
+
+
+def test_episode_buffer_add_single_dict():
+    buf_size = 5
+    sl = 5
+    n_envs = 4
+    obs_keys = ("dones",)
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys)
+    ep1 = {"dones": np.zeros((sl, n_envs, 1))}
+    ep1["dones"][-1] = 1
+    rb.add(ep1)
+    assert rb.full
+    for env in range(n_envs):
+        assert (rb._buf[0]["dones"] == ep1["dones"][:, env]).all()
+
+
+def test_episode_buffer_error_add():
+    buf_size = 10
+    sl = 5
+    n_envs = 4
+    obs_keys = ("dones",)
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys)
+
+    ep1 = TensorDict({"dones": torch.zeros(sl, n_envs, 1)}, batch_size=[sl, n_envs])
+    with pytest.raises(ValueError, match="`data` must be a dictionary containing Numpy arrays, but `data` is of type"):
+        rb.add(ep1, validate_args=True)
+
+    ep2 = {"dones": torch.zeros((sl, n_envs, 1))}
+    with pytest.raises(ValueError, match="`data` must be a dictionary containing Numpy arrays. Found key"):
+        rb.add(ep2, validate_args=True)
+
+    ep3 = None
+    with pytest.raises(ValueError, match="The `data` replay buffer must be not None"):
+        rb.add(ep3, validate_args=True)
+
+    ep4 = {"dones": np.zeros((1,))}
+    with pytest.raises(RuntimeError, match=r"`data` must have at least 2 dimensions: \[sequence_length, n_envs"):
+        rb.add(ep4, validate_args=True)
+
+    obs_keys = ("dones", "obs")
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys)
+    ep5 = {"dones": np.zeros((sl, n_envs, 1)), "obs": np.zeros((sl, 1, 6))}
+    with pytest.raises(RuntimeError, match="Every array in `data` must be congruent in the first 2 dimensions"):
+        rb.add(ep5, validate_args=True)
+
+    ep6 = {"obs": np.zeros((sl, 1, 6))}
+    with pytest.raises(RuntimeError, match="The episode must contain the `dones` key"):
+        rb.add(ep6, validate_args=True)
+
+
+def test_save_episode():
+    buf_size = 100
+    sl = 5
+    n_envs = 4
+    obs_keys = ("dones",)
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys)
+    ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)]
+    ep_chunks[-1]["dones"][-1] = 1
+    rb._save_episode(ep_chunks)
+
+    assert len(rb._buf) == 1
+    assert (rb.buffer["dones"] == np.concatenate([c["dones"] for c in ep_chunks], axis=0)).all()
+
+
+def test_save_episode_errors():
+    buf_size = 100
+    sl = 5
+    n_envs = 4
+    obs_keys = ("dones",)
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys)
+
+    with pytest.raises(RuntimeError, match="Invalid episode, an empty sequence is given"):
+        rb._save_episode([])
+
+    ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)]
+    ep_chunks[-1]["dones"][-1] = 1
+    ep_chunks[0]["dones"][-1] = 1
+    with pytest.raises(RuntimeError, match="The episode must contain exactly one done"):
+        rb._save_episode(ep_chunks)
+
+    ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)]
+    ep_chunks[0]["dones"][-1] = 1
+    with pytest.raises(RuntimeError, match="The episode must contain exactly one done"):
+        rb._save_episode(ep_chunks)
+
+    ep_chunks = [{"dones": np.ones((1, 1))}]
+    with pytest.raises(RuntimeError, match="Episode too short"):
+        rb._save_episode(ep_chunks)
+
+    ep_chunks = [{"dones": np.zeros((110, 1))} for _ in range(8)]
+    ep_chunks[-1]["dones"][-1] = 1
+    with pytest.raises(RuntimeError, match="Episode too long"):
+        rb._save_episode(ep_chunks)
+
+
+def test_episode_buffer_sample_one_element():
+    buf_size = 5
+    sl = 5
+    n_envs = 1
+    obs_keys = ("dones", "a")
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys)
+    ep = {"dones": np.zeros((sl, n_envs, 1)), "a": np.random.rand(sl, n_envs, 1)}
+    ep["dones"][-1] = 1
+    rb.add(ep)
+    sample = rb.sample(1, 1)
+    assert rb.full
+    assert (sample["dones"][0, :, 0] == ep["dones"][:, 0]).all()
+    assert (sample["a"][0, :, 0] == ep["a"][:, 0]).all()
+
+
+def test_episode_buffer_sample_shapes():
+    buf_size = 30
+    sl = 2
+    n_envs = 1
+    obs_keys = ("dones",)
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys)
+    ep = {"dones": np.zeros((sl, n_envs, 1))}
+    ep["dones"][-1] = 1
+    rb.add(ep)
+    sample = rb.sample(3, n_samples=2)
+    assert sample["dones"].shape[:-1] == (2, sl, 3)
+
+
+def test_episode_buffer_sample_more_episodes():
+    buf_size = 100
+    sl = 15
+    n_envs = 1
+    obs_keys = ("dones", "a")
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys)
+    ep1 = {"dones": np.zeros((40, n_envs, 1)), "a": np.ones((40, n_envs, 1)) * -1}
+    ep2 = {"dones": np.zeros((45, n_envs, 1)), "a": np.ones((45, n_envs, 1)) * -2}
+    ep3 = {"dones": np.zeros((50, n_envs, 1)), "a": np.ones((50, n_envs, 1)) * -3}
+    ep1["dones"][-1] = 1
+    ep2["dones"][-1] = 1
+    ep3["dones"][-1] = 1
+    rb.add(ep1)
+    rb.add(ep2)
+    rb.add(ep3)
+    samples = rb.sample(50, n_samples=5)
+    assert samples["dones"].shape[:-1] == (5, sl, 50)
+    samples = {k: np.moveaxis(samples[k], 2, 1).reshape(-1, sl, 1) for k in obs_keys}
+    for i in range(len(samples["dones"])):
+        assert (
+            np.isin(samples["a"][i], -1).all()
+            or np.isin(samples["a"][i], -2).all()
+            or np.isin(samples["a"][i], -3).all()
+        )
+        assert len(samples["dones"][i].nonzero()[0]) == 0 or samples["dones"][i][-1] == 1
+
+
+def test_episode_buffer_error_sample():
+    buf_size = 10
+    sl = 5
+    rb = EpisodeBuffer(buf_size, sl)
+    with pytest.raises(RuntimeError, match="No sample has been added"):
+        rb.sample(2, 2)
+    with pytest.raises(ValueError, match="Batch size must be greater than 0"):
+        rb.sample(-1, n_samples=2)
+    with pytest.raises(ValueError, match="The number of samples must be greater than 0"):
+        rb.sample(2, -1)
+
+
+def test_episode_buffer_prioritize_ends():
+    buf_size = 100
+    sl = 15
+    n_envs = 1
+    obs_keys = ("dones",)
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys, prioritize_ends=True)
+    ep1 = {"dones": np.zeros((15, n_envs, 1))}
+    ep2 = {"dones": np.zeros((25, n_envs, 1))}
+    ep3 = {"dones": np.zeros((30, n_envs, 1))}
+    ep1["dones"][-1] = 1
+    ep2["dones"][-1] = 1
+    ep3["dones"][-1] = 1
+    rb.add(ep1)
+    rb.add(ep2)
+    rb.add(ep3)
+    samples = rb.sample(50, n_samples=5)
+    assert samples["dones"].shape[:-1] == (5, sl, 50)
+    assert np.isin(samples["dones"], 1).any()
+
+
+def test_memmap_episode_buffer():
+    buf_size = 10
+    bs = 4
+    sl = 4
+    n_envs = 1
+    obs_keys = ("dones", "observations")
+    with pytest.raises(
+        ValueError,
+        match="The buffer is set to be memory-mapped but the `memmap_dir` attribute is None",
+    ):
+        rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, memmap=True)
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys, memmap=True, memmap_dir="test_episode_buffer")
+    for _ in range(buf_size // bs):
+        ep = {
+            "observations": np.random.randint(0, 256, (bs, n_envs, 3, 64, 64), dtype=np.uint8),
+            "dones": np.zeros((bs, n_envs, 1)),
+        }
+        ep["dones"][-1] = 1
+        rb.add(ep)
+        assert isinstance(rb._buf[-1]["dones"], MemmapArray)
+        assert isinstance(rb._buf[-1]["observations"], MemmapArray)
+    assert rb.is_memmap
+    shutil.rmtree(os.path.abspath("test_episode_buffer"))
+
+
+def test_memmap_to_file_episode_buffer():
+    buf_size = 10
+    bs = 5
+    sl = 4
+    n_envs = 1
+    obs_keys = ("dones", "observations")
+    memmap_dir = "test_episode_buffer"
+    rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys, memmap=True, memmap_dir=memmap_dir)
+    for i in range(4):
+        if i >= 2:
+            bs = 7
+        else:
+            bs = 5
+        ep = {
+            "observations": np.random.randint(0, 256, (bs, n_envs, 3, 64, 64), dtype=np.uint8),
+            "dones": np.zeros((bs, n_envs, 1)),
+        }
+        ep["dones"][-1] = 1
+        rb.add(ep)
+        del ep
+        assert isinstance(rb._buf[-1]["dones"], MemmapArray)
+        assert isinstance(rb._buf[-1]["observations"], MemmapArray)
+    memmap_dir = os.path.dirname(rb._buf[-1]["dones"].filename)
+    assert os.path.exists(os.path.join(memmap_dir, "dones.memmap"))
+    assert os.path.exists(os.path.join(memmap_dir, "observations.memmap"))
+    assert rb.is_memmap
+    del rb
+    shutil.rmtree(os.path.abspath("test_episode_buffer"))