
Fix KeyError raised in Module.update. #1630

Open · wants to merge 1 commit into main from fix-module-update

Conversation

hehua2008

e.g., load safetensors from black-forest-labs/FLUX.1-Depth-dev-lora.

Traceback (most recent call last):
  File "/Volumes/Develop/MLX/mlx-examples/flux/txt2image.py", line 141, in <module>
    load_adapter(flux, args.adapter, fuse=args.fuse_adapter)
  File "/Volumes/Develop/MLX/mlx-examples/flux/txt2image.py", line 106, in load_adapter
    flux.flow.load_weights(list(weights.items()), strict=False)
  File "/opt/anaconda3/envs/MLX/lib/python3.12/site-packages/mlx/nn/layers/base.py", line 195, in load_weights
    self.update(tree_unflatten(weights))
  File "/opt/anaconda3/envs/MLX/lib/python3.12/site-packages/mlx/nn/layers/base.py", line 335, in update
    apply(self, parameters)
  File "/opt/anaconda3/envs/MLX/lib/python3.12/site-packages/mlx/nn/layers/base.py", line 323, in apply
    apply(current_value, new_value)
  File "/opt/anaconda3/envs/MLX/lib/python3.12/site-packages/mlx/nn/layers/base.py", line 331, in apply
    current_value.update(new_value)
  File "/opt/anaconda3/envs/MLX/lib/python3.12/site-packages/mlx/nn/layers/base.py", line 335, in update
    apply(self, parameters)
  File "/opt/anaconda3/envs/MLX/lib/python3.12/site-packages/mlx/nn/layers/base.py", line 321, in apply
    current_value.update(new_value)
  File "/opt/anaconda3/envs/MLX/lib/python3.12/site-packages/mlx/nn/layers/base.py", line 335, in update
    apply(self, parameters)
  File "/opt/anaconda3/envs/MLX/lib/python3.12/site-packages/mlx/nn/layers/base.py", line 326, in apply
    current_value = dst[i]
                    ~~~^^^
KeyError: 0
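To see where the "KeyError: 0" comes from, here is a minimal plain-Python sketch of the structural mismatch (no mlx required; the dict/list shapes below are illustrative, not actual mlx internals): a Sequential keeps its children in a list under a "layers" key, while unflattening checkpoint keys like "adaLN_modulation.1.weight" yields a bare list, so walking that list by index into the module's dict fails on index 0.

```python
# Minimal plain-Python sketch of the mismatch behind "KeyError: 0".
# The shapes here are illustrative, not actual mlx internals.

# A Sequential's parameter tree keeps its children in a list under "layers":
module_tree = {"layers": [{"weight": "W0"}, {"weight": "W1"}]}

# Unflattening checkpoint keys like "adaLN_modulation.1.weight" (with no
# ".layers." segment) yields a bare list instead:
new_values = [{}, {"weight": "W1_new"}]

# update() walks the incoming list by index and looks each index up in the
# destination, so the very first lookup already fails:
try:
    for i, new_value in enumerate(new_values):
        current_value = module_tree[i]  # dict has only "layers", not 0
except KeyError as e:
    print(f"KeyError: {e}")  # -> KeyError: 0
```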

@hehua2008 hehua2008 force-pushed the fix-module-update branch 3 times, most recently from 551101d to e42ac51 on November 27, 2024 at 14:13
@awni
Member

awni commented Nov 27, 2024

I'm not sure this fix makes sense. This might be an issue in the Flux code. Could you share steps to reproduce it?

@hehua2008
Author

hehua2008 commented Nov 28, 2024

> I'm not sure this fix makes sense. This might be an issue in the Flux code. Could you share steps to reproduce it?

This can be reproduced in the mlx-examples/flux project:

python txt2image.py --model dev --save-raw --steps 20 --n-images 1 \
    --adapter black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors \
    --fuse-adapter \
    "Any prompt..."

I found

"final_layer.adaLN_modulation.1.bias": {"dtype":"BF16","shape":[6144],"data_offsets":[12913589248,12913601536]},
"final_layer.adaLN_modulation.1.weight": {"dtype":"BF16","shape":[6144,3072],"data_offsets":[12913601536,12951350272]},

in black-forest-labs/FLUX.1-dev/flux1-dev.safetensors,

and

"final_layer.adaLN_modulation.1.lora_A.weight": {"dtype":"BF16","shape":[128,3072],"data_offsets":[690863104,691649536]},
"final_layer.adaLN_modulation.1.lora_B.bias": {"dtype":"BF16","shape":[6144],"data_offsets":[691649536,691661824]},
"final_layer.adaLN_modulation.1.lora_B.weight": {"dtype":"BF16","shape":[6144,128],"data_offsets":[691661824,693234688]},

in black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors.

When mlx loads the above models, an mlx.nn.layers.containers.Sequential is created for adaLN_modulation, which results in the "KeyError: 0" exception when calling flux.flow.load_weights(list(weights.items()), strict=False).

In short, the internal implementation of the Module.update(parameters: dict) method in the current version of mlx does not account for container modules like Sequential.
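To illustrate the shape mismatch, here is a hypothetical, simplified re-implementation of tree_unflatten-style behavior (the real function is mlx.utils.tree_unflatten; this sketch only shows how numeric path components become list indices, and how a ".layers." segment restores the shape a Sequential expects):

```python
# Hypothetical simplified sketch of tree_unflatten-style behavior (the real
# function is mlx.utils.tree_unflatten; this is for illustration only).
def unflatten_sketch(items):
    """Turn dotted keys into a nested tree; all-numeric levels become lists."""
    tree = {}
    for key, value in items:
        node = tree
        parts = key.split(".")
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value

    def listify(node):
        if not isinstance(node, dict):
            return node
        if node and all(k.isdigit() for k in node):
            size = max(int(k) for k in node) + 1
            return [listify(node.get(str(i), {})) for i in range(size)]
        return {k: listify(v) for k, v in node.items()}

    return listify(tree)

# A checkpoint key without ".layers." unflattens to a bare list ...
print(unflatten_sketch([("adaLN_modulation.1.weight", "W")]))
# -> {'adaLN_modulation': [{}, {'weight': 'W'}]}

# ... but a Sequential nests its children under a "layers" key, so the
# key must be rewritten to match the module's parameter tree:
print(unflatten_sketch([("adaLN_modulation.layers.1.weight", "W")]))
# -> {'adaLN_modulation': {'layers': [{}, {'weight': 'W'}]}}
```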

@hehua2008
Author

hehua2008 commented Nov 28, 2024

Actually, I'm not sure this commit makes sense either, although it works smoothly. Maybe this should be converted to an issue.
Looking forward to your bug fix!

@hehua2008
Author

>>> import mlx.nn
>>> sequential = mlx.nn.Sequential()
>>> print(sequential.items())
dict_items([('layers', [])])

@hehua2008
Author

hehua2008 commented Nov 28, 2024

The author of mlx-examples/flux had to add a sanitize method to work around the ".layers" segment that Sequential inserts into the affected weights' full names:

mlx-examples/flux/flux/model.py

    def sanitize(self, weights):
        new_weights = {}
        for k, w in weights.items():
            if k.endswith(".scale"):
                k = k[:-6] + ".weight"
            for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
                if f".{seq}." in k:
                    k = k.replace(f".{seq}.", f".{seq}.layers.")
                    break
            new_weights[k] = w
        return new_weights
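As a standalone illustration, the same logic can be copied out of the class as a plain function and applied to one of the LoRA keys from the checkpoint above:

```python
# Standalone copy of the sanitize logic above (as a plain function, for
# illustration), showing the ".layers." segment being inserted.
def sanitize(weights):
    new_weights = {}
    for k, w in weights.items():
        if k.endswith(".scale"):
            k = k[:-6] + ".weight"
        for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
            if f".{seq}." in k:
                k = k.replace(f".{seq}.", f".{seq}.layers.")
                break
        new_weights[k] = w
    return new_weights

fixed = sanitize({"final_layer.adaLN_modulation.1.lora_A.weight": "W"})
print(list(fixed))
# -> ['final_layer.adaLN_modulation.layers.1.lora_A.weight']
```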
