-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Refactor] TensorClass drop _tensordict argument in constructor #175
Changes from 2 commits
b5704d8
bcace64
7228d9b
78dc118
7f8d197
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -125,7 +125,7 @@ def __torch_function__( | |||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
cls.__init__ = _init_wrapper(cls.__init__, expected_keys) | ||||||||||||||||||||||||||||||
cls._build_from_tensordict = classmethod(_build_from_tensordict) | ||||||||||||||||||||||||||||||
cls.from_tensordict = classmethod(_from_tensordict) | ||||||||||||||||||||||||||||||
cls.__torch_function__ = classmethod(__torch_function__) | ||||||||||||||||||||||||||||||
cls.__getstate__ = _getstate | ||||||||||||||||||||||||||||||
cls.__setstate__ = _setstate | ||||||||||||||||||||||||||||||
|
@@ -165,65 +165,45 @@ def _init_wrapper(init, expected_keys): | |||||||||||||||||||||||||||||
required_params = [p.name for p in params[1:] if p.default is inspect._empty] | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
@functools.wraps(init) | ||||||||||||||||||||||||||||||
def wrapper(self, *args, batch_size=None, device=None, _tensordict=None, **kwargs): | ||||||||||||||||||||||||||||||
if (args or kwargs) and _tensordict is not None: | ||||||||||||||||||||||||||||||
raise ValueError("Cannot pass both args/kwargs and _tensordict.") | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if _tensordict is not None: | ||||||||||||||||||||||||||||||
if not all(key in expected_keys for key in _tensordict.keys()): | ||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||
f"Keys from the tensordict ({set(_tensordict.keys())}) must " | ||||||||||||||||||||||||||||||
f"correspond to the class attributes ({expected_keys})." | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
input_dict = {key: None for key in _tensordict.keys()} | ||||||||||||||||||||||||||||||
init(self, **input_dict) | ||||||||||||||||||||||||||||||
self.tensordict = _tensordict | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
for value, key in zip(args, self.__dataclass_fields__): | ||||||||||||||||||||||||||||||
if key in kwargs: | ||||||||||||||||||||||||||||||
raise ValueError(f"The key {key} is already set in kwargs") | ||||||||||||||||||||||||||||||
kwargs[key] = value | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
for key, field in self.__dataclass_fields__.items(): | ||||||||||||||||||||||||||||||
if field.default_factory is not dataclasses.MISSING: | ||||||||||||||||||||||||||||||
default = field.default_factory() | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
default = field.default | ||||||||||||||||||||||||||||||
if default not in (None, dataclasses.MISSING): | ||||||||||||||||||||||||||||||
kwargs.setdefault(key, default) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
missing_params = [p for p in required_params if p not in kwargs] | ||||||||||||||||||||||||||||||
if missing_params: | ||||||||||||||||||||||||||||||
n_missing = len(missing_params) | ||||||||||||||||||||||||||||||
raise TypeError( | ||||||||||||||||||||||||||||||
f"{self.__class__.__name__}.__init__() missing {n_missing} " | ||||||||||||||||||||||||||||||
f"required positional argument{'' if n_missing == 1 else 's'}: " | ||||||||||||||||||||||||||||||
f"""{", ".join(f"'{name}'" for name in missing_params)}""" | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
def wrapper(self, *args, batch_size, device=None, **kwargs): | ||||||||||||||||||||||||||||||
for value, key in zip(args, self.__dataclass_fields__): | ||||||||||||||||||||||||||||||
if key in kwargs: | ||||||||||||||||||||||||||||||
raise ValueError(f"The key {key} is already set in kwargs") | ||||||||||||||||||||||||||||||
kwargs[key] = value | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
for key, field in self.__dataclass_fields__.items(): | ||||||||||||||||||||||||||||||
if field.default_factory is not dataclasses.MISSING: | ||||||||||||||||||||||||||||||
default = field.default_factory() | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
default = field.default | ||||||||||||||||||||||||||||||
if default not in (None, dataclasses.MISSING): | ||||||||||||||||||||||||||||||
kwargs.setdefault(key, default) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
self.tensordict = TensorDict({}, batch_size=batch_size, device=device) | ||||||||||||||||||||||||||||||
init( | ||||||||||||||||||||||||||||||
self, **{key: _get_typed_value(value) for key, value in kwargs.items()} | ||||||||||||||||||||||||||||||
missing_params = [p for p in required_params if p not in kwargs] | ||||||||||||||||||||||||||||||
if missing_params: | ||||||||||||||||||||||||||||||
n_missing = len(missing_params) | ||||||||||||||||||||||||||||||
raise TypeError( | ||||||||||||||||||||||||||||||
f"{self.__class__.__name__}.__init__() missing {n_missing} " | ||||||||||||||||||||||||||||||
f"required positional argument{'' if n_missing == 1 else 's'}: " | ||||||||||||||||||||||||||||||
f"""{", ".join(f"'{name}'" for name in missing_params)}""" | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
self.tensordict = TensorDict({}, batch_size=batch_size, device=device) | ||||||||||||||||||||||||||||||
init(self, **{key: _get_typed_value(value) for key, value in kwargs.items()}) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
new_params = [ | ||||||||||||||||||||||||||||||
inspect.Parameter( | ||||||||||||||||||||||||||||||
"batch_size", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None | ||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||
inspect.Parameter( | ||||||||||||||||||||||||||||||
"device", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None | ||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||
inspect.Parameter( | ||||||||||||||||||||||||||||||
"_tensordict", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None | ||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||
inspect.Parameter("batch_size", inspect.Parameter.KEYWORD_ONLY), | ||||||||||||||||||||||||||||||
inspect.Parameter("device", inspect.Parameter.KEYWORD_ONLY, default=None), | ||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||
wrapper.__signature__ = init_sig.replace(parameters=params + new_params) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
return wrapper | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _build_from_tensordict(cls, tensordict): | ||||||||||||||||||||||||||||||
return cls(_tensordict=tensordict) | ||||||||||||||||||||||||||||||
def _from_tensordict(cls, tensordict): | ||||||||||||||||||||||||||||||
tc = cls(**tensordict, batch_size=tensordict.batch_size) | ||||||||||||||||||||||||||||||
vmoens marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||
tc.__dict__["tensordict"] = tensordict | ||||||||||||||||||||||||||||||
return tc | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I see where you're doing with this, just a bit awkward because currently the internals of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wouldn't calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But that's the tensorclass' init method right? So that's effectively what I'm currently doing anyway with tc = cls(**tensordict, batch_size=tensordict.batch_size) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh, my bad...in this PR we are still overwriting |
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _getstate(self): | ||||||||||||||||||||||||||||||
|
@@ -274,7 +254,7 @@ def _getattr(self, attr): | |||||||||||||||||||||||||||||
def wrapped_func(*args, **kwargs): | ||||||||||||||||||||||||||||||
res = func(*args, **kwargs) | ||||||||||||||||||||||||||||||
if isinstance(res, TensorDictBase): | ||||||||||||||||||||||||||||||
new = self.__class__(_tensordict=res) | ||||||||||||||||||||||||||||||
new = self.from_tensordict(res) | ||||||||||||||||||||||||||||||
return new | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
return res | ||||||||||||||||||||||||||||||
|
@@ -288,7 +268,7 @@ def _getitem(self, item): | |||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
raise ValueError("Invalid indexing arguments.") | ||||||||||||||||||||||||||||||
res = self.tensordict[item] | ||||||||||||||||||||||||||||||
return self.__class__(_tensordict=res) # device=res.device) | ||||||||||||||||||||||||||||||
return self.from_tensordict(res) # device=res.device) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _setitem(self, item, value): | ||||||||||||||||||||||||||||||
|
@@ -351,13 +331,13 @@ def _batch_size_setter(self, new_size: torch.Size) -> None: | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _unbind(tdc, dim): | ||||||||||||||||||||||||||||||
tensordicts = torch.unbind(tdc.tensordict, dim) | ||||||||||||||||||||||||||||||
out = [tdc.__class__(_tensordict=td) for td in tensordicts] | ||||||||||||||||||||||||||||||
out = [tdc.from_tensordict(td) for td in tensordicts] | ||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _full_like(tdc, fill_value): | ||||||||||||||||||||||||||||||
tensordict = torch.full_like(tdc.tensordict, fill_value) | ||||||||||||||||||||||||||||||
out = tdc.__class__(_tensordict=tensordict) | ||||||||||||||||||||||||||||||
out = tdc.from_tensordict(tensordict) | ||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -371,43 +351,43 @@ def _ones_like(tdc): | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _clone(tdc): | ||||||||||||||||||||||||||||||
tensordict = torch.clone(tdc.tensordict) | ||||||||||||||||||||||||||||||
out = tdc.__class__(_tensordict=tensordict) | ||||||||||||||||||||||||||||||
out = tdc.from_tensordict(tensordict) | ||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _squeeze(tdc): | ||||||||||||||||||||||||||||||
tensordict = torch.squeeze(tdc.tensordict) | ||||||||||||||||||||||||||||||
out = tdc.__class__(_tensordict=tensordict) | ||||||||||||||||||||||||||||||
out = tdc.from_tensordict(tensordict) | ||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _unsqueeze(tdc, dim=0): | ||||||||||||||||||||||||||||||
tensordict = torch.unsqueeze(tdc.tensordict, dim) | ||||||||||||||||||||||||||||||
out = tdc.__class__(_tensordict=tensordict) | ||||||||||||||||||||||||||||||
out = tdc.from_tensordict(tensordict) | ||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _permute(tdc, dims): | ||||||||||||||||||||||||||||||
tensordict = torch.permute(tdc.tensordict, dims) | ||||||||||||||||||||||||||||||
out = tdc.__class__(_tensordict=tensordict) | ||||||||||||||||||||||||||||||
out = tdc.from_tensordict(tensordict) | ||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _split(tdc, split_size_or_sections, dim=0): | ||||||||||||||||||||||||||||||
tensordicts = torch.split(tdc.tensordict, split_size_or_sections, dim) | ||||||||||||||||||||||||||||||
out = [tdc.__class__(_tensordict=td) for td in tensordicts] | ||||||||||||||||||||||||||||||
out = [tdc.from_tensordict(td) for td in tensordicts] | ||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _stack(list_of_tdc, dim): | ||||||||||||||||||||||||||||||
tensordict = torch.stack([tdc.tensordict for tdc in list_of_tdc], dim) | ||||||||||||||||||||||||||||||
out = list_of_tdc[0].__class__(_tensordict=tensordict) | ||||||||||||||||||||||||||||||
out = list_of_tdc[0].from_tensordict(tensordict) | ||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _cat(list_of_tdc, dim): | ||||||||||||||||||||||||||||||
tensordict = torch.cat([tdc.tensordict for tdc in list_of_tdc], dim) | ||||||||||||||||||||||||||||||
out = list_of_tdc[0].__class__(_tensordict=tensordict) | ||||||||||||||||||||||||||||||
out = list_of_tdc[0].from_tensordict(tensordict) | ||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -425,18 +405,18 @@ def _get_typed_output(out, expected_type): | |||||||||||||||||||||||||||||
# Otherwise, if the output is some TensorDictBase subclass, we check the type and if it | ||||||||||||||||||||||||||||||
# does not match, we map it. In all other cases, just return what has been gathered. | ||||||||||||||||||||||||||||||
if isinstance(expected_type, str) and expected_type in CLASSES_DICT: | ||||||||||||||||||||||||||||||
out = CLASSES_DICT[expected_type](_tensordict=out) | ||||||||||||||||||||||||||||||
out = CLASSES_DICT[expected_type].from_tensordict(out) | ||||||||||||||||||||||||||||||
elif ( | ||||||||||||||||||||||||||||||
isinstance(expected_type, type) | ||||||||||||||||||||||||||||||
and not isinstance(out, expected_type) | ||||||||||||||||||||||||||||||
and isinstance(out, TensorDictBase) | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
out = expected_type(_tensordict=out) | ||||||||||||||||||||||||||||||
out = expected_type.from_tensordict(out) | ||||||||||||||||||||||||||||||
elif isinstance(out, TensorDictBase): | ||||||||||||||||||||||||||||||
dest_dtype = _check_td_out_type(expected_type) | ||||||||||||||||||||||||||||||
if dest_dtype is not None: | ||||||||||||||||||||||||||||||
print(dest_dtype) | ||||||||||||||||||||||||||||||
out = dest_dtype(_tensordict=out) | ||||||||||||||||||||||||||||||
out = dest_dtype.from_tensordict(out) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest adding a
copy
parameter to allow just storing the reference or deep copyThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you say a bit more about how you imagine this working?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see here.