Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 9, 2023
1 parent 6fa85b3 commit 13b9ca3
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4682,23 +4682,31 @@ def load_memmap(cls, prefix: str) -> T:
key = key[:-1] # drop "meta.pt" from key
metadata = torch.load(path)
if key in out.keys(include_nested=True):
out[key].batch_size = metadata["batch_size"]
out.get(key).batch_size = metadata["batch_size"]
device = metadata["device"]
if device is not None:
out[key] = out[key].to(device)
else:
out[key] = cls(
{}, batch_size=metadata["batch_size"], device=metadata["device"]
out.set(
key,
cls(
{},
batch_size=metadata["batch_size"],
device=metadata["device"],
),
)
else:
leaf, *_ = key[-1].rsplit(".", 2) # remove .meta.pt suffix
key = (*key[:-1], leaf)
metadata = torch.load(path)
out[key] = MemmapTensor(
*metadata["shape"],
device=metadata["device"],
dtype=metadata["dtype"],
filename=str(path.parent / f"{leaf}.memmap"),
out.set(
key,
MemmapTensor(
*metadata["shape"],
device=metadata["device"],
dtype=metadata["dtype"],
filename=str(path.parent / f"{leaf}.memmap"),
),
)

return out
Expand Down

0 comments on commit 13b9ca3

Please sign in to comment.