
Commit

Update

[ghstack-poisoned]
vmoens committed Dec 20, 2024
1 parent 98ad4e2 commit 6d56dc7
Showing 3 changed files with 315 additions and 11 deletions.
242 changes: 234 additions & 8 deletions tensordict/base.py
@@ -220,6 +220,30 @@ class TensorDictBase(MutableMapping):
def __bool__(self) -> bool:
raise RuntimeError("Converting a tensordict to boolean value is not permitted")

def __abs__(self) -> T:
"""Returns a new TensorDict instance with absolute values of all tensors.

Returns:
A new TensorDict instance with the same key set as the original,
but with all tensors having their absolute values computed.

.. seealso:: :meth:`~.abs`

"""
return self.abs()

def __neg__(self) -> T:
"""Returns a new TensorDict instance with negated values of all tensors.

Returns:
A new TensorDict instance with the same key set as the original,
but with all tensors having their values negated.

.. seealso:: :meth:`~.neg`

"""
return self.neg()
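
A quick usage sketch of the two new unary dunders (hypothetical keys and values):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.tensor([-1.0, 2.0])}, batch_size=[2])
assert (abs(td)["a"] == torch.tensor([1.0, 2.0])).all()  # abs() -> __abs__ -> td.abs()
assert ((-td)["a"] == torch.tensor([1.0, -2.0])).all()   # unary minus -> __neg__ -> td.neg()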

@abc.abstractmethod
def __ne__(self, other: object) -> T:
"""NOT operation over two tensordicts, for evey key.
@@ -237,7 +261,7 @@ def __ne__(self, other: object) -> T:
...

@abc.abstractmethod
def __xor__(self, other: TensorDictBase | float):
def __xor__(self, other: TensorDictBase | torch.Tensor | float):
"""XOR operation over two tensordicts, for evey key.

The two tensordicts must have the same key set.
@@ -252,6 +276,13 @@ def __xor__(self, other: TensorDictBase | float):
"""
...

def __rxor__(self, other: TensorDictBase | torch.Tensor | float):
"""XOR operation over two tensordicts, for evey key.

Wraps `__xor__` as it is assumed to be commutative.
"""
return self.__xor__(other)

@abc.abstractmethod
def __or__(self, other: TensorDictBase | torch.Tensor) -> T:
"""OR operation over two tensordicts, for evey key.
@@ -268,6 +299,71 @@ def __or__(self, other: TensorDictBase | torch.Tensor) -> T:
"""
...

def __ror__(self, other: TensorDictBase | torch.Tensor) -> T:
"""Right-side OR operation over two tensordicts, for evey key.

This is a wrapper around `__or__` since it is assumed to be commutative.
"""
return self | other
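
The reflected variants matter when the tensordict sits on the right: Python first asks the left operand, and only falls back to `__rxor__`/`__ror__` when it returns `NotImplemented`. A minimal sketch with made-up values:

import torch
from tensordict import TensorDict

td = TensorDict({"mask": torch.tensor([True, False])}, batch_size=[2])
flipped = True ^ td   # bool.__xor__(td) returns NotImplemented, so td.__rxor__(True) runs
either = False | td   # likewise routed to td.__ror__(False)
assert (flipped["mask"] == torch.tensor([False, True])).all()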

def __invert__(self) -> T:
"""Returns a new TensorDict instance with all tensors inverted (i.e., bitwise NOT operation).

Returns:
A new TensorDict instance with the same key set as the original,
but with all tensors having their bits inverted.
"""
keys, vals = self._items_list(True, True)
vals = [~v for v in vals]
items = dict(zip(keys, vals))

def get(name, val):
return items.get(name, val)

return self._fast_apply(
get,
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
)
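
For example, inverting a tensordict of boolean masks (a sketch with a made-up entry):

import torch
from tensordict import TensorDict

done = TensorDict({"done": torch.tensor([True, False])}, batch_size=[2])
not_done = ~done  # applies `~` to every leaf tensor
assert (not_done["done"] == torch.tensor([False, True])).all()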

def __and__(self, other: TensorDictBase | torch.Tensor | float) -> T:
"""Returns a new TensorDict instance with all tensors performing a logical or bitwise AND operation with the given value.

Args:
other: The value to perform the AND operation with.

Returns:
A new TensorDict instance with the same key set as the original,
but with all tensors having performed an AND operation with the given value.
"""
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
new_keys, other_val = other._items_list(True, True, sorting_keys=keys)
vals = [(v1 & v2) for v1, v2 in zip(vals, other_val)]
else:
vals = [(v & other) for v in vals]
items = dict(zip(keys, vals))

def pop(name, val):
return items.pop(name, None)

result = self._fast_apply(
pop,
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
filter_empty=True,
default=None,
)
if items:
result.update(items)
return result

__rand__ = __and__
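
Since AND commutes, `__rand__` can reuse `__and__` directly. A sketch with hypothetical entries:

import torch
from tensordict import TensorDict

a = TensorDict({"m": torch.tensor([True, False])}, batch_size=[2])
b = TensorDict({"m": torch.tensor([True, True])}, batch_size=[2])
both = a & b       # key-wise AND between two tensordicts
gated = True & a   # scalar on the left: bool.__and__ fails over to a.__rand__
assert (both["m"] == torch.tensor([True, False])).all()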

@abc.abstractmethod
def __eq__(self, other: object) -> T:
"""Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set.
@@ -8903,21 +8999,27 @@ def record(tensor):
def __add__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.add(other)

def __radd__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.add(other)

def __iadd__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.add_(other)


def __truediv__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.div(other)

def __itruediv__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.div_(other)

def __rtruediv__(self, other: TensorDictBase | torch.Tensor) -> T:
return other * self.reciprocal()

def __mul__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.mul(other)

def __rmul__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.mul(other)

def __imul__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.mul_(other)

@@ -8927,9 +9029,18 @@ def __sub__(self, other: TensorDictBase | torch.Tensor) -> T:
def __isub__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.sub_(other)

def __rsub__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.neg().add(other)  # reflected subtraction: computes other - self

def __pow__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.pow(other)

def __rpow__(self, other: TensorDictBase | torch.Tensor) -> T:
raise NotImplementedError(
"rpow isn't implemented for tensordict yet. Make sure both elements are wrapped "
"in a tensordict for this to work."
)

def __ipow__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.pow_(other)
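
Together these let plain Python numbers sit on the left-hand side of arithmetic with a tensordict. A sketch with made-up values (and assuming the corrected `__rsub__` above, which computes `other - self`):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.tensor([1.0, 2.0])}, batch_size=[2])
assert ((1 + td)["a"] == torch.tensor([2.0, 3.0])).all()   # __radd__
assert ((1 - td)["a"] == torch.tensor([0.0, -1.0])).all()  # __rsub__: other - self
assert ((1 / td)["a"] == torch.tensor([1.0, 0.5])).all()   # __rtruediv__: other * reciprocal
# 2 ** td would raise NotImplementedError, since __rpow__ is not implemented yet.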

@@ -9661,6 +9772,110 @@ def pop(name, val):
result.update(items)
return result

def bitwise_and(
self,
other: TensorDictBase | torch.Tensor,
*,
default: str | CompatibleType | None = None,
) -> TensorDictBase: # noqa: D417
r"""Performs a bitwise AND operation between ``self`` and :attr:`other`.

.. math::
\text{out}_i = \text{input}_i \land \text{other}_i

Args:
other (TensorDictBase or torch.Tensor): the tensor or TensorDict to perform the bitwise AND with.

Keyword Args:
default (torch.Tensor or str, optional): the default value to use for exclusive entries.
If none is provided, the two tensordicts key list must match exactly.
If ``default="intersection"`` is passed, only the intersecting key sets will be considered
and other keys will be ignored.
In all other cases, ``default`` will be used for all missing entries on both sides of the
operation.
"""
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
new_keys, other_val = other._items_list(
True, True, sorting_keys=keys, default=default
)
if default is not None:
as_dict = dict(zip(keys, vals))
vals = [as_dict.get(key, default) for key in new_keys]
keys = new_keys
vals = [(v1.bitwise_and(v2)) for v1, v2 in zip(vals, other_val)]
else:
vals = [v.bitwise_and(other) for v in vals]
items = dict(zip(keys, vals))

def pop(name, val):
return items.pop(name, None)

result = self._fast_apply(
pop,
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
filter_empty=True,
default=None,
)
if items:
result.update(items)
return result
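
For instance, with made-up flag fields, `bitwise_and` mirrors `torch.bitwise_and` entry by entry and also broadcasts scalars:

import torch
from tensordict import TensorDict

a = TensorDict({"flags": torch.tensor([0b1100, 0b1010])}, batch_size=[2])
b = TensorDict({"flags": torch.tensor([0b1010, 0b1010])}, batch_size=[2])
assert (a.bitwise_and(b)["flags"] == torch.tensor([0b1000, 0b1010])).all()
assert (a.bitwise_and(0b0100)["flags"] == torch.tensor([0b0100, 0b0000])).all()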

def logical_and(
self,
other: TensorDictBase | torch.Tensor,
*,
default: str | CompatibleType | None = None,
) -> TensorDictBase: # noqa: D417
r"""Performs a logical AND operation between ``self`` and :attr:`other`.

.. math::
\text{out}_i = \text{input}_i \land \text{other}_i

Args:
other (TensorDictBase or torch.Tensor): the tensor or TensorDict to perform the logical AND with.

Keyword Args:
default (torch.Tensor or str, optional): the default value to use for exclusive entries.
If none is provided, the two tensordicts key list must match exactly.
If ``default="intersection"`` is passed, only the intersecting key sets will be considered
and other keys will be ignored.
In all other cases, ``default`` will be used for all missing entries on both sides of the
operation.
"""
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
new_keys, other_val = other._items_list(
True, True, sorting_keys=keys, default=default
)
if default is not None:
as_dict = dict(zip(keys, vals))
vals = [as_dict.get(key, default) for key in new_keys]
keys = new_keys
vals = [(v1.logical_and(v2)) for v1, v2 in zip(vals, other_val)]
else:
vals = [v.logical_and(other) for v in vals]
items = dict(zip(keys, vals))

def pop(name, val):
return items.pop(name, None)

result = self._fast_apply(
pop,
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
filter_empty=True,
default=None,
)
if items:
result.update(items)
return result
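
The `default` keyword is what the bare `&` operator lacks. A sketch of `default="intersection"` with mismatched key sets (hypothetical keys; exclusive entries are silently dropped):

import torch
from tensordict import TensorDict

a = TensorDict({"x": torch.tensor([True, False]),
                "y": torch.tensor([True, True])}, batch_size=[2])
b = TensorDict({"x": torch.tensor([True, True])}, batch_size=[2])
out = a.logical_and(b, default="intersection")  # "y" is dropped, not an error
assert set(out.keys()) == {"x"}
assert (out["x"] == torch.tensor([True, False])).all()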

def add(
self,
other: TensorDictBase | torch.Tensor,
@@ -9719,7 +9934,12 @@ def pop(name, val):
result.update(items)
return result

def add_(self, other: TensorDictBase | float, *, alpha: float | None = None):
def add_(
self,
other: TensorDictBase | torch.Tensor | float,
*,
alpha: float | None = None,
):
"""In-place version of :meth:`~.add`.

.. note::
@@ -9781,7 +10001,11 @@ def get(name, val):
propagate_lock=True,
)

def lerp_(self, end: TensorDictBase | float, weight: TensorDictBase | float):
def lerp_(
self,
end: TensorDictBase | torch.Tensor | float,
weight: TensorDictBase | torch.Tensor | float,
):
"""In-place version of :meth:`~.lerp`."""
if _is_tensor_collection(type(end)):
end_val = end._values_list(True, True)
@@ -9917,7 +10141,7 @@ def addcmul_(self, other1, other2, *, value: float | None = 1):

def sub(
self,
other: TensorDictBase | float,
other: TensorDictBase | torch.Tensor | float,
*,
alpha: float | None = None,
default: str | CompatibleType | None = None,
Expand Down Expand Up @@ -9976,7 +10200,9 @@ def pop(name, val):
result.update(items)
return result

def sub_(self, other: TensorDictBase | float, alpha: float | None = None):
def sub_(
self, other: TensorDictBase | torch.Tensor | float, alpha: float | None = None
):
"""In-place version of :meth:`~.sub`.

.. note::
18 changes: 15 additions & 3 deletions tensordict/tensorclass.py
@@ -166,19 +166,30 @@ def __subclasscheck__(self, subclass):
_FALLBACK_METHOD_FROM_TD = [
"__abs__",
"__add__",
"__and__",
"__bool__",
"__eq__",
"__ge__",
"__gt__",
"__iadd__",
"__imul__",
"__invert__",
"__ipow__",
"__isub__",
"__itruediv__",
"__mul__",
"__ne__",
"__neg__",
"__or__",
"__pow__",
"__radd__",
"__rand__",
"__rmul__",
"__ror__",
"__rpow__",
"__rsub__",
"__rtruediv__",
"__rxor__",
"__sub__",
"__truediv__",
"__xor__",
@@ -228,6 +239,7 @@ def __subclasscheck__(self, subclass):
"atan_",
"auto_batch_size_",
"auto_device_",
"bitwise_and",
"ceil",
"ceil_",
"chunk",
@@ -291,10 +303,8 @@ def __subclasscheck__(self, subclass):
"log2",
"log2_",
"log_",
"map",
"logical_and" "map",
"map_iter",
"to_namedtuple",
"to_pytree",
"masked_fill",
"masked_fill_",
"max",
@@ -356,6 +366,8 @@ def __subclasscheck__(self, subclass):
"tanh_",
"to",
"to_module",
"to_namedtuple",
"to_pytree",
"transpose",
"trunc",
"trunc_",
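
The practical effect of growing `_FALLBACK_METHOD_FROM_TD`: a `@tensorclass` that defines none of these dunders still picks them up from its underlying tensordict. A sketch with a made-up class:

import torch
from tensordict import tensorclass

@tensorclass
class Data:
    x: torch.Tensor

d = Data(x=torch.tensor([-1.0, 2.0]), batch_size=[2])
assert (abs(d).x == torch.tensor([1.0, 2.0])).all()   # __abs__ via fallback
assert ((1 + d).x == torch.tensor([0.0, 3.0])).all()  # __radd__ via fallback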