diff --git a/tensordict/_td.py b/tensordict/_td.py index 6453ba139..93209fc2d 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1479,12 +1479,19 @@ def memmap_( copy_existing: bool = False, num_threads: int=0, executor=None, + futures=None, ) -> T: - if num_threads > 0 and executor is None: + if num_threads > 1: + if executor is not None: + raise TypeError("num_threads and executor are exclusive arguments.") from concurrent.futures import ThreadPoolExecutor with ThreadPoolExecutor(max_workers=num_threads) as executor: - return self.memmap_(prefix=prefix, copy_existing=copy_existing, num_threads=num_threads, executor=executor) + futures = [] + result = self.memmap_(prefix=prefix, copy_existing=copy_existing, executor=executor, futures=futures) + for future in futures: + future.result() + return result def save_metadata(data: TensorDictBase, filepath, metadata=None): if metadata is None: @@ -1518,10 +1525,10 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): # ensure subdirectory exists os.makedirs(prefix / key, exist_ok=True) self._tensordict[key] = value.memmap_( - prefix=prefix / key, copy_existing=copy_existing, executor=executor, num_threads=num_threads, + prefix=prefix / key, copy_existing=copy_existing, executor=executor, futures=futures, ) else: - self._tensordict[key] = value.memmap_(executor=executor, num_threads=num_threads) + self._tensordict[key] = value.memmap_(executor=executor, futures=futures) continue else: # user did specify location and memmap is in wrong place, so we copy @@ -1537,7 +1544,7 @@ def _populate(): if executor is None: _populate() else: - executor.submit(_populate) + futures.append(executor.submit(_populate)) if prefix is not None: metadata[key] = { "device": str(value.device), @@ -1553,7 +1560,7 @@ def _populate(): metadata=metadata, ) else: - executor.submit(save_metadata, args=(self, prefix / "meta.json", metadata)) + futures.append(executor.submit(save_metadata, args=(self, prefix / "meta.json", metadata))) self._is_memmap = True self._is_shared = False # since they are mutually exclusive self._device = torch.device("cpu")