[Feature] TensorClass __post_init__ support #172
Co-authored-by: Alessandro Pietro Bardelli <[email protected]>
def test_default():
    @tensorclass
    class MyData:
        X: torch.Tensor = None  # TODO: do we want to allow any default, say an integer?
        y: torch.Tensor = torch.ones(3, 4, 5)

    data = MyData(batch_size=[3, 4])
    assert data.__dict__["y"] is None
This was ultimately an implementation detail that changed and caused the tests to fail. "y" no longer exists in data.__dict__.
tensordict/prototype/tensorclass.py
Outdated
@@ -238,7 +230,6 @@ def wrapper(self, key, value):
    if type(value) in CLASSES_DICT.values():
        value = value.__dict__["tensordict"]
    self.__dict__["tensordict"][key] = value
    assert self.__dict__["tensordict"][key] is value
This check was failing both for None values and for tensors that were moved to a new device when assigned to the tensordict, so I decided to just delete it. Let me know, though, if you think we need to retain some kind of input validation.
I agree with you. We can drop that check.
We don't want assert in the code base; these should be kept for tests only.
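As a rough illustration of why an identity check such as the removed assert can fail, here is a minimal standard-library sketch. ConvertingStore is a hypothetical stand-in, not part of tensordict: it stores a converted copy on write, in the same spirit as a container that moves a tensor to another device when the value is assigned.

```python
class ConvertingStore:
    """Hypothetical container that converts values on write (illustration only)."""

    def __init__(self):
        self._data = {}

    def __setitem__(self, key, value):
        # Mimic a device move or cast: keep a converted copy, not the original object
        self._data[key] = list(value)

    def __getitem__(self, key):
        return self._data[key]


store = ConvertingStore()
value = [1, 2, 3]
store["x"] = value

same_object = store["x"] is value    # identity is lost by the conversion
same_content = store["x"] == value   # the stored content is still equal
```

An `is` check is therefore only safe when the container is guaranteed to keep the very same object, which is not the case described in the comment above.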
@tcbegley sorry, but I did not understand from your comment when things would break. Can you give an example?
Sure

import torch
from tensordict.prototype import tensorclass

@tensorclass
class Data:
    x: torch.Tensor
    y: torch.Tensor

    def __post_init__(self):
        self.z = self.x + self.y

d = Data(x=torch.rand(10), y=torch.rand(10), batch_size=[10])

That will fail because
But this will work, right?

from typing import Any

import torch
from tensordict.prototype import tensorclass

@tensorclass
class Data:
    x: torch.Tensor
    y: torch.Tensor
    z: Any

    def __post_init__(self):
        self.z = self.x + self.y

d = Data(x=torch.rand(10), y=torch.rand(10), batch_size=[10])
It gives me an error, but I think it could be made to work. Currently it's failing because we try to put
My opinion (do you guys agree?) is that the purpose of tensorclass is to have an explicit list of content before instantiation. If someone violates that, we just need to handle the error in an informative way (essentially telling people that assigning undefined values is not permitted).
Agree, your example with @tcbegley we should add both use-cases to unit-tests
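One way to picture the "informative error" idea is the sketch below. It uses plain dataclasses only; strict_fields is a hypothetical decorator written for this illustration and does not reflect how tensorclass is implemented.

```python
from dataclasses import dataclass, fields


def strict_fields(cls):
    """Hypothetical decorator: reject assignment to names not declared as fields."""
    cls = dataclass(cls)
    allowed = {f.name for f in fields(cls)}

    def __setattr__(self, key, value):
        if key not in allowed:
            raise AttributeError(
                f"Cannot set attribute '{key}': {type(self).__name__} only accepts "
                f"its declared fields {sorted(allowed)}."
            )
        object.__setattr__(self, key, value)

    cls.__setattr__ = __setattr__
    return cls


@strict_fields
class Data:
    x: int
    y: int

    def __post_init__(self):
        self.z = self.x + self.y  # 'z' is not declared, so this raises


try:
    Data(x=1, y=2)
    message = None
except AttributeError as err:
    message = str(err)
```

The error names the offending attribute and lists the declared fields, which is the kind of feedback suggested in the comment above.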
I also agree with this. Though in that case I'm not sure we should support this pattern

@tensorclass
class Data:
    x: torch.Tensor
    y: torch.Tensor
    z: Any

    def __post_init__(self):
        self.z = self.x + self.y

d = Data(x=torch.rand(10), y=torch.rand(10), batch_size=[10])

EDIT - to clarify, I don't think there's a problem necessarily with values being modified in the
An analogous dataclass would throw a
If the user really doesn't want to pass

@tensorclass
class Data:
    x: torch.Tensor
    y: torch.Tensor

    @property
    def z(self):
        return self.x + self.y

Which I think is much clearer and has the added benefit of being up to date if
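For comparison, the two patterns can be tried with standard-library dataclasses alone. This is only a sketch of dataclass behaviour (the class names WithField and WithProperty are made up here), not of tensorclass: a declared field without a default is a required argument of the generated __init__, while a property is recomputed on every access.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class WithField:
    x: int
    y: int
    z: Any  # required by the generated __init__, even though __post_init__ overwrites it

    def __post_init__(self):
        self.z = self.x + self.y


try:
    WithField(x=1, y=2)  # 'z' is missing from the call
    raised = False
except TypeError:
    raised = True


@dataclass
class WithProperty:
    x: int
    y: int

    @property
    def z(self):
        # Derived on access, so it always reflects the current x and y
        return self.x + self.y


p = WithProperty(x=1, y=2)
first = p.z
p.x = 10
second = p.z
```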
You mean that the supported use case should be the following?
If yes, agreed. I would try to stick as closely as possible to dataclass semantics, as you suggest.
Yeah, sort of. That example would run under my proposed changes, but it's not a pattern I would necessarily encourage. I think passing in an arbitrary object only to have it get overwritten no matter the value is not ideal. But yes, my main argument is that we should try to stick to dataclass-like behaviour where possible.
Thanks to #175 I noticed that we are missing one test: a class with a custom __post_init__ that gets initialised with a tensordict. Currently this test would fail, as self.tensordict is filled after the init.
Is this something we need to address before landing this?
LGTM
@apbard let me know if you want your last comment to be addressed before merging
@@ -67,6 +68,34 @@ def test_type():
    assert type(data) is MyDataUndecorated


def test_signature():
Nice!
Yes, I think we should. In fact, we lack coverage for that use-case and it would also fail.
We should definitely add a test. It seems to me, @apbard, that #175 resolves that issue, right? I think the bigger problem then is that merging #175 is going to cause havoc with the non-tensor data PR that is close to being finished. Perhaps, @vmoens, if you're happy with #175 in principle, we can merge into this branch, add the test, and then work on adapting to the pending non-tensor data changes.
EDIT - I added the test to both branches. As expected it fails here, but is working on #175.
Didn't previously test compatibility of a custom __post_init__ and building the tensorclass from a tensordict instance.
# Conflicts: # tensordict/prototype/tensorclass.py
# bypass initialisation. this means we don't incur any overhead creating an
# empty tensordict and writing values to it. we can skip this because we already
# have a tensordict to use as the underlying tensordict
tc = cls.__new__(cls)
tc.__dict__["tensordict"] = tensordict
# since we aren't calling the dataclass init method, we need to manually check
# whether a __post_init__ method has been defined and invoke it if so
if hasattr(tc, "__post_init__"):
    tc.__post_init__()
return tc
I added this to reduce overhead when constructing from TensorDict. We bypass the regular constructor, which means we don't create a tensordict or have to set the attributes. The only thing we need to make sure we do is run the __post_init__ method if it exists, because we're no longer invoking the dataclass' init method.
I think this is ok, but if you can think of any edge cases I might have missed please let me know!
See relevant discussion here: #175 (comment)
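The same bypass pattern can be sketched with a plain dataclass. Record and from_mapping are hypothetical names used only for this illustration; the point is that an alternative constructor built on cls.__new__ skips the generated __init__, so __post_init__ has to be called by hand if its logic should still run.

```python
from dataclasses import dataclass


@dataclass
class Record:
    x: int
    y: int

    def __post_init__(self):
        # Validation that should run however the object is built
        if self.x > self.y:
            raise ValueError("x must not exceed y")

    @classmethod
    def from_mapping(cls, values):
        # Skip the generated __init__ and adopt the existing storage directly
        obj = cls.__new__(cls)
        obj.__dict__.update(values)
        # __init__ was not called, so __post_init__ must be invoked manually
        if hasattr(obj, "__post_init__"):
            obj.__post_init__()
        return obj


fast = Record.from_mapping({"x": 1, "y": 2})
same = fast == Record(1, 2)

try:
    Record.from_mapping({"x": 5, "y": 2})
    validated = False
except ValueError:
    validated = True
```

Without the manual call, the second construction would silently produce an object that the regular constructor would have rejected.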
Happy with that solution
Description
This PR makes it possible to write a __post_init__ method on a tensorclass, as one can with dataclasses.
Note, however, that we do not currently support setting derived values. At the moment this is because the tensorclass permits setting only the attributes in __dataclass_fields__, which is determined from the class definition. The main challenge with lifting this restriction is that we would not have type information for the derived attributes, which is required elsewhere in the code to determine the types of return values. I don't personally see a way around this at the moment...
Example
cc @apbard
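To make the type-information point in the description concrete, here is a small sketch using standard-library dataclasses rather than tensorclass. It shows that __dataclass_fields__ is built from the class definition, so an attribute created in __post_init__ never appears there and carries no declared type.

```python
from dataclasses import dataclass


@dataclass
class Data:
    x: int
    y: float

    def __post_init__(self):
        # Plain dataclasses allow this, but 'z' is not a declared field
        self.z = self.x + self.y


d = Data(1, 2.0)

# Field names mapped to their annotated types, taken from the class definition
declared = {name: f.type for name, f in Data.__dataclass_fields__.items()}
has_type_for_z = "z" in declared
```

Any code that relies on the declared types (as the description says tensorclass does) would have nothing to look up for such a derived attribute.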