diff --git a/tensordict/base.py b/tensordict/base.py
index 8ae783905..e4abac65f 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -220,6 +220,30 @@ class TensorDictBase(MutableMapping):
     def __bool__(self) -> bool:
         raise RuntimeError("Converting a tensordict to boolean value is not permitted")
 
+    def __abs__(self) -> T:
+        """Returns a new TensorDict instance with absolute values of all tensors.
+
+        Returns:
+            A new TensorDict instance with the same key set as the original,
+            but with all tensors having their absolute values computed.
+
+        .. seealso:: :meth:`~.abs`
+
+        """
+        return self.abs()
+
+    def __neg__(self) -> T:
+        """Returns a new TensorDict instance with negated values of all tensors.
+
+        Returns:
+            A new TensorDict instance with the same key set as the original,
+            but with all tensors having their values negated.
+
+        .. seealso:: :meth:`~.neg`
+
+        """
+        return self.neg()
+
     @abc.abstractmethod
     def __ne__(self, other: object) -> T:
         """NOT operation over two tensordicts, for every key.
@@ -237,7 +261,7 @@ def __ne__(self, other: object) -> T:
         ...
 
     @abc.abstractmethod
-    def __xor__(self, other: TensorDictBase | float):
+    def __xor__(self, other: TensorDictBase | torch.Tensor | float):
         """XOR operation over two tensordicts, for every key.
 
         The two tensordicts must have the same key set.
@@ -252,6 +276,13 @@ def __xor__(self, other: TensorDictBase | float):
         """
         ...
 
+    def __rxor__(self, other: TensorDictBase | torch.Tensor | float):
+        """Right-side XOR operation over two tensordicts, for every key.
+
+        Wraps `__xor__`, as it is assumed to be commutative.
+        """
+        return self.__xor__(other)
+
     @abc.abstractmethod
     def __or__(self, other: TensorDictBase | torch.Tensor) -> T:
         """OR operation over two tensordicts, for every key.
@@ -268,6 +299,71 @@ def __or__(self, other: TensorDictBase | torch.Tensor) -> T:
         """
         ...
 
+    def __ror__(self, other: TensorDictBase | torch.Tensor) -> T:
+        """Right-side OR operation over two tensordicts, for every key.
+
+        This is a wrapper around `__or__` since it is assumed to be commutative.
+        """
+        return self | other
+
+    def __invert__(self) -> T:
+        """Returns a new TensorDict instance with all tensors inverted (i.e., bitwise NOT operation).
+
+        Returns:
+            A new TensorDict instance with the same key set as the original,
+            but with all tensors having their bits inverted.
+        """
+        keys, vals = self._items_list(True, True)
+        vals = [~v for v in vals]
+        items = dict(zip(keys, vals))
+
+        def get(name, val):
+            return items.get(name, val)
+
+        return self._fast_apply(
+            get,
+            named=True,
+            nested_keys=True,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
+            propagate_lock=True,
+        )
+
+    def __and__(self, other: TensorDictBase | torch.Tensor | float) -> T:
+        """Returns a new TensorDict instance with all tensors performing a logical or bitwise AND operation with the given value.
+
+        Args:
+            other: The value to perform the AND operation with.
+
+        Returns:
+            A new TensorDict instance with the same key set as the original,
+            but with all tensors having performed an AND operation with the given value.
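+
+        Example (illustrative; assumes boolean entries):
+            >>> import torch
+            >>> from tensordict import TensorDict
+            >>> td = TensorDict(a=torch.tensor([True, False]))
+            >>> (td & torch.tensor(True))["a"]
+            tensor([ True, False])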
+ """ + keys, vals = self._items_list(True, True) + if _is_tensor_collection(type(other)): + new_keys, other_val = other._items_list(True, True, sorting_keys=keys) + vals = [(v1 & v2) for v1, v2 in zip(vals, other_val)] + else: + vals = [(v & other) for v in vals] + items = dict(zip(keys, vals)) + + def pop(name, val): + return items.pop(name, None) + + result = self._fast_apply( + pop, + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=True, + filter_empty=True, + default=None, + ) + if items: + result.update(items) + return result + + __rand__ = __and__ + @abc.abstractmethod def __eq__(self, other: object) -> T: """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set. @@ -8903,21 +8999,27 @@ def record(tensor): def __add__(self, other: TensorDictBase | torch.Tensor) -> T: return self.add(other) + def __radd__(self, other: TensorDictBase | torch.Tensor) -> T: + return self.add(other) + def __iadd__(self, other: TensorDictBase | torch.Tensor) -> T: return self.add_(other) - def __abs__(self): - return self.abs() - def __truediv__(self, other: TensorDictBase | torch.Tensor) -> T: return self.div(other) def __itruediv__(self, other: TensorDictBase | torch.Tensor) -> T: return self.div_(other) + def __rtruediv__(self, other: TensorDictBase | torch.Tensor) -> T: + return other * self.reciprocal() + def __mul__(self, other: TensorDictBase | torch.Tensor) -> T: return self.mul(other) + def __rmul__(self, other: TensorDictBase | torch.Tensor) -> T: + return self.mul(other) + def __imul__(self, other: TensorDictBase | torch.Tensor) -> T: return self.mul_(other) @@ -8927,9 +9029,18 @@ def __sub__(self, other: TensorDictBase | torch.Tensor) -> T: def __isub__(self, other: TensorDictBase | torch.Tensor) -> T: return self.sub_(other) + def __rsub__(self, other: TensorDictBase | torch.Tensor) -> T: + return self.sub(other) + def __pow__(self, other: TensorDictBase | torch.Tensor) -> T: return self.pow(other) + def __rpow__(self, other: TensorDictBase | torch.Tensor) -> T: + raise NotImplementedError( + "rpow isn't implemented for tensordict yet. Make sure both elements are wrapped " + "in a tensordict for this to work." + ) + def __ipow__(self, other: TensorDictBase | torch.Tensor) -> T: return self.pow_(other) @@ -9661,6 +9772,110 @@ def pop(name, val): result.update(items) return result + def bitwise_and( + self, + other: TensorDictBase | torch.Tensor, + *, + default: str | CompatibleType | None = None, + ) -> TensorDictBase: # noqa: D417 + r"""Performs a bitwise AND operation between ``self`` and :attr:`other`. + + .. math:: + \text{{out}}_i = \text{{input}}_i \land \text{{other}}_i + + Args: + other (TensorDictBase or torch.Tensor): the tensor or TensorDict to perform the bitwise AND with. + + Keyword Args: + default (torch.Tensor or str, optional): the default value to use for exclusive entries. + If none is provided, the two tensordicts key list must match exactly. + If ``default="intersection"`` is passed, only the intersecting key sets will be considered + and other keys will be ignored. + In all other cases, ``default`` will be used for all missing entries on both sides of the + operation. 
+ """ + keys, vals = self._items_list(True, True) + if _is_tensor_collection(type(other)): + new_keys, other_val = other._items_list( + True, True, sorting_keys=keys, default=default + ) + if default is not None: + as_dict = dict(zip(keys, vals)) + vals = [as_dict.get(key, default) for key in new_keys] + keys = new_keys + vals = [(v1.bitwise_and(v2)) for v1, v2 in zip(vals, other_val)] + else: + vals = [v.bitwise_and(other) for v in vals] + items = dict(zip(keys, vals)) + + def pop(name, val): + return items.pop(name, None) + + result = self._fast_apply( + pop, + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=True, + filter_empty=True, + default=None, + ) + if items: + result.update(items) + return result + + def logical_and( + self, + other: TensorDictBase | torch.Tensor, + *, + default: str | CompatibleType | None = None, + ) -> TensorDictBase: # noqa: D417 + r"""Performs a logical AND operation between ``self`` and :attr:`other`. + + .. math:: + \text{{out}}_i = \text{{input}}_i \land \text{{other}}_i + + Args: + other (TensorDictBase or torch.Tensor): the tensor or TensorDict to perform the logical AND with. + + Keyword Args: + default (torch.Tensor or str, optional): the default value to use for exclusive entries. + If none is provided, the two tensordicts key list must match exactly. + If ``default="intersection"`` is passed, only the intersecting key sets will be considered + and other keys will be ignored. + In all other cases, ``default`` will be used for all missing entries on both sides of the + operation. + """ + keys, vals = self._items_list(True, True) + if _is_tensor_collection(type(other)): + new_keys, other_val = other._items_list( + True, True, sorting_keys=keys, default=default + ) + if default is not None: + as_dict = dict(zip(keys, vals)) + vals = [as_dict.get(key, default) for key in new_keys] + keys = new_keys + vals = [(v1.logical_and(v2)) for v1, v2 in zip(vals, other_val)] + else: + vals = [v.logical_and(other) for v in vals] + items = dict(zip(keys, vals)) + + def pop(name, val): + return items.pop(name, None) + + result = self._fast_apply( + pop, + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=True, + filter_empty=True, + default=None, + ) + if items: + result.update(items) + return result + def add( self, other: TensorDictBase | torch.Tensor, @@ -9719,7 +9934,12 @@ def pop(name, val): result.update(items) return result - def add_(self, other: TensorDictBase | float, *, alpha: float | None = None): + def add_( + self, + other: TensorDictBase | torch.Tensor | float, + *, + alpha: float | None = None, + ): """In-place version of :meth:`~.add`. .. 
         .. note::
@@ -9781,7 +10001,11 @@ def get(name, val):
             propagate_lock=True,
         )
 
-    def lerp_(self, end: TensorDictBase | float, weight: TensorDictBase | float):
+    def lerp_(
+        self,
+        end: TensorDictBase | torch.Tensor | float,
+        weight: TensorDictBase | torch.Tensor | float,
+    ):
         """In-place version of :meth:`~.lerp`."""
         if _is_tensor_collection(type(end)):
             end_val = end._values_list(True, True)
@@ -9917,7 +10141,7 @@ def addcmul_(self, other1, other2, *, value: float | None = 1):
 
     def sub(
         self,
-        other: TensorDictBase | float,
+        other: TensorDictBase | torch.Tensor | float,
         *,
         alpha: float | None = None,
         default: str | CompatibleType | None = None,
@@ -9976,7 +10200,9 @@ def pop(name, val):
         result.update(items)
         return result
 
-    def sub_(self, other: TensorDictBase | float, alpha: float | None = None):
+    def sub_(
+        self, other: TensorDictBase | torch.Tensor | float, alpha: float | None = None
+    ):
         """In-place version of :meth:`~.sub`.
 
         .. note::
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 62eeb83bb..50af28215 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -166,19 +166,30 @@ def __subclasscheck__(self, subclass):
 _FALLBACK_METHOD_FROM_TD = [
     "__abs__",
     "__add__",
+    "__and__",
     "__bool__",
     "__eq__",
     "__ge__",
     "__gt__",
     "__iadd__",
     "__imul__",
+    "__invert__",
     "__ipow__",
     "__isub__",
     "__itruediv__",
     "__mul__",
     "__ne__",
+    "__neg__",
     "__or__",
     "__pow__",
+    "__radd__",
+    "__rand__",
+    "__rmul__",
+    "__ror__",
+    "__rpow__",
+    "__rsub__",
+    "__rtruediv__",
+    "__rxor__",
     "__sub__",
     "__truediv__",
     "__xor__",
@@ -228,6 +239,7 @@ def __subclasscheck__(self, subclass):
     "atan_",
     "auto_batch_size_",
     "auto_device_",
+    "bitwise_and",
     "ceil",
     "ceil_",
     "chunk",
@@ -291,10 +303,9 @@ def __subclasscheck__(self, subclass):
     "log2",
     "log2_",
     "log_",
+    "logical_and",
     "map",
     "map_iter",
-    "to_namedtuple",
-    "to_pytree",
     "masked_fill",
     "masked_fill_",
     "max",
@@ -356,6 +367,8 @@ def __subclasscheck__(self, subclass):
     "tanh_",
     "to",
     "to_module",
+    "to_namedtuple",
+    "to_pytree",
     "transpose",
     "trunc",
     "trunc_",
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 3133f1c2a..e29778bf9 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -3080,6 +3080,72 @@ def test_zero_grad_module(self):
 
 
 class TestPointwiseOps:
+    def test_r_ops(self):
+        td = TensorDict(a=1)
+        # mul
+        assert isinstance(0 * td, TensorDict)
+        assert isinstance(torch.zeros(()) * td, TensorDict)
+        # +
+        assert isinstance(0 + td, TensorDict)
+        assert isinstance(torch.zeros(()) + td, TensorDict)
+        # -
+        assert isinstance(0 - td, TensorDict)
+        assert isinstance(torch.zeros(()) - td, TensorDict)
+        # /
+        assert isinstance(0 / td, TensorDict)
+        assert isinstance(torch.zeros(()) / td, TensorDict)
+        # ** (skipped: __rpow__ raises NotImplementedError for now)
+        # assert isinstance(1 ** td, TensorDict)
+        # assert isinstance(torch.ones(()) ** td, TensorDict)
+
+        td = TensorDict(a=True)
+        # |
+        assert isinstance(False | td, TensorDict)
+        assert isinstance(torch.zeros((), dtype=torch.bool) | td, TensorDict)
+        # ^
+        assert isinstance(False ^ td, TensorDict)
+        assert isinstance(torch.zeros((), dtype=torch.bool) ^ td, TensorDict)
+
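+    def test_r_ops_values(self):
+        # Illustrative value checks for the reflected arithmetic ops; the
+        # scalar operands below are arbitrary choices.
+        td = TensorDict(a=torch.tensor(2.0))
+        assert ((3.0 + td)["a"] == 5.0).all()
+        assert ((3.0 - td)["a"] == 1.0).all()
+        assert ((3.0 * td)["a"] == 6.0).all()
+        assert ((3.0 / td)["a"] == 1.5).all()
+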
+    def test_builtins(self):
+        td_float = TensorDict(a=1.0)
+        td_bool = TensorDict(a=True)
+        ones = torch.ones(())
+        bool_ones = torch.ones(()).to(torch.bool)
+        assert ((-td_float) == (-ones)).all()
+        # assert ((-td_bool) == (-bool_ones)).all()  # Not defined for bool
+        assert (abs(td_float) == abs(ones)).all()
+        # assert (abs(td_bool) == abs(bool_ones)).all()  # Not defined for bool
+        # assert ((~td_float) == (~ones)).all()  # Not defined for float
+        assert ((~td_bool) == (~bool_ones)).all()
+        assert ((td_float != td_float) == (ones != ones)).all()
+        assert ((td_bool != td_bool) == (bool_ones != bool_ones)).all()
+        assert ((td_float == td_float) == (ones == ones)).all()
+        assert ((td_bool == td_bool) == (bool_ones == bool_ones)).all()
+        assert ((td_float < td_float) == (ones < ones)).all()
+        assert ((td_bool < td_bool) == (bool_ones < bool_ones)).all()
+        assert ((td_float <= td_float) == (ones <= ones)).all()
+        assert ((td_bool <= td_bool) == (bool_ones <= bool_ones)).all()
+        assert ((td_float > td_float) == (ones > ones)).all()
+        assert ((td_bool > td_bool) == (bool_ones > bool_ones)).all()
+        assert ((td_float >= td_float) == (ones >= ones)).all()
+        assert ((td_bool >= td_bool) == (bool_ones >= bool_ones)).all()
+        assert ((td_float + td_float) == (ones + ones)).all()
+        # assert ((td_bool + td_bool) == (bool_ones + bool_ones)).all()  # Not defined for bool
+        assert ((td_float - td_float) == (ones - ones)).all()
+        # assert ((td_bool - td_bool) == (bool_ones - bool_ones)).all()  # Not defined for bool
+        assert ((td_float * td_float) == (ones * ones)).all()
+        # assert ((td_bool * td_bool) == (bool_ones * bool_ones)).all()  # Not defined for bool
+        assert ((td_float / td_float) == (ones / ones)).all()
+        # assert ((td_bool / td_bool) == (bool_ones / bool_ones)).all()  # Not defined for bool
+        assert ((td_float**td_float) == (ones**ones)).all()
+        # assert ((td_bool**td_bool) == (bool_ones**bool_ones)).all()  # Not defined for bool
+        # assert ((td_float & td_float) == (ones & ones)).all()  # Not defined for float
+        assert ((td_bool & td_bool) == (bool_ones & bool_ones)).all()
+        # assert ((td_float ^ td_float) == (ones ^ ones)).all()  # Not defined for float
+        assert ((td_bool ^ td_bool) == (bool_ones ^ bool_ones)).all()
+        # assert ((td_float | td_float) == (ones | ones)).all()  # Not defined for float
+        assert ((td_bool | td_bool) == (bool_ones | bool_ones)).all()
+
     @property
     def dummy_td_0(self):
         return TensorDict(