diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 73c316981..46b636b0e 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -3041,6 +3041,62 @@ def _transpose(self, dim0, dim1): ) return result + def _repeat(self, *repeats: int) -> TensorDictBase: + repeats = list(repeats) + r_dim = repeats.pop(self.stack_dim) + tds = [td.repeat(*repeats) for td in self.tensordicts] + tds = [td for _ in range(r_dim) for td in tds] + return type(self)( + *tds, + stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, + hook_in=self.hook_in, + hook_out=self.hook_out, + ) + + def repeat_interleave( + self, repeats: torch.Tensor | int, dim: int = None, *, output_size: int = None + ) -> TensorDictBase: + if self.ndim == 0: + return self.unsqueeze(0).repeat_interleave( + repeats=repeats, dim=dim, output_size=output_size + ) + if dim is None: + if self.ndim > 1: + return self.reshape(-1).repeat_interleave(repeats, dim=0) + return self.repeat_interleave(repeats, dim=0) + dim_corrected = dim if dim >= 0 else self.ndim + dim + if not (dim_corrected >= 0): + raise ValueError( + f"dim {dim} is out of range for tensordict with shape {self.shape}." + ) + if dim_corrected == self.stack_dim: + new_list_of_tds = [t for t in self.tensordicts for _ in range(repeats)] + result = type(self)( + *new_list_of_tds, + stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, + hook_out=self.hook_out, + hook_in=self.hook_in, + ) + else: + dim_corrected = ( + dim_corrected if dim_corrected < self.stack_dim else dim_corrected - 1 + ) + result = type(self)( + *( + td.repeat_interleave( + repeats=repeats, dim=dim_corrected, output_size=output_size + ) + for td in self.tensordicts + ), + stack_dim=self.stack_dim, + stack_dim_name=self._td_dim_name, + hook_in=self.hook_in, + hook_out=self.hook_out, + ) + return result + def _permute( self, *args, @@ -3815,23 +3871,24 @@ def _cast_reduction( _check_is_shared = TensorDict._check_is_shared _convert_to_tensordict = TensorDict._convert_to_tensordict _index_tensordict = TensorDict._index_tensordict - masked_select = TensorDict.masked_select - reshape = TensorDict.reshape - split = TensorDict.split - _to_module = TensorDict._to_module _apply_nest = TensorDict._apply_nest + _get_names_idx = TensorDict._get_names_idx + _maybe_remove_batch_dim = TensorDict._maybe_remove_batch_dim _multithread_apply_flat = TensorDict._multithread_apply_flat _multithread_rebuild = TensorDict._multithread_rebuild - _remove_batch_dim = TensorDict._remove_batch_dim - _maybe_remove_batch_dim = TensorDict._maybe_remove_batch_dim + _to_module = TensorDict._to_module + _unbind = TensorDict._unbind all = TensorDict.all any = TensorDict.any expand = TensorDict.expand - _unbind = TensorDict._unbind - _get_names_idx = TensorDict._get_names_idx from_dict_instance = TensorDict.from_dict_instance + masked_select = TensorDict.masked_select + _repeat = TensorDict._repeat + repeat_interleave = TensorDict.repeat_interleave + reshape = TensorDict.reshape + split = TensorDict.split class _UnsqueezedTensorDict(_CustomOpTensorDict): diff --git a/tensordict/_td.py b/tensordict/_td.py index c6b03e1ae..5672fc070 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1790,6 +1790,46 @@ def _reshape(tensor): propagate_lock=True, ) + def repeat_interleave( + self, repeats: torch.Tensor | int, dim: int = None, *, output_size: int = None + ) -> T: + if self.ndim == 0: + return self.unsqueeze(0).repeat_interleave( + repeats=repeats, dim=dim, output_size=output_size + ) + if dim is None: + if self.ndim > 1: + 
return self.reshape(-1).repeat_interleave(repeats, dim=0) + return self.repeat_interleave(repeats, dim=0) + dim_corrected = dim if dim >= 0 else self.ndim + dim + if not (dim_corrected >= 0): + raise ValueError( + f"dim {dim} is out of range for tensordict with shape {self.shape}." + ) + new_batch_size = torch.Size( + [ + s if i != dim_corrected else s * repeats + for i, s in enumerate(self.batch_size) + ] + ) + return self._fast_apply( + lambda leaf: leaf.repeat_interleave( + repeats=repeats, dim=dim_corrected, output_size=output_size + ), + batch_size=new_batch_size, + call_on_nested=True, + propagate_lock=True, + ) + + def _repeat(self, *repeats: int) -> TensorDictBase: + new_batch_size = torch.Size([i * r for i, r in zip(self.batch_size, repeats)]) + return self._fast_apply( + lambda leaf: leaf.repeat(*repeats, *((1,) * (leaf.ndim - self.ndim))), + batch_size=new_batch_size, + call_on_nested=True, + propagate_lock=True, + ) + def _transpose(self, dim0, dim1): def _transpose(tensor): return tensor.transpose(dim0, dim1) @@ -4208,14 +4248,16 @@ def _cast_reduction( __or__ = TensorDict.__or__ _check_device = TensorDict._check_device _check_is_shared = TensorDict._check_is_shared + _to_module = TensorDict._to_module + _unbind = TensorDict._unbind all = TensorDict.all any = TensorDict.any masked_select = TensorDict.masked_select memmap_like = TensorDict.memmap_like + repeat_interleave = TensorDict.repeat_interleave + _repeat = TensorDict._repeat reshape = TensorDict.reshape split = TensorDict.split - _to_module = TensorDict._to_module - _unbind = TensorDict._unbind def _view(self, *args, **kwargs): raise RuntimeError( diff --git a/tensordict/base.py b/tensordict/base.py index 91b94e714..8b86cb250 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -2574,6 +2574,123 @@ def reshape( """ ... + @abc.abstractmethod + def repeat_interleave( + self, repeats: torch.Tensor | int, dim: int = None, *, output_size: int = None + ) -> TensorDictBase: + """Repeat elements of a TensorDict. + + .. warning:: This is different from :meth:`~torch.Tensor.repeat` but similar to :func:`numpy.repeat`. + + Args: + repeats (torch.Tensor or int): The number of repetitions for each element. `repeats` is broadcast to fit + the shape of the given axis. + dim (int, optional): The dimension along which to repeat values. By default, use the flattened input + array, and return a flat output array. + + Keyword Args: + output_size (int, optional): Total output size for the given axis (e.g. sum of repeats). If given, it + will avoid stream synchronization needed to calculate output shape of the tensordict. + + Returns: + Repeated TensorDict which has the same shape as input, except along the given axis. + + Examples: + >>> import torch + >>> + >>> from tensordict import TensorDict + >>> + >>> td = TensorDict( + ... { + ... "a": torch.randn(3, 4, 5), + ... "b": TensorDict({ + ... "c": torch.randn(3, 4, 10, 1), + ... "a string": "a string!", + ... }, batch_size=[3, 4, 10]) + ... }, batch_size=[3, 4], + ... 
)
+            >>> print(td.repeat_interleave(2, dim=0))
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            a string: NonTensorData(data=a string!, batch_size=torch.Size([6, 4, 10]), device=None),
+                            c: Tensor(shape=torch.Size([6, 4, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([6, 4, 10]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([6, 4]),
+                device=None,
+                is_shared=False)
+
+        """
+        ...
+
+    @overload
+    def repeat(self, repeats: torch.Size): ...
+
+    def repeat(self, *repeats: int) -> TensorDictBase:
+        """Repeats this tensordict along the specified dimensions.
+
+        Unlike :meth:`~.expand()`, this function copies the tensordict’s data.
+
+        .. warning:: :meth:`~.repeat` behaves differently from :func:`~numpy.repeat`, but is more similar to
+            :func:`numpy.tile`. For the operator similar to :func:`numpy.repeat`, see :meth:`~tensordict.TensorDictBase.repeat_interleave`.
+
+        Args:
+            repeats (torch.Size, int..., tuple of int or list of int): The number of times to repeat this tensordict along
+                each dimension.
+
+        Examples:
+            >>> import torch
+            >>>
+            >>> from tensordict import TensorDict
+            >>>
+            >>> td = TensorDict(
+            ...     {
+            ...         "a": torch.randn(3, 4, 5),
+            ...         "b": TensorDict({
+            ...             "c": torch.randn(3, 4, 10, 1),
+            ...             "a string": "a string!",
+            ...         }, batch_size=[3, 4, 10])
+            ...     }, batch_size=[3, 4],
+            ... )
+            >>> print(td.repeat(1, 2))
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([3, 8, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            a string: NonTensorData(data=a string!, batch_size=torch.Size([3, 8, 10]), device=None),
+                            c: Tensor(shape=torch.Size([3, 8, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([3, 8, 10]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([3, 8]),
+                device=None,
+                is_shared=False)
+
+        """
+        if len(repeats) == 1 and not isinstance(repeats[0], int):
+            repeats = repeats[0]
+            if isinstance(repeats, torch.Size):
+                return self.repeat(*repeats)
+            if isinstance(repeats, torch.Tensor):
+                # This will cause cuda to sync, which may not be desirable
+                return self.repeat(*repeats.tolist())
+            raise ValueError(
+                f"repeats must be a sequence of integers, a tensor or a torch.Size object. Got {type(repeats)} instead."
+            )
+        if len(repeats) != self.ndimension():
+            raise ValueError(
+                f"The number of repeat elements must match the number of dimensions of the tensordict. Got {len(repeats)} but ndim={self.ndimension()}."
+            )
+        return self._repeat(*repeats)
+
+    @abc.abstractmethod
+    def _repeat(self, *repeats: int) -> TensorDictBase: ...
+
     def cat_tensors(
         self,
         *keys: NestedKey,
diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index f305fcddd..125e0e9c5 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -1095,6 +1095,12 @@ def memmap_like(
     @_fallback
     def reshape(self, *shape: int): ...
 
+    @_fallback
+    def repeat_interleave(self, *shape: int): ...
+
+    @_fallback
+    def _repeat(self, *repeats: int): ...
+
     @_fallback
     def split(
         self, split_size: int | list[int], dim: int = 0
diff --git a/tensordict/persistent.py b/tensordict/persistent.py
index 332023587..70d06aa8c 100644
--- a/tensordict/persistent.py
+++ b/tensordict/persistent.py
@@ -1412,25 +1412,26 @@ def _unsqueeze(self, dim):
     __le__ = TensorDict.__le__
     __lt__ = TensorDict.__lt__
 
-    _cast_reduction = TensorDict._cast_reduction
     _apply_nest = TensorDict._apply_nest
-    _multithread_apply_flat = TensorDict._multithread_apply_flat
-    _multithread_rebuild = TensorDict._multithread_rebuild
-
+    _cast_reduction = TensorDict._cast_reduction
     _check_device = TensorDict._check_device
     _check_is_shared = TensorDict._check_is_shared
     _convert_to_tensordict = TensorDict._convert_to_tensordict
+    _get_names_idx = TensorDict._get_names_idx
     _index_tensordict = TensorDict._index_tensordict
+    _multithread_apply_flat = TensorDict._multithread_apply_flat
+    _multithread_rebuild = TensorDict._multithread_rebuild
+    _to_module = TensorDict._to_module
+    _unbind = TensorDict._unbind
     all = TensorDict.all
     any = TensorDict.any
     expand = TensorDict.expand
+    from_dict_instance = TensorDict.from_dict_instance
     masked_select = TensorDict.masked_select
+    _repeat = TensorDict._repeat
+    repeat_interleave = TensorDict.repeat_interleave
     reshape = TensorDict.reshape
     split = TensorDict.split
-    _to_module = TensorDict._to_module
-    _unbind = TensorDict._unbind
-    _get_names_idx = TensorDict._get_names_idx
-    from_dict_instance = TensorDict.from_dict_instance
 
 
 def _set_max_batch_size(source: PersistentTensorDict):
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 728fc4055..0ed505279 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -177,6 +177,7 @@ def __subclasscheck__(self, subclass):
     "_maybe_remove_batch_dim",
     "_multithread_apply_flat",
     "_remove_batch_dim",
+    "_repeat",
     "_select",  # TODO: must be specialized
     "_set_at_tuple",
     "_set_tuple",
@@ -297,6 +298,8 @@ def __subclasscheck__(self, subclass):
     "reciprocal_",
     "refine_names",
     "rename_",  # TODO: must be specialized
+    "repeat",
+    "repeat_interleave",
     "replace",
     "requires_grad_",
     "reshape",
@@ -395,6 +398,7 @@ def from_dataclass(
     frozen: bool = False,
     autocast: bool = False,
     nocast: bool = False,
+    inplace: bool = False,
 ) -> Any:
     """Converts a dataclass instance or a type into a tensorclass instance or type, respectively.
 
@@ -409,6 +413,8 @@ def from_dataclass(
         frozen (bool, optional): If ``True``, the resulting class or instance will be immutable. Defaults to ``False``.
         autocast (bool, optional): If ``True``, enables automatic type casting for the resulting class or instance. Defaults to ``False``.
         nocast (bool, optional): If ``True``, disables any type casting for the resulting class or instance. Defaults to ``False``.
+        inplace (bool, optional): If ``True``, the dataclass type passed in is modified in place. Defaults to ``False``.
+            Has no effect if an instance is provided.
 
     Returns:
         A tensor-compatible class or instance derived from the provided dataclass.
@@ -457,9 +463,14 @@ def from_dataclass( if isinstance(obj, type): if is_tensorclass(obj): return obj - cls = make_dataclass( - obj.__name__ + "_tc", fields=obj.__dataclass_fields__, bases=obj.__bases__ - ) + if not inplace: + cls = make_dataclass( + obj.__name__ + "_tc", + fields=obj.__dataclass_fields__, + bases=obj.__bases__, + ) + else: + cls = obj clz = _tensorclass(cls, frozen=frozen) clz._type_hints = get_type_hints(obj) clz._autocast = autocast @@ -768,7 +779,11 @@ def __torch_function__( cls.__doc__ = f"{cls.__name__}{inspect.signature(cls)}" _register_tensor_class(cls) - _register_td_node(cls) + try: + _register_td_node(cls) + except ValueError: + # The class may already be registered as a pytree node + pass # faster than doing instance checks cls._is_non_tensor = _is_non_tensor diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index a77ef185a..da3dc7e07 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -361,6 +361,10 @@ class TensorClass: def reshape(self, *shape: int): ... @overload def reshape(self, shape: list | tuple): ... + def repeat_interleave( + self, repeats: torch.Tensor | int, dim: int = None, *, output_size: int = None + ) -> TensorDictBase: ... + def repeat(self, *repeats: int) -> TensorDictBase: ... def cat_tensors( self, *keys: NestedKey, diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 0f71bd743..806a282a1 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -691,9 +691,7 @@ class MyClass: assert (c_gather2 == c_gather).all() - def test_get( - self, - ): + def test_get(self): @tensorclass class MyDataNest: X: torch.Tensor @@ -776,9 +774,7 @@ class MyDataParent: assert data.y.v == "test_nested" assert data.y.batch_size == torch.Size(batch_size) - def test_indexing( - self, - ): + def test_indexing(self): @tensorclass class MyDataNested: X: torch.Tensor @@ -836,9 +832,7 @@ class MyClass: assert a.grad.x is not None assert a.grad.z is None - def test_kjt( - self, - ): + def test_kjt(self): try: from torchrec import KeyedJaggedTensor except ImportError: @@ -887,9 +881,7 @@ class MyData: ).all() assert subdata.z == data.z == z - def test_len( - self, - ): + def test_len(self): myc = MyData( X=torch.rand(2, 3, 4), y=torch.rand(2, 3, 4, 5), @@ -906,18 +898,14 @@ def test_len( ) assert len(myc2) == 0 - def test_multiprocessing( - self, - ): + def test_multiprocessing(self): with Pool(os.cpu_count()) as p: catted = torch.cat(p.map(_make_data, [(i, 2) for i in range(1, 9)]), dim=0) assert catted.batch_size == torch.Size([36]) assert catted.z == "test_tensorclass" - def test_nested( - self, - ): + def test_nested(self): @tensorclass class MyDataNested: X: torch.Tensor @@ -932,9 +920,7 @@ class MyDataNested: assert isinstance(data.y, MyDataNested), type(data.y) assert data.z == data_nest.z == data.y.z == z - def test_nested_eq( - self, - ): + def test_nested_eq(self): @tensorclass class MyDataNested: X: torch.Tensor @@ -995,9 +981,7 @@ class MyDataParent: assert isinstance(data[0].y, type(data.y)) assert data[0].y.X.shape == torch.Size([4, 5]) - def test_nested_ne( - self, - ): + def test_nested_ne(self): @tensorclass class MyDataNested: X: torch.Tensor @@ -1018,9 +1002,7 @@ class MyDataNested: assert not (data != data2).y.X.any() assert not (data != data2).y.z.any() - def test_permute( - self, - ): + def test_permute(self): @tensorclass class MyDataNested: X: torch.Tensor @@ -1041,9 +1023,7 @@ class MyDataNested: assert isinstance(permuted_data._tensordict, _PermutedTensorDict) assert 
permuted_data.z == permuted_data.y.z == z - def test_pickle( - self, - ): + def test_pickle(self): data = MyData( X=torch.ones(3, 4, 5), y=torch.zeros(3, 4, 5, dtype=torch.bool), @@ -1067,9 +1047,7 @@ def test_pickle( assert isinstance(data2, MyData) assert data2.z == data.z - def test_post_init( - self, - ): + def test_post_init(self): @tensorclass class MyDataPostInit: X: torch.Tensor @@ -1097,9 +1075,7 @@ def __post_init__(self): TensorDict({"X": -torch.ones(2), "y": torch.rand(2)}, batch_size=[2]) ) - def test_pre_allocate( - self, - ): + def test_pre_allocate(self): @tensorclass class M1: X: Any @@ -1118,9 +1094,35 @@ class M3: m1[0] = m2 assert (m1[0].X.X.X == m2.X.X.X).all() - def test_reshape( - self, - ): + def test_repeat(self): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + assert (data.repeat(2, 3) == torch.cat([torch.cat([data] * 2, 0)] * 3, 1)).all() + + def test_repeat_interleave(self): + @tensorclass + class MyDataNested: + X: torch.Tensor + z: str + y: "MyDataNested" = None + + X = torch.ones(3, 4, 5) + z = "test_tensorclass" + batch_size = [3, 4] + data_nest = MyDataNested(X=X, z=z, batch_size=batch_size) + data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size) + assert data.repeat_interleave(2, dim=1).shape == torch.Size((3, 8)) + + def test_reshape(self): @tensorclass class MyDataNested: X: torch.Tensor @@ -1140,9 +1142,7 @@ class MyDataNested: assert isinstance(stacked_tc._tensordict, TensorDict) assert stacked_tc.z == stacked_tc.y.z == z - def test_set( - self, - ): + def test_set(self): @tensorclass class MyDataNest: X: torch.Tensor @@ -1316,9 +1316,7 @@ class MyDataParent: # ensure optional fields are writable data.k = torch.zeros(3, 4, 5) - def test_setitem( - self, - ): + def test_setitem(self): data = MyData( X=torch.ones(3, 4, 5), y=torch.zeros(3, 4, 5), @@ -1423,9 +1421,7 @@ class MyDataNested: assert (data.X[:2] == 0).all() assert (data.y.X[:2] == 0).all() - def test_setitem_memmap( - self, - ): + def test_setitem_memmap(self): # regression test PR #203 # We should be able to set tensors items with MemoryMappedTensors and viceversa @tensorclass @@ -1454,9 +1450,7 @@ class MyDataMemMap1: assert (data2.x[2:] == 0).all() assert (data2.y[2:] == 0).all() - def test_setitem_other_cls( - self, - ): + def test_setitem_other_cls(self): @tensorclass class MyData1: x: torch.Tensor @@ -1504,9 +1498,7 @@ class MyData3: ): data_wrong_cls[2:] = data1[2:] - def test_signature( - self, - ): + def test_signature(self): sig = inspect.signature(MyData) assert list(sig.parameters) == ["X", "y", "z", "batch_size", "device", "names"] @@ -1689,9 +1681,7 @@ class MyDataNested: ): torch.stack([data1, data3], dim=0) - def test_statedict_errors( - self, - ): + def test_statedict_errors(self): @tensorclass class MyClass: x: torch.Tensor @@ -1723,9 +1713,7 @@ class MyClass: with pytest.raises(KeyError, match="Key 'a' wasn't expected in the state-dict"): tc.load_state_dict(sd) - def test_tensorclass_get_at( - self, - ): + def test_tensorclass_get_at(self): @tensorclass class MyDataNest: X: torch.Tensor @@ -1753,9 +1741,7 @@ class MyDataParent: assert data.get_at(("y", "foo"), slice(2, 3), "working") == "working" assert data.get_at("foo", slice(2, 3), "working") == "working" - def test_tensorclass_set_at_( - self, - ): + def 
test_tensorclass_set_at_(self): @tensorclass class MyDataNest: X: torch.Tensor @@ -1786,9 +1772,7 @@ class MyDataParent: assert (data.get_at("X", slice(3, 5)) == 1).all() assert (data.get_at(("y", "X"), slice(3, 5)) == 1).all() - def test_to_tensordict( - self, - ): + def test_to_tensordict(self): @tensorclass class MyClass: x: torch.Tensor @@ -1871,9 +1855,7 @@ class MyClass: # load_state_dict outperforms snapshot in this case assert tc_dest.z == z - def test_type( - self, - ): + def test_type(self): data = MyData( X=torch.ones(3, 4, 5), y=torch.zeros(3, 4, 5, dtype=torch.bool), @@ -1886,9 +1868,7 @@ def test_type( # we get an instance of the user defined class, not a dynamically defined subclass assert type(data) is MyDataUndecorated - def test_unbind( - self, - ): + def test_unbind(self): @tensorclass class MyDataNested: X: torch.Tensor @@ -1909,9 +1889,7 @@ class MyDataNested: assert unbind_tcs[0].batch_size == torch.Size([4]) assert unbind_tcs[0].z == unbind_tcs[1].z == unbind_tcs[2].z == z - def test_unsqueeze( - self, - ): + def test_unsqueeze(self): @tensorclass class MyDataNested: X: torch.Tensor @@ -1929,9 +1907,7 @@ class MyDataNested: assert unsqueeze_tc.y.X.shape == torch.Size([3, 1, 4, 5]) assert unsqueeze_tc.z == unsqueeze_tc.y.z == z - def test_view( - self, - ): + def test_view(self): @tensorclass class MyDataNested: X: torch.Tensor diff --git a/test/test_tensordict.py b/test/test_tensordict.py index bc9e45963..edf9eb580 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -5564,6 +5564,32 @@ def test_rename_key_nested(self, td_name, device) -> None: assert (td[("nested", "back")] == 0).all() assert "second" not in td.keys() + def test_repeat(self, td_name, device): + td = getattr(self, td_name)(device) + assert (td.repeat(1, 1, 1, 1) == td).all() + assert (td.repeat(2, 1, 1, 1) == torch.cat([td] * 2, 0)).all() + assert (td.repeat(1, 2, 1, 1) == torch.cat([td] * 2, 1)).all() + assert (td.repeat(1, 1, 2, 1) == torch.cat([td] * 2, 2)).all() + assert (td.repeat(1, 1, 1, 2) == torch.cat([td] * 2, 3)).all() + + def test_repeat_interleave(self, td_name, device): + td = getattr(self, td_name)(device) + for d in [0, 1, 2, 3, -1, -2, -3, -4]: + t = torch.empty(td.shape) + t_shape = t.repeat_interleave(3, dim=d).shape + td_repeat = td.repeat_interleave(3, dim=d) + assert td_repeat.shape == t_shape + assert td_repeat.device == td.device + if d < 0: + d = td.ndim + d + a = td["a"] + a_repeat = td_repeat["a"] + torch.testing.assert_close(a.repeat_interleave(3, dim=d), a_repeat) + + t = torch.empty(td.shape) + t_shape = t.repeat_interleave(3).shape + assert t_shape == td.repeat_interleave(3).shape + def test_replace(self, td_name, device): td = getattr(self, td_name)(device) td_dict = td.to_dict() @@ -10703,6 +10729,30 @@ def test_nontensor_tensor(self): ) # this triggers an exception assert all(isinstance(t, torch.Tensor) for t in stack.tolist()) + def test_repeat(self): + stack = NonTensorStack( + NonTensorData("a", batch_size=(3,)), + NonTensorData("b", batch_size=(3,)), + stack_dim=1, + ) + assert stack.shape == (3, 2) + r = stack.repeat(1, 2) + assert r.shape == (3, 4) + assert r[0].tolist() == ["a", "b", "a", "b"] + assert r[:, 0].tolist() == ["a", "a", "a"] + assert r[:, -1].tolist() == ["b", "b", "b"] + + def test_repeat_interleave(self): + stack = NonTensorStack( + NonTensorData("a", batch_size=(3,)), + NonTensorData("b", batch_size=(3,)), + stack_dim=1, + ) + assert stack.shape == (3, 2) + r = stack.repeat_interleave(3, dim=1) + assert isinstance(r, 
NonTensorStack) + assert r[0].tolist() == ["a", "a", "a", "b", "b", "b"] + def test_set(self, non_tensor_data): non_tensor_data.set(("nested", "another_string"), "another string!") assert (
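Usage sketch for reviewers (not part of the diff): a minimal doctest-style example of the two methods added here, mirroring the docstring examples in tensordict/base.py above. The tensordict and its shapes are illustrative assumptions, not taken from the test suite.
>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.randn(3, 4, 5)}, batch_size=[3, 4])
>>> td.repeat(2, 1).shape                   # tile semantics, like numpy.tile
torch.Size([6, 4])
>>> td.repeat_interleave(2, dim=0).shape    # per-element repeats, like numpy.repeat
torch.Size([6, 4])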