
[Feature] TensorClass __post_init__ support #172

Merged · 18 commits · Jan 23, 2023

Conversation

tcbegley
Contributor

Description

This PR makes it possible to write a __post_init__ method on a tensorclass, as one can with dataclasses.

Note, however, that we do not currently support setting derived values: the tensorclass only permits setting attributes that appear in __dataclass_fields__, which is determined from the class definition. The main challenge with lifting this restriction is that we would have no type information for derived attributes, which is required elsewhere in the code to determine the types of return values. I don't personally see a way around this at the moment...

Example

import torch
from tensordict.prototype import tensorclass


@tensorclass
class MyDataPostInit:
    X: torch.Tensor
    y: torch.Tensor

    def __post_init__(self):
        assert (self.X > 0).all()
        assert self.y.abs().max() <= 10
        # modifying existing fields is fine
        self.y = self.y.abs()


y = torch.clamp(torch.randn(3, 4), min=-10, max=10)
data = MyDataPostInit(X=torch.rand(3, 4), y=y, batch_size=[3, 4])
assert (data.y == y.abs()).all()

# this results in an assertion error
MyDataPostInit(X=-torch.ones(2), y=torch.rand(2), batch_size=[2])

cc @apbard

tcbegley and others added 3 commits January 19, 2023 13:35
Co-authored-by: Alessandro Pietro Bardelli <[email protected]>
Co-authored-by: Alessandro Pietro Bardelli <[email protected]>
Co-authored-by: Alessandro Pietro Bardelli <[email protected]>
@facebook-github-bot added the CLA Signed label Jan 19, 2023
def test_default():
    @tensorclass
    class MyData:
        X: torch.Tensor = None  # TODO: do we want to allow any default, say an integer?
        y: torch.Tensor = torch.ones(3, 4, 5)

    data = MyData(batch_size=[3, 4])
    assert data.__dict__["y"] is None
Contributor Author
This was ultimately an implementation detail that changed and caused the tests to fail. "y" no longer exists in data.__dict__.

@@ -238,7 +230,6 @@ def wrapper(self, key, value):
     if type(value) in CLASSES_DICT.values():
         value = value.__dict__["tensordict"]
     self.__dict__["tensordict"][key] = value
-    assert self.__dict__["tensordict"][key] is value
Contributor Author

This check was failing both for None values and for tensors that were moved to a new device when assigned to the tensordict, so I decided to just delete it. Let me know, though, if you think we need to retain some kind of input validation.

Contributor

I agree with you. We can drop that check.

Contributor

  • we don't want asserts in the code base; these should be kept for tests only

@vmoens
Contributor

vmoens commented Jan 19, 2023

@tcbegley sorry, but I didn't understand from your comment when things would break. Can you give an example?

@tcbegley
Contributor Author

Can you give an example?

Sure

import torch
from tensordict.prototype import tensorclass


@tensorclass
class Data:
    x: torch.Tensor
    y: torch.Tensor

    def __post_init__(self):
        self.z = self.x + self.y

d = Data(x=torch.rand(10), y=torch.rand(10), batch_size=[10])

That will fail because "z" is not in expected_keys, but even if we were to solve that problem, I wonder if there will be problems around not having a field definition / type annotation, which we use in various places, e.g. here

@vmoens
Contributor

vmoens commented Jan 19, 2023

That will fail because "z" is not in expected_keys

But this will work right?

import torch
from typing import Any  # needed for the z: Any annotation below
from tensordict.prototype import tensorclass


@tensorclass
class Data:
    x: torch.Tensor
    y: torch.Tensor
    z: Any

    def __post_init__(self):
        self.z = self.x + self.y

d = Data(x=torch.rand(10), y=torch.rand(10), batch_size=[10])

@tcbegley
Contributor Author

It gives me an error, but I think it could be made to work. Currently it's failing because we try to put dataclasses._MISSING_TYPE into the tensordict, and it doesn't know what to do with it. But we could handle that.

@vmoens
Contributor

vmoens commented Jan 19, 2023

It gives me an error, but I think it could be made to work. Currently it's failing because we try to put dataclasses._MISSING_TYPE into the tensordict, and it doesn't know what to do with it. But we could handle that.

My opinion (do you guys agree?) is that the purpose of tensorclass is to have an explicit list of content before instantiation. If someone violates that, we just need to handle the error in an informative way (essentially telling people that assigning undefined values is not permitted).
@apbard what do you think? Is this something that you'd find intuitive enough?
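For illustration, an informative error along those lines could look something like the following. This is a hypothetical sketch using a plain dataclass (the `Guarded` class and its message are invented for the example, not the actual tensordict code):

```python
import dataclasses


@dataclasses.dataclass
class Guarded:
    x: int

    def __setattr__(self, key, value):
        # only permit assignment to fields declared on the class, and
        # explain the restriction rather than failing obscurely
        if key not in self.__dataclass_fields__:
            raise AttributeError(
                f"Cannot set undeclared attribute {key!r}: only fields "
                "declared in the class definition may be assigned"
            )
        super().__setattr__(key, value)


g = Guarded(x=1)  # __init__ assigns x through __setattr__, which is declared
g.x = 2           # fine: declared field
try:
    g.z = 3       # rejected: not declared
except AttributeError as err:
    print(err)
```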

@apbard
Contributor

apbard commented Jan 19, 2023

My opinion (do you guys agree?) is that the purpose of tensorclass is to have an explicit list of content before instantiation. If someone violates that, we just need to handle the error in an informative way (essentially telling people that assigning undefined values is not permitted).

Agree, your example with z: Any should work. Declaring it in __post_init__ should not.

@tcbegley we should add both use-cases to unit-tests

@tcbegley
Contributor Author

tcbegley commented Jan 19, 2023

My opinion (do you guys agree?) is that the purpose of tensorclass is to have an explicit list of content before instantiation. If someone violates that, we just need to handle the error in an informative way (essentially telling people that assigning undefined values is not permitted).

I also agree with this. Though in that case I'm not sure we should support this pattern

@tensorclass
class Data:
    x: torch.Tensor
    y: torch.Tensor
    z: Any

    def __post_init__(self):
        self.z = self.x + self.y

d = Data(x=torch.rand(10), y=torch.rand(10), batch_size=[10])

EDIT - to clarify, I don't think there's necessarily a problem with values being modified in __post_init__; rather, listed fields with no defaults really ought to be required arguments to the constructor. Using an annotation just to make the attribute exist and be writable inside __post_init__ feels messy.

An analogous dataclass would throw a TypeError on instantiation because the user hasn't supplied an argument for z which has no default. I think if we require all features to be listed up-front, then they should all be passed to the constructor also.
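A minimal sketch of the analogous dataclass behaviour referred to above (plain Python, no tensordict involved):

```python
import dataclasses


@dataclasses.dataclass
class Data:
    x: float
    y: float
    z: float  # no default, so it is a required constructor argument


try:
    Data(x=1.0, y=2.0)  # z omitted
except TypeError as err:
    print(err)  # reports the missing required argument 'z'
```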

If the user really doesn't want to pass z to the constructor because it's supposed to be derived from x and y then there is always the option to do the following

@tensorclass
class Data:
    x: torch.Tensor
    y: torch.Tensor

    @property
    def z(self):
        return self.x + self.y

That is much clearer, I think, and has the added benefit of staying up to date if x and y change. You could even use functools.cached_property if they won't change and the computation is more expensive than an addition.
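To illustrate the difference between the two approaches with a plain dataclass analogue (hypothetical class names, not tensordict code): a `property` tracks changes to its inputs, while `functools.cached_property` computes once and then serves the stored value.

```python
import dataclasses
from functools import cached_property


@dataclasses.dataclass
class DerivedSum:
    x: float
    y: float

    @property
    def z(self):
        # recomputed on every access, so it tracks changes to x and y
        return self.x + self.y


@dataclasses.dataclass
class CachedSum:
    x: float
    y: float

    @cached_property
    def z(self):
        # computed once on first access, then read from the instance dict
        return self.x + self.y


d = DerivedSum(x=1.0, y=2.0)
assert d.z == 3.0
d.x = 10.0
assert d.z == 12.0  # stays in sync after mutation

c = CachedSum(x=1.0, y=2.0)
assert c.z == 3.0
c.x = 10.0
assert c.z == 3.0  # cached value is not recomputed
```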

@apbard
Contributor

apbard commented Jan 19, 2023

you mean that the supported use case should be the following?

@tensorclass
class Data:
    x: torch.Tensor
    y: torch.Tensor
    z: Any

    def __post_init__(self):
        self.z = self.x + self.y

d = Data(x=torch.rand(10), y=torch.rand(10), z=anyobject, batch_size=[10])

if yes, agree. I would try to stick as close as possible to dataclass semantics as you suggest

@tcbegley
Contributor Author

Yeah, sort of. That example would run under my proposed changes, but it's not a pattern I would necessarily encourage. I think passing in an arbitrary object only to have it get overwritten no matter the value is not ideal.

But yes, my main argument is that we should try to stick to dataclass-like behaviour where possible.

@apbard
Contributor

apbard commented Jan 20, 2023

Thanks to #175 I noticed that we're missing one test: a class with a custom __post_init__ that gets initialised with a tensordict. Currently this test would fail, as self.tensordict is filled after the init.

@vmoens added the enhancement label Jan 23, 2023
@vmoens
Contributor

vmoens commented Jan 23, 2023

Thanks to #175 I noticed that we're missing one test: a class with a custom __post_init__ that gets initialised with a tensordict. Currently this test would fail, as self.tensordict is filled after the init.

Is this something we need to address before landing this?

Contributor

@vmoens vmoens left a comment

LGTM
@apbard let me know if you want your last comment to be addressed before merging


@@ -67,6 +68,34 @@ def test_type():
assert type(data) is MyDataUndecorated


def test_signature():
Contributor

Nice!

@apbard
Contributor

apbard commented Jan 23, 2023

LGTM @apbard let me know if you want your last comment to be addressed before merging

Yes, I think we should. In fact, we lack coverage for that use case, and it would also fail.

@tcbegley
Contributor Author

tcbegley commented Jan 23, 2023

We should definitely add a test.

It seems to me @apbard that #175 resolves that issue right? I think the bigger problem then is that merging #175 is going to cause havoc with the non-tensor data PR that is close to being finished.

Perhaps @vmoens if you're happy with #175 in principle, we can merge into this branch, add the test, and then work on adapting to the pending non-tensor data changes.

EDIT - I added the test to both branches. As expected it fails here, but is working on #175.

Comment on lines +219 to +228
    # bypass initialisation. this means we don't incur any overhead creating an
    # empty tensordict and writing values to it. we can skip this because we already
    # have a tensordict to use as the underlying tensordict
    tc = cls.__new__(cls)
    tc.__dict__["tensordict"] = tensordict
    # since we aren't calling the dataclass init method, we need to manually check
    # whether a __post_init__ method has been defined and invoke it if so
    if hasattr(tc, "__post_init__"):
        tc.__post_init__()
    return tc
Contributor Author

I added this to reduce overhead when constructing from TensorDict. We bypass the regular constructor which means we don't create a tensordict or have to set the attributes. The only thing we need to make sure we do is run the __post_init__ method if it exists because we're no longer invoking the dataclass' init method.

I think this is ok, but if you can think of any edge cases I might have missed please let me know!

See relevant discussion here: #175 (comment)
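The general pattern being used can be sketched in plain Python, independent of tensordict (the `Point` class and `from_state` helper below are invented for illustration):

```python
# Bypass __init__ via cls.__new__, install pre-built state directly, then
# invoke __post_init__ manually so user-defined hooks still run.
class Point:
    def __init__(self, x, y):
        # imagine this does expensive setup we want to skip
        self.x = x
        self.y = y

    def __post_init__(self):
        # validation / derived state that should run regardless of how
        # the instance was constructed
        self.norm = (self.x ** 2 + self.y ** 2) ** 0.5


def from_state(cls, state):
    obj = cls.__new__(cls)       # allocate without running __init__
    obj.__dict__.update(state)   # install the pre-built state directly
    if hasattr(obj, "__post_init__"):
        obj.__post_init__()      # preserve the post-init hook
    return obj


p = from_state(Point, {"x": 3.0, "y": 4.0})
assert p.norm == 5.0
```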

Contributor

Happy with that solution

@vmoens vmoens merged commit 95bf524 into main Jan 23, 2023
@vmoens vmoens deleted the tensorclass-post-init branch February 11, 2023 10:28