Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 27, 2024
1 parent f400102 commit e193148
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
6 changes: 3 additions & 3 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,7 @@ def _apply_nest(
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
complete_names: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
**constructor_kwargs,
) -> T:
Expand All @@ -1369,7 +1369,7 @@ def _apply_nest(
call_on_nested=call_on_nested,
default=default,
named=named,
complete_names=complete_names,
nested_keys=nested_keys,
prefix=prefix,
**constructor_kwargs,
)
Expand All @@ -1384,7 +1384,7 @@ def _apply_nest(
call_on_nested=call_on_nested,
default=default,
named=named,
complete_names=complete_names,
nested_keys=nested_keys,
prefix=prefix + (i,),
)
for i, (td, *oth) in enumerate(zip(self.tensordicts, *others))
Expand Down
6 changes: 3 additions & 3 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def _apply_nest(
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
complete_names: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
**constructor_kwargs,
) -> T:
Expand Down Expand Up @@ -681,15 +681,15 @@ def _apply_nest(
device=device,
checked=checked,
named=named,
complete_names=complete_names,
nested_keys=nested_keys,
default=default,
prefix=prefix + (key,),
**constructor_kwargs,
)
else:
_others = [_other._get_str(key, default=default) for _other in others]
if named:
if complete_names:
if nested_keys:
item_trsf = fn(unravel_key(prefix + (key,)), item, *_others)
else:
item_trsf = fn(key, item, *_others)
Expand Down
15 changes: 8 additions & 7 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3618,7 +3618,7 @@ def named_apply(
self,
fn: Callable,
*others: T,
complete_names: bool = False,
nested_keys: bool = False,
batch_size: Sequence[int] | None = None,
device: torch.device | None = None,
names: Sequence[str] | None = None,
Expand All @@ -3640,8 +3640,9 @@ def named_apply(
unnamed inputs as the number of tensordicts, including self.
If other tensordicts have missing entries, a default value
can be passed through the ``default`` keyword argument.
complete_names (bool, optional): if ``True``, the complete path
to the leaf will be used. Defaults to ``False``.
nested_keys (bool, optional): if ``True``, the complete path
to the leaf will be used. Defaults to ``False``, i.e. only the last
string is passed to the function.
batch_size (sequence of int, optional): if provided,
the resulting TensorDict will have the desired batch_size.
The :obj:`batch_size` argument should match the batch_size after
Expand Down Expand Up @@ -3737,7 +3738,7 @@ def named_apply(
checked=False,
default=default,
named=True,
complete_names=complete_names,
nested_keys=nested_keys,
**constructor_kwargs,
)

Expand All @@ -3754,7 +3755,7 @@ def _apply_nest(
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
complete_names: bool = False,
nested_keys: bool = False,
prefix: tuple = (),
**constructor_kwargs,
) -> T:
Expand All @@ -3771,7 +3772,7 @@ def _fast_apply(
call_on_nested: bool = False,
default: Any = NO_DEFAULT,
named: bool = False,
complete_names: bool = False,
nested_keys: bool = False,
**constructor_kwargs,
) -> T:
"""A faster apply method.
Expand All @@ -3792,7 +3793,7 @@ def _fast_apply(
call_on_nested=call_on_nested,
named=named,
default=default,
complete_names=complete_names,
nested_keys=nested_keys,
**constructor_kwargs,
)

Expand Down
4 changes: 2 additions & 2 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3027,10 +3027,10 @@ def count(name, value, keys):
keys.add(name)

td.named_apply(
functools.partial(count, keys=keys_complete), complete_names=True
functools.partial(count, keys=keys_complete), nested_keys=True
)
td.named_apply(
functools.partial(count, keys=keys_not_complete), complete_names=False
functools.partial(count, keys=keys_not_complete), nested_keys=False
)
assert len(keys_complete) == len(list(td.keys(True, True)))
assert len(keys_complete) > len(keys_not_complete)
Expand Down

0 comments on commit e193148

Please sign in to comment.