Skip to content

Commit

Permalink
[Feature] Better constructors for MemoryMappedTensors (#557)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 15, 2023
1 parent f601dfa commit b862fe2
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 17 deletions.
76 changes: 71 additions & 5 deletions tensordict/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from multiprocessing import util
from multiprocessing.context import reduction
from pathlib import Path
from typing import Any
from typing import Any, overload

import numpy as np
import torch
Expand Down Expand Up @@ -247,7 +247,17 @@ def ones_like(cls, input, *, filename=None):
)

@classmethod
def ones(cls, *shape, dtype=None, device=None, filename=None):
@overload
def ones(cls, *size, dtype=None, device=None, filename=None):
...

@classmethod
@overload
def ones(cls, shape, *, dtype=None, device=None, filename=None):
...

@classmethod
def ones(cls, *args, **kwargs):
# noqa: D417
"""Creates a tensor with a 1-filled content, specific shape, dtype and filename.
Expand All @@ -261,6 +271,7 @@ def ones(cls, *shape, dtype=None, device=None, filename=None):
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
if device is not None:
device = torch.device(device)
if device.type != "cpu":
Expand All @@ -278,7 +289,17 @@ def ones(cls, *shape, dtype=None, device=None, filename=None):
)

@classmethod
def zeros(cls, *shape, dtype=None, device=None, filename=None):
@overload
def zeros(cls, *size, dtype=None, device=None, filename=None):
...

@classmethod
@overload
def zeros(cls, shape, *, dtype=None, device=None, filename=None):
...

@classmethod
def zeros(cls, *args, **kwargs):
# noqa: D417
"""Creates a tensor with a 0-filled content, specific shape, dtype and filename.
Expand All @@ -292,6 +313,7 @@ def zeros(cls, *shape, dtype=None, device=None, filename=None):
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
if device is not None:
device = torch.device(device)
if device.type != "cpu":
Expand All @@ -310,7 +332,17 @@ def zeros(cls, *shape, dtype=None, device=None, filename=None):
return result

@classmethod
def empty(cls, *shape, dtype=None, device=None, filename=None):
@overload
def empty(cls, *size, dtype=None, device=None, filename=None):
...

@classmethod
@overload
def empty(cls, shape, *, dtype=None, device=None, filename=None):
...

