diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index c3024e0f0..74c011fce 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1345,7 +1345,7 @@ def _apply_nest( call_on_nested: bool = False, default: Any = NO_DEFAULT, named: bool = False, - complete_names: bool = False, + nested_keys: bool = False, prefix: tuple = (), **constructor_kwargs, ) -> T: @@ -1369,7 +1369,7 @@ def _apply_nest( call_on_nested=call_on_nested, default=default, named=named, - complete_names=complete_names, + nested_keys=nested_keys, prefix=prefix, **constructor_kwargs, ) @@ -1384,7 +1384,7 @@ def _apply_nest( call_on_nested=call_on_nested, default=default, named=named, - complete_names=complete_names, + nested_keys=nested_keys, prefix=prefix + (i,), ) for i, (td, *oth) in enumerate(zip(self.tensordicts, *others)) diff --git a/tensordict/_td.py b/tensordict/_td.py index 830f29117..7e975be00 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -632,7 +632,7 @@ def _apply_nest( call_on_nested: bool = False, default: Any = NO_DEFAULT, named: bool = False, - complete_names: bool = False, + nested_keys: bool = False, prefix: tuple = (), **constructor_kwargs, ) -> T: @@ -681,7 +681,7 @@ def _apply_nest( device=device, checked=checked, named=named, - complete_names=complete_names, + nested_keys=nested_keys, default=default, prefix=prefix + (key,), **constructor_kwargs, @@ -689,7 +689,7 @@ def _apply_nest( else: _others = [_other._get_str(key, default=default) for _other in others] if named: - if complete_names: + if nested_keys: item_trsf = fn(unravel_key(prefix + (key,)), item, *_others) else: item_trsf = fn(key, item, *_others) diff --git a/tensordict/base.py b/tensordict/base.py index 4d76de5ec..18b597a5c 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3618,7 +3618,7 @@ def named_apply( self, fn: Callable, *others: T, - complete_names: bool = False, + nested_keys: bool = False, batch_size: Sequence[int] | None = None, device: torch.device | None = None, names: Sequence[str] | None = None, @@ -3640,8 +3640,9 @@ def named_apply( unnamed inputs as the number of tensordicts, including self. If other tensordicts have missing entries, a default value can be passed through the ``default`` keyword argument. - complete_names (bool, optional): if ``True``, the complete path - to the leaf will be used. Defaults to ``False``. + nested_keys (bool, optional): if ``True``, the complete path + to the leaf will be used. Defaults to ``False``, i.e. only the last + string is passed to the function. batch_size (sequence of int, optional): if provided, the resulting TensorDict will have the desired batch_size. The :obj:`batch_size` argument should match the batch_size after @@ -3737,7 +3738,7 @@ def named_apply( checked=False, default=default, named=True, - complete_names=complete_names, + nested_keys=nested_keys, **constructor_kwargs, ) @@ -3754,7 +3755,7 @@ def _apply_nest( call_on_nested: bool = False, default: Any = NO_DEFAULT, named: bool = False, - complete_names: bool = False, + nested_keys: bool = False, prefix: tuple = (), **constructor_kwargs, ) -> T: @@ -3771,7 +3772,7 @@ def _fast_apply( call_on_nested: bool = False, default: Any = NO_DEFAULT, named: bool = False, - complete_names: bool = False, + nested_keys: bool = False, **constructor_kwargs, ) -> T: """A faster apply method. @@ -3792,7 +3793,7 @@ def _fast_apply( call_on_nested=call_on_nested, named=named, default=default, - complete_names=complete_names, + nested_keys=nested_keys, **constructor_kwargs, ) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index d2241a9fe..8ae08a69f 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3027,10 +3027,10 @@ def count(name, value, keys): keys.add(name) td.named_apply( - functools.partial(count, keys=keys_complete), complete_names=True + functools.partial(count, keys=keys_complete), nested_keys=True ) td.named_apply( - functools.partial(count, keys=keys_not_complete), complete_names=False + functools.partial(count, keys=keys_not_complete), nested_keys=False ) assert len(keys_complete) == len(list(td.keys(True, True))) assert len(keys_complete) > len(keys_not_complete)