Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Dec 9, 2023
1 parent d66fcdb commit bb71fa8
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,7 +1477,7 @@ def memmap_(
self,
prefix: str | None = None,
copy_existing: bool = False,
num_threads: int=0,
num_threads: int = 0,
executor=None,
futures=None,
) -> T:
Expand All @@ -1488,7 +1488,12 @@ def memmap_(

with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
result = self.memmap_(prefix=prefix, copy_existing=copy_existing, executor=executor, futures=futures)
result = self.memmap_(
prefix=prefix,
copy_existing=copy_existing,
executor=executor,
futures=futures,
)
for future in futures:
future.result()
return result
Expand Down Expand Up @@ -1525,22 +1530,28 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None):
# ensure subdirectory exists
os.makedirs(prefix / key, exist_ok=True)
self._tensordict[key] = value.memmap_(
prefix=prefix / key, copy_existing=copy_existing, executor=executor, futures=futures,
prefix=prefix / key,
copy_existing=copy_existing,
executor=executor,
futures=futures,
)
else:
self._tensordict[key] = value.memmap_(executor=executor, futures=futures)
self._tensordict[key] = value.memmap_(
executor=executor, futures=futures
)
continue
else:
# user did specify location and memmap is in wrong place, so we copy
def _populate():
def _populate(value=value, key=key, copy_existing=copy_existing):
filename = None if prefix is None else str(prefix / f"{key}.memmap")
self._tensordict[key] = MemoryMappedTensor.from_tensor(
value,
filename=str(prefix / f"{key}.memmap")
if prefix is not None
else None,
filename=filename,
copy_existing=copy_existing,
existsok=True,
# copy_existing=copy_existing,
)

if executor is None:
_populate()
else:
Expand All @@ -1560,7 +1571,9 @@ def _populate():
metadata=metadata,
)
else:
futures.append(executor.submit(save_metadata, args=(self, prefix / "meta.json", metadata)))
futures.append(
executor.submit(save_metadata, self, prefix / "meta.json", metadata)
)
self._is_memmap = True
self._is_shared = False # since they are mutually exclusive
self._device = torch.device("cpu")
Expand Down

0 comments on commit bb71fa8

Please sign in to comment.