Docstring for State is wrong about leaf types (or else the FSDP example is wrong) #4342

Closed
gabbard opened this issue Oct 28, 2024 · 1 comment · Fixed by #4346

gabbard commented Oct 28, 2024

System information

  • Flax: 0.4.34

Problem you have encountered:

In the example for how to do Fully Sharded Data Parallelism (FSDP), we do:

    state = nnx.state(optimizer)
    ...
    def get_named_shardings(path: tuple, value: nnx.VariableState):
        if path[0] == 'params':
            ret = value.replace(NamedSharding(mesh, P(*value.sharding)))
            return ret
        elif path[0] == 'momentum':
            # currently the same as above but in general it could be different
            return value.replace(NamedSharding(mesh, P(*value.sharding)))
        else:
            raise ValueError(f'Unknown path: {path}')

    named_shardings = state.map(get_named_shardings)
    sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)
    nnx.update(optimizer, sharded_state)

The code implies that state has a key type of str and a value type of VariableState (which the debugger confirms). But the docstring of State says:

    A pytree-like structure that contains a ``Mapping`` from strings or
    integers to leaves. A valid leaf type is either :class:`Variable`,
    ``jax.Array``, ``numpy.ndarray`` or nested ``State``'s....

So having a VariableState as a leaf value seems at odds with the docstring.
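
To make the mismatch concrete, here is a minimal sketch of the pattern the example relies on: a callback invoked with `(path, leaf)` for every leaf of a nested mapping. It uses no Flax at all. `FakeVariableState` and `map_with_path` are hypothetical stand-ins invented for illustration, assumed to behave roughly like `nnx.VariableState` and `State.map`; they are not Flax's implementation.

```python
from dataclasses import dataclass
from collections.abc import Mapping
from typing import Any, Callable


@dataclass(frozen=True)
class FakeVariableState:
    """Hypothetical stand-in for nnx.VariableState: a value plus sharding metadata."""
    value: Any
    sharding: tuple


def map_with_path(tree: Mapping, fn: Callable[[tuple, Any], Any], prefix: tuple = ()) -> dict:
    """Apply fn(path, leaf) to every non-mapping leaf of a nested mapping."""
    out = {}
    for key, child in tree.items():
        path = prefix + (key,)
        if isinstance(child, Mapping):
            # Nested mapping: recurse, extending the path.
            out[key] = map_with_path(child, fn, path)
        else:
            # Anything that is not a mapping is treated as a leaf.
            out[key] = fn(path, child)
    return out


# Assumed layout mirroring the example: top-level 'params' and 'momentum' keys.
state = {
    'params': {'w1': FakeVariableState([[0.0]], ('model', None))},
    'momentum': {'w1': FakeVariableState([[0.0]], ('model', None))},
}


def describe(path: tuple, leaf: Any) -> str:
    """Report the path and the Python type of the leaf found there."""
    return f"{'/'.join(map(str, path))}: {type(leaf).__name__}"


leaf_types = map_with_path(state, describe)
print(leaf_types)
```

In this sketch the callback receives the wrapper object itself as the leaf, not a raw array, which is the same situation the FSDP example's `value: nnx.VariableState` annotation describes.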

To avoid confusion, either the docstring on State should be updated, or the FSDP example should be updated.

Also, a small nit: the sharded_state defined on line 107 of the FSDP example is unused. This isn't a big deal in itself, but it makes the reader doubt the correctness of the example.

@cgarciae
Collaborator

Hey @gabbard, thanks for reporting this! The docstrings are outdated as we no longer treat jax.Array and np.ndarray as State leaves.
