From ade3f599bb174dc05cc27fe32e40746152a57b5a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 24 Jun 2024 11:16:17 +0100 Subject: [PATCH] init --- tensordict/base.py | 159 +++++++++++++++++++++++++++----------- tensordict/tensorclass.py | 3 + test/test_tensordict.py | 20 +++++ 3 files changed, 139 insertions(+), 43 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index ec7163f21..8978f2b7d 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -420,6 +420,79 @@ def any(self, dim: int = None) -> bool | TensorDictBase: """ ... + def isfinite(self) -> T: + """Returns a new tensordict with boolean elements representing if each element is finite or not. + + Real values are finite when they are not NaN, negative infinity, or infinity. Complex values are finite when both their real and imaginary parts are finite. + + """ + keys, vals = self._items_list(True, True) + vals = [val.isfinite() for val in vals] + items = dict(zip(keys, vals)) + return self._fast_apply( + lambda name, val: items.get(name, val), + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=True, + ) + + def isnan(self) -> T: + """Returns a new tensordict with boolean elements representing if each element of input is NaN or not. + + Complex values are considered NaN when either their real and/or imaginary part is NaN. + + """ + keys, vals = self._items_list(True, True) + vals = [val.isnan() for val in vals] + items = dict(zip(keys, vals)) + return self._fast_apply( + lambda name, val: items.get(name, val), + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=True, + ) + + def isneginf(self) -> T: + """Tests if each element of input is negative infinity or not.""" + keys, vals = self._items_list(True, True) + vals = [val.isneginf() for val in vals] + items = dict(zip(keys, vals)) + return self._fast_apply( + lambda name, val: items.get(name, val), + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=True, + ) + + def isposinf(self) -> T: + """Tests if each element of input is negative infinity or not.""" + keys, vals = self._items_list(True, True) + vals = [val.isposinf() for val in vals] + items = dict(zip(keys, vals)) + return self._fast_apply( + lambda name, val: items.get(name, val), + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=True, + ) + + def isreal(self) -> T: + """Returns a new tensordict with boolean elements representing if each element of input is real-valued or not.""" + keys, vals = self._items_list(True, True) + vals = [val.isreal() for val in vals] + items = dict(zip(keys, vals)) + return self._fast_apply( + lambda name, val: items.get(name, val), + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=True, + ) + def mean( self, dim: int | Tuple[int] = NO_DEFAULT, @@ -5338,7 +5411,7 @@ def abs(self) -> T: vals = torch._foreach_abs(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5354,7 +5427,7 @@ def acos(self) -> T: vals = torch._foreach_acos(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5370,7 +5443,7 @@ def exp(self) -> T: vals = torch._foreach_exp(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5386,7 +5459,7 @@ def neg(self) -> T: vals = torch._foreach_neg(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5402,7 +5475,7 @@ def reciprocal(self) -> T: vals = torch._foreach_reciprocal(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5418,7 +5491,7 @@ def sigmoid(self) -> T: vals = torch._foreach_sigmoid(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5434,7 +5507,7 @@ def sign(self) -> T: vals = torch._foreach_sign(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5450,7 +5523,7 @@ def sin(self) -> T: vals = torch._foreach_sin(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5466,7 +5539,7 @@ def sinh(self) -> T: vals = torch._foreach_sinh(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5482,7 +5555,7 @@ def tan(self) -> T: vals = torch._foreach_tan(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5498,7 +5571,7 @@ def tanh(self) -> T: vals = torch._foreach_tanh(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5514,7 +5587,7 @@ def trunc(self) -> T: vals = torch._foreach_trunc(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5537,7 +5610,7 @@ def norm( vals = torch._foreach_norm(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, batch_size=[], @@ -5554,7 +5627,7 @@ def norm( # noqa: F811 vals = torch._foreach_norm(vals, dtype=dtype) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, batch_size=[], @@ -5566,7 +5639,7 @@ def lgamma(self) -> T: vals = torch._foreach_lgamma(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5582,7 +5655,7 @@ def frac(self) -> T: vals = torch._foreach_frac(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5598,7 +5671,7 @@ def expm1(self) -> T: vals = torch._foreach_expm1(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5614,7 +5687,7 @@ def log(self) -> T: vals = torch._foreach_log(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5630,7 +5703,7 @@ def log10(self) -> T: vals = torch._foreach_log10(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5646,7 +5719,7 @@ def log1p(self) -> T: vals = torch._foreach_log1p(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5662,7 +5735,7 @@ def log2(self) -> T: vals = torch._foreach_log2(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5678,7 +5751,7 @@ def ceil(self) -> T: vals = torch._foreach_ceil(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5694,7 +5767,7 @@ def floor(self) -> T: vals = torch._foreach_floor(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5710,7 +5783,7 @@ def round(self) -> T: vals = torch._foreach_round(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5726,7 +5799,7 @@ def erf(self) -> T: vals = torch._foreach_erf(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5742,7 +5815,7 @@ def erfc(self) -> T: vals = torch._foreach_erfc(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5758,7 +5831,7 @@ def asin(self) -> T: vals = torch._foreach_asin(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5774,7 +5847,7 @@ def atan(self) -> T: vals = torch._foreach_atan(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5790,7 +5863,7 @@ def cos(self) -> T: vals = torch._foreach_cos(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5806,7 +5879,7 @@ def cosh(self) -> T: vals = torch._foreach_cosh(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5829,7 +5902,7 @@ def add(self, other: TensorDictBase | float, alpha: float | None = None): vals = torch._foreach_add(vals, other_val) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5860,7 +5933,7 @@ def lerp(self, end: TensorDictBase | float, weight: TensorDictBase | float): vals = torch._foreach_lerp(vals, end_val, weight_val) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5892,7 +5965,7 @@ def addcdiv(self, other1, other2, value: float | None = 1): vals = torch._foreach_addcdiv(vals, other1_val, other2_val, value=value) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5926,7 +5999,7 @@ def addcmul(self, other1, other2, value: float | None = 1): vals = torch._foreach_addcmul(vals, other1_val, other2_val, value=value) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5959,7 +6032,7 @@ def sub(self, other: TensorDictBase | float, alpha: float | None = None): vals = torch._foreach_sub(vals, other_val) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -5994,7 +6067,7 @@ def mul(self, other: TensorDictBase | float) -> T: vals = torch._foreach_mul(vals, other_val) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -6018,7 +6091,7 @@ def maximum(self, other: TensorDictBase | float) -> T: vals = torch._foreach_maximum(vals, other_val) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -6042,7 +6115,7 @@ def minimum(self, other: TensorDictBase | float) -> T: vals = torch._foreach_minimum(vals, other_val) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -6066,7 +6139,7 @@ def clamp_max(self, other: TensorDictBase | float) -> T: vals = torch._foreach_clamp_max(vals, other_val) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -6090,7 +6163,7 @@ def clamp_min(self, other: TensorDictBase | float) -> T: vals = torch._foreach_clamp_min(vals, other_val) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -6114,7 +6187,7 @@ def pow(self, other: TensorDictBase | float) -> T: vals = torch._foreach_pow(vals, other_val) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -6138,7 +6211,7 @@ def div(self, other: TensorDictBase | float) -> T: vals = torch._foreach_div(vals, other_val) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, @@ -6154,7 +6227,7 @@ def sqrt(self): vals = torch._foreach_sqrt(vals) items = dict(zip(keys, vals)) return self._fast_apply( - lambda name, val: items[name], + lambda name, val: items.get(name, val), named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 44296f347..bb3c67081 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -189,6 +189,9 @@ def __subclasscheck__(self, subclass): "is_memmap", "is_shared", "is_shared", + "isfinite", + "isnan", + "isreal", "items", "keys", "lerp", diff --git a/test/test_tensordict.py b/test/test_tensordict.py index e7efbb742..af5e3035d 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3428,6 +3428,26 @@ def test_inferred_view_size(self, td_name, device): assert td.view(*new_shape) is td assert td.view(-1).view(*new_shape) is td + def test_isfinite(self, td_name, device): + td = getattr(self, td_name)(device) + assert td.isfinite().all() + + def test_isnan(self, td_name, device): + td = getattr(self, td_name)(device) + assert not td.isnan().any() + + def test_isreal(self, td_name, device): + td = getattr(self, td_name)(device) + assert td.isreal().all() + + def test_isposinf(self, td_name, device): + td = getattr(self, td_name)(device) + assert not td.isposinf().any() + + def test_isneginf(self, td_name, device): + td = getattr(self, td_name)(device) + assert not td.isneginf().any() + def test_items_values_keys(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device)