Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 12, 2023
1 parent d79f4ee commit 11f7f32
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4723,7 +4723,7 @@ def memmap_(
"copy_existing=True"
)
else:
self._tensordict[key] = MemoryMappedTensor.from_tensor(
self._tensordict[key] = MemmapTensor.from_tensor(
value,
filename=str(prefix / f"{key}.memmap")
if prefix is not None
Expand Down Expand Up @@ -4776,7 +4776,7 @@ def load_memmap(cls, prefix: str) -> T:
metadata = torch.load(path)
out.set(
key,
MemoryMappedTensor(
MemmapTensor(
*metadata["shape"],
device=metadata["device"],
dtype=metadata["dtype"],
Expand Down Expand Up @@ -5185,10 +5185,10 @@ def assert_allclose_td(

default_msg = f"key {key} does not match, got mse = {mse:4.4f}"
msg = "\t".join([default_msg, msg]) if len(msg) else default_msg
if isinstance(input1, MemmapTensor):
input1 = input1._tensor
if isinstance(input2, MemmapTensor):
input2 = input2._tensor
# if isinstance(input1, MemmapTensor):
# input1 = input1._tensor
# if isinstance(input2, MemmapTensor):
# input2 = input2._tensor
torch.testing.assert_close(
input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg
)
Expand Down

0 comments on commit 11f7f32

Please sign in to comment.