Commit: amend
vmoens committed Oct 9, 2023
1 parent 49c1f9f commit 15e6950
Showing 1 changed file with 60 additions and 7 deletions.
67 changes: 60 additions & 7 deletions tensordict/tensordict.py
@@ -1570,6 +1570,28 @@ def apply(
>>> assert (td_2["a"] == -2).all()
>>> assert (td_2["b", "c"] == 2).all()
"""
+        return self._apply(
+            fn,
+            *others,
+            batch_size=batch_size,
+            device=device,
+            names=names,
+            inplace=inplace,
+            checked=False,
+            **constructor_kwargs,
+        )
+
+    def _apply(
+        self,
+        fn: Callable,
+        *others: T,
+        batch_size: Sequence[int] | None = None,
+        device: torch.device | None = None,
+        names: Sequence[str] | None = None,
+        inplace: bool = False,
+        checked: bool = False,
+        **constructor_kwargs,
+    ) -> T:
if inplace:
out = self
elif batch_size is not None:
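
The public `apply` keeps its documented behavior and now simply forwards to `_apply` with `checked=False`. A minimal doctest-style sketch of the unchanged entry point, mirroring the docstring examples above (the `TensorDict` construction here is illustrative, not part of the commit):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.ones(3)}, batch_size=[3])
>>> td_2 = td.apply(lambda x: x * 2)  # routes through _apply(..., checked=False)
>>> assert (td_2["a"] == 2).all()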
@@ -1596,14 +1618,15 @@ def apply(
out.unlock_()

for key, item in self.items():
-            _others = [_other.get(key) for _other in others]
+            _others = [_other._get_str(key, default=NO_DEFAULT) for _other in others]
if _is_tensor_collection(item.__class__):
-                item_trsf = item.apply(
+                item_trsf = item._apply(
fn,
*_others,
inplace=inplace,
batch_size=batch_size,
device=device,
+                    checked=checked,
**constructor_kwargs,
)
else:
@@ -1618,13 +1641,41 @@ def apply(
key,
item_trsf,
inplace=BEST_ATTEMPT_INPLACE if inplace else False,
-                validated=False,
+                validated=checked,
)

if not inplace and is_locked:
out.lock_()
return out

+    def _fast_apply(
+        self,
+        fn: Callable,
+        *others: T,
+        batch_size: Sequence[int] | None = None,
+        device: torch.device | None = None,
+        names: Sequence[str] | None = None,
+        inplace: bool = False,
+        **constructor_kwargs,
+    ) -> T:
+        """A faster apply method.
+
+        This method does not run any check after performing the func. This
+        means that it is up to the user to make sure that the metadata of the
+        resulting tensors (device, shape, etc.) match those of :meth:`~.apply`.
+        """
+        return self._apply(
+            fn,
+            *others,
+            batch_size=batch_size,
+            device=device,
+            names=names,
+            inplace=inplace,
+            checked=True,
+            **constructor_kwargs,
+        )

def map(
self,
fn: Callable,
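
`_fast_apply` trades safety for speed: its `checked=True` propagates down to `validated=checked` in the `_set_str` call above, so the outputs of `fn` are stored without metadata validation. A hedged sketch of the fast path (`_fast_apply` is a private helper introduced by this commit; calling it directly is for illustration only):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.ones(3)}, batch_size=[3])
>>> # unchecked: fn's outputs are trusted to keep compatible device/shape
>>> td_fast = td._fast_apply(lambda x: x * 2)
>>> assert (td_fast["a"] == 2).all()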
@@ -4736,7 +4787,7 @@ def to(tensor):
if device is not None or dtype is not None:
apply_kwargs["device"] = device
apply_kwargs["batch_size"] = batch_size
-            result = result.apply(to, **apply_kwargs)
+            result = result._fast_apply(to, **apply_kwargs)
elif batch_size is not None:
result.batch_size = batch_size
return result
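
`to` is a natural consumer of the unchecked path: the target device, dtype, and batch size are known before `fn` runs, so post-hoc validation adds nothing. A small usage sketch, assuming `to` accepts a dtype as the surrounding `device is not None or dtype is not None` check suggests:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.ones(3)}, batch_size=[3])
>>> td_64 = td.to(torch.float64)  # internally uses _fast_apply
>>> assert td_64["a"].dtype == torch.float64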
@@ -7192,14 +7243,15 @@ def apply_(self, fn: Callable, *others):
td.apply_(fn, *[other[idx] for other in others])
return self

-    def apply(
+    def _apply(
self,
fn: Callable,
*others: T,
batch_size: Sequence[int] | None = None,
device: torch.device | None = None,
names: Sequence[str] | None = None,
inplace: bool = False,
+        checked: bool = False,
**constructor_kwargs,
) -> T:
if inplace:
@@ -7210,18 +7262,19 @@ def apply(
return self.apply_(fn, *others)
else:
if batch_size is not None:
-            return super().apply(
+            return super()._apply(
fn,
*others,
batch_size=batch_size,
device=device,
names=names,
+                checked=checked,
**constructor_kwargs,
)
others = (other.unbind(self.stack_dim) for other in others)
out = LazyStackedTensorDict(
*(
-                    td.apply(fn, *oth, device=device)
+                    td._apply(fn, *oth, checked=checked, device=device)
for td, *oth in zip(self.tensordicts, *others)
),
stack_dim=self.stack_dim,
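
In the lazy-stacked case, `_apply` either defers to the base implementation (when a `batch_size` is passed) or unbinds `others` along `stack_dim` and recurses into each stacked tensordict, threading `checked` through. A doctest-style sketch, assuming `torch.stack` over tensordicts builds a `LazyStackedTensorDict` in this version of the library:

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.ones(4)}, batch_size=[4])
>>> stacked = torch.stack([td.clone(), td.clone()], dim=0)
>>> out = stacked.apply(lambda x: x - 1)  # recurses into each stacked td
>>> assert (out["a"] == 0).all()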
