Cleaner Cache dtype and device extraction for CUDA graph generation for quantizers compatibility #29079

Merged · 6 commits · Feb 27, 2024

Conversation

BlackSamorez (Contributor):

What does this PR do?

As of now, this PR fixes a small problem preventing one from using CUDA graph generation from #28937 with quantized models.

In the long run, it would be great to have compiled generation actually working for GPTQ, AQLM, and other quantization methods.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts (Collaborator):

cc @younesbelkada

@younesbelkada (Contributor) left a comment:

Thanks a lot for fixing this!
For retrieving the correct device, the fix sounds correct.
However, for the dtype, I am afraid this might lead to some bugs / unexpected behaviours 😭 Many users perform text generation after calling utility methods such as prepare_model_for_kbit_training (using PEFT), in which case we sometimes cast the layer norms to FP32. This is quite a niche use case though. I propose to be on the safe side and retrieve the dtype similarly to what we do here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L451-L457 - can you let me know if applying that logic here would fix CUDA graph generation for quantized models?
Also, can you elaborate a bit on the original issue, i.e. what you are trying to achieve and what bug you get?
Thanks!

@BlackSamorez (Contributor, Author) commented Feb 22, 2024:

@younesbelkada The error I get on main is quite simple:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[8], line 22
     19     return new_token
     21 with torch.no_grad():
---> 22     model._setup_cache(StaticCache, BS, max_cache_len=max_cache_length)
     24     ### PREFILL
     25     # input_pos = torch.arange(seq_length, device=device)
     26     cache_position = torch.arange(seq_length, device=device)

File ~/AQLM/.conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:817, in LlamaPreTrainedModel._setup_cache(self, cache_cls, max_batch_size, max_cache_len)
    814     self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
    816 for layer in self.model.layers:
--> 817     weights = layer.self_attn.o_proj.weight
    818     layer.self_attn.past_key_value = cache_cls(
    819         self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
    820     )

File ~/AQLM/.conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1706, in Module.__getattr__(self, name)
   1704     if name in modules:
   1705         return modules[name]
-> 1706 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'QuantizedLinear' object has no attribute 'weight'

That's why, I believe, it would make sense to source device and dtype from outside the linear layer to be quantization-agnostic.
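
(For illustration, a rough sketch of what "sourcing device and dtype from outside the linear layer" could look like in _setup_cache, guessing from the commit title "input_layernorm as the beacon of hope" further down — the actual diff may differ; input_layernorm is a plain module that quantizers leave untouched:)

for layer in self.model.layers:
    # sketch: take device/dtype from a module no quantizer replaces
    norm_weight = layer.input_layernorm.weight
    layer.self_attn.past_key_value = cache_cls(
        self.config, max_batch_size, max_cache_len,
        device=norm_weight.device, dtype=norm_weight.dtype,
    )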

@BlackSamorez (Contributor, Author):

And I don't really understand what you're proposing with the code you referenced.

@younesbelkada (Contributor):

@BlackSamorez thanks!
Sorry, I sent the wrong link; it should be:

target_dtype = self.config._pre_quantization_dtype

You can do something like:

if hasattr(self.config, "_pre_quantization_dtype"):
    target_dtype = self.config._pre_quantization_dtype
else:
    target_dtype = layer.self_attn.o_proj.weight.dtype

Does that fix the issue?
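
(Putting the two pieces together — device from a module quantizers don't touch, dtype from the config's recorded pre-quantization dtype with a weight fallback — the _setup_cache loop could then look roughly like this; a sketch, not necessarily the exact merged code:)

for layer in self.model.layers:
    # device from a module that is never replaced by quantizers (sketch)
    device = layer.input_layernorm.weight.device
    # dtype recorded on the config before quantization, if available
    if hasattr(self.config, "_pre_quantization_dtype"):
        dtype = self.config._pre_quantization_dtype
    else:
        dtype = layer.self_attn.o_proj.weight.dtype
    layer.self_attn.past_key_value = cache_cls(
        self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
    )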

@BlackSamorez BlackSamorez changed the title Quantization support for CUDA graph generation. Cleaner Cache dtype and device extraction for CUDA graph generation for quantizers compatibility Feb 22, 2024
@BlackSamorez (Contributor, Author):

What you proposed seems to work fine with both FP16 and AQLM in a notebook test based on @ArthurZucker's test script.
Which other models should I try running it with?

@younesbelkada (Contributor) left a comment:

Amazing work!
Could you add a simple test in the aqlm testing file to cover that use case? 🙏


@BlackSamorez (Contributor, Author) commented Feb 22, 2024:

@younesbelkada CUDA graph generation diverges at some point:
Prefix: Hello my name is
Normal: Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. I love to have fun and be active. I am very easy going and love to make
CUDA graph: Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. I love to have fun and be active. I am a very hard worker and I

A stupid solution would be to generate shorter texts, but I'm not sure it's a good idea to have unstable tests.

P.S. As you might have guessed, I added a CUDA graph generation test for AQLM.
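
(For context, the CUDA graph generation pattern being tested follows the traceback earlier in the thread: a static cache plus a compiled single-token decode step. The sketch below is not the actual test added in this PR; decode_one_token, BS, and max_cache_length are placeholder names:)

import torch
from transformers import StaticCache

def decode_one_token(model, cur_token, cache_position):
    # one decoding step against the static cache; torch.compile with
    # mode="reduce-overhead" can capture this step as a CUDA graph
    logits = model(cur_token, cache_position=cache_position, use_cache=True).logits
    return torch.argmax(logits[:, -1], dim=-1, keepdim=True)

compiled_decode = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

with torch.no_grad():
    model._setup_cache(StaticCache, BS, max_cache_len=max_cache_length)
    # prefill once with the full prompt, then generate token by token
    # with compiled_decode(model, cur_token, cache_position)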

@ArthurZucker (Collaborator) left a comment:

Looks very good to me!
Do you have any benchmark numbers to share?

Comment on lines +189 to +193
@unittest.skipUnless(
is_aqlm_available() and version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.3"),
"test requires `aqlm>=1.0.3`",
)
def test_quantized_model_compile(self):
Collaborator:

Loving this test ❤️

@BlackSamorez (Contributor, Author) commented Feb 23, 2024:

@ArthurZucker @BlackSamorez The problem with it is that it's failing :) See this. So, advice needed on what to do here.

Collaborator:

I don't think it is super important that the outputs match for quantized models, no? The distributions are the same, but the kernels / ops are not run in the same order. It's a small difference, but it could explain this.
I would just add a long generation and make sure it still makes sense!

@BlackSamorez (Contributor, Author) commented Feb 23, 2024:

I don't really know how to automatically check whether text makes sense.
Alternatively, I've shortened the generation length from 40 tokens to 32, and it matches perfectly on an RTX 3090, an RTX 2080 Ti, and an A6000. Maybe we could just leave it as is, since the tests above are exact-match anyway.
(The current iteration of the tests passes.)

Collaborator:

fine with me 😉

@younesbelkada (Contributor) left a comment:

Amazing work @BlackSamorez!


@younesbelkada younesbelkada merged commit e3fc90a into huggingface:main Feb 27, 2024
19 checks passed
itazap pushed a commit that referenced this pull request May 14, 2024
…on for quantizers compatibility (#29079)

* input_layernorm as the beacon of hope

* cleaner dtype extraction

* AQLM + CUDA graph test

* is available check

* shorter text test