From f93e7308fd117341d479c4f2bc7d57cfc2700a36 Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Sat, 17 Feb 2024 20:55:36 +0100
Subject: [PATCH 1/5] input_layernorm as the beacon of hope

---
 src/transformers/models/llama/modeling_llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index c30be2a2da4f63..a9fb777e40119f 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -796,7 +796,7 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
         for layer in self.model.layers:
-            weights = layer.self_attn.o_proj.weight
+            weights = layer.input_layernorm.weight
             layer.self_attn.past_key_value = cache_cls(
                 self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
             )

From a956ec8a98aaf998453c41bf94527f0222d8d2aa Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Thu, 22 Feb 2024 11:14:12 +0100
Subject: [PATCH 2/5] cleaner dtype extraction

---
 src/transformers/models/llama/modeling_llama.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index a9fb777e40119f..92865f667b7edc 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -796,9 +796,13 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
         for layer in self.model.layers:
-            weights = layer.input_layernorm.weight
+            device = layer.input_layernorm.weight.device
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                dtype = self.config._pre_quantization_dtype
+            else:
+                dtype = layer.self_attn.o_proj.weight.dtype
             layer.self_attn.past_key_value = cache_cls(
-                self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
+                self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
             )
 
     def _reset_cache(self):
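A short aside on what the two patches above converge on: the cache device is read from the always-unquantized `input_layernorm`, and the cache dtype from `_pre_quantization_dtype` when the config records it (quantized linears may hold packed integer codes rather than the model's working dtype). A minimal sketch of that selection in isolation; the helper name is illustrative, not part of the patch:

import torch


def cache_device_and_dtype(model) -> tuple[torch.device, torch.dtype]:
    """Illustrative helper mirroring the selection above; not part of the patch."""
    layer = model.model.layers[0]
    # Layernorm weights are not quantized, so their device is safe to reuse.
    device = layer.input_layernorm.weight.device
    # Prefer the dtype the model had before quantization, if the config recorded it.
    if hasattr(model.config, "_pre_quantization_dtype"):
        dtype = model.config._pre_quantization_dtype
    else:
        dtype = layer.self_attn.o_proj.weight.dtype
    return device, dtype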
From edebc721bc6d66779eb28780e06628bfbf0609c7 Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Thu, 22 Feb 2024 15:25:27 +0100
Subject: [PATCH 3/5] AQLM + CUDA graph test

---
 .../aqlm_integration/test_aqlm.py | 66 +++++++++++++++++--
 1 file changed, 60 insertions(+), 6 deletions(-)

diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py
index 6a5cefea2fb177..96f8200aba1ef8 100644
--- a/tests/quantization/aqlm_integration/test_aqlm.py
+++ b/tests/quantization/aqlm_integration/test_aqlm.py
@@ -14,10 +14,13 @@
 # limitations under the License.
 
 import gc
+import importlib
 import tempfile
 import unittest
 
-from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM
+from packaging import version
+
+from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM, StaticCache
 from transformers.testing_utils import (
     require_accelerate,
     require_aqlm,
@@ -71,11 +74,12 @@ def test_from_dict(self):
 @require_aqlm
 @require_accelerate
 class AqlmTest(unittest.TestCase):
-    model_name = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"
+    model_name = "BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf"
 
     input_text = "Hello my name is"
+    max_new_tokens = 40
 
-    EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am currently a sophomore and am majoring in Psychology. I am"
+    EXPECTED_OUTPUT = "Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. I love to have fun and be active. I am very easy going and love to make"
 
     device_map = "cuda"
 
@@ -144,7 +148,7 @@ def test_quantized_model(self):
         """
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
 
-        output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
+        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
     def test_raise_if_non_quantized(self):
@@ -164,7 +168,7 @@ def test_save_pretrained(self):
 
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
 
-            output = model.generate(**input_ids, max_new_tokens=40)
+            output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
             self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
     @require_torch_multi_gpu
@@ -178,6 +182,56 @@ def test_quantized_model_multi_gpu(self):
 
         self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
 
-        output = quantized_model.generate(**input_ids, max_new_tokens=40)
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
 
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
+    @unittest.skipUnless(
+        version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.3"),
+        "test requires `aqlm>=1.0.3`",
+    )
+    def test_quantized_model_compile(self):
+        """
+        Simple test that checks if the quantized model is working properly
+        """
+
+        # Sample tokens greedily
+        def decode_one_tokens(model, cur_token, input_pos, cache_position):
+            logits = model(
+                cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
+            )[0]
+            new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
+
+            return new_token
+
+        # Tokenize the test input
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)["input_ids"]
+        seq_length = input_ids.shape[1]
+
+        # Setup static KV cache for generation
+        self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens)
+
+        # Allocate token ids to be generated and copy prefix ids
+        cache_position = torch.arange(seq_length, device=torch_device)
+        generated_ids = torch.zeros(1, seq_length + self.max_new_tokens, dtype=torch.int, device=torch_device)
+        generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)
+
+        # Do a forward pass to fill the prefix cache and compile the kernels if necessary
+        logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]
+        next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
+        generated_ids[:, [seq_length]] = next_token
+
+        with torch.no_grad():
+            # Compile the CUDA graph
+            decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
+
+            # Generate tokens one by one
+            cache_position = torch.tensor([seq_length + 1], device=torch_device)
+            for _ in range(1, self.max_new_tokens - 1):
+                with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+                    next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position)
+                generated_ids.index_copy_(1, cache_position, next_token)
+                cache_position += 1
+
+        # Check generated text
+        self.assertEqual(self.tokenizer.decode(generated_ids[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

From 39d7603de929e07554d8684d5ea57199eefe943b Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Thu, 22 Feb 2024 16:00:11 +0100
Subject: [PATCH 4/5] is available check

---
 tests/quantization/aqlm_integration/test_aqlm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py
index 96f8200aba1ef8..dfeb3d31643901 100644
--- a/tests/quantization/aqlm_integration/test_aqlm.py
+++ b/tests/quantization/aqlm_integration/test_aqlm.py
@@ -29,7 +29,7 @@
     slow,
     torch_device,
 )
-from transformers.utils import is_accelerate_available, is_torch_available
+from transformers.utils import is_accelerate_available, is_aqlm_available, is_torch_available
 
 
 if is_torch_available():
@@ -187,7 +187,7 @@ def test_quantized_model_multi_gpu(self):
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
     @unittest.skipUnless(
-        version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.3"),
+        is_aqlm_available() and version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.3"),
         "test requires `aqlm>=1.0.3`",
     )
     def test_quantized_model_compile(self):
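Before the final patch, a side note on the skip condition patch 4 settles on: `importlib.metadata.version("aqlm")` raises when the package is absent, so availability has to be checked first. A minimal sketch of the same guard as a standalone predicate; the function name is hypothetical:

import importlib.metadata

from packaging import version

from transformers.utils import is_aqlm_available


def aqlm_supports_cuda_graphs() -> bool:
    """Hypothetical predicate mirroring the skip condition in the test above."""
    if not is_aqlm_available():
        # Querying the version of a missing package would raise, so bail out early.
        return False
    # CUDA-graph-compatible kernels are only expected from aqlm>=1.0.3.
    return version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.3")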
From 1c0adb8e9f341f456108d8fb2f4282d4944024bc Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Fri, 23 Feb 2024 10:59:28 +0100
Subject: [PATCH 5/5] shorter text test

---
 tests/quantization/aqlm_integration/test_aqlm.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py
index dfeb3d31643901..46b64573b93802 100644
--- a/tests/quantization/aqlm_integration/test_aqlm.py
+++ b/tests/quantization/aqlm_integration/test_aqlm.py
@@ -77,9 +77,9 @@ class AqlmTest(unittest.TestCase):
     model_name = "BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf"
 
     input_text = "Hello my name is"
-    max_new_tokens = 40
+    max_new_tokens = 32
 
-    EXPECTED_OUTPUT = "Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. I love to have fun and be active. I am very easy going and love to make"
+    EXPECTED_OUTPUT = "Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. I love to have fun and be active. I"
 
     device_map = "cuda"
 
@@ -209,7 +209,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position):
         seq_length = input_ids.shape[1]
 
         # Setup static KV cache for generation
-        self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens)
+        self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens + 1)
 
         # Allocate token ids to be generated and copy prefix ids
         cache_position = torch.arange(seq_length, device=torch_device)
@@ -227,7 +227,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position):
 
         # Generate tokens one by one
         cache_position = torch.tensor([seq_length + 1], device=torch_device)
-        for _ in range(1, self.max_new_tokens - 1):
+        for _ in range(1, self.max_new_tokens):
             with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                 next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position)
             generated_ids.index_copy_(1, cache_position, next_token)
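Outside the test harness, the series enables roughly the following flow on an AQLM checkpoint: prefill once with a static KV cache, then decode token by token through a `torch.compile`d step that can be captured as a CUDA graph. The sketch below condenses the test above into a standalone script; it assumes a CUDA device, `aqlm>=1.0.3`, and the same checkpoint, and it omits the math-only SDPA pin the test applies around each step:

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_name = "BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf"
max_new_tokens = 32

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda")


def decode_one_token(model, cur_token, cache_position):
    # One greedy decoding step; this is what gets compiled into a CUDA graph.
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    return torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)


input_ids = tokenizer("Hello my name is", return_tensors="pt").to("cuda")["input_ids"]
seq_length = input_ids.shape[1]

# Static KV cache sized as in the final patch: prompt + new tokens + 1.
model._setup_cache(StaticCache, 1, max_cache_len=seq_length + max_new_tokens + 1)

# Prefill: fill the cache for the prompt and greedily pick the first new token.
cache_position = torch.arange(seq_length, device="cuda")
generated_ids = torch.zeros(1, seq_length + max_new_tokens, dtype=torch.int, device="cuda")
generated_ids[:, cache_position] = input_ids.to(torch.int)
logits = model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0]
next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
generated_ids[:, [seq_length]] = next_token

# Decode the remaining tokens one by one through the compiled step.
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
with torch.no_grad():
    cache_position = torch.tensor([seq_length + 1], device="cuda")
    for _ in range(1, max_new_tokens):
        next_token = decode_one_token(model, next_token.clone(), cache_position)
        generated_ids.index_copy_(1, cache_position, next_token)
        cache_position += 1

print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))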