Cleaner Cache dtype and device extraction for CUDA graph generation for quantizers compatibility #29079

Merged Feb 27, 2024 (6 commits). The diff below shows changes from 5 commits.
8 changes: 6 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -815,9 +815,13 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

         for layer in self.model.layers:
-            weights = layer.self_attn.o_proj.weight
+            device = layer.input_layernorm.weight.device
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                dtype = self.config._pre_quantization_dtype
+            else:
+                dtype = layer.self_attn.o_proj.weight.dtype
             layer.self_attn.past_key_value = cache_cls(
-                self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
+                self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
             )

     def _reset_cache(self):
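
For context, a minimal standalone sketch of the dtype/device selection the hunk above introduces (the helper below is illustrative and not part of the PR; it assumes a LLaMA-style module layout). Quantized backends such as AQLM replace floating-point linear weights with packed integer tensors, so `o_proj.weight.dtype` would report the packed storage dtype; when a quantizer has recorded `_pre_quantization_dtype` on the config, the static cache is allocated in that original dtype instead, and the device is read from the input layernorm weight, which presumably remains a regular parameter under quantization.

def cache_dtype_and_device(model):
    """Illustrative sketch mirroring the dtype/device logic added to _setup_cache above."""
    layer = model.model.layers[0]  # assumes a LLaMA-style layout (not part of the PR)
    # The input layernorm weight stays an ordinary floating-point parameter
    # even when the attention/MLP linear layers are quantized.
    device = layer.input_layernorm.weight.device
    if hasattr(model.config, "_pre_quantization_dtype"):
        # Quantizers record the original dtype on the config because the
        # packed weight tensors no longer carry it.
        dtype = model.config._pre_quantization_dtype
    else:
        dtype = layer.self_attn.o_proj.weight.dtype
    return dtype, device
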
68 changes: 61 additions & 7 deletions tests/quantization/aqlm_integration/test_aqlm.py
@@ -14,10 +14,13 @@
 # limitations under the License.

 import gc
+import importlib
 import tempfile
 import unittest

-from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM
+from packaging import version
+
+from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM, StaticCache
 from transformers.testing_utils import (
     require_accelerate,
     require_aqlm,
@@ -26,7 +29,7 @@
     slow,
     torch_device,
 )
-from transformers.utils import is_accelerate_available, is_torch_available
+from transformers.utils import is_accelerate_available, is_aqlm_available, is_torch_available


 if is_torch_available():
@@ -71,11 +74,12 @@ def test_from_dict(self):
 @require_aqlm
 @require_accelerate
 class AqlmTest(unittest.TestCase):
-    model_name = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"
+    model_name = "BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf"

     input_text = "Hello my name is"
+    max_new_tokens = 40

-    EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am currently a sophomore and am majoring in Psychology. I am"
+    EXPECTED_OUTPUT = "Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. I love to have fun and be active. I am very easy going and love to make"

     device_map = "cuda"
@@ -144,7 +148,7 @@ def test_quantized_model(self):
         """
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

-        output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
+        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

     def test_raise_if_non_quantized(self):
@@ -164,7 +168,7 @@ def test_save_pretrained(self):

         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

-        output = model.generate(**input_ids, max_new_tokens=40)
+        output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

     @require_torch_multi_gpu
@@ -178,6 +182,56 @@ def test_quantized_model_multi_gpu(self):

         self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

-        output = quantized_model.generate(**input_ids, max_new_tokens=40)
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)

         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
+    @unittest.skipUnless(
+        is_aqlm_available() and version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.3"),
+        "test requires `aqlm>=1.0.3`",
+    )
+    def test_quantized_model_compile(self):
Review thread on lines +189 to +193 (the @unittest.skipUnless decorator and test definition above):

Collaborator: Loving this test ❤️

BlackSamorez (Contributor Author), Feb 23, 2024: @ArthurZucker @BlackSamorez The problem with it is that it's failing :). See this. So, advice is needed on what to do here.

Collaborator: I don't think it is super important that the outputs match for quantized models, no? The distributions are the same, but the kernels / ops are not run in the same order. It's a small difference, but it could explain this. I would just add a long generation and make sure it still makes sense!

BlackSamorez (Contributor Author), Feb 23, 2024: I don't really know how to automatically check whether text makes sense. Alternatively, I've shortened the generation length from 40 tokens to 32 and it matches perfectly on an RTX 3090, an RTX 2080 Ti, and an A6000. Maybe we could just leave it as is, since the tests above are exact matches anyway. (The current iteration of the tests passes.)

Collaborator: Fine with me 😉
"""
Simple test that checks if the quantized model is working properly
"""

# Sample tokens greedily
def decode_one_tokens(model, cur_token, input_pos, cache_position):
logits = model(
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
)[0]
new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)

return new_token

# Tokenize the test input
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)["input_ids"]
seq_length = input_ids.shape[1]

# Setup static KV cache for generation
self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens)

# Allocate token ids to be generated and copy prefix ids
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(1, seq_length + self.max_new_tokens, dtype=torch.int, device=torch_device)
generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)

# Do a forward pass to fill the prefix cache and compile the kernels if necessary
logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
generated_ids[:, [seq_length]] = next_token

with torch.no_grad():
# Compile the CUDA graph
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

# Generate tokens one by one
cache_position = torch.tensor([seq_length + 1], device=torch_device)
for _ in range(1, self.max_new_tokens - 1):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position)
generated_ids.index_copy_(1, cache_position, next_token)
cache_position += 1

# Check generated text
self.assertEqual(self.tokenizer.decode(generated_ids[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
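
As a usage note (not part of the PR), a condensed sketch of how the same static-cache setup can be driven outside the test harness, assuming `aqlm>=1.0.3` and a CUDA device are available; `_setup_cache` and `_reset_cache` are the private helpers shown in the modeling diff above, so they may change in later transformers versions.

from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_name = "BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf"  # checkpoint used by the test above
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda")

inputs = tokenizer("Hello my name is", return_tensors="pt").to("cuda")
max_new_tokens = 40

# Allocate the static KV cache; with this PR, a quantized model picks up the
# pre-quantization dtype from the config rather than the packed weight dtype.
model._setup_cache(StaticCache, 1, max_cache_len=inputs["input_ids"].shape[1] + max_new_tokens)

# ... run a prefill pass and a compiled per-token decode loop as in the test above ...

# Tear the static cache back down before ordinary dynamic-cache generation.
model._reset_cache()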