From bb71fa8ef1de643f8ab8ace07384b7ec4280370a Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 9 Dec 2023 06:47:34 +0000 Subject: [PATCH] amend --- tensordict/_td.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 93209fc2d..40d56f3c4 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1477,7 +1477,7 @@ def memmap_( self, prefix: str | None = None, copy_existing: bool = False, - num_threads: int=0, + num_threads: int = 0, executor=None, futures=None, ) -> T: @@ -1488,7 +1488,12 @@ def memmap_( with ThreadPoolExecutor(max_workers=num_threads) as executor: futures = [] - result = self.memmap_(prefix=prefix, copy_existing=copy_existing, executor=executor, futures=futures) + result = self.memmap_( + prefix=prefix, + copy_existing=copy_existing, + executor=executor, + futures=futures, + ) for future in futures: future.result() return result @@ -1525,22 +1530,28 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): # ensure subdirectory exists os.makedirs(prefix / key, exist_ok=True) self._tensordict[key] = value.memmap_( - prefix=prefix / key, copy_existing=copy_existing, executor=executor, futures=futures, + prefix=prefix / key, + copy_existing=copy_existing, + executor=executor, + futures=futures, ) else: - self._tensordict[key] = value.memmap_(executor=executor, futures=futures) + self._tensordict[key] = value.memmap_( + executor=executor, futures=futures + ) continue else: # user did specify location and memmap is in wrong place, so we copy - def _populate(): + def _populate(value=value, key=key, copy_existing=copy_existing): + filename = None if prefix is None else str(prefix / f"{key}.memmap") self._tensordict[key] = MemoryMappedTensor.from_tensor( value, - filename=str(prefix / f"{key}.memmap") - if prefix is not None - else None, + filename=filename, copy_existing=copy_existing, + existsok=True, # copy_existing=copy_existing, ) + if executor is None: _populate() else: @@ -1560,7 +1571,9 @@ def _populate(): metadata=metadata, ) else: - futures.append(executor.submit(save_metadata, args=(self, prefix / "meta.json", metadata))) + futures.append( + executor.submit(save_metadata, self, prefix / "meta.json", metadata) + ) self._is_memmap = True self._is_shared = False # since they are mutually exclusive self._device = torch.device("cpu")