cloudpickle fails only when jitting after de-serialisation #231

Closed
simon-bachhuber opened this issue Nov 16, 2022 · 2 comments
Labels
question User queries

Comments

@simon-bachhuber

Not an issue but a question: why does cloudpickle only partially work in this example?

import equinox as eqx
import jax.numpy as jnp 
import cloudpickle as p 

class Model(eqx.Module):
    x: jnp.ndarray
    def __call__(self):
        return self.x

m = Model(jnp.ones((3,3)))
m = p.loads(p.dumps(m))

# this works
m()
# this does not
eqx.filter_jit(m)()
# >> AttributeError: 'Model' object has no attribute 'x'
@patrick-kidger
Owner

patrick-kidger commented Nov 17, 2022

Looks like some kind of bug in cloudpickle: round-tripping an instance mutates what counts as the fields of the original dataclass. Here's a reproducer that doesn't use Equinox or JAX:

import cloudpickle as p
import dataclasses

@dataclasses.dataclass
class Model:
    x: int

print(dataclasses.fields(Model))
p.loads(p.dumps(Model(1)))
print(dataclasses.fields(Model))
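To check for this mutation programmatically instead of comparing the two printed outputs by eye, one could snapshot the field names of the original class before and after a round trip. The sketch below is illustrative only: `roundtrip_changes_fields` is a hypothetical helper, not part of cloudpickle or the standard library. It defaults to the standard-library `pickle`, which stores classes by reference and so should leave the fields intact; passing `cloudpickle.dumps`/`cloudpickle.loads` instead would exercise the reported case.

```python
import dataclasses
import pickle
from typing import Any, Callable


def roundtrip_changes_fields(
    obj: Any,
    dumps: Callable[[Any], bytes] = pickle.dumps,
    loads: Callable[[bytes], Any] = pickle.loads,
) -> bool:
    """Hypothetical helper: return True if a dumps/loads round trip of `obj`
    alters what dataclasses.fields() reports for the *original* class."""
    cls = type(obj)
    before = [f.name for f in dataclasses.fields(cls)]
    loads(dumps(obj))  # the restored object itself is not needed
    # Re-query the original class, mirroring the two print calls above.
    after = [f.name for f in dataclasses.fields(cls)]
    return before != after


@dataclasses.dataclass
class Model:
    x: int


if __name__ == "__main__":
    # Standard-library pickle stores the class by reference, so when run as
    # a script this prints False.
    print(roundtrip_changes_fields(Model(1)))
    # With cloudpickle installed, the reported case would be:
    #   import cloudpickle
    #   print(roundtrip_changes_fields(Model(1), cloudpickle.dumps, cloudpickle.loads))
```

Comparing field names rather than `Field` objects keeps the check focused on the symptom shown in the reproducer, namely which fields `dataclasses.fields` reports for the class.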

@patrick-kidger patrick-kidger added the question User queries label Nov 17, 2022
@simon-bachhuber
Author

Ha. I'll move this over to cloudpickle directly, then.
