From 15e6950cf11cd4390100f7c0b254923b4ab2c170 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 9 Oct 2023 14:43:20 +0100 Subject: [PATCH] amend --- tensordict/tensordict.py | 67 +++++++++++++++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index b3618a854..75b705acd 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -1570,6 +1570,28 @@ def apply( >>> assert (td_2["a"] == -2).all() >>> assert (td_2["b", "c"] == 2).all() """ + return self._apply( + fn, + *others, + batch_size=batch_size, + device=device, + names=names, + inplace=inplace, + checked=False, + **constructor_kwargs, + ) + + def _apply( + self, + fn: Callable, + *others: T, + batch_size: Sequence[int] | None = None, + device: torch.device | None = None, + names: Sequence[str] | None = None, + inplace: bool = False, + checked: bool = False, + **constructor_kwargs, + ) -> T: if inplace: out = self elif batch_size is not None: @@ -1596,14 +1618,15 @@ def apply( out.unlock_() for key, item in self.items(): - _others = [_other.get(key) for _other in others] + _others = [_other._get_str(key, default=NO_DEFAULT) for _other in others] if _is_tensor_collection(item.__class__): - item_trsf = item.apply( + item_trsf = item._apply( fn, *_others, inplace=inplace, batch_size=batch_size, device=device, + checked=checked, **constructor_kwargs, ) else: @@ -1618,13 +1641,41 @@ def apply( key, item_trsf, inplace=BEST_ATTEMPT_INPLACE if inplace else False, - validated=False, + validated=checked, ) if not inplace and is_locked: out.lock_() return out + def _fast_apply( + self, + fn: Callable, + *others: T, + batch_size: Sequence[int] | None = None, + device: torch.device | None = None, + names: Sequence[str] | None = None, + inplace: bool = False, + **constructor_kwargs, + ) -> T: + """A faster apply method. + + This method does not run any check after performing the func. This + means that one to make sure that the metadata of the resulting tensors + (device, shape etc.) match the :meth:`~.apply` ones. + + """ + return self._apply( + fn, + *others, + batch_size=batch_size, + device=device, + names=names, + inplace=inplace, + checked=True, + **constructor_kwargs, + ) + def map( self, fn: Callable, @@ -4736,7 +4787,7 @@ def to(tensor): if device is not None or dtype is not None: apply_kwargs["device"] = device apply_kwargs["batch_size"] = batch_size - result = result.apply(to, **apply_kwargs) + result = result._fast_apply(to, **apply_kwargs) elif batch_size is not None: result.batch_size = batch_size return result @@ -7192,7 +7243,7 @@ def apply_(self, fn: Callable, *others): td.apply_(fn, *[other[idx] for other in others]) return self - def apply( + def _apply( self, fn: Callable, *others: T, @@ -7200,6 +7251,7 @@ def apply( device: torch.device | None = None, names: Sequence[str] | None = None, inplace: bool = False, + checked: bool = False, **constructor_kwargs, ) -> T: if inplace: @@ -7210,18 +7262,19 @@ def apply( return self.apply_(fn, *others) else: if batch_size is not None: - return super().apply( + return super()._apply( fn, *others, batch_size=batch_size, device=device, names=names, + checked=checked, **constructor_kwargs, ) others = (other.unbind(self.stack_dim) for other in others) out = LazyStackedTensorDict( *( - td.apply(fn, *oth, device=device) + td._apply(fn, *oth, checked=checked, device=device) for td, *oth in zip(self.tensordicts, *others) ), stack_dim=self.stack_dim,