Skip to content

Commit

Permalink
[Minor] NestedKey typing issues (#640)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 26, 2024
1 parent a5a7ab5 commit c72d500
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 49 deletions.
24 changes: 11 additions & 13 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def _stack_onto_(
def _get_str(
self,
key: NestedKey,
default: str | CompatibleType = NO_DEFAULT,
default: Any = NO_DEFAULT,
) -> CompatibleType:
# we can handle the case where the key is a tuple of length 1
tensors = []
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def _remove_batch_dim(self, vmap_level, batch_size, out_dim):
def get_nestedtensor(
self,
key: NestedKey,
default: str | CompatibleType = NO_DEFAULT,
default: Any = NO_DEFAULT,
) -> CompatibleType:
"""Returns a nested tensor when stacking cannot be achieved.
Expand Down Expand Up @@ -1386,7 +1386,7 @@ def _apply_nest(

def _select(
self,
*keys: str,
*keys: NestedKey,
inplace: bool = False,
strict: bool = False,
set_shared: bool = True,
Expand All @@ -1402,7 +1402,7 @@ def _select(
return result

def _exclude(
self, *keys: str, inplace: bool = False, set_shared: bool = True
self, *keys: NestedKey, inplace: bool = False, set_shared: bool = True
) -> LazyStackedTensorDict:
tensordicts = [
tensordict._exclude(*keys, inplace=inplace, set_shared=set_shared)
Expand Down Expand Up @@ -1840,9 +1840,7 @@ def del_(self, key: NestedKey, **kwargs: Any) -> T:
raise error
return self

def pop(
self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT
) -> CompatibleType:
def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType:
# using try/except for get/del is suboptimal, but
# this is faster that checkink if key in self keys
key = _unravel_key_to_tuple(key)
Expand Down Expand Up @@ -2071,7 +2069,9 @@ def update_at_(
)
return self

def rename_key_(self, old_key: str, new_key: str, safe: bool = False) -> T:
def rename_key_(
self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
) -> T:
for td in self.tensordicts:
td.rename_key_(old_key, new_key, safe=safe)
return self
Expand Down Expand Up @@ -2627,7 +2627,7 @@ def keys(

def _select(
self,
*keys: str,
*keys: NestedKey,
inplace: bool = False,
strict: bool = True,
set_shared: bool = True,
Expand All @@ -2639,7 +2639,7 @@ def _select(
)

def _exclude(
self, *keys: str, inplace: bool = False, set_shared: bool = True
self, *keys: NestedKey, inplace: bool = False, set_shared: bool = True
) -> _CustomOpTensorDict:
if inplace:
raise RuntimeError("Cannot call exclude inplace on a lazy tensordict.")
Expand Down Expand Up @@ -2676,7 +2676,7 @@ def contiguous(self) -> T:
return self._fast_apply(lambda x: x.contiguous())

def rename_key_(
self, old_key: str, new_key: str, safe: bool = False
self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
) -> _CustomOpTensorDict:
self._source.rename_key_(old_key, new_key, safe=safe)
return self
Expand Down Expand Up @@ -2998,7 +2998,6 @@ def _legacy_unsqueeze(self, dim: int) -> T:

def _stack_onto_(
self,
# key: str,
list_item: list[CompatibleType],
dim: int,
) -> T:
Expand Down Expand Up @@ -3224,7 +3223,6 @@ def _update_inv_op_kwargs(self, tensor: Tensor) -> dict[str, Any]:

def _stack_onto_(
self,
# key: str,
list_item: list[CompatibleType],
dim: int,
) -> T:
Expand Down
20 changes: 13 additions & 7 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,7 +1489,9 @@ def del_(self, key: NestedKey) -> T:
return self

@lock_blocked
def rename_key_(self, old_key: str, new_key: str, safe: bool = False) -> T:
def rename_key_(
self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
) -> T:
# these checks are not perfect, tuples that are not tuples of strings or empty
# tuples could go through but (1) it will raise an error anyway and (2)
# those checks are expensive when repeated often.
Expand Down Expand Up @@ -1951,7 +1953,9 @@ def _select(
# self._maybe_set_shared_attributes(result)
return result

def _exclude(self, *keys: str, inplace: bool = False, set_shared: bool = True) -> T:
def _exclude(
self, *keys: NestedKey, inplace: bool = False, set_shared: bool = True
) -> T:
# faster than Base.exclude
if not len(keys):
return self.copy() if not inplace else self
Expand Down Expand Up @@ -2184,7 +2188,7 @@ def device(self) -> None | torch.device:
def device(self, value: DeviceType) -> None:
self._source.device = value

def _preallocate(self, key: str, value: CompatibleType) -> T:
def _preallocate(self, key: NestedKey, value: CompatibleType) -> T:
return self._source.set(key, value)

def _convert_inplace(self, inplace, key):
Expand Down Expand Up @@ -2570,7 +2574,7 @@ def contiguous(self) -> T:

def _select(
self,
*keys: str,
*keys: NestedKey,
inplace: bool = False,
strict: bool = True,
set_shared: bool = True,
Expand All @@ -2581,7 +2585,9 @@ def _select(
*keys, inplace=False, strict=strict, set_shared=set_shared
)

def _exclude(self, *keys: str, inplace: bool = False, set_shared: bool = True) -> T:
def _exclude(
self, *keys: NestedKey, inplace: bool = False, set_shared: bool = True
) -> T:
if inplace:
raise RuntimeError("Cannot call exclude inplace on a lazy tensordict.")
return self.to_tensordict()._exclude(
Expand All @@ -2604,7 +2610,7 @@ def is_memmap(self) -> bool:
return self._source.is_memmap()

def rename_key_(
self, old_key: str, new_key: str, safe: bool = False
self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
) -> _SubTensorDict:
self._source.rename_key_(old_key, new_key, safe=safe)
return self
Expand Down Expand Up @@ -2880,7 +2886,7 @@ def _iter_helper(
if not self.leaves_only or is_leaf:
yield full_key

def _combine_keys(self, prefix: tuple | None, key: str) -> tuple:
def _combine_keys(self, prefix: tuple | None, key: NestedKey) -> tuple:
if prefix is not None:
return prefix + (key,)
return (key,)
Expand Down
22 changes: 10 additions & 12 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2213,7 +2213,7 @@ def _stack_onto_(

def _stack_onto_at_(
self,
key: str,
key: NestedKey,
list_item: list[CompatibleType],
dim: int,
idx: IndexType,
Expand All @@ -2225,9 +2225,7 @@ def _stack_onto_at_(
"before calling __getindex__ and stack."
)

def _default_get(
self, key: str, default: str | CompatibleType = NO_DEFAULT
) -> CompatibleType:
def _default_get(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType:
if default is not NO_DEFAULT:
return default
else:
Expand All @@ -2236,9 +2234,7 @@ def _default_get(
_KEY_ERROR.format(key, self.__class__.__name__, sorted(self.keys()))
)

def get(
self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT
) -> CompatibleType:
def get(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType:
"""Gets the value stored with the input key.
Args:
Expand Down Expand Up @@ -2952,7 +2948,9 @@ def unflatten(tensor):
return out

@abc.abstractmethod
def rename_key_(self, old_key: str, new_key: str, safe: bool = False) -> T:
def rename_key_(
self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
) -> T:
"""Renames a key with a new string and returns the same tensordict with the updated key name.
Args:
Expand Down Expand Up @@ -4125,7 +4123,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return self

# Clone, select, exclude, empty
def select(self, *keys: str, inplace: bool = False, strict: bool = True) -> T:
def select(self, *keys: NestedKey, inplace: bool = False, strict: bool = True) -> T:
"""Selects the keys of the tensordict and returns a new tensordict with only the selected keys.
The values are not copied: in-place modifications a tensor of either
Expand Down Expand Up @@ -4188,14 +4186,14 @@ def select(self, *keys: str, inplace: bool = False, strict: bool = True) -> T:
@abc.abstractmethod
def _select(
self,
*keys: str,
*keys: NestedKey,
inplace: bool = False,
strict: bool = True,
set_shared: bool = True,
) -> T:
...

def exclude(self, *keys: str, inplace: bool = False) -> T:
def exclude(self, *keys: NestedKey, inplace: bool = False) -> T:
"""Excludes the keys of the tensordict and returns a new tensordict without these entries.
The values are not copied: in-place modifications a tensor of either
Expand Down Expand Up @@ -4242,7 +4240,7 @@ def exclude(self, *keys: str, inplace: bool = False) -> T:
@abc.abstractmethod
def _exclude(
self,
*keys: str,
*keys: NestedKey,
inplace: bool = False,
set_shared: bool = True,
) -> T:
Expand Down
14 changes: 5 additions & 9 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,15 +427,13 @@ def update(

@lock_blocked
@_unlock_and_set
def pop(
self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT
) -> CompatibleType:
def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType:
...

@lock_blocked
@_unlock_and_set
def rename_key_(
self, old_key: str, new_key: str, safe: bool = False
self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
) -> TensorDictBase:
...

Expand Down Expand Up @@ -488,9 +486,7 @@ def _apply_nest(*args, **kwargs):

@_get_post_hook
@_fallback
def get(
self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT
) -> CompatibleType:
def get(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType:
...

@_get_post_hook
Expand Down Expand Up @@ -815,7 +811,7 @@ def unflatten_keys(

@_unlock_and_set(inplace=True)
def _exclude(
self, *keys: str, inplace: bool = False, set_shared: bool = True
self, *keys: NestedKey, inplace: bool = False, set_shared: bool = True
) -> TensorDictBase:
...

Expand Down Expand Up @@ -1054,7 +1050,7 @@ def _stack_onto_(
@_apply_on_data
def _stack_onto_at_(
self,
key: str,
key: NestedKey,
list_item: list[CompatibleType],
dim: int,
idx: IndexType,
Expand Down
14 changes: 7 additions & 7 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def get(self, key, default=NO_DEFAULT):
_get_tuple = get

def get_at(
self, key: str, idx: IndexType, default: CompatibleType = NO_DEFAULT
self, key: NestedKey, idx: IndexType, default: CompatibleType = NO_DEFAULT
) -> CompatibleType:
array = self._get_array(key, default)
if isinstance(array, (h5py.Dataset,)):
Expand Down Expand Up @@ -705,7 +705,7 @@ def map(
return out

def rename_key_(
self, old_key: str, new_key: str, safe: bool = False
self, old_key: NestedKey, new_key: NestedKey, safe: bool = False
) -> PersistentTensorDict:
old_key = self._process_key(old_key)
new_key = self._process_key(new_key)
Expand All @@ -715,7 +715,7 @@ def rename_key_(
raise KeyError(f"key {new_key} already present in TensorDict.") from err
return self

def fill_(self, key: str, value: float | bool) -> TensorDictBase:
def fill_(self, key: NestedKey, value: float | bool) -> TensorDictBase:
"""Fills a tensor pointed by the key with the a given value.
Args:
Expand All @@ -742,15 +742,15 @@ def _create_nested_str(self, key):
return target_td

def _select(
self, *keys: str, inplace: bool = False, strict: bool = True
self, *keys: NestedKey, inplace: bool = False, strict: bool = True
) -> PersistentTensorDict:
raise NotImplementedError(
"Cannot call select on a PersistentTensorDict. "
"Create a regular tensordict first using the `to_tensordict` method."
)

def _exclude(
self, *keys: str, inplace: bool = False, set_shared: bool = True
self, *keys: NestedKey, inplace: bool = False, set_shared: bool = True
) -> PersistentTensorDict:
raise NotImplementedError(
"Cannot call exclude on a PersistentTensorDict. "
Expand Down Expand Up @@ -817,8 +817,8 @@ def _to_numpy(self, value):

def _set(
self,
key: str,
value,
key: NestedKey,
value: Any,
inplace: bool = False,
idx=None,
validated=False,
Expand Down
2 changes: 1 addition & 1 deletion tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,7 +1343,7 @@ def _get_repr_custom(cls, shape, device, dtype, is_shared) -> str:
return f"{cls.__name__}({s})"


def _make_repr(key: str, item, tensordict: T) -> str:
def _make_repr(key: NestedKey, item, tensordict: T) -> str:
from tensordict.base import _is_tensor_collection

if _is_tensor_collection(type(item)):
Expand Down

0 comments on commit c72d500

Please sign in to comment.