@classmethod
def empty(cls, *args, **kwargs):
# noqa: D417
"""Creates a tensor with empty content, specific shape, dtype and filename.
Expand All @@ -324,6 +356,7 @@ def empty(cls, *shape, dtype=None, device=None, filename=None):
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
if device is not None:
device = torch.device(device)
if device.type != "cpu":
Expand All @@ -339,7 +372,17 @@ def empty(cls, *shape, dtype=None, device=None, filename=None):
return result

@classmethod
def full(cls, *shape, fill_value, dtype=None, device=None, filename=None):
@overload
def full(cls, *size, fill_value, dtype=None, device=None, filename=None):
...

@classmethod
@overload
def full(cls, shape, *, fill_value, dtype=None, device=None, filename=None):
...

@classmethod
def full(cls, *args, **kwargs):
# noqa: D417
"""Creates a tensor with a single content specified by `fill_value`, specific shape, dtype and filename.
Expand All @@ -354,6 +397,7 @@ def full(cls, *shape, fill_value, dtype=None, device=None, filename=None):
filename (path or equivalent): the path to the file, if any. If none
is provided, a handler is used.
"""
shape, device, dtype, fill_value, filename = _proc_args_const(*args, **kwargs)
if device is not None:
device = torch.device(device)
if device.type != "cpu":
Expand Down Expand Up @@ -546,3 +590,25 @@ def _reduce_memmap(memmap_tensor):

# For backward compatibility in imports
from tensordict.memmap_deprec import MemmapTensor # noqa: F401


def _proc_args_const(*args, **kwargs):
if len(args) > 0:
# then the first (or the N first) args are the shape
if len(args) == 1 and not isinstance(args[0], int):
shape = torch.Size(args[0])
else:
shape = torch.Size(args)
else:
# we should have a "shape" keyword arg
shape = kwargs.pop("shape", None)
if shape is None:
raise TypeError("Could not find the shape argument in the arguments.")
shape = torch.Size(shape)
return (
shape,
kwargs.pop("device", None),
kwargs.pop("dtype", None),
kwargs.pop("fill_value", None),
kwargs.pop("filename", None),
)
97 changes: 85 additions & 12 deletions test/test_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def test_existing(tmp_path):
@pytest.mark.parametrize("device", [None] + get_available_devices())
@pytest.mark.parametrize("from_path", [True, False])
class TestConstructors:
def test_zeros(self, shape, dtype, device, tmp_path, from_path):
@pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
def test_zeros(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if from_path:
filename = tmp_path / "file.memmap"
else:
Expand All @@ -129,17 +130,31 @@ def test_zeros(self, shape, dtype, device, tmp_path, from_path):
shape, dtype=dtype, device=device, filename=filename
)
return
t = MemoryMappedTensor.zeros(
shape, dtype=dtype, device=device, filename=filename
)
if shape_arg == "expand":
with pytest.raises(TypeError) if shape == () else nullcontext():
t = MemoryMappedTensor.zeros(
*shape, dtype=dtype, device=device, filename=filename
)
if shape == ():
return
elif shape_arg == "arg":
t = MemoryMappedTensor.zeros(
shape, dtype=dtype, device=device, filename=filename
)
elif shape_arg == "kwarg":
t = MemoryMappedTensor.zeros(
shape=shape, dtype=dtype, device=device, filename=filename
)

assert t.shape == shape
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert (t == 0).all()

def test_ones(self, shape, dtype, device, tmp_path, from_path):
@pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
def test_ones(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if from_path:
filename = tmp_path / "file.memmap"
else:
Expand All @@ -150,17 +165,63 @@ def test_ones(self, shape, dtype, device, tmp_path, from_path):
shape, dtype=dtype, device=device, filename=filename
)
return
t = MemoryMappedTensor.ones(
shape, dtype=dtype, device=device, filename=filename
)
if shape_arg == "expand":
with pytest.raises(TypeError) if shape == () else nullcontext():
t = MemoryMappedTensor.ones(
*shape, dtype=dtype, device=device, filename=filename
)
if shape == ():
return
elif shape_arg == "arg":
t = MemoryMappedTensor.ones(
shape, dtype=dtype, device=device, filename=filename
)
elif shape_arg == "kwarg":
t = MemoryMappedTensor.ones(
shape=shape, dtype=dtype, device=device, filename=filename
)
assert t.shape == shape
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert (t == 1).all()

def test_full(self, shape, dtype, device, tmp_path, from_path):
@pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
def test_empty(self, shape, dtype, device, tmp_path, from_path, shape_arg):
    """Check ``MemoryMappedTensor.empty`` for every way of passing the shape.

    ``shape_arg`` selects the calling convention: ``"expand"`` unpacks the
    shape into positional integers, ``"arg"`` passes it as a single
    positional argument and ``"kwarg"`` passes it as ``shape=...``.
    """
    filename = tmp_path / "file.memmap" if from_path else None
    options = {"dtype": dtype, "device": device, "filename": filename}

    # Non-CPU devices are expected to be rejected by the constructor.
    if device is not None and device.type != "cpu":
        with pytest.raises(RuntimeError):
            MemoryMappedTensor.empty(shape, **options)
        return

    if shape_arg == "expand":
        # Unpacking an empty shape leaves no positional argument at all,
        # which is expected to raise a TypeError.
        expectation = pytest.raises(TypeError) if shape == () else nullcontext()
        with expectation:
            t = MemoryMappedTensor.empty(*shape, **options)
        if shape == ():
            return
    elif shape_arg == "arg":
        t = MemoryMappedTensor.empty(shape, **options)
    elif shape_arg == "kwarg":
        t = MemoryMappedTensor.empty(shape=shape, **options)

    # The content of an empty tensor is unspecified: only metadata is checked.
    assert t.shape == shape
    if dtype is not None:
        assert t.dtype is dtype
    if filename is not None:
        assert t.filename == filename

@pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
def test_full(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if from_path:
filename = tmp_path / "file.memmap"
else:
Expand All @@ -171,9 +232,21 @@ def test_full(self, shape, dtype, device, tmp_path, from_path):
shape, fill_value=2, dtype=dtype, device=device, filename=filename
)
return
t = MemoryMappedTensor.full(
shape, fill_value=2, dtype=dtype, device=device, filename=filename
)
if shape_arg == "expand":
with pytest.raises(TypeError) if shape == () else nullcontext():
t = MemoryMappedTensor.full(
*shape, fill_value=2, dtype=dtype, device=device, filename=filename
)
if shape == ():
return
elif shape_arg == "arg":
t = MemoryMappedTensor.full(
shape, fill_value=2, dtype=dtype, device=device, filename=filename
)
elif shape_arg == "kwarg":
t = MemoryMappedTensor.full(
shape=shape, fill_value=2, dtype=dtype, device=device, filename=filename
)
assert t.shape == shape
if dtype is not None:
assert t.dtype is dtype
Expand Down

0 comments on commit b862fe2

Please sign in to comment.