[Feature] repeat and repeat_interleave
ghstack-source-id: d90a1a7bd87115c5f7af1a413788a30cbc2096ee
Pull Request resolved: #1115
vmoens committed Nov 27, 2024
1 parent a45c7e3 commit 004f979
Showing 9 changed files with 369 additions and 100 deletions.
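
For context, a minimal usage sketch of the two methods this commit introduces, based on the docstring examples added to `tensordict/base.py` below (shapes chosen for illustration):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3, 4, 5)}, batch_size=[3, 4])

# repeat tiles the batch dims, like torch.Tensor.repeat / numpy.tile
tiled = td.repeat(2, 1)                       # batch_size becomes [6, 4]

# repeat_interleave repeats each element along a dim, like numpy.repeat
interleaved = td.repeat_interleave(2, dim=1)  # batch_size becomes [3, 8]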
73 changes: 65 additions & 8 deletions tensordict/_lazy.py
@@ -3041,6 +3041,62 @@ def _transpose(self, dim0, dim1):
        )
        return result

    def _repeat(self, *repeats: int) -> TensorDictBase:
        repeats = list(repeats)
        r_dim = repeats.pop(self.stack_dim)
        tds = [td.repeat(*repeats) for td in self.tensordicts]
        tds = [td for _ in range(r_dim) for td in tds]
        return type(self)(
            *tds,
            stack_dim=self.stack_dim,
            stack_dim_name=self._td_dim_name,
            hook_in=self.hook_in,
            hook_out=self.hook_out,
        )

    def repeat_interleave(
        self, repeats: torch.Tensor | int, dim: int = None, *, output_size: int = None
    ) -> TensorDictBase:
        if self.ndim == 0:
            return self.unsqueeze(0).repeat_interleave(
                repeats=repeats, dim=dim, output_size=output_size
            )
        if dim is None:
            if self.ndim > 1:
                return self.reshape(-1).repeat_interleave(repeats, dim=0)
            return self.repeat_interleave(repeats, dim=0)
        dim_corrected = dim if dim >= 0 else self.ndim + dim
        if not (dim_corrected >= 0):
            raise ValueError(
                f"dim {dim} is out of range for tensordict with shape {self.shape}."
            )
        if dim_corrected == self.stack_dim:
            new_list_of_tds = [t for t in self.tensordicts for _ in range(repeats)]
            result = type(self)(
                *new_list_of_tds,
                stack_dim=self.stack_dim,
                stack_dim_name=self._td_dim_name,
                hook_out=self.hook_out,
                hook_in=self.hook_in,
            )
        else:
            dim_corrected = (
                dim_corrected if dim_corrected < self.stack_dim else dim_corrected - 1
            )
            result = type(self)(
                *(
                    td.repeat_interleave(
                        repeats=repeats, dim=dim_corrected, output_size=output_size
                    )
                    for td in self.tensordicts
                ),
                stack_dim=self.stack_dim,
                stack_dim_name=self._td_dim_name,
                hook_in=self.hook_in,
                hook_out=self.hook_out,
            )
        return result
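
A note on the lazy-stack implementation above: along the stack dimension, `repeat_interleave` simply duplicates entries in the list of stacked tensordicts; along any other dimension, the call is forwarded to each stacked tensordict with the dimension index shifted past the stack dim. A rough sketch of the expected behaviour, assuming `lazy_stack` is importable from the top-level `tensordict` namespace (shapes chosen for illustration):

import torch
from tensordict import TensorDict, lazy_stack

a = TensorDict({"x": torch.zeros(4)}, batch_size=[4])
b = TensorDict({"x": torch.ones(4)}, batch_size=[4])
lt = lazy_stack([a, b], dim=0)           # batch_size [2, 4], stack_dim=0

# Along the stack dim: each stacked tensordict is duplicated in place -> [a, a, b, b]
out0 = lt.repeat_interleave(2, dim=0)    # batch_size [4, 4]

# Along another dim: the call is delegated to every stacked tensordict
out1 = lt.repeat_interleave(2, dim=1)    # batch_size [2, 8]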

def _permute(
self,
*args,
@@ -3815,23 +3871,24 @@ def _cast_reduction(
_check_is_shared = TensorDict._check_is_shared
_convert_to_tensordict = TensorDict._convert_to_tensordict
_index_tensordict = TensorDict._index_tensordict
masked_select = TensorDict.masked_select
reshape = TensorDict.reshape
split = TensorDict.split
_to_module = TensorDict._to_module

_apply_nest = TensorDict._apply_nest
_get_names_idx = TensorDict._get_names_idx
_maybe_remove_batch_dim = TensorDict._maybe_remove_batch_dim
_multithread_apply_flat = TensorDict._multithread_apply_flat
_multithread_rebuild = TensorDict._multithread_rebuild

_remove_batch_dim = TensorDict._remove_batch_dim
_maybe_remove_batch_dim = TensorDict._maybe_remove_batch_dim
_to_module = TensorDict._to_module
_unbind = TensorDict._unbind
all = TensorDict.all
any = TensorDict.any
expand = TensorDict.expand
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx
from_dict_instance = TensorDict.from_dict_instance
masked_select = TensorDict.masked_select
_repeat = TensorDict._repeat
repeat_interleave = TensorDict.repeat_interleave
reshape = TensorDict.reshape
split = TensorDict.split


class _UnsqueezedTensorDict(_CustomOpTensorDict):
46 changes: 44 additions & 2 deletions tensordict/_td.py
@@ -1790,6 +1790,46 @@ def _reshape(tensor):
            propagate_lock=True,
        )

    def repeat_interleave(
        self, repeats: torch.Tensor | int, dim: int = None, *, output_size: int = None
    ) -> T:
        if self.ndim == 0:
            return self.unsqueeze(0).repeat_interleave(
                repeats=repeats, dim=dim, output_size=output_size
            )
        if dim is None:
            if self.ndim > 1:
                return self.reshape(-1).repeat_interleave(repeats, dim=0)
            return self.repeat_interleave(repeats, dim=0)
        dim_corrected = dim if dim >= 0 else self.ndim + dim
        if not (dim_corrected >= 0):
            raise ValueError(
                f"dim {dim} is out of range for tensordict with shape {self.shape}."
            )
        new_batch_size = torch.Size(
            [
                s if i != dim_corrected else s * repeats
                for i, s in enumerate(self.batch_size)
            ]
        )
        return self._fast_apply(
            lambda leaf: leaf.repeat_interleave(
                repeats=repeats, dim=dim_corrected, output_size=output_size
            ),
            batch_size=new_batch_size,
            call_on_nested=True,
            propagate_lock=True,
        )

    def _repeat(self, *repeats: int) -> TensorDictBase:
        new_batch_size = torch.Size([i * r for i, r in zip(self.batch_size, repeats)])
        return self._fast_apply(
            lambda leaf: leaf.repeat(*repeats, *((1,) * (leaf.ndim - self.ndim))),
            batch_size=new_batch_size,
            call_on_nested=True,
            propagate_lock=True,
        )

    def _transpose(self, dim0, dim1):
        def _transpose(tensor):
            return tensor.transpose(dim0, dim1)
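
The dense `_repeat` above maps onto `torch.Tensor.repeat` of every leaf, padding the repeat counts with 1s so that the leaves' non-batch (feature) dimensions are left untouched. A standalone sketch of that padding logic with plain tensors (variable names are illustrative only):

import torch

batch_ndim = 2
leaf = torch.randn(3, 4, 5)          # batch dims (3, 4) plus one feature dim of size 5
repeats = (2, 3)                     # one repeat factor per batch dim

# Pad with 1s so feature dims are not repeated, mirroring _repeat above
padded = (*repeats, *((1,) * (leaf.ndim - batch_ndim)))
out = leaf.repeat(*padded)           # shape (6, 12, 5)
assert out.shape == (6, 12, 5)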
@@ -4208,14 +4248,16 @@ def _cast_reduction(
__or__ = TensorDict.__or__
_check_device = TensorDict._check_device
_check_is_shared = TensorDict._check_is_shared
_to_module = TensorDict._to_module
_unbind = TensorDict._unbind
all = TensorDict.all
any = TensorDict.any
masked_select = TensorDict.masked_select
memmap_like = TensorDict.memmap_like
repeat_interleave = TensorDict.repeat_interleave
_repeat = TensorDict._repeat
reshape = TensorDict.reshape
split = TensorDict.split
_to_module = TensorDict._to_module
_unbind = TensorDict._unbind

def _view(self, *args, **kwargs):
raise RuntimeError(
117 changes: 117 additions & 0 deletions tensordict/base.py
@@ -2574,6 +2574,123 @@ def reshape(
"""
...

@abc.abstractmethod
def repeat_interleave(
self, repeats: torch.Tensor | int, dim: int = None, *, output_size: int = None
) -> TensorDictBase:
"""Repeat elements of a TensorDict.
.. warning:: This is different from :meth:`~torch.Tensor.repeat` but similar to :func:`numpy.repeat`.
Args:
repeats (torch.Tensor or int): The number of repetitions for each element. `repeats` is broadcast to fit
the shape of the given axis.
dim (int, optional): The dimension along which to repeat values. By default, use the flattened input
array, and return a flat output array.
Keyword Args:
output_size (int, optional): Total output size for the given axis (e.g. sum of repeats). If given, it
will avoid stream synchronization needed to calculate output shape of the tensordict.
Returns:
Repeated TensorDict which has the same shape as input, except along the given axis.
Examples:
>>> import torch
>>>
>>> from tensordict import TensorDict
>>>
>>> td = TensorDict(
... {
... "a": torch.randn(3, 4, 5),
... "b": TensorDict({
... "c": torch.randn(3, 4, 10, 1),
... "a string": "a string!",
... }, batch_size=[3, 4, 10])
... }, batch_size=[3, 4],
... )
>>> print(td.repeat_interleave(2, dim=0))
TensorDict(
fields={
a: Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
a string: NonTensorData(data=a string!, batch_size=torch.Size([6, 4, 10]), device=None),
c: Tensor(shape=torch.Size([6, 4, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([6, 4, 10]),
device=None,
is_shared=False)},
batch_size=torch.Size([6, 4]),
device=None,
is_shared=False)
"""
...

@overload
def repeat(self, repeats: torch.Size): ...

def repeat(self, *repeats: int) -> TensorDictBase:
"""Repeats this tensor along the specified dimensions.
Unlike :meth:`~.expand()`, this function copies the tensor’s data.
.. warning:: :meth:`~.repeat` behaves differently from :func:`~numpy.repeat`, but is more similar to
:func:`numpy.tile`. For the operator similar to :func:`numpy.repeat`, see :meth:`~tensordict.TensorDictBase.repeat_interleave`.
Args:
repeat (torch.Size, int..., tuple of int or list of int): The number of times to repeat this tensor along
each dimension.
Examples:
>>> import torch
>>>
>>> from tensordict import TensorDict
>>>
>>> td = TensorDict(
... {
... "a": torch.randn(3, 4, 5),
... "b": TensorDict({
... "c": torch.randn(3, 4, 10, 1),
... "a string": "a string!",
... }, batch_size=[3, 4, 10])
... }, batch_size=[3, 4],
... )
>>> print(td.repeat(1, 2))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 8, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
a string: NonTensorData(data=a string!, batch_size=torch.Size([3, 8, 10]), device=None),
c: Tensor(shape=torch.Size([3, 8, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 8, 10]),
device=None,
is_shared=False)},
batch_size=torch.Size([3, 8]),
device=None,
is_shared=False)
"""
        if len(repeats) == 1 and not isinstance(repeats[0], int):
            repeats = repeats[0]
            if isinstance(repeats, torch.Size):
                return self.repeat(*repeats)
            if isinstance(repeats, torch.Tensor):
                # This will cause cuda to sync, which may not be desirable
                return self.repeat(*repeats.tolist())
            raise ValueError(
                f"repeats must be a sequence of integers, a tensor or a torch.Size object. Got {type(repeats)} instead."
            )
        if len(repeats) != self.ndimension():
            raise ValueError(
                f"The number of repeat elements must match the number of dimensions of the tensordict. Got {len(repeats)} but ndim={self.ndimension()}."
            )
        return self._repeat(*repeats)

    @abc.abstractmethod
    def _repeat(self, *repeats: int) -> TensorDictBase: ...
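
The `repeat()` wrapper above normalizes its argument before dispatching to `_repeat`: separate integers, a `torch.Size`, or an integer tensor are all accepted and reduced to one repeat factor per batch dimension (a plain tensor triggers a `.tolist()` call, which may synchronize CUDA). A small illustration of the equivalent call forms, with shapes chosen for illustration only:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3, 4, 5)}, batch_size=[3, 4])

r1 = td.repeat(2, 3)                  # plain integers, one per batch dim
r2 = td.repeat(torch.Size([2, 3]))    # a torch.Size is unpacked
r3 = td.repeat(torch.tensor([2, 3]))  # a tensor is converted via .tolist()
assert r1.batch_size == r2.batch_size == r3.batch_size == torch.Size([6, 12])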

def cat_tensors(
self,
*keys: NestedKey,
6 changes: 6 additions & 0 deletions tensordict/nn/params.py
@@ -1095,6 +1095,12 @@ def memmap_like(
    @_fallback
    def reshape(self, *shape: int): ...

    @_fallback
    def repeat_interleave(self, *shape: int): ...

    @_fallback
    def _repeat(self, *repeats: int): ...

    @_fallback
    def split(
        self, split_size: int | list[int], dim: int = 0
18 changes: 10 additions & 8 deletions tensordict/persistent.py
@@ -1412,25 +1412,27 @@ def _unsqueeze(self, dim):
__le__ = TensorDict.__le__
__lt__ = TensorDict.__lt__

_cast_reduction = TensorDict._cast_reduction
_apply_nest = TensorDict._apply_nest
_multithread_apply_flat = TensorDict._multithread_apply_flat
_multithread_rebuild = TensorDict._multithread_rebuild

_cast_reduction = TensorDict._cast_reduction
_check_device = TensorDict._check_device
_check_is_shared = TensorDict._check_is_shared
_convert_to_tensordict = TensorDict._convert_to_tensordict
_get_names_idx = TensorDict._get_names_idx
_index_tensordict = TensorDict._index_tensordict
_multithread_apply_flat = TensorDict._multithread_apply_flat
_multithread_rebuild = TensorDict._multithread_rebuild
_to_module = TensorDict._to_module
_unbind = TensorDict._unbind
all = TensorDict.all
any = TensorDict.any
expand = TensorDict.expand
from_dict_instance = TensorDict.from_dict_instance
masked_select = TensorDict.masked_select
_repeat = TensorDict._repeat
repeat_interleave = TensorDict.repeat_interleave
reshape = TensorDict.reshape
split = TensorDict.split
_to_module = TensorDict._to_module
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx
from_dict_instance = TensorDict.from_dict_instance


def _set_max_batch_size(source: PersistentTensorDict):
23 changes: 19 additions & 4 deletions tensordict/tensorclass.py
@@ -177,6 +177,7 @@ def __subclasscheck__(self, subclass):
"_maybe_remove_batch_dim",
"_multithread_apply_flat",
"_remove_batch_dim",
"_repeat",
"_select", # TODO: must be specialized
"_set_at_tuple",
"_set_tuple",
@@ -297,6 +298,8 @@ def __subclasscheck__(self, subclass):
"reciprocal_",
"refine_names",
"rename_", # TODO: must be specialized
"repeat",
"repeat_interleave",
"replace",
"requires_grad_",
"reshape",
@@ -395,6 +398,7 @@ def from_dataclass(
frozen: bool = False,
autocast: bool = False,
nocast: bool = False,
inplace: bool = False,
) -> Any:
"""Converts a dataclass instance or a type into a tensorclass instance or type, respectively.
@@ -409,6 +413,8 @@ def from_dataclass(
frozen (bool, optional): If ``True``, the resulting class or instance will be immutable. Defaults to ``False``.
autocast (bool, optional): If ``True``, enables automatic type casting for the resulting class or instance. Defaults to ``False``.
nocast (bool, optional): If ``True``, disables any type casting for the resulting class or instance. Defaults to ``False``.
inplace (bool, optional): If ``True``, the dataclass type passed will be modified in-place. Defaults to ``False``.
Without effect if an instance is provided.
Returns:
A tensor-compatible class or instance derived from the provided dataclass.
@@ -457,9 +463,14 @@ def from_dataclass(
    if isinstance(obj, type):
        if is_tensorclass(obj):
            return obj
        cls = make_dataclass(
            obj.__name__ + "_tc", fields=obj.__dataclass_fields__, bases=obj.__bases__
        )
        if not inplace:
            cls = make_dataclass(
                obj.__name__ + "_tc",
                fields=obj.__dataclass_fields__,
                bases=obj.__bases__,
            )
        else:
            cls = obj
        clz = _tensorclass(cls, frozen=frozen)
        clz._type_hints = get_type_hints(obj)
        clz._autocast = autocast
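
A short sketch of how the new ``inplace`` flag of ``from_dataclass`` is meant to be used, per the docstring above (the dataclass and its fields are illustrative, and ``from_dataclass`` is assumed to be importable from the top-level ``tensordict`` package):

from dataclasses import dataclass

from tensordict import from_dataclass


@dataclass
class Config:
    lr: float = 1e-3
    steps: int = 100


# Default: a new tensorclass type (named "Config_tc") is derived from the dataclass
ConfigTC = from_dataclass(Config)

# inplace=True: the dataclass type passed in is converted in place, no "_tc" copy is made
from_dataclass(Config, inplace=True)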
@@ -768,7 +779,11 @@ def __torch_function__(
    cls.__doc__ = f"{cls.__name__}{inspect.signature(cls)}"

    _register_tensor_class(cls)
    _register_td_node(cls)
    try:
        _register_td_node(cls)
    except ValueError:
        # The class may already be registered as a pytree node
        pass

    # faster than doing instance checks
    cls._is_non_tensor = _is_non_tensor
