Using pytreeclass with jax and pytorch without specifying the backend as an environment variable #90

Open
mfinzi opened this issue Oct 6, 2023 · 0 comments


mfinzi commented Oct 6, 2023

Hi @ASEM000,

I really like your library compared to some other automated pytree alternatives and would love to see more people using it.
I was interested in using pytreeclass in CoLA, a numerical linear algebra library that I have been involved in developing. One of our design constraints is that CoLA must work with jax, pytorch, or both, depending on which of them is installed. Which framework is used depends on the LinearOperator objects the user creates, and there can even be scenarios where jax and pytorch objects exist simultaneously.

We were hoping to use pytreeclass as the base pytree for the LinearOperator objects, but have run into some issues with this cross-platform support. We know that pytreeclass was designed with support for both jax and pytorch in mind, but I couldn't find details on this topic in the docs.

Looking at pytreeclass/_src/backend/__init__.py, is the backend selected using an environment variable?
Is there any way for pytreeclass to work whether or not jax or pytorch is installed, choosing the backend based on whether the imports succeed or fail? Also, do you have any thoughts on whether it would be possible to have jax and pytorch pytrees exist at the same time?
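To make the import-based idea concrete, here is a minimal sketch that assumes nothing about pytreeclass internals. `available_backends` and `backend_of` are hypothetical helper names, not part of pytreeclass. The first helper checks which frameworks are installed without importing them. The second guesses, per leaf, which framework it belongs to by looking at the module path of the leaf's type, which is a heuristic. Per-leaf detection like this is what would let jax and pytorch objects coexist, instead of relying on one global backend choice.

```python
import importlib.util

# Candidate frameworks, checked in order ("torch" is PyTorch's import name).
_CANDIDATES = ("jax", "torch")


def available_backends() -> tuple[str, ...]:
    """Return the frameworks whose top-level package is importable.

    importlib.util.find_spec locates a top-level package without importing
    it, and returns None when the package is not installed.
    """
    return tuple(name for name in _CANDIDATES if importlib.util.find_spec(name) is not None)


def backend_of(leaf: object) -> str:
    """Hypothetical heuristic: infer a leaf's framework from its type's module.

    jax array types live under "jax..." or "jaxlib...", and torch.Tensor
    lives under "torch". Anything else (Python scalars, NumPy arrays, ...)
    is reported as "none".
    """
    root = type(leaf).__module__.split(".")[0]
    if root in ("jax", "jaxlib"):
        return "jax"
    if root == "torch":
        return "torch"
    return "none"
```

For example, `available_backends()` could return `("jax",)`, `("torch",)`, both, or neither depending on the environment. A library could then register its pytree class only with the frameworks actually present, rather than reading an environment variable.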

Cheers,
Marc
