
ensure module prefixes only match that module #1319

Merged

merged 1 commit into huggingface:main from prefix-conflation on Apr 17, 2023

Conversation

xloem
Contributor

@xloem xloem commented Apr 17, 2023

Hi,

After some laborious debugging today, I found that the prefix matching used to identify submodules in modeling.py also matches sibling modules with longer names: for example, model.layers.1 matches model.layers.18. This produced broken structures that prevented execution.

This patch changes every use of .startswith() in that file so that the module name is suffixed with a "." before matching. That resolved the immediate issue I encountered.
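
For illustration, here is a minimal sketch of the failure mode and the fix (is_submodule is a hypothetical helper for this example, not the actual code in modeling.py):

prefix = "model.layers.1"

# Broken check: a bare startswith also matches longer sibling names.
print("model.layers.18".startswith(prefix))        # True  (wrong: a sibling)
print("model.layers.1.weight".startswith(prefix))  # True  (right: a child)

# Fixed check: require an exact match or the prefix followed by ".".
def is_submodule(name, prefix):
    return name == prefix or name.startswith(prefix + ".")

print(is_submodule("model.layers.18", prefix))        # False (sibling excluded)
print(is_submodule("model.layers.1.weight", prefix))  # True  (child still matches)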

@muellerzr muellerzr requested a review from sgugger April 17, 2023 16:04
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 17, 2023

The documentation is not available anymore as the PR was closed or merged.

@xloem
Contributor Author

xloem commented Apr 17, 2023

I see there are test failures; maybe one of the changes I made was incorrect. The change that fixed things for me was the one in load_state_dict. I'm running the tests locally.

@xloem xloem force-pushed the prefix-conflation branch from 6e91e5d to 344391c on April 17, 2023 17:16
@xloem
Contributor Author

xloem commented Apr 17, 2023

Tests appear to be passing now and I've rebased to a single commit against main.

@sgugger
Collaborator

sgugger commented Apr 17, 2023

Can you please post a sample of code that fails before this PR, so we can better understand the problem you are fixing?

@xloem
Contributor Author

xloem commented Apr 17, 2023

The code below exhausts memory without this patch: layer10 shares a prefix with layer1, is incorrectly determined to be a submodule of it, and its weights are then uploaded needlessly to both devices.

It seemed to me that, without this change, models with more than 10 layers were duplicating VRAM usage by uploading their weights multiple times.

import accelerate
import safetensors.torch

import torch
import torch.nn as nn

DEV = 0
dtype = torch.float32

class Test(nn.Module):
    # Two sibling layers whose names share a string prefix:
    # "layer1" is a prefix of "layer10", which triggers the bug.
    def __init__(self, linear_size_1, linear_size_2, dtype):
        super().__init__()
        self.layer1 = nn.Linear(linear_size_1, linear_size_2, dtype=dtype)
        self.layer10 = nn.Linear(linear_size_1, linear_size_2, dtype=dtype)

# Size each layer to use about 2/3 of the device's available memory,
# so either layer fits on its own but nothing fits twice.
max_dev_memory = accelerate.utils.get_max_memory()[DEV]
max_dim_fits = int((max_dev_memory / accelerate.utils.dtype_byte_size(dtype)) ** (1/2))
module = Test(
    linear_size_1 = max_dim_fits * 2 // 3,
    linear_size_2 = max_dim_fits,
    dtype = dtype,
)

# Save the weights, infer a device map, and reload from disk. Without the
# fix, layer10's name also matches the "layer1" prefix, so its weights are
# loaded onto layer1's device as well, exhausting memory.
safetensors.torch.save_file(module.state_dict(), 'test-prefix.safetensors', metadata=dict(format='pt'))
device_map = accelerate.infer_auto_device_map(module)
del module
accelerate.utils.load_state_dict('test-prefix.safetensors', device_map=device_map)
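
With the patched prefix check, layer10 no longer matches the layer1 prefix: each layer keeps its own device-map entry, its weights are loaded exactly once, and the script above completes without exhausting memory.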

Collaborator

@sgugger sgugger left a comment


Thanks for your PR on this and the clear explanation of the problem!

@sgugger sgugger merged commit ee0c587 into huggingface:main Apr 17, 2023