Commit 04e52a1

[Feature] Memory-mapped nested tensors (#618)

vmoens authored May 8, 2024
1 parent 1b84e0c commit 04e52a1
Showing 5 changed files with 470 additions and 56 deletions.
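
For orientation, a minimal end-to-end sketch of what this commit enables, under the assumption of a hypothetical path and tensor sizes (`memmap_` and `load_memmap` are existing tensordict APIs; the `.shape.memmap` companion file is confirmed by the diff below):

    # Memory-map a TensorDict holding a nested (ragged) tensor, then reload it.
    import torch
    from tensordict import TensorDict

    nt = torch.nested.nested_tensor([torch.randn(3, 4), torch.randn(5, 4)])
    td = TensorDict({"obs": nt}, batch_size=[2])
    td.memmap_("/tmp/td_nested")  # writes obs.memmap plus obs.shape.memmap
    loaded = TensorDict.load_memmap("/tmp/td_nested")
    assert loaded["obs"].is_nested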
34 changes: 31 additions & 3 deletions tensordict/_td.py
@@ -2188,12 +2188,28 @@ def _populate(
             dest=dest, value=value, key=key, copy_existing=copy_existing
         ):
             filename = None if prefix is None else str(prefix / f"{key}.memmap")
+            if value.is_nested:
+                shape = value._nested_tensor_size()
+                # Make the shape a memmap tensor too
+                if prefix is not None:
+                    shape_filename = Path(filename)
+                    shape_filename = shape_filename.with_suffix(".shape.memmap")
+                    MemoryMappedTensor.from_tensor(
+                        shape,
+                        filename=shape_filename,
+                        copy_existing=copy_existing,
+                        existsok=True,
+                        copy_data=not like,
+                    )
+            else:
+                shape = None
             dest._tensordict[key] = MemoryMappedTensor.from_tensor(
                 value.data if value.requires_grad else value,
                 filename=filename,
                 copy_existing=copy_existing,
                 existsok=True,
                 copy_data=not like,
+                shape=shape,
             )
 
         if executor is None:
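
The key idea in this hunk: a nested tensor has no single `shape`, so its per-entry sizes are materialized as a regular `(num_tensors, ndim)` long tensor and memmapped alongside the data under a `.shape.memmap` suffix. A small sketch of what `_nested_tensor_size()` (the private PyTorch API used above) returns:

    import torch

    nt = torch.nested.nested_tensor([torch.randn(3, 4), torch.randn(5, 4)])
    print(nt._nested_tensor_size())
    # tensor([[3, 4],
    #         [5, 4]])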
@@ -2203,8 +2219,11 @@ def _populate(
             if prefix is not None:
                 metadata[key] = {
                     "device": str(value.device),
-                    "shape": list(value.shape),
+                    "shape": list(value.shape)
+                    if not value.is_nested
+                    else value._nested_tensor_size().shape,
                     "dtype": str(value.dtype),
+                    "is_nested": value.is_nested,
                 }
 
         if prefix is not None:
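
Because a nested tensor's per-entry sizes live in the companion shape file, the metadata records the shape of that size matrix rather than a tensor shape, plus an `is_nested` flag to select the right branch at load time. Roughly, with hypothetical values:

    # Hypothetical metadata entries for a regular vs. a nested leaf,
    # mirroring the dict built above.
    regular = {"device": "cpu", "shape": [3, 4], "dtype": "torch.float32", "is_nested": False}
    nested = {"device": "cpu", "shape": [2, 2], "dtype": "torch.float32", "is_nested": True}
    # For the nested case, [2, 2] is the shape of the (num_tensors, ndim)
    # size matrix, not of the data itself.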
@@ -2258,16 +2277,25 @@ def _load_memmap(
         if (
             device is None or device != torch.device("meta")
         ) and not torch._guards.active_fake_mode():
+            if entry_metadata.get("is_nested", False):
+                # The shape is the shape of the shape, get the shape from it
+                shape = MemoryMappedTensor.from_filename(
+                    (prefix / f"{key}.memmap").with_suffix(".shape.memmap"),
+                    shape=shape,
+                    dtype=torch.long,
+                )
+            else:
+                shape = torch.Size(shape)
             tensor = MemoryMappedTensor.from_filename(
                 dtype=_STRDTYPE2DTYPE[dtype],
-                shape=torch.Size(entry_metadata["shape"]),
+                shape=shape,
                 filename=str(prefix / f"{key}.memmap"),
             )
             if device is not None:
                 tensor = tensor.to(device, non_blocking=True)
         else:
             tensor = torch.zeros(
-                torch.Size(shape),
+                torch.Size(shape),
                 device=device,
                 dtype=_STRDTYPE2DTYPE[dtype],
             )
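
Loading reverses the two steps: map the shape matrix first, then hand it to `from_filename` as the shape of the data file. A minimal sketch, where the path, dtype, and the `[2, 2]` metadata shape are hypothetical; per this commit, passing a shape *tensor* rather than a `torch.Size` is what signals the nested layout:

    from pathlib import Path

    import torch
    from tensordict import MemoryMappedTensor

    prefix, key = Path("/tmp/td_nested"), "obs"
    # 1) Recover the per-entry size matrix; its own shape comes from metadata.
    shape = MemoryMappedTensor.from_filename(
        (prefix / f"{key}.memmap").with_suffix(".shape.memmap"),
        dtype=torch.long,
        shape=torch.Size([2, 2]),
    )
    # 2) Map the flat data file, interpreted as a nested tensor.
    tensor = MemoryMappedTensor.from_filename(
        filename=str(prefix / f"{key}.memmap"),
        dtype=torch.float32,
        shape=shape,
    )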
4 changes: 3 additions & 1 deletion tensordict/base.py
@@ -6143,8 +6143,10 @@ def to_tensordict(self) -> T:
             {
                 key: value.clone()
                 if not _is_tensor_collection(value.__class__)
+                else value
+                if is_non_tensor(value)
                 else value.to_tensordict()
-                for key, value in self.items()
+                for key, value in self.items(is_leaf=_is_leaf_nontensor)
             },
             device=self.device,
             batch_size=self.batch_size,
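
The comprehension now dispatches three ways: tensor leaves are cloned, non-tensor leaves are passed through unchanged, and nested tensordicts recurse. A sketch, assuming tensordict's `NonTensorData` wrapper (the constructor usage here is an assumption based on the `is_non_tensor` check above):

    import torch
    from tensordict import TensorDict
    from tensordict.tensorclass import NonTensorData

    td = TensorDict(
        {
            "x": torch.zeros(3),                                # cloned
            "meta": NonTensorData("a string", batch_size=[3]),  # kept as-is
            "sub": TensorDict({"y": torch.ones(3)}, [3]),       # recursed
        },
        batch_size=[3],
    )
    out = td.to_tensordict()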