
ValueError: PegasusXForConditionalGeneration does not support device_map='auto'. To implement support, the model class needs to implement the _no_split_modules attribute. #1900

Closed
andreeahedes opened this issue Aug 28, 2023 · 5 comments


andreeahedes commented Aug 28, 2023

Hi, I am using Accelerate version 0.21.0. I am trying to implement disk/CPU offload with Accelerate for a pretrained PegasusX-large model. Below is my implementation, using an "auto" device map to see whether Accelerate would load the model with a map that assigns a device to each layer.

import accelerate
from transformers import PegasusXConfig, PegasusXForConditionalGeneration

def get_device_map(model_checkpoint: str):
    with accelerate.init_empty_weights():
        config = PegasusXConfig.from_pretrained(model_checkpoint)
        model = PegasusXForConditionalGeneration(config=config)
        model.tie_weights()
        model = model.from_pretrained(model_checkpoint, device_map="auto")
    model = accelerate.load_checkpoint_and_dispatch(
        model, model_checkpoint, no_split_module_classes=["PegasusXDecoderLayer"]
    )

On the from_pretrained call, I get the following error: ValueError: PegasusXForConditionalGeneration does not support device_map='auto'. To implement support, the model class needs to implement the _no_split_modules attribute. This issue seems related to the implementation of the PegasusX model. How would one go about fixing it?

@muellerzr
Collaborator

cc @SunMarc

@SunMarc
Member

SunMarc commented Aug 28, 2023

Hi @andreeahedes, there are two ways to load a model with a device_map:

  • accelerate:
with accelerate.init_empty_weights():
    config = PegasusXConfig.from_pretrained(model_checkpoint)
    model = PegasusXForConditionalGeneration.from_config(config)
model.tie_weights()
model = accelerate.load_checkpoint_and_dispatch(model, model_checkpoint, no_split_module_classes=["PegasusXDecoderLayer"])

However, it is better to use the transformers integration of this feature, because you might run into issues with the first method. This is because when we call PegasusXForConditionalGeneration or AutoModelForCausalLM, we create additional layers (e.g. lm_head), and this can lead to errors when loading the weights.

  • transformers
model = PegasusXForConditionalGeneration.from_pretrained(model_checkpoint, device_map="auto")
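For context on why loading under init_empty_weights needs this care: parameters created inside that context live on PyTorch's "meta" device, so they have shapes but no storage, and the real weights must be loaded and dispatched afterwards. A minimal sketch of that behavior using plain torch (not Pegasus-specific, and using torch's meta device directly rather than accelerate):

```python
import torch

# Create a layer on the "meta" device: the parameter shapes exist, but no
# memory is allocated for the values. accelerate.init_empty_weights relies
# on the same meta-device mechanism to build huge models "for free".
layer = torch.nn.Linear(4, 4, device="meta")

print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([4, 4])
```

Calling such a layer on real data would fail, which is why the weights have to be materialized with something like load_checkpoint_and_dispatch before use.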

If the device_map is not supported for your model, you can make it work by implementing the _no_split_modules attribute in the model class. Feel free to submit a PR to add support! You can also check this issue if you want more context.
Let me know if that is clearer!
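For illustration, the fix being asked for is just a class attribute on the model's pretrained base class, naming the layer classes whose submodules must never be split across devices. A hedged sketch (the stand-in class and the exact layer names for PegasusX are assumptions; the right list has to be confirmed by running the offload tests):

```python
# Hypothetical stand-in for transformers' PegasusX pretrained base class;
# in the real library this attribute would live on the PreTrainedModel
# subclass for PegasusX.
class PegasusXPreTrainedModel:
    # Modules listed here are kept whole on a single device when computing
    # device_map="auto". The encoder/decoder layer names are an assumption
    # for this sketch, not verified against the transformers source.
    _no_split_modules = ["PegasusXEncoderLayer", "PegasusXDecoderLayer"]
```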

@andreeahedes
Author

Thank you for your quick response. Which accelerate tests should I run to check that _no_split_modules works?

@SunMarc
Member

SunMarc commented Aug 29, 2023

That's awesome that you are interested in implementing this =) . You need to make sure that three tests pass: test_disk_offload, test_cpu_offload and test_model_parallelism. You can find them here.

To run them, use the following command from the root of the transformers folder, for example for the test_cpu_offload test: pytest tests/models/pegasus_x/test_modeling_pegasus_x.py::PegasusXModelTest::test_cpu_offload -s. Make sure that you have access to at least 2 GPUs to check that model parallelism works.

For debugging, you can inspect the device_map that was computed by checking model.hf_device_map! If a test fails, it probably means that some modules that should not be split were not listed in _no_split_modules. Let me know how it goes. If everything works, you can then submit your PR!
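As a hedged illustration of what debugging via model.hf_device_map can look like: the map is a plain dict from module-name prefixes to a GPU index, "cpu", or "disk", and grouping it by device makes unexpected splits easy to spot. The map below is invented for the sketch, not real PegasusX output:

```python
from collections import defaultdict

# Invented example of the dict transformers exposes as model.hf_device_map:
# module-name prefixes mapped to a GPU index, "cpu", or "disk".
hf_device_map = {
    "model.encoder": 0,
    "model.decoder.layers.0": 0,
    "model.decoder.layers.1": 1,
    "lm_head": "cpu",
}

# Group modules by device to see at a glance where each part landed and
# whether a module that should stay together was split across devices.
by_device = defaultdict(list)
for name, device in hf_device_map.items():
    by_device[device].append(name)

for device, modules in sorted(by_device.items(), key=lambda kv: str(kv[0])):
    print(device, modules)
```

If, say, sublayers of one decoder block showed up under two different devices here, that block's class would be a candidate for _no_split_modules.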

@andreeahedes
Author

Hi, I found a combination that works and passes all three tests on a machine with 2 GPUs. I created a PR; it should be linked to this thread.
