Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 22, 2024
2 parents b0e9b81 + df3e231 commit 7c999e5
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 59 deletions.
23 changes: 12 additions & 11 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,17 +338,17 @@ def from_dict(
stack_dim_name=None,
stack_dim=0,
):
if batch_size is not None:
batch_size = list(batch_size)
if stack_dim is None:
stack_dim = 0
n = batch_size.pop(stack_dim)
if n != len(input_dict):
raise ValueError(
"The number of dicts and the corresponding batch-size must match, "
f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}."
)
batch_size = torch.Size(batch_size)
# if batch_size is not None:
# batch_size = list(batch_size)
# if stack_dim is None:
# stack_dim = 0
# n = batch_size.pop(stack_dim)
# if n != len(input_dict):
# raise ValueError(
# "The number of dicts and the corresponding batch-size must match, "
# f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}."
# )
# batch_size = torch.Size(batch_size)
return LazyStackedTensorDict(
*(
TensorDict.from_dict(
Expand All @@ -357,6 +357,7 @@ def from_dict(
auto_batch_size=auto_batch_size,
device=device,
batch_dims=batch_dims,
batch_size=batch_size,
)
for i in range(len(input_dict))
),
Expand Down
1 change: 1 addition & 0 deletions tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def from_metadata(metadata=metadata, prefix=None):
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
)
print('cls_metadata', cls_metadata)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if is_locked:
result = result.lock_()
Expand Down
48 changes: 29 additions & 19 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def __ne__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other != self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
other = self.from_dict_instance(other, auto_batch_size=False)
if _is_tensor_collection(type(other)):
keys1 = set(self.keys())
keys2 = set(other.keys())
Expand All @@ -639,7 +639,7 @@ def __xor__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other ^ self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
other = self.from_dict_instance(other, auto_batch_size=False)
if _is_tensor_collection(type(other)):
keys1 = set(self.keys())
keys2 = set(other.keys())
Expand All @@ -663,7 +663,7 @@ def __or__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other | self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
other = self.from_dict_instance(other, auto_batch_size=False)
if _is_tensor_collection(type(other)):
keys1 = set(self.keys())
keys2 = set(other.keys())
Expand All @@ -687,7 +687,7 @@ def __eq__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other == self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
other = self.from_dict_instance(other, auto_batch_size=False)
if _is_tensor_collection(type(other)):
keys1 = set(self.keys())
keys2 = set(other.keys())
Expand All @@ -709,7 +709,7 @@ def __ge__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other <= self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
other = self.from_dict_instance(other, auto_batch_size=False)
if _is_tensor_collection(type(other)):
keys1 = set(self.keys())
keys2 = set(other.keys())
Expand All @@ -731,7 +731,7 @@ def __gt__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other < self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
other = self.from_dict_instance(other, auto_batch_size=False)
if _is_tensor_collection(type(other)):
keys1 = set(self.keys())
keys2 = set(other.keys())
Expand All @@ -753,7 +753,7 @@ def __le__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other >= self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
other = self.from_dict_instance(other, auto_batch_size=False)
if _is_tensor_collection(type(other)):
keys1 = set(self.keys())
keys2 = set(other.keys())
Expand All @@ -775,7 +775,7 @@ def __lt__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other > self
if isinstance(other, (dict,)):
other = self.from_dict_instance(other)
other = self.from_dict_instance(other, auto_batch_size=False)
if _is_tensor_collection(type(other)):
keys1 = set(self.keys())
keys2 = set(other.keys())
Expand Down Expand Up @@ -2019,7 +2019,7 @@ def from_dict(
names=names,
)
if batch_size is None:
if auto_batch_size is None:
if auto_batch_size is None and batch_dims is None:
warn(
"The batch-size was not provided and auto_batch_size isn't set either. "
"Currently, from_dict will call set auto_batch_size=True but this behaviour "
Expand All @@ -2028,6 +2028,8 @@ def from_dict(
category=DeprecationWarning,
)
auto_batch_size = True
elif auto_batch_size is None:
auto_batch_size = True
if auto_batch_size:
_set_max_batch_size(out, batch_dims)
else:
Expand Down Expand Up @@ -2099,23 +2101,26 @@ def from_dict_instance(
# TODO: v0.7: remove the None
cur_value = self.get(key, None)
if cur_value is not None:
print(type(cur_value))
input_dict[key] = cur_value.from_dict_instance(
value,
device=device,
auto_batch_size=auto_batch_size,
auto_batch_size=False,
)
print(type(cur_value), type(input_dict[key]))
continue
# we don't know if another tensor of smaller size is coming
# so we can't be sure that the batch-size will still be valid later
input_dict[key] = TensorDict.from_dict(
value,
device=device,
auto_batch_size=auto_batch_size,
)
else:
# we don't know if another tensor of smaller size is coming
# so we can't be sure that the batch-size will still be valid later
input_dict[key] = TensorDict.from_dict(
value,
device=device,
auto_batch_size=False,
)
else:
input_dict[key] = TensorDict.from_any(
value,
auto_batch_size=auto_batch_size,
auto_batch_size=False,
)

out = TensorDict.from_dict(
Expand All @@ -2125,7 +2130,7 @@ def from_dict_instance(
names=names,
)
if batch_size is None:
if auto_batch_size is None:
if auto_batch_size is None and batch_dims is None:
warn(
"The batch-size was not provided and auto_batch_size isn't set either. "
"Currently, from_dict will call set auto_batch_size=True but this behaviour "
Expand All @@ -2134,8 +2139,13 @@ def from_dict_instance(
category=DeprecationWarning,
)
auto_batch_size = True
elif auto_batch_size is None:
auto_batch_size = True
if auto_batch_size:
print('self', self)
print('out', out)
_set_max_batch_size(out, batch_dims)
print('out', out)
else:
out.batch_size = batch_size
return out
Expand Down
5 changes: 5 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,13 +1216,16 @@ def _from_dict_validated(cls, *args, **kwargs):
By default, falls back on :meth:`~.from_dict`.
"""
kwargs.setdefault("auto_batch_size", True)
print('kwargs', kwargs)
return cls.from_dict(*args, **kwargs)

@abc.abstractmethod
def from_dict_instance(
self,
input_dict,
*others,
auto_batch_size: bool | None=None,
batch_size=None,
device=None,
batch_dims=None,
Expand Down Expand Up @@ -9866,6 +9869,8 @@ def from_any(cls, obj, *, auto_batch_size: bool = False):
- h5 objects through :meth:`~.from_h5`
"""
if is_tensor_collection(obj):
return obj
if isinstance(obj, dict):
return cls.from_dict(obj, auto_batch_size=auto_batch_size)
if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"):
Expand Down
5 changes: 4 additions & 1 deletion tensordict/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def make_tensordict(
input_dict: dict[str, CompatibleType] | None = None,
batch_size: Sequence[int] | torch.Size | int | None = None,
device: DeviceType | None = None,
auto_batch_size:bool|None=None,
**kwargs: CompatibleType, # source
) -> TensorDict:
"""Returns a TensorDict created from the keyword arguments or an input dictionary.
Expand All @@ -453,6 +454,8 @@ def make_tensordict(
(incompatible with nested keys).
batch_size (iterable of int, optional): a batch size for the tensordict.
device (torch.device or compatible type, optional): a device for the TensorDict.
        auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically.
            Defaults to ``None``, which currently resolves to ``True`` with a deprecation warning;
            it will default to ``False`` from v0.8 onward.
Examples:
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
Expand Down Expand Up @@ -500,4 +503,4 @@ def make_tensordict(
"""
if input_dict is not None:
kwargs.update(input_dict)
return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)
return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device, auto_batch_size=auto_batch_size)
30 changes: 22 additions & 8 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from textwrap import indent

from typing import Any, Callable, get_type_hints, List, Sequence, Type, TypeVar
from warnings import warn

import numpy as np
import orjson as json
Expand Down Expand Up @@ -632,8 +633,9 @@ def __torch_function__(

_is_non_tensor = getattr(cls, "_is_non_tensor", False)

if not dataclasses.is_dataclass(cls):
cls = dataclass(cls, frozen=frozen)
# Breaks some tests, don't do that:
# if not dataclasses.is_dataclass(cls):
cls = dataclass(cls, frozen=frozen)
_TENSORCLASS_MEMO[cls] = True

expected_keys = cls.__expected_keys__ = set(cls.__dataclass_fields__)
Expand Down Expand Up @@ -1367,7 +1369,7 @@ def _update(
non_blocking: bool = False,
):
if isinstance(input_dict_or_td, dict):
input_dict_or_td = self.from_dict(input_dict_or_td)
input_dict_or_td = self.from_dict(input_dict_or_td, auto_batch_size=False)

if is_tensorclass(input_dict_or_td):
non_tensordict = {
Expand Down Expand Up @@ -1579,7 +1581,7 @@ def _to_dict(self, *, retain_none: bool = True) -> dict:
return td_dict


def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
def _from_dict(cls, input_dict, *, auto_batch_size:bool|None=None, batch_size=None, device=None, batch_dims=None):
# we pass through a tensordict because keys could be passed as NestedKeys
# We can't assume all keys are strings, otherwise calling cls(**kwargs)
# would work ok
Expand All @@ -1593,15 +1595,15 @@ def _from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
non_tensordict=input_dict,
)
td = TensorDict.from_dict(
input_dict, batch_size=batch_size, device=device, batch_dims=batch_dims
input_dict, batch_size=batch_size, device=device, batch_dims=batch_dims, auto_batch_size=auto_batch_size
)
non_tensordict = {}

return cls.from_tensordict(tensordict=td, non_tensordict=non_tensordict)


def _from_dict_instance(
self, input_dict, batch_size=None, device=None, batch_dims=None
self, input_dict, *, auto_batch_size:bool|None=None, batch_size=None, device=None, batch_dims=None
):
if batch_dims is not None and batch_size is not None:
raise ValueError("Cannot pass both batch_size and batch_dims to `from_dict`.")
Expand All @@ -1611,7 +1613,7 @@ def _from_dict_instance(
# TODO: this is a bit slow and will be a bottleneck every time td[idx] = dict(subtd)
# is called when there are non tensor data in it
if not _is_tensor_collection(type(input_dict)):
input_tdict = TensorDict.from_dict(input_dict)
input_tdict = TensorDict.from_dict(input_dict, auto_batch_size=auto_batch_size)
else:
input_tdict = input_dict
trsf_dict = {}
Expand Down Expand Up @@ -1639,7 +1641,19 @@ def _from_dict_instance(
)
# check that
if batch_size is None:
out._tensordict.auto_batch_size_()
if auto_batch_size is None and batch_dims is None:
warn(
"The batch-size was not provided and auto_batch_size isn't set either. "
"Currently, from_dict will call set auto_batch_size=True but this behaviour "
"will be changed in v0.8 and auto_batch_size will be False onward. "
"To silence this warning, pass auto_batch_size directly.",
category=DeprecationWarning,
)
auto_batch_size = True
elif auto_batch_size is None:
auto_batch_size = True
if auto_batch_size:
out.auto_batch_size_()
return out


Expand Down
Loading

0 comments on commit 7c999e5

Please sign in to comment.