Issue on restoring NNX state from Orbax checkpoint as pure dict #4308

Closed
mmorinag127 opened this issue Oct 18, 2024 · 2 comments · Fixed by #4317
Labels
Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required)

Comments

@mmorinag127

Dear team,

I'm having an issue restoring an nnx.State from a pure dict: after saving and restoring with Orbax, keys that were of type int come back as strings.
Is there any way to fix this?
Below is sample code that reproduces the issue.

from flax import nnx
import orbax.checkpoint as ocp
import jax


class MLPs(nnx.Module):
    def __init__(self, dim, rngs: nnx.Rngs):
        self.layers = []
        for _ in range(4):
            self.layers.append(nnx.Linear(dim, dim, rngs=rngs, use_bias=False))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def test1():
    model = MLPs(4, rngs=nnx.Rngs(0))
    x = jax.random.normal(jax.random.key(42), (3, 4))
    assert model(x).shape == (3, 4)
    
    _, state = nnx.split(model)
    pure_dict_state = state.to_pure_dict()
    nnx.display(pure_dict_state)

    ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
    checkpointer = ocp.StandardCheckpointer()
    # checkpointer.save(ckpt_dir / 'state', state)
    checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)
    
    # Restore as a pure dictionary.
    restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
    nnx.display(restored_pure_dict)
    
    abstract_model = nnx.eval_shape(lambda: MLPs(4, rngs=nnx.Rngs(0)))
    graphdef, abstract_state = nnx.split(abstract_model)
    abstract_state.replace_by_pure_dict(restored_pure_dict)
    model = nnx.merge(graphdef, abstract_state)
    assert model(x).shape == (3, 4)  # The model still works!


if __name__ == '__main__':
    test1()

This code throws the following error:
ValueError: key in pure_dict not available in state: ('layers', '0', 'bias')
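To see where this ValueError comes from, here is a minimal pure-Python sketch with no Flax or Orbax involved. It assumes that the state's key paths keep the integer list indices that to_pure_dict() produces, while the restored dict has the same paths with string indices; flatten_paths is a hypothetical helper, not a Flax API, and the leaf values are placeholders.

```python
# Pure-Python sketch of the key-path mismatch behind the ValueError.
# `flatten_paths` is a hypothetical helper, not part of Flax.


def flatten_paths(tree, prefix=()):
    """Return the set of key paths that lead to the leaves of a nested dict."""
    paths = set()
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            paths |= flatten_paths(value, path)
        else:
            paths.add(path)
    return paths


# Shape of what state.to_pure_dict() produced: list indices are ints.
state_dict = {'layers': {0: {'kernel': 'k', 'bias': 'b'}}}
# Shape of what the restore returned: the same indices, now strings.
restored = {'layers': {'0': {'kernel': 'k', 'bias': 'b'}}}

state_paths = flatten_paths(state_dict)
missing = sorted(p for p in flatten_paths(restored) if p not in state_paths)
print(missing)  # [('layers', '0', 'bias'), ('layers', '0', 'kernel')]
```

Every restored path is "missing" only because 0 != '0' as a dict key, which matches the path reported in the error message.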

@OmerRochman

OmerRochman commented Oct 20, 2024

I'm having the same issue. I think it happens without Orbax too: is it an int vs. str key mismatch?

    from flax import nnx


    class MLPs(nnx.Module):
        def __init__(self, dim, rngs: nnx.Rngs):
            self.layers = []
            for _ in range(4):
                self.layers.append(nnx.Linear(dim, dim, rngs=rngs, use_bias=False))

        def __call__(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    model = MLPs(10, nnx.Rngs(42))
    _, model_state = nnx.split(model)

    print(model_state.to_pure_dict().keys())
    print(model_state.to_pure_dict()['layers'].keys())
    print(model_state.to_pure_dict()['layers'][0].keys())    # int index works
    print(model_state.to_pure_dict()['layers']['0'].keys())  # str index raises KeyError

output:

dict_keys(['layers'])
dict_keys([0, 1, 2, 3])
dict_keys(['bias', 'kernel'])
Traceback (most recent call last):
  File "<removed>", line 294, in <module>
    print(model_state.to_pure_dict()['layers']['0'].keys())
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
KeyError: '0'
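The same int vs. str mismatch shows up with any serializer whose mapping keys must be strings. As an analogy only (whether Orbax stores keys exactly this way is an assumption here, although the restored '0' keys in the first post are consistent with it), Python's built-in json module shows the effect on a dict shaped like the output above:

```python
import json

# A pure dict shaped like state.to_pure_dict(): the list indices are ints.
pure = {'layers': {0: {'kernel': [1.0]}, 1: {'kernel': [2.0]}}}

# JSON object keys must be strings, so json.dumps converts the int keys.
roundtrip = json.loads(json.dumps(pure))

print(list(pure['layers']))       # [0, 1]
print(list(roundtrip['layers']))  # ['0', '1']
print(0 in roundtrip['layers'])   # False: roundtrip['layers'][0] would raise KeyError
```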

As a temporary fix, I'm using the function below, which recursively converts every string key that represents an integer (e.g. '0') back into an int:

        def convert_keys(d):
            new_d = {}
            for key, value in d.items():
                # Convert key if it is a string representation of an integer
                if isinstance(key, str):
                    try:
                        new_key = int(key)
                    except ValueError:
                        new_key = key
                else:
                    new_key = key

                # Recursively process nested dictionaries
                if isinstance(value, dict):
                    new_value = convert_keys(value)
                else:
                    new_value = value

                new_d[new_key] = new_value

            return new_d
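For reference, here is a more compact variant of the same workaround, with a usage example. It is a sketch, and convert_keys_strict is a hypothetical name. Unlike convert_keys, it only converts keys made entirely of decimal digits (checked with str.isdecimal), so strings such as '-1' or ' 1', which int() would also accept, are left untouched.

```python
def convert_keys_strict(d):
    """Recursively rebuild `d`, turning decimal-digit string keys into ints.

    Hypothetical helper: keys like '0' or '12' become 0 or 12; every other
    key (non-strings, '-1', ' 1', 'kernel', ...) is kept as-is.
    """
    return {
        (int(k) if isinstance(k, str) and k.isdecimal() else k):
            convert_keys_strict(v) if isinstance(v, dict) else v
        for k, v in d.items()
    }


restored = {'layers': {'0': {'kernel': 'k'}, '1': {'kernel': 'k'}}, 'step': 7}
print(convert_keys_strict(restored))
# {'layers': {0: {'kernel': 'k'}, 1: {'kernel': 'k'}}, 'step': 7}
```

Being stricter is a deliberate trade-off: the list indices in the output above are non-negative integers, so there is no need to reinterpret any other string key.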

@cgarciae
Collaborator

cgarciae commented Oct 21, 2024

Hey @mmorinag127, thanks for posting this. I believe we knew about this edge case but forgot to upstream a fix for it.
I've created #4317 with a fix that passes your test. The solution is similar to what @OmerRochman showed above (try casting to int), but I've integrated it into State.replace_by_pure_dict so it works out of the box.
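The comment above describes the fix only at a high level. The following is not the Flax code from #4317; it is a pure-Python illustration of what "try casting to int" inside a replace-by-pure-dict step could look like, assuming the state is represented as a flat {key_path: value} mapping. The names flatten, _try_int, and replace_by_pure_dict_sketch are hypothetical.

```python
def flatten(tree, prefix=()):
    """Flatten a nested dict into {key_path: leaf}."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def _try_int(key):
    """Cast a key to int if possible, otherwise return it unchanged."""
    try:
        return int(key)
    except (TypeError, ValueError):
        return key


def replace_by_pure_dict_sketch(state_flat, pure_dict):
    """Hypothetical: copy the leaves of `pure_dict` into `state_flat` in place.

    A path that is not found as-is is retried with each key cast to int,
    so ('layers', '0', 'kernel') can match ('layers', 0, 'kernel').
    """
    for path, value in flatten(pure_dict).items():
        key = path
        if key not in state_flat:
            key = tuple(_try_int(k) for k in path)
        if key not in state_flat:
            raise ValueError(f'key in pure_dict not available in state: {path}')
        state_flat[key] = value
    return state_flat


state_flat = {('layers', 0, 'kernel'): None, ('layers', 1, 'kernel'): None}
restored = {'layers': {'0': {'kernel': 'k0'}, '1': {'kernel': 'k1'}}}
print(replace_by_pure_dict_sketch(state_flat, restored))
# {('layers', 0, 'kernel'): 'k0', ('layers', 1, 'kernel'): 'k1'}
```

Casting only on a failed lookup keeps exact matches first, so a state that genuinely uses string keys such as '0' is not affected.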

cgarciae self-assigned this Oct 21, 2024
cgarciae added the Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) label Oct 21, 2024
3 participants