0.6.3 introduced a circular dependency with orbax #2707

Closed
mattsoulanille opened this issue Dec 10, 2022 · 16 comments
Assignees
Labels
Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required)

Comments

@mattsoulanille

Problem you have encountered:

Flax 0.6.3 added a dependency on orbax, which has a dependency on flax. This is causing tensorflow/tfjs#7159 in the TensorFlow.js repository. TFJS resolves pypi packages using Bazel, which does not support circular dependencies.

Was this change intentional? If so, I can file a bug with rules_python instead, although the last time this kind of circular dependency issue arose, it was determined to be a bug in the downstream package. I'm not sure if that's true in this case, though.

What you expected to happen:

No circular dependency.
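
To illustrate why a strict resolver rejects this, here is a minimal sketch of cycle detection over a dependency graph. The two graphs are simplified assumptions based on the description above (only `jax`, `flax` and `orbax` are listed), and `find_cycle` is a hypothetical helper, not part of Bazel, rules_python or pip.

```python
from typing import Dict, List, Optional


def find_cycle(deps: Dict[str, List[str]]) -> Optional[List[str]]:
    """Return one dependency cycle as a list of names, or None if acyclic."""
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited, on the current path, finished
    color = {name: WHITE for name in deps}
    stack: List[str] = []

    def visit(node: str) -> Optional[List[str]]:
        color[node] = GRAY
        stack.append(node)
        for dep in deps.get(node, []):
            state = color.get(dep, WHITE)
            if state == GRAY:
                # dep is already on the current path, so we closed a loop.
                return stack[stack.index(dep):] + [dep]
            if state == WHITE:
                found = visit(dep)
                if found:
                    return found
        stack.pop()
        color[node] = BLACK
        return None

    for name in list(deps):
        if color[name] == WHITE:
            found = visit(name)
            if found:
                return found
    return None


# Simplified, assumed graphs: real packages have more dependencies.
flax_063 = {"flax": ["jax", "orbax"], "orbax": ["jax", "flax"], "jax": []}
flax_062 = {"flax": ["jax"], "orbax": ["jax", "flax"], "jax": []}

print(find_cycle(flax_063))  # ['flax', 'orbax', 'flax']
print(find_cycle(flax_062))  # None
```

With the 0.6.2-style graph there is no cycle, which is consistent with pinning flax to 0.6.2 as a workaround.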


@marcvanzee marcvanzee added the Priority: P0 (urgent) Response within 1 business day. Resolution as soon as possible. (Assignee required) label Dec 12, 2022
@marcvanzee
Collaborator

Thanks @mattsoulanille for filing this issue. We will look into it as soon as possible. I have assigned it to @IvyZX since she is leading the Orbax work and has the most context here.

@mattsoulanille
Author

Thanks! TFJS has pinned flax to 0.6.2 for now, so we are unblocked for the moment.

@IvyZX
Collaborator

IvyZX commented Dec 12, 2022

In internal equivalents of Bazel, the BUILD rules are specified so that there is no actual code-level circular dependency between Flax and Orbax. Is it possible to specify BUILD rules in the open-source codebase so that Bazel knows how to run it?

@IvyZX IvyZX added Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) and removed Priority: P0 (urgent) Response within 1 business day. Resolution as soon as possible. (Assignee required) labels Dec 12, 2022
@IvyZX
Collaborator

IvyZX commented Dec 13, 2022

Discussed with @mattsoulanille and @levskaya offline. TFJS seems to be okay with pinning flax==0.6.2 for a while, since it looks tricky to find a Bazel-based solution for TFJS that recognizes the Flax and Orbax dependency graph at the code level. TFJS also doesn't use Flax in its code except for one test.

But in the long term this circular dependency between Flax and Orbax may be a blocker for other installations as well. Generally the right direction is probably to cut Orbax's dependency on Flax, since Orbax targets general (JAX/TF) tooling.

Right now Orbax depends on Flax for the serialization.py and traverse_util.py modules.
Most of that code can be copied into Orbax as a temporary solution, except for the _STATE_DICT_REGISTRY in serialization, which lets users register serialization handlers for custom leaf types. If Orbax copies serialization.py and creates its own registry, custom types registered through Flax will not propagate to it. I wonder if Flax can simply frame its register_serialization_state() as a wrapper around Orbax's state registration.

Either way, gonna have more discussion with Orbax and figure out a long-term solution.

@cgarciae
Collaborator

For the future, here are some ideas:

  • Orbax becomes the owner of serialization.py, which makes sense for a checkpointing library. On the Flax side we add some wrappers.
  • Create a new flax-serialization library which both libraries can depend on to avoid the circular dependency.

@levskaya
Collaborator

We have to maintain control of traversals and flattening of nested dicts, struct dataclasses, etc., as we use those for more than just checkpoints. One hope is to enhance JAX pytree calls with the notion of "paths" so that we could remove our hacky state-dict registry altogether, but this is still under discussion.

@aschleck

aschleck commented Feb 7, 2023

We hit this today using a rules_python Bazel repository :(. Has there been any further thought into resolving this?

@samuela

samuela commented Feb 9, 2023

Just came across this today trying to upgrade flax in nixpkgs...

@cgarciae
Collaborator

cgarciae commented Feb 9, 2023

We just released 0.6.5. Can you try again and see if it's fixed for you?

@samuela

samuela commented Feb 9, 2023

Looking at 0.6.5 it appears that there's still a dependency on orbax:

flax/setup.py

Line 33 in 524629e

"orbax",

@wookayin
Contributor

wookayin commented Feb 9, 2023

I guess this is a problem of orbax depending on flax?

@samuela

samuela commented Feb 9, 2023

> I guess this is a problem of orbax depending on flax?

Or the other way around, depending on how you want to think about it. All comes down to how flax/orbax maintainers choose to remedy the situation. Seems like #2707 (comment) is the latest on that.

dotlambda added a commit to dotlambda/nixpkgs that referenced this issue Feb 10, 2023
This reverts commit fe0048c which broke
flax due to google/flax#2707.
gador pushed a commit to gador/nixpkgs that referenced this issue Feb 13, 2023
This reverts commit fe0048c which broke
flax due to google/flax#2707.
@levskaya
Collaborator

Hi all - sorry for the delay on this issue! The underlying issue is that orbax has been using the flax serialization routines, partly for some backwards-compatibility reasons, but mainly because the simple flax "state dict" machinery was a common way to handle deriving the "key paths" to each leaf in a pytree. The circular dependency is occurring since we're trying to transition to being able to use orbax for checkpoints.

We're trying to resolve this issue in a fundamental way by adopting a mechanism in jax itself to define the key-paths to pytree leaves so that we needn't use our relatively simple state-dict abstraction in other libraries (and ultimately to delete it ourselves).

Our sincere apologies for the build troubles with this circular dependency - we and the orbax maintainers are working to try to resolve it in the next week or so.
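
For readers unfamiliar with the "key paths" mentioned above, here is a minimal pure-Python sketch of the idea: flattening a nested dict so that each leaf is addressed by a tuple of keys. `flatten_with_paths` is a hypothetical helper for illustration only. It is neither the Flax state-dict code nor the JAX mechanism under discussion, and it assumes string keys and treats anything that is not a dict as a leaf.

```python
from typing import Any, Dict, Tuple


def flatten_with_paths(tree: Any, prefix: Tuple[str, ...] = ()) -> Dict[Tuple[str, ...], Any]:
    """Map each leaf of a nested dict to the tuple of keys leading to it."""
    if isinstance(tree, dict):
        flat: Dict[Tuple[str, ...], Any] = {}
        # Sort keys so the flattened order is deterministic.
        for key in sorted(tree):
            flat.update(flatten_with_paths(tree[key], prefix + (key,)))
        return flat  # note: an empty dict contributes no entries
    return {prefix: tree}


params = {"dense": {"kernel": [[1.0, 2.0]], "bias": [0.0]}, "step": 3}
flat = flatten_with_paths(params)

for path, leaf in flat.items():
    print("/".join(path), leaf)
# dense/bias [0.0]
# dense/kernel [[1.0, 2.0]]
# step 3
```

A checkpointing library needs stable names like these to match saved arrays back to positions in the tree, which is why it matters where the path logic lives.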

@wookayin
Contributor

wookayin commented Apr 9, 2023

Hi! When will this be fixed? It breaks pip dependency resolution when orbax and flax are both listed as dependencies.

@daskol

daskol commented Apr 10, 2023

@wookayin Fix was promised soon in #2882. 😃

As far as I remember, pip actually installs flax and orbax but reports a broken dependency resolution.

@IvyZX
Collaborator

IvyZX commented Jul 27, 2023

As Flax depends on orbax-checkpoint now, there is no longer a circular dependency.


9 participants