Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 16, 2024
1 parent 88691d0 commit d99cf3c
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,8 @@ def _memmap_(
inplace,
like,
) -> T:

if inplace:
raise RuntimeError("Cannot call memmap inplace in a persistent tensordict.")
# re-implements this to make it faster using the meta-data
def save_metadata(data: TensorDictBase, filepath, metadata=None):
if metadata is None:
Expand All @@ -591,16 +592,13 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None):
"memmap_like() must be called when the TensorDict is (partially) "
"populated. Set a tensor first."
)
dest = (
self
if inplace
else TensorDict(
dest = TensorDict(
{},
batch_size=self.batch_size,
names=self.names if self._has_names() else None,
device=torch.device("cpu"),
)
)
dest._is_memmap = True
for key, value in self._items_metadata():
if not value["array"]:
value = self._get_str(key)
Expand Down Expand Up @@ -658,8 +656,6 @@ def _populate(
futures.append(
executor.submit(save_metadata, dest, prefix / "meta.json", metadata)
)
dest._is_memmap = True
dest.lock_()
return dest

_load_memmap = TensorDict._load_memmap
Expand Down

0 comments on commit d99cf3c

Please sign in to comment.