Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Dec 9, 2023
1 parent 0edb0be commit 847de8e
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,12 +1479,19 @@ def memmap_(
copy_existing: bool = False,
num_threads: int=0,
executor=None,
futures=None,
) -> T:
if num_threads > 0 and executor is None:
if num_threads > 1:
if executor is not None:
raise TypeError("num_threads and executor are exclusive arguments.")
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=num_threads) as executor:
return self.memmap_(prefix=prefix, copy_existing=copy_existing, num_threads=num_threads, executor=executor)
futures = []
result = self.memmap_(prefix=prefix, copy_existing=copy_existing, executor=executor, futures=futures)
for future in futures:
future.result()
return result

def save_metadata(data: TensorDictBase, filepath, metadata=None):
if metadata is None:
Expand Down Expand Up @@ -1518,10 +1525,10 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None):
# ensure subdirectory exists
os.makedirs(prefix / key, exist_ok=True)
self._tensordict[key] = value.memmap_(
prefix=prefix / key, copy_existing=copy_existing, executor=executor, num_threads=num_threads,
prefix=prefix / key, copy_existing=copy_existing, executor=executor, futures=futures,
)
else:
self._tensordict[key] = value.memmap_(executor=executor, num_threads=num_threads)
self._tensordict[key] = value.memmap_(executor=executor, futures=futures)
continue
else:
# user did specify location and memmap is in wrong place, so we copy
Expand All @@ -1537,7 +1544,7 @@ def _populate():
if executor is None:
_populate()
else:
executor.submit(_populate)
futures.append(executor.submit(_populate))
if prefix is not None:
metadata[key] = {
"device": str(value.device),
Expand All @@ -1553,7 +1560,7 @@ def _populate():
metadata=metadata,
)
else:
executor.submit(save_metadata, args=(self, prefix / "meta.json", metadata))
futures.append(executor.submit(save_metadata, args=(self, prefix / "meta.json", metadata)))
self._is_memmap = True
self._is_shared = False # since they are mutually exclusive
self._device = torch.device("cpu")
Expand Down

0 comments on commit 847de8e

Please sign in to comment.