diff --git a/tensordict/_td.py b/tensordict/_td.py
index bff0e5f81..2905994f4 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -2188,12 +2188,28 @@ def _populate(
             dest=dest, value=value, key=key, copy_existing=copy_existing
         ):
             filename = None if prefix is None else str(prefix / f"{key}.memmap")
+            if value.is_nested:
+                shape = value._nested_tensor_size()
+                # Make the shape a memmap tensor too
+                if prefix is not None:
+                    shape_filename = Path(filename)
+                    shape_filename = shape_filename.with_suffix(".shape.memmap")
+                    MemoryMappedTensor.from_tensor(
+                        shape,
+                        filename=shape_filename,
+                        copy_existing=copy_existing,
+                        existsok=True,
+                        copy_data=not like,
+                    )
+            else:
+                shape = None
             dest._tensordict[key] = MemoryMappedTensor.from_tensor(
                 value.data if value.requires_grad else value,
                 filename=filename,
                 copy_existing=copy_existing,
                 existsok=True,
                 copy_data=not like,
+                shape=shape,
             )
 
         if executor is None:
@@ -2203,8 +2219,11 @@
         if prefix is not None:
             metadata[key] = {
                 "device": str(value.device),
-                "shape": list(value.shape),
+                "shape": list(value.shape)
+                if not value.is_nested
+                else value._nested_tensor_size().shape,
                 "dtype": str(value.dtype),
+                "is_nested": value.is_nested,
             }
 
     if prefix is not None:
@@ -2258,16 +2277,25 @@ def _load_memmap(
         if (
             device is None or device != torch.device("meta")
         ) and not torch._guards.active_fake_mode():
+            if entry_metadata.get("is_nested", False):
+                # here "shape" is the shape of the nested-size tensor: load the actual per-tensor shapes from it
+                shape = MemoryMappedTensor.from_filename(
+                    (prefix / f"{key}.memmap").with_suffix(".shape.memmap"),
+                    shape=shape,
+                    dtype=torch.long,
+                )
+            else:
+                shape = torch.Size(shape)
             tensor = MemoryMappedTensor.from_filename(
                 dtype=_STRDTYPE2DTYPE[dtype],
-                shape=torch.Size(entry_metadata["shape"]),
+                shape=shape,
                 filename=str(prefix / f"{key}.memmap"),
             )
             if device is not None:
                 tensor = tensor.to(device, non_blocking=True)
         else:
             tensor = torch.zeros(
-                torch.Size(entry_metadata["shape"]),
+                torch.Size(shape),
                 device=device,
                 dtype=_STRDTYPE2DTYPE[dtype],
             )
diff --git a/tensordict/base.py b/tensordict/base.py
index bae07b5c6..abd7719b3 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -6143,8 +6143,10 @@ def to_tensordict(self) -> T:
             {
                 key: value.clone()
                 if not _is_tensor_collection(value.__class__)
+                else value
+                if is_non_tensor(value)
                 else value.to_tensordict()
-                for key, value in self.items()
+                for key, value in self.items(is_leaf=_is_leaf_nontensor)
             },
             device=self.device,
             batch_size=self.batch_size,
diff --git a/tensordict/memmap.py b/tensordict/memmap.py
index 1e8f4e070..67311951b 100644
--- a/tensordict/memmap.py
+++ b/tensordict/memmap.py
@@ -20,10 +20,16 @@
 import numpy as np
 import torch
 
-from tensordict.utils import implement_for
+from tensordict.utils import _shape, implement_for
 from torch.multiprocessing.reductions import ForkingPickler
 
+NESTED_TENSOR_ERR = (
+    "The PyTorch version isn't compatible with memmap "
+    "nested tensors. Please upgrade to a more recent "
+    "version."
+)
+
 
 class MemoryMappedTensor(torch.Tensor):
     """A Memory-mapped Tensor.
@@ -60,31 +66,47 @@ class MemoryMappedTensor(torch.Tensor):
         ...     memmap_tensor = MemoryMappedTensor.ones_like(tensor, filename=file.name)
     """
 
-    _filename: str | Path
-    _handler: _FileHandler
+    _filename: str | Path = None
+    _handler: _FileHandler = None
     _clear: bool
     index: Any
     parent_shape: torch.Size
 
     def __new__(
         cls,
-        tensor_or_file,
+        source,
         *,
         dtype=None,
         shape=None,
         index=None,
         device=None,
         handler=None,
+        filename=None,
     ):
         if device is not None and torch.device(device).type != "cpu":
             raise ValueError(f"{cls} device must be cpu!")
-        if isinstance(tensor_or_file, str):
+        if isinstance(source, str):
+            if filename is not None:
+                raise TypeError("Duplicated filename argument.")
+            filename = source
+            source = None
+        if filename is not None:
             return cls.from_filename(
-                tensor_or_file,
+                filename,
                 dtype,
                 shape,
                 index,
             )
+        elif isinstance(source, torch.StorageBase):
+            return cls.from_storage(
+                source,
+                dtype=dtype,
+                shape=shape,
+                index=index,
+                device=device,
+                handler=handler,
+                filename=filename,
+            )
         elif handler is not None:
             return cls.from_handler(
                 handler,
@@ -92,11 +114,9 @@ def __new__(
                 shape,
                 index,
             )
-        return super().__new__(cls, tensor_or_file)
+        return super().__new__(cls, source)
 
-    def __init__(
-        self, tensor_or_file, handler=None, dtype=None, shape=None, device=None
-    ):
+    def __init__(self, source, handler=None, dtype=None, shape=None, device=None):
         ...
 
     __torch_function__ = torch._C._disabled_torch_function_impl
@@ -110,6 +130,7 @@ def from_tensor(
         existsok=False,
         copy_existing=False,
         copy_data=True,
+        shape=None,
     ):
         """Creates a MemoryMappedTensor with the same content as another tensor.
@@ -127,12 +148,14 @@
             an existing file. Defaults to ``False``.
         copy_existing (bool, optional): if ``True`` and the provided input
             is a MemoryMappedTensor with an associated filename, copying
-            the content to the new location is permitted. Otherwise an
-            exception is thown. This behaviour exists to prevent
-            unadvertedly duplicating data on disk.
+            the content to the new location is permitted. Otherwise, an
+            exception is thrown. This behaviour exists to prevent
+            inadvertently duplicating data on disk.
         copy_data (bool, optional): if ``True``, the content of the tensor
             will be copied on the storage. Defaults to ``True``.
-
+        shape (torch.Size or torch.Tensor): a shape to override the tensor
+            shape. If a tensor is passed, it must represent the nested shapes of a
+            nested tensor.
         """
         if isinstance(input, MemoryMappedTensor):
             if (filename is None and input._filename is None) or (
@@ -161,41 +184,126 @@ def from_tensor(
             raise RuntimeError(
                 "MemoryMappedTensor.from_tensor is incompatible with tensor.requires_grad."
             )
-        shape = input.shape
+        if shape is None:
+            shape = _shape(input, nested_shape=True)
+        if isinstance(shape, torch.Tensor):
+            shape_numel = shape.prod(-1).sum()
+        elif isinstance(shape, torch.Size):
+            shape_numel = shape.numel()
+        else:
+            shape_numel = torch.Size(shape).numel()
         if filename is None:
             if input.dtype.is_floating_point:
-                size = torch.finfo(input.dtype).bits // 8 * shape.numel()
+                size = torch.finfo(input.dtype).bits // 8 * shape_numel
             elif input.dtype.is_complex:
                 raise ValueError(
                     "Complex-valued tensors are not supported by MemoryMappedTensor."
                 )
             elif input.dtype == torch.bool:
-                size = shape.numel()
+                size = shape_numel
             else:
                 # assume integer
-                size = torch.iinfo(input.dtype).bits // 8 * shape.numel()
+                size = torch.iinfo(input.dtype).bits // 8 * shape_numel
             handler = _FileHandler(size)
-            out = torch.frombuffer(memoryview(handler.buffer), dtype=input.dtype)
-            out = out.view(shape)
-            out = cls(out)
+            if isinstance(shape, torch.Tensor):
+                func_offset_stride = getattr(
+                    torch, "_nested_compute_contiguous_strides_offsets", None
+                )
+                if func_offset_stride is not None:
+                    offsets_strides = func_offset_stride(shape)
+                else:
+                    raise RuntimeError(NESTED_TENSOR_ERR)
+                result = torch.frombuffer(memoryview(handler.buffer), dtype=input.dtype)
+                if copy_data:
+                    result.untyped_storage().copy_(input.untyped_storage())
+                result = torch._nested_view_from_buffer(
+                    result,
+                    shape,
+                    *offsets_strides,
+                )
+            else:
+                result = torch.frombuffer(memoryview(handler.buffer), dtype=input.dtype)
+                result = result.view(shape)
+            result = cls(result)
         else:
             handler = None
             if not existsok and os.path.exists(str(filename)):
                 raise RuntimeError(f"The file {filename} already exists.")
-            out = cls(
-                torch.from_file(
-                    str(filename), shared=True, dtype=input.dtype, size=shape.numel()
-                ).view(input.shape)
+            result = torch.from_file(
+                str(filename), shared=True, dtype=input.dtype, size=shape_numel
             )
-        out._handler = handler
-        out._filename = filename
-        out.index = None
-        out.parent_shape = input.shape
+            if isinstance(shape, torch.Tensor):
+                func_offset_stride = getattr(
+                    torch, "_nested_compute_contiguous_strides_offsets", None
+                )
+                if func_offset_stride is not None:
+                    offsets_strides = func_offset_stride(shape)
+                else:
+                    raise RuntimeError(NESTED_TENSOR_ERR)
+                if copy_data:
+                    result.untyped_storage().copy_(input.untyped_storage())
+                result = torch._nested_view_from_buffer(
+                    result,
+                    shape,
+                    *offsets_strides,
+                )
+            else:
+                result = result.view(shape)
+            result = cls(result)
+        result._handler = handler
+        result._filename = filename
+        result.index = None
+        result.parent_shape = shape
         if copy_data:
             if hasattr(input, "full_tensor"):
+                # for DTensors, cheaper than importing DTensor every time
                 input = input.full_tensor()
-            out.copy_(input)
-        return out
+            if not result.is_nested:
+                result.copy_(input)
+        return result
+
+    @classmethod
+    def from_storage(
+        cls,
+        storage,
+        *,
+        shape=None,
+        dtype=None,
+        device=None,
+        index=None,
+        filename=None,
+        handler=None,
+    ):
+        tensor = torch.tensor(storage, dtype=dtype, device=device)
+        if shape is not None:
+            if isinstance(shape, torch.Tensor):
+                func_offset_stride = getattr(
+                    torch, "_nested_compute_contiguous_strides_offsets", None
+                )
+                if func_offset_stride is not None:
+                    offsets_strides = func_offset_stride(shape)
+                else:
+                    raise RuntimeError(
+                        "The PyTorch version isn't compatible with memmap "
+                        "nested tensors. Please upgrade to a more recent "
+                        "version."
+                    )
+                tensor = torch._nested_view_from_buffer(
+                    tensor,
+                    shape,
+                    *offsets_strides,
+                )
+            else:
+                tensor = tensor.view(shape)
+
+        tensor = cls(tensor)
+        if filename is not None:
+            tensor._filename = filename
+        elif handler is not None:
+            tensor._handler = handler
+        if index is not None:
+            return tensor[index]
+        return tensor
 
     @property
     def filename(self):
@@ -312,6 +420,10 @@ def ones(cls, *args, **kwargs):
         if device.type != "cpu":
             raise RuntimeError("Only CPU tensors are supported.")
         result = torch.ones((), dtype=dtype, device=device)
+        if isinstance(shape, torch.Tensor):
+            return cls.empty(
+                shape, device=device, dtype=dtype, filename=filename
+            ).fill_(1)
         if shape:
             if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
                 shape = torch.Size(shape[0])
@@ -353,6 +465,10 @@ def zeros(cls, *args, **kwargs):
         device = torch.device(device)
         if device.type != "cpu":
             raise RuntimeError("Only CPU tensors are supported.")
+        if isinstance(shape, torch.Tensor):
+            return cls.empty(
+                shape, device=device, dtype=dtype, filename=filename
+            ).fill_(0)
         result = torch.zeros((), dtype=dtype, device=device)
         if shape:
             if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
@@ -397,6 +513,94 @@ def empty(cls, *args, **kwargs):
         if device.type != "cpu":
             raise RuntimeError("Only CPU tensors are supported.")
         result = torch.zeros((), dtype=dtype, device=device)
+        if isinstance(shape, torch.Tensor):
+            # nested tensor
+            shape_numel = shape.prod(-1).sum()
+
+            if filename is None:
+                if dtype.is_floating_point:
+                    size = torch.finfo(dtype).bits // 8 * shape_numel
+                elif dtype.is_complex:
+                    raise ValueError(
+                        "Complex-valued tensors are not supported by MemoryMappedTensor."
+                    )
+                elif dtype == torch.bool:
+                    size = shape_numel
+                else:
+                    # assume integer
+                    size = torch.iinfo(dtype).bits // 8 * shape_numel
+                handler = _FileHandler(size)
+
+                # buffer
+                func_offset_stride = getattr(
+                    torch, "_nested_compute_contiguous_strides_offsets", None
+                )
+                if func_offset_stride is not None:
+                    offsets_strides = func_offset_stride(shape)
+                else:
+                    raise RuntimeError(NESTED_TENSOR_ERR)
+                result = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
+                result = torch._nested_view_from_buffer(
+                    result,
+                    shape,
+                    *offsets_strides,
+                )
+                result = cls(result)
+                result._handler = handler
+                return result
+            else:
+                result = torch.from_file(
+                    str(filename), shared=True, dtype=dtype, size=shape_numel
+                )
+                func_offset_stride = getattr(
+                    torch, "_nested_compute_contiguous_strides_offsets", None
+                )
+                if func_offset_stride is not None:
+                    offsets_strides = func_offset_stride(shape)
+                else:
+                    raise RuntimeError(NESTED_TENSOR_ERR)
+                result = torch._nested_view_from_buffer(
+                    result,
+                    shape,
+                    *offsets_strides,
+                )
+                result = cls(result)
+                result._filename = filename
+                return result
+            return result
+
+        if shape:
+            if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
+                shape = torch.Size(shape[0])
+            else:
+                shape = torch.Size(shape)
+            result = result.expand(shape)
+        result = cls.from_tensor(result, filename=filename)
+        return result
+
+    @classmethod
+    def empty_nested(cls, *args, **kwargs):
+        # noqa: D417
+        """Creates a tensor with empty content, specific shape, dtype and filename.
+
+        Args:
+            shape (nested_shape): the shapes of the tensors.
+
+        Keyword Args:
+            dtype (torch.dtype): the dtype of the tensor.
+            device (torch.device): the device of the tensor. Only `None` and `"cpu"`
+                are accepted, any other device will raise an exception.
+            filename (path or equivalent): the path to the file, if any. If none
+                is provided, a handler is used.
+        """
+        shape = kwargs.pop("shape", args[0])
+        args = (torch.Size([]), *args)
+        _, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
+        if device is not None:
+            device = torch.device(device)
+            if device.type != "cpu":
+                raise RuntimeError("Only CPU tensors are supported.")
+        result = torch.zeros((), dtype=dtype, device=device)
         if shape:
             if isinstance(shape[0], (list, tuple)) and len(shape) == 1:
                 shape = torch.Size(shape[0])
@@ -454,15 +658,39 @@ def from_filename(cls, filename, dtype, shape, index=None):
         Args:
             filename (path or equivalent): the path to the file.
             dtype (torch.dtype): the dtype of the tensor.
-            shape (integers or torch.Size): the shape of the tensor.
+            shape (torch.Size or torch.Tensor): the shape of the tensor. If
+                a tensor is provided, it is assumed that the tensor is a nested_tensor
+                instance.
             index (torch-compatible index type): an index to use to build the
                 tensor.
 
         """
-        shape = torch.Size(shape)
-        tensor = torch.from_file(
-            str(filename), shared=True, dtype=dtype, size=shape.numel()
-        ).view(shape)
+        if isinstance(shape, torch.Tensor):
+            func_offset_stride = getattr(
+                torch, "_nested_compute_contiguous_strides_offsets", None
+            )
+            if func_offset_stride is not None:
+                offsets_strides = func_offset_stride(shape)
+            else:
+                raise RuntimeError(
+                    "The PyTorch version isn't compatible with memmap "
+                    "nested tensors. Please upgrade to a more recent "
+                    "version."
+                )
+            tensor = torch.from_file(
+                str(filename), shared=True, dtype=dtype, size=shape.prod(-1).sum().int()
+            )
+            tensor = torch._nested_view_from_buffer(
+                tensor,
+                shape,
+                *offsets_strides,
+            )
+        else:
+            shape = torch.Size(shape)
+            tensor = torch.from_file(
+                str(filename), shared=True, dtype=dtype, size=shape.numel()
+            ).view(shape)
+
         if index is not None:
             tensor = tensor[index]
         out = cls(tensor)
@@ -473,21 +701,42 @@
         return out
 
     @classmethod
-    def from_handler(cls, handler, dtype, shape, index):
+    def from_handler(cls, handler, dtype, shape, index=None):
         # noqa: D417
         """Loads a MemoryMappedTensor from a given handler.
 
         Args:
             handler (compatible file handler): the handler for the tensor.
             dtype (torch.dtype): the dtype of the tensor.
-            shape (integers or torch.Size): the shape of the tensor.
-            index (torch-compatible index type): an index to use to build the
+            shape (torch.Size or torch.Tensor): the shape of the tensor. If
+                a tensor is provided, it is assumed that the tensor is a nested_tensor
+                instance.
+            index (torch-compatible index type, optional): an index to use to build the
                 tensor.
 
         """
-        shape = torch.Size(shape)
         out = torch.frombuffer(memoryview(handler.buffer), dtype=dtype)
-        out = torch.reshape(out, shape)
+        if isinstance(shape, torch.Tensor):
+            func_offset_stride = getattr(
+                torch, "_nested_compute_contiguous_strides_offsets", None
+            )
+            if func_offset_stride is not None:
+                offsets_strides = func_offset_stride(shape)
+            else:
+                raise RuntimeError(
+                    "The PyTorch version isn't compatible with memmap "
+                    "nested tensors. Please upgrade to a more recent "
+                    "version."
+                )
+            out = torch._nested_view_from_buffer(
+                out,
+                shape,
+                *offsets_strides,
+            )
+        else:
+            shape = torch.Size(shape)
+            out = torch.reshape(out, shape)
+
         if index is not None:
             out = out[index]
         out = cls(out)
@@ -583,10 +832,10 @@ def _index_wrap(self, tensor, item, check=False):
                 return self._index_wrap(tensor, item)
             return tensor
         tensor = MemoryMappedTensor(tensor)
-        tensor._handler = self._handler
-        tensor._filename = self._filename
+        tensor._handler = getattr(self, "_handler", None)
+        tensor._filename = getattr(self, "_filename", None)
         tensor.index = item
-        tensor.parent_shape = self.parent_shape
+        tensor.parent_shape = getattr(self, "parent_shape", None)
         return tensor
 
     def unbind(self, dim):
@@ -696,7 +945,9 @@ def _reduce_memmap(memmap_tensor):
 def _proc_args_const(*args, **kwargs):
     if len(args) > 0:
         # then the first (or the N first) args are the shape
-        if len(args) == 1 and not isinstance(args[0], int):
+        if len(args) == 1 and isinstance(args[0], torch.Tensor):
+            shape = args[0]
+        elif len(args) == 1 and not isinstance(args[0], int):
             shape = torch.Size(args[0])
         else:
             shape = torch.Size(args)
@@ -705,7 +956,8 @@
         shape = kwargs.pop("shape", None)
         if shape is None:
             raise TypeError("Could not find the shape argument in the arguments.")
-        shape = torch.Size(shape)
+        if not isinstance(shape, torch.Tensor):
+            shape = torch.Size(shape)
     return (
         shape,
         kwargs.pop("device", None),
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 11e065b4a..cdef54925 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -586,7 +586,7 @@ def _ndimension(tensor: Tensor) -> int:
     return tensor.ndimension()
 
 
-def _shape(tensor: Tensor) -> torch.Size:
+def _shape(tensor: Tensor, nested_shape=False) -> torch.Size:
     if isinstance(tensor, UninitializedTensorMixin):
         return torch.Size([*getattr(tensor, "batch_size", ()), -1])
     elif not isinstance(tensor, Tensor):
@@ -594,6 +594,8 @@ def _shape(tensor: Tensor) -> torch.Size:
             return torch.Size([len(tensor.lengths()) // len(tensor.keys())])
         return tensor.shape
     if tensor.is_nested:
+        if nested_shape:
+            return tensor._nested_tensor_size()
         shape = []
         for i in range(tensor.ndim):
             try:
@@ -1553,7 +1555,7 @@ def _expand_to_match_shape(
 
 
 def _set_max_batch_size(source: T, batch_dims=None):
-    """Updates a tensordict with its maximium batch size."""
+    """Updates a tensordict with its maximum batch size."""
     from tensordict.base import _is_tensor_collection
 
     tensor_data = [val for val in source.values() if not is_non_tensor(val)]
@@ -1571,17 +1573,19 @@ def _set_max_batch_size(source: T, batch_dims=None):
             return source
 
     curr_dim = 0
+    tensor_shapes = [_shape(_tensor_data) for _tensor_data in tensor_data]
+
     while True:
-        if tensor_data[0].dim() > curr_dim:
-            curr_dim_size = tensor_data[0].size(curr_dim)
+        if len(tensor_shapes[0]) > curr_dim:
+            curr_dim_size = tensor_shapes[0][curr_dim]
         else:
             source.batch_size = batch_size
             return
-        for leaf in tensor_data[1:]:
+        for leaf, shape in zip(tensor_data[1:], tensor_shapes[1:]):
             # if we have a nested empty tensordict we can modify its batch size at will
             if _is_tensor_collection(type(leaf)) and leaf.is_empty():
                 continue
-            if (leaf.dim() <= curr_dim) or (leaf.size(curr_dim) != curr_dim_size):
+            if (len(shape) <= curr_dim) or (shape[curr_dim] != curr_dim_size):
                 source.batch_size = batch_size
                 return
         if batch_dims is None or len(batch_size) < batch_dims:
diff --git a/test/test_memmap.py b/test/test_memmap.py
index 4c9ce6555..ae62895b8 100644
--- a/test/test_memmap.py
+++ b/test/test_memmap.py
@@ -3,11 +3,14 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import gc
 from contextlib import nullcontext
+from pathlib import Path
 
 import pytest
 import torch
 from _utils_internal import get_available_devices
+from tensordict import TensorDict
 from tensordict.memmap import MemoryMappedTensor
 from torch import multiprocessing as mp
@@ -579,6 +582,131 @@ def test_ne(self):
         assert (memmap != ~memmap).all()
 
 
+class TestNestedTensor:
+    shape = torch.tensor([[2, 3], [2, 4], [3, 2]])
+
+    def test_with_filename(self, tmpdir):
+        filename = tmpdir + "/test_file2.memmap"
+        tensor = MemoryMappedTensor.empty(
+            self.shape, filename=filename, dtype=torch.int
+        )
+        assert isinstance(tensor, MemoryMappedTensor)
+        assert tensor.dtype == torch.int
+        tensor.fill_(2)
+        assert (tensor[0] == 2).all()
+        assert tensor.filename is not None
+
+        filename = tmpdir + "/test_file0.memmap"
+        tensor = MemoryMappedTensor.zeros(
+            self.shape, filename=filename, dtype=torch.bool
+        )
+        assert isinstance(tensor, MemoryMappedTensor)
+        assert tensor.dtype == torch.bool
+        assert tensor.filename is not None
+
+        filename = tmpdir + "/test_file1.memmap"
+        tensor = MemoryMappedTensor.ones(self.shape, filename=filename, dtype=torch.int)
+        assert type(tensor) is MemoryMappedTensor
+        assert tensor.dtype == torch.int
+        assert (tensor[0] == 1).all()
+        assert tensor.filename is not None
+
+        filename = tmpdir + "/test_file3.memmap"
+        tensor = torch.nested.nested_tensor(
+            [torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape)]
+        )
+        memmap_tensor = MemoryMappedTensor.from_tensor(tensor, filename=filename)
+        assert type(memmap_tensor) is MemoryMappedTensor
+        for t1, t2 in zip(tensor, memmap_tensor):
+            assert t1.dtype == t2.dtype
+            assert (t1 == t2).all()
+
+        memmap_tensor2 = MemoryMappedTensor.from_filename(
+            filename, dtype=memmap_tensor.dtype, shape=self.shape
+        )
+        assert type(memmap_tensor2) is MemoryMappedTensor
+        for t1, t2 in zip(memmap_tensor2, memmap_tensor):
+            assert t1.dtype == t2.dtype
+            assert (t1 == t2).all()
+
+    def test_with_handler(self):
+        tensor = MemoryMappedTensor.empty(self.shape, dtype=torch.int)
+        assert isinstance(tensor, MemoryMappedTensor)
+        assert tensor.dtype == torch.int
+        tensor.fill_(2)
+        assert (tensor[0] == 2).all()
+        assert tensor._handler is not None
+
+        tensor = MemoryMappedTensor.zeros(self.shape, dtype=torch.bool)
+        assert isinstance(tensor, MemoryMappedTensor)
+        assert tensor.dtype == torch.bool
+        assert tensor._handler is not None
+
+        tensor = MemoryMappedTensor.ones(self.shape, dtype=torch.int)
+        assert type(tensor) is MemoryMappedTensor
+        assert tensor.dtype == torch.int
+        assert (tensor[0] == 1).all()
+        assert tensor._handler is not None
+
+        tensor = torch.nested.nested_tensor(
+            [torch.zeros(shape.tolist()) + i for i, shape in enumerate(self.shape)]
+        )
+        memmap_tensor = MemoryMappedTensor.from_tensor(tensor)
+        assert type(memmap_tensor) is MemoryMappedTensor
+        for t1, t2 in zip(tensor, memmap_tensor):
+            assert t1.dtype == t2.dtype
+            assert (t1 == t2).all()
+
+        memmap_tensor2 = MemoryMappedTensor.from_handler(
+            memmap_tensor._handler, dtype=memmap_tensor.dtype, shape=self.shape
+        )
+        assert type(memmap_tensor2) is MemoryMappedTensor
+        for t1, t2 in zip(memmap_tensor2, memmap_tensor):
+            assert t1.dtype == t2.dtype
+            assert (t1 == t2).all()
+
+    @pytest.mark.parametrize("with_filename", [False, True])
+    def test_from_storage(self, with_filename, tmpdir):
+        if with_filename:
+            filename = Path(tmpdir) / "file.memmap"
+            filename = str(filename)
+        else:
+            filename = None
+        a = MemoryMappedTensor.from_tensor(
+            torch.arange(10, dtype=torch.float64), filename=filename
+        )
+        assert type(a) is MemoryMappedTensor
+        shape = torch.tensor([[2, 2], [2, 3]])
+        b = MemoryMappedTensor.from_storage(
+            a.untyped_storage(), filename=filename, shape=shape, dtype=a.dtype
+        )
+        assert type(b) is MemoryMappedTensor
+        assert (b._nested_tensor_size() == shape).all()
+        assert (b[0] == torch.arange(4).view(2, 2)).all()
+        assert (b[1] == torch.arange(4, 10).view(2, 3)).all()
+
+    def test_save_td_with_nested(self, tmpdir):
+        td = TensorDict(
+            {
+                "a": torch.nested.nested_tensor(
+                    [
+                        torch.arange(12, dtype=torch.float64).view(3, 4),
+                        torch.arange(15, dtype=torch.float64).view(3, 5),
+                    ]
+                )
+            },
+            batch_size=[2, 3],
+        )
+        tdsave = td.clone()
+        td.memmap(tmpdir)
+        del td
+        gc.collect()
+        td = TensorDict.load(tmpdir)
+        for i in range(2):
+            for j in range(3):
+                assert (td[i, j] == tdsave[i, j]).all()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
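
For reference, the round-trip this patch enables looks roughly as follows (a sketch in doctest style, mirroring the tests above; "data.memmap" is an illustrative path, and `_nested_tensor_size()` is the private PyTorch accessor the patch itself relies on):

    >>> import torch
    >>> from tensordict.memmap import MemoryMappedTensor
    >>> nt = torch.nested.nested_tensor([torch.ones(2, 3), torch.ones(3, 4)])
    >>> # persist the nested tensor: the flat buffer is written to data.memmap
    >>> mm = MemoryMappedTensor.from_tensor(nt, filename="data.memmap")
    >>> # the per-tensor shapes, as a (num_tensors, ndim) integer tensor
    >>> shapes = nt._nested_tensor_size()  # tensor([[2, 3], [3, 4]])
    >>> # passing a shape *tensor* (rather than a torch.Size) signals a nested layout
    >>> mm2 = MemoryMappedTensor.from_filename("data.memmap", dtype=nt.dtype, shape=shapes)
    >>> assert all((a == b).all() for a, b in zip(mm, mm2))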