From b862fe2edd9a22c72732b71290b0a1cea0fab5ae Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 15 Nov 2023 09:40:52 +0000
Subject: [PATCH] [Feature] Better constructors for MemoryMappedTensors (#557)

---
 tensordict/memmap.py | 76 +++++++++++++++++++++++++++++++---
 test/test_memmap.py  | 97 ++++++++++++++++++++++++++++++++++++++------
 2 files changed, 156 insertions(+), 17 deletions(-)

diff --git a/tensordict/memmap.py b/tensordict/memmap.py
index 3b33c45ed..b9f23e52b 100644
--- a/tensordict/memmap.py
+++ b/tensordict/memmap.py
@@ -13,7 +13,7 @@
 from multiprocessing import util
 from multiprocessing.context import reduction
 from pathlib import Path
-from typing import Any
+from typing import Any, overload

 import numpy as np
 import torch
@@ -247,7 +247,17 @@ def ones_like(cls, input, *, filename=None):
         )

     @classmethod
-    def ones(cls, *shape, dtype=None, device=None, filename=None):
+    @overload
+    def ones(cls, *size, dtype=None, device=None, filename=None):
+        ...
+
+    @classmethod
+    @overload
+    def ones(cls, shape, *, dtype=None, device=None, filename=None):
+        ...
+
+    @classmethod
+    def ones(cls, *args, **kwargs):  # noqa: D417
         """Creates a tensor with a 1-filled content, specific shape, dtype and
         filename.

@@ -261,6 +271,7 @@ def ones(cls, *shape, dtype=None, device=None, filename=None):
             filename (path or equivalent): the path to the file, if any. If
                 none is provided, a handler is used.
         """
+        shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
         if device is not None:
             device = torch.device(device)
             if device.type != "cpu":
@@ -278,7 +289,17 @@ def ones(cls, *shape, dtype=None, device=None, filename=None):
         )

     @classmethod
-    def zeros(cls, *shape, dtype=None, device=None, filename=None):
+    @overload
+    def zeros(cls, *size, dtype=None, device=None, filename=None):
+        ...
+
+    @classmethod
+    @overload
+    def zeros(cls, shape, *, dtype=None, device=None, filename=None):
+        ...
+
+    @classmethod
+    def zeros(cls, *args, **kwargs):  # noqa: D417
         """Creates a tensor with a 0-filled content, specific shape, dtype and
         filename.

@@ -292,6 +313,7 @@ def zeros(cls, *shape, dtype=None, device=None, filename=None):
             filename (path or equivalent): the path to the file, if any. If
                 none is provided, a handler is used.
         """
+        shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
         if device is not None:
             device = torch.device(device)
             if device.type != "cpu":
@@ -310,7 +332,17 @@ def zeros(cls, *shape, dtype=None, device=None, filename=None):
         return result

     @classmethod
-    def empty(cls, *shape, dtype=None, device=None, filename=None):
+    @overload
+    def empty(cls, *size, dtype=None, device=None, filename=None):
+        ...
+
+    @classmethod
+    @overload
+    def empty(cls, shape, *, dtype=None, device=None, filename=None):
+        ...
+
+    @classmethod
+    def empty(cls, *args, **kwargs):  # noqa: D417
         """Creates a tensor with empty content, specific shape, dtype and
         filename.

@@ -324,6 +356,7 @@ def empty(cls, *shape, dtype=None, device=None, filename=None):
             filename (path or equivalent): the path to the file, if any. If
                 none is provided, a handler is used.
         """
+        shape, device, dtype, _, filename = _proc_args_const(*args, **kwargs)
         if device is not None:
             device = torch.device(device)
             if device.type != "cpu":
@@ -339,7 +372,17 @@ def empty(cls, *shape, dtype=None, device=None, filename=None):
         return result

     @classmethod
-    def full(cls, *shape, fill_value, dtype=None, device=None, filename=None):
+    @overload
+    def full(cls, *size, fill_value, dtype=None, device=None, filename=None):
+        ...
+
+    @classmethod
+    @overload
+    def full(cls, shape, *, fill_value, dtype=None, device=None, filename=None):
+        ...
+
+    @classmethod
+    def full(cls, *args, **kwargs):  # noqa: D417
         """Creates a tensor with a single content specified by `fill_value`,
         specific shape, dtype and filename.

@@ -354,6 +397,7 @@ def full(cls, *shape, fill_value, dtype=None, device=None, filename=None):
             filename (path or equivalent): the path to the file, if any. If
                 none is provided, a handler is used.
         """
+        shape, device, dtype, fill_value, filename = _proc_args_const(*args, **kwargs)
         if device is not None:
             device = torch.device(device)
             if device.type != "cpu":
@@ -546,3 +590,25 @@ def _reduce_memmap(memmap_tensor):

 # For backward compatibility in imports
 from tensordict.memmap_deprec import MemmapTensor  # noqa: F401
+
+
+def _proc_args_const(*args, **kwargs):
+    if len(args) > 0:
+        # then the first (or the N first) args are the shape
+        if len(args) == 1 and not isinstance(args[0], int):
+            shape = torch.Size(args[0])
+        else:
+            shape = torch.Size(args)
+    else:
+        # we should have a "shape" keyword arg
+        shape = kwargs.pop("shape", None)
+        if shape is None:
+            raise TypeError("Could not find the shape argument in the arguments.")
+        shape = torch.Size(shape)
+    return (
+        shape,
+        kwargs.pop("device", None),
+        kwargs.pop("dtype", None),
+        kwargs.pop("fill_value", None),
+        kwargs.pop("filename", None),
+    )
diff --git a/test/test_memmap.py b/test/test_memmap.py
index 4cef90d71..841921b0e 100644
--- a/test/test_memmap.py
+++ b/test/test_memmap.py
@@ -118,7 +118,8 @@ def test_existing(tmp_path):
 @pytest.mark.parametrize("device", [None] + get_available_devices())
 @pytest.mark.parametrize("from_path", [True, False])
 class TestConstructors:
-    def test_zeros(self, shape, dtype, device, tmp_path, from_path):
+    @pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
+    def test_zeros(self, shape, dtype, device, tmp_path, from_path, shape_arg):
         if from_path:
             filename = tmp_path / "file.memmap"
         else:
@@ -129,9 +130,22 @@ def test_zeros(self, shape, dtype, device, tmp_path, from_path):
                     shape, dtype=dtype, device=device, filename=filename
                 )
             return
-        t = MemoryMappedTensor.zeros(
-            shape, dtype=dtype, device=device, filename=filename
-        )
+        if shape_arg == "expand":
+            with pytest.raises(TypeError) if shape == () else nullcontext():
+                t = MemoryMappedTensor.zeros(
+                    *shape, dtype=dtype, device=device, filename=filename
+                )
+            if shape == ():
+                return
+        elif shape_arg == "arg":
+            t = MemoryMappedTensor.zeros(
+                shape, dtype=dtype, device=device, filename=filename
+            )
+        elif shape_arg == "kwarg":
+            t = MemoryMappedTensor.zeros(
+                shape=shape, dtype=dtype, device=device, filename=filename
+            )
+
         assert t.shape == shape
         if dtype is not None:
             assert t.dtype is dtype
@@ -139,7 +153,8 @@ def test_zeros(self, shape, dtype, device, tmp_path, from_path):
             assert t.filename == filename
         assert (t == 0).all()

-    def test_ones(self, shape, dtype, device, tmp_path, from_path):
+    @pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
+    def test_ones(self, shape, dtype, device, tmp_path, from_path, shape_arg):
         if from_path:
             filename = tmp_path / "file.memmap"
         else:
@@ -150,9 +165,21 @@ def test_ones(self, shape, dtype, device, tmp_path, from_path):
                     shape, dtype=dtype, device=device, filename=filename
                 )
             return
-        t = MemoryMappedTensor.ones(
-            shape, dtype=dtype, device=device, filename=filename
-        )
+        if shape_arg == "expand":
+            with pytest.raises(TypeError) if shape == () else nullcontext():
+                t = MemoryMappedTensor.ones(
+                    *shape, dtype=dtype, device=device, filename=filename
+                )
+            if shape == ():
+                return
+        elif shape_arg == "arg":
+            t = MemoryMappedTensor.ones(
+                shape, dtype=dtype, device=device, filename=filename
+            )
+        elif shape_arg == "kwarg":
+            t = MemoryMappedTensor.ones(
+                shape=shape, dtype=dtype, device=device, filename=filename
+            )
         assert t.shape == shape
         if dtype is not None:
             assert t.dtype is dtype
@@ -160,7 +187,41 @@ def test_ones(self, shape, dtype, device, tmp_path, from_path):
             assert t.filename == filename
         assert (t == 1).all()

-    def test_full(self, shape, dtype, device, tmp_path, from_path):
+    @pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
+    def test_empty(self, shape, dtype, device, tmp_path, from_path, shape_arg):
+        if from_path:
+            filename = tmp_path / "file.memmap"
+        else:
+            filename = None
+        if device is not None and device.type != "cpu":
+            with pytest.raises(RuntimeError):
+                MemoryMappedTensor.empty(
+                    shape, dtype=dtype, device=device, filename=filename
+                )
+            return
+        if shape_arg == "expand":
+            with pytest.raises(TypeError) if shape == () else nullcontext():
+                t = MemoryMappedTensor.empty(
+                    *shape, dtype=dtype, device=device, filename=filename
+                )
+            if shape == ():
+                return
+        elif shape_arg == "arg":
+            t = MemoryMappedTensor.empty(
+                shape, dtype=dtype, device=device, filename=filename
+            )
+        elif shape_arg == "kwarg":
+            t = MemoryMappedTensor.empty(
+                shape=shape, dtype=dtype, device=device, filename=filename
+            )
+        assert t.shape == shape
+        if dtype is not None:
+            assert t.dtype is dtype
+        if filename is not None:
+            assert t.filename == filename
+
+    @pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
+    def test_full(self, shape, dtype, device, tmp_path, from_path, shape_arg):
         if from_path:
             filename = tmp_path / "file.memmap"
         else:
@@ -171,9 +232,21 @@ def test_full(self, shape, dtype, device, tmp_path, from_path):
                     shape, fill_value=2, dtype=dtype, device=device, filename=filename
                 )
             return
-        t = MemoryMappedTensor.full(
-            shape, fill_value=2, dtype=dtype, device=device, filename=filename
-        )
+        if shape_arg == "expand":
+            with pytest.raises(TypeError) if shape == () else nullcontext():
+                t = MemoryMappedTensor.full(
+                    *shape, fill_value=2, dtype=dtype, device=device, filename=filename
+                )
+            if shape == ():
+                return
+        elif shape_arg == "arg":
+            t = MemoryMappedTensor.full(
+                shape, fill_value=2, dtype=dtype, device=device, filename=filename
+            )
+        elif shape_arg == "kwarg":
+            t = MemoryMappedTensor.full(
+                shape=shape, fill_value=2, dtype=dtype, device=device, filename=filename
+            )
         assert t.shape == shape
         if dtype is not None:
             assert t.dtype is dtype
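
Usage note (not part of the patch). The snippet below sketches the calling conventions the new constructors accept, mirroring what the parametrized tests above exercise. It assumes a tensordict build that includes this change and that MemoryMappedTensor is importable from the package top level (it is defined in tensordict.memmap); the file path used for full() is an arbitrary example.

import torch
from tensordict import MemoryMappedTensor

# The three spellings normalized by _proc_args_const resolve to the same shape:
a = MemoryMappedTensor.zeros(3, 4)          # expanded integers, torch-style
b = MemoryMappedTensor.zeros((3, 4))        # a single shape sequence
c = MemoryMappedTensor.zeros(shape=(3, 4))  # explicit keyword argument
assert a.shape == b.shape == c.shape == torch.Size([3, 4])

# ones() and empty() follow the same pattern; full() takes fill_value as a
# keyword and, like the others, accepts an optional filename to back the
# tensor with a specific memory-mapped file instead of a temporary handler.
t = MemoryMappedTensor.full((2, 2), fill_value=2, filename="/tmp/buf.memmap")
assert (t == 2).all()

# Per the patch, only CPU tensors are supported (a non-"cpu" device raises
# RuntimeError), and omitting the shape entirely raises a TypeError from
# _proc_args_const.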