diff --git a/tensordict/_td.py b/tensordict/_td.py index c2f493274..6453ba139 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1477,7 +1477,15 @@ def memmap_( self, prefix: str | None = None, copy_existing: bool = False, + num_threads: int=0, + executor=None, ) -> T: + if num_threads > 0 and executor is None: + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + return self.memmap_(prefix=prefix, copy_existing=copy_existing, num_threads=num_threads, executor=executor) + def save_metadata(data: TensorDictBase, filepath, metadata=None): if metadata is None: metadata = {} @@ -1510,21 +1518,26 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): # ensure subdirectory exists os.makedirs(prefix / key, exist_ok=True) self._tensordict[key] = value.memmap_( - prefix=prefix / key, copy_existing=copy_existing + prefix=prefix / key, copy_existing=copy_existing, executor=executor, num_threads=num_threads, ) else: - self._tensordict[key] = value.memmap_() + self._tensordict[key] = value.memmap_(executor=executor, num_threads=num_threads) continue else: # user did specify location and memmap is in wrong place, so we copy - self._tensordict[key] = MemoryMappedTensor.from_tensor( - value, - filename=str(prefix / f"{key}.memmap") - if prefix is not None - else None, - copy_existing=copy_existing, - # copy_existing=copy_existing, - ) + def _populate(): + self._tensordict[key] = MemoryMappedTensor.from_tensor( + value, + filename=str(prefix / f"{key}.memmap") + if prefix is not None + else None, + copy_existing=copy_existing, + # copy_existing=copy_existing, + ) + if executor is None: + _populate() + else: + executor.submit(_populate) if prefix is not None: metadata[key] = { "device": str(value.device), @@ -1533,11 +1546,14 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): } if prefix is not None: - save_metadata( - self, - prefix / "meta.json", - metadata=metadata, - ) + if executor is None: + save_metadata( + self, + prefix / "meta.json", + metadata=metadata, + ) + else: + executor.submit(save_metadata, args=(self, prefix / "meta.json", metadata)) self._is_memmap = True self._is_shared = False # since they are mutually exclusive self._device = torch.device("cpu")