diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile
index 3ee774270ba40f..e96eb9539c8bd2 100644
--- a/docker/transformers-all-latest-gpu/Dockerfile
+++ b/docker/transformers-all-latest-gpu/Dockerfile
@@ -55,6 +55,9 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
 # Add einops for additional model testing
 RUN python3 -m pip install --no-cache-dir einops
 
+# Add aqlm for quantization testing
+RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.1
+
 # Add autoawq for quantization testing
 RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md
index c28d2e23fbb2ac..297dd1a49531bd 100644
--- a/docs/source/en/main_classes/quantization.md
+++ b/docs/source/en/main_classes/quantization.md
@@ -26,6 +26,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 
+## AqlmConfig
+
+[[autodoc]] AqlmConfig
+
 ## AwqConfig
 
 [[autodoc]] AwqConfig
diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md
index d33acf94c9ae6a..29ee188852feca 100644
--- a/docs/source/en/quantization.md
+++ b/docs/source/en/quantization.md
@@ -26,6 +26,34 @@ Interested in adding a new quantization method to Transformers? Read the [HfQuan
 
+## AQLM
+
+Try AQLM on [Google Colab](https://colab.research.google.com/drive/1-xZmBRXT5Fm3Ghn4Mwa2KRypORXb855X?usp=sharing)!
+
+Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a compression method for Large Language Models. It quantizes multiple weights together and takes advantage of interdependencies between them. AQLM represents groups of 8-16 weights as a sum of multiple vector codes.
+
+Inference support for AQLM is realised in the `aqlm` library. Make sure to install it to run the models (note that `aqlm` works only with Python >= 3.10):
+```bash
+pip install aqlm[gpu,cpu]
+```
+
+The library provides efficient kernels for both GPU and CPU inference.
+
+The instructions on how to quantize models yourself, as well as all the relevant code, can be found in the corresponding GitHub [repository](https://github.com/Vahe1994/AQLM).
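For orientation (an illustrative aside, not part of the patch), loading a prequantized AQLM checkpoint goes through the standard `from_pretrained` API once `aqlm` is installed. The repository id below is the one used in this PR's integration tests, and a CUDA GPU is assumed:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Prequantized AQLM checkpoint used in this PR's integration tests (1x16 setup, ~2 bits/weight).
model_id = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")  # defaults to torch.float16 on GPU

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```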
+
+### AQLM configurations
+
+AQLM quantization setups vary mainly in the number of codebooks used, as well as in the codebook sizes in bits. The most popular setups, as well as the inference kernels they support, are:
+
+| Kernel | Number of codebooks | Codebook size, bits | Notation | Accuracy | Speedup | Fast GPU inference | Fast CPU inference |
+|---|---------------------|---------------------|----------|-------------|-------------|--------------------|--------------------|
+| Triton | K | N | KxN | - | Up to ~0.7x | ✅ | ❌ |
+| CUDA | 1 | 16 | 1x16 | Best | Up to ~1.3x | ✅ | ❌ |
+| CUDA | 2 | 8 | 2x8 | OK | Up to ~3.0x | ✅ | ❌ |
+| Numba | K | 8 | Kx8 | Good | Up to ~4.0x | ❌ | ✅ |
+
 ## AWQ
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 44e36f662fdb67..84a66458022730 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1087,7 +1087,7 @@
         "is_vision_available",
         "logging",
     ],
-    "utils.quantization_config": ["AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
+    "utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
 }
 
 # sentencepiece-backed objects
@@ -5845,7 +5845,7 @@
     )
 
     # bitsandbytes config
-    from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig
+    from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig
 
     try:
         if not is_sentencepiece_available():
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
index 3d1e41263eef70..bded6b3984a59c 100644
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -17,6 +17,7 @@
 _import_structure = {
+    "aqlm": ["replace_with_aqlm_linear"],
     "awq": ["fuse_awq_modules", "replace_with_awq_linear"],
     "bitsandbytes": [
         "get_keys_to_not_convert",
@@ -80,6 +81,7 @@
 }
 
 if TYPE_CHECKING:
+    from .aqlm import replace_with_aqlm_linear
     from .awq import fuse_awq_modules, replace_with_awq_linear
     from .bitsandbytes import (
         get_keys_to_not_convert,
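To make the notation in the docs table above concrete, here is a rough sketch of how the exported `AqlmConfig` parameters relate to bits per weight. It deliberately ignores the storage of the codebooks themselves (an assumption on my part), counting only the per-group codes:

```python
from transformers import AqlmConfig

# "1x16" = 1 codebook of 16 bits. With the default group of 8 input weights,
# each group of 8 weights is encoded by a single 16-bit code -> ~2 bits per weight.
config = AqlmConfig(in_group_size=8, out_group_size=1, num_codebooks=1, nbits_per_codebook=16)

code_bits_per_weight = (
    config.num_codebooks * config.nbits_per_codebook / (config.in_group_size * config.out_group_size)
)
print(code_bits_per_weight)  # 2.0, matching the "2Bit" naming of the 1x16 checkpoints
```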
diff --git a/src/transformers/integrations/aqlm.py b/src/transformers/integrations/aqlm.py
new file mode 100644
index 00000000000000..903d0ecdaebc05
--- /dev/null
+++ b/src/transformers/integrations/aqlm.py
@@ -0,0 +1,99 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"AQLM (Additive Quantization of Language Model) integration file"
+
+
+from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available
+
+
+if is_torch_available():
+    import torch.nn as nn
+
+
+def replace_with_aqlm_linear(
+    model,
+    quantization_config=None,
+    linear_weights_not_to_quantize=None,
+    current_key_name=None,
+    has_been_replaced=False,
+):
+    """
+    Public method that recursively replaces the Linear layers of the given model with AQLM quantized layers.
+    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates whether
+    the conversion has been successful or not.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to convert, can be any `torch.nn.Module` instance.
+        quantization_config (`AqlmConfig`):
+            The quantization config object that contains the quantization parameters.
+        linear_weights_not_to_quantize (`list[str]`, *optional*):
+            A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`),
+            the corresponding module will not be converted.
+        current_key_name (`list`, *optional*):
+            A list that contains the current key name. This is used for recursion and should not be passed by the user.
+        has_been_replaced (`bool`, *optional*):
+            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
+            should not be passed by the user.
+    """
+    if not is_aqlm_available():
+        raise ValueError("AQLM is not available. Please install it with `pip install aqlm[cpu,gpu]`")
+
+    if not is_accelerate_available():
+        raise ValueError("AQLM requires Accelerate to be installed: `pip install accelerate`")
+
+    if linear_weights_not_to_quantize is None:
+        linear_weights_not_to_quantize = []
+
+    from accelerate import init_empty_weights
+    from aqlm import QuantizedLinear
+
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+        current_key_name.append(name)
+
+        if isinstance(module, nn.Linear):
+            # Check if the current key is not in the `linear_weights_not_to_quantize`
+            if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
+                with init_empty_weights():
+                    in_features = module.in_features
+                    out_features = module.out_features
+
+                    model._modules[name] = QuantizedLinear(
+                        in_features,
+                        out_features,
+                        bias=module.bias is not None,
+                        in_group_size=quantization_config.in_group_size,
+                        out_group_size=quantization_config.out_group_size,
+                        num_codebooks=quantization_config.num_codebooks,
+                        nbits_per_codebook=quantization_config.nbits_per_codebook,
+                    )
+                has_been_replaced = True
+
+                # Store the module class in case we need to transpose the weight later
+                model._modules[name].source_cls = type(module)
+                # Force requires grad to False to avoid unexpected errors
+                model._modules[name].requires_grad_(False)
+        if len(list(module.children())) > 0:
+            _, has_been_replaced = replace_with_aqlm_linear(
+                module,
+                quantization_config=quantization_config,
+                linear_weights_not_to_quantize=linear_weights_not_to_quantize,
+                current_key_name=current_key_name,
+                has_been_replaced=has_been_replaced,
+            )
+        # Remove the last key for recursion
+        current_key_name.pop(-1)
+    return model, has_been_replaced
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index 6b8d71b7c73090..a78b07fdb3a331 100644
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -16,12 +16,14 @@
 from ..models.auto.configuration_auto import AutoConfig
 from ..utils.quantization_config import (
+    AqlmConfig,
     AwqConfig,
     BitsAndBytesConfig,
     GPTQConfig,
     QuantizationConfigMixin,
     QuantizationMethod,
 )
+from .quantizer_aqlm import AqlmHfQuantizer
 from .quantizer_awq import AwqQuantizer
 from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
 from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
@@ -33,6 +35,7 @@
     "bitsandbytes_4bit": Bnb4BitHfQuantizer,
     "bitsandbytes_8bit": Bnb8BitHfQuantizer,
     "gptq": GptqHfQuantizer,
+    "aqlm": AqlmHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -40,6 +43,7 @@
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
     "gptq": GPTQConfig,
+    "aqlm": AqlmConfig,
 }
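The integration entry point above can also be exercised directly. The sketch below mirrors what the PR's `test_quantized_model_conversion` test does, and assumes `aqlm` and `accelerate` are installed:

```python
from accelerate import init_empty_weights
from transformers import AqlmConfig, AutoConfig, OPTForCausalLM
from transformers.integrations import replace_with_aqlm_linear

# Build an OPT model on the meta device, then swap its nn.Linear layers for aqlm.QuantizedLinear,
# keeping the lm_head in full precision.
config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    model = OPTForCausalLM(config)

model, has_been_replaced = replace_with_aqlm_linear(
    model,
    quantization_config=AqlmConfig(),  # defaults: 1 codebook of 16 bits, in_group_size=8
    linear_weights_not_to_quantize=["lm_head.weight"],
)
print(has_been_replaced)  # True if at least one layer was converted
```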
diff --git a/src/transformers/quantizers/quantizer_aqlm.py b/src/transformers/quantizers/quantizer_aqlm.py
new file mode 100644
index 00000000000000..6e17fe77186e20
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_aqlm.py
@@ -0,0 +1,89 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Optional
+
+from .base import HfQuantizer
+
+
+if TYPE_CHECKING:
+    from ..modeling_utils import PreTrainedModel
+
+from ..integrations import replace_with_aqlm_linear
+from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available, logging
+from ..utils.quantization_config import QuantizationConfigMixin
+
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)
+
+
+class AqlmHfQuantizer(HfQuantizer):
+    """
+    Quantizer of the AQLM method. Enables the loading of prequantized models.
+    """
+
+    requires_calibration = True
+    required_packages = ["aqlm"]
+    optimum_quantizer = None
+
+    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+        self.quantization_config = quantization_config
+
+    def validate_environment(self, *args, **kwargs):
+        if not is_accelerate_available():
+            raise ImportError("Using `aqlm` quantization requires Accelerate: `pip install accelerate`")
+
+        if not is_aqlm_available():
+            raise ImportError("Using `aqlm` quantization requires AQLM: `pip install aqlm[gpu,cpu]`")
+
+    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+        if torch_dtype is None:
+            if torch.cuda.is_available():
+                torch_dtype = torch.float16
+                logger.info(
+                    "CUDA available. Assuming AQLM inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually."
+                )
+            else:
+                torch_dtype = torch.float32
+                logger.info(
+                    "CUDA is unavailable. Assuming AQLM inference on CPU and loading the model in `torch.float32`. To overwrite it, set `torch_dtype` manually."
+                )
+        return torch_dtype
+
+    def _process_model_before_weight_loading(
+        self,
+        model: "PreTrainedModel",
+        **kwargs,
+    ):
+        replace_with_aqlm_linear(
+            model,
+            quantization_config=self.quantization_config,
+            linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize,
+        )
+        model.config.quantization_config = self.quantization_config
+
+    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
+        model._is_quantized_training_enabled = False
+        return model
+
+    @property
+    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
+        return False
+
+    @property
+    def is_serializable(self):
+        return True
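As a usage note (my illustration, not part of the patch), the automatic `float16`/`float32` choice made by `update_torch_dtype` above can always be overridden by passing `torch_dtype` explicitly; the checkpoint id is reused from this PR's tests:

```python
import torch
from transformers import AutoModelForCausalLM

# Force float32 instead of the quantizer's automatic dtype selection.
model = AutoModelForCausalLM.from_pretrained(
    "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch",
    torch_dtype=torch.float32,
)
```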
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index eb74af7a4a35c8..0ff7e718af20a9 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -53,6 +53,7 @@
 from .utils import (
     is_accelerate_available,
     is_apex_available,
+    is_aqlm_available,
     is_auto_awq_available,
     is_auto_gptq_available,
     is_bitsandbytes_available,
@@ -956,6 +957,13 @@ def require_apex(test_case):
     return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)
 
 
+def require_aqlm(test_case):
+    """
+    Decorator marking a test that requires aqlm
+    """
+    return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
+
+
 def require_bitsandbytes(test_case):
     """
     Decorator for bits and bytes (bnb) dependency
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index a608304ac93cd3..4f69b629b22df0 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -105,6 +105,7 @@
     get_torch_version,
     is_accelerate_available,
     is_apex_available,
+    is_aqlm_available,
     is_auto_awq_available,
     is_auto_gptq_available,
     is_bitsandbytes_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 501d68b4929ee6..57b4e840414be0 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -74,6 +74,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
 _apex_available = _is_package_available("apex")
+_aqlm_available = _is_package_available("aqlm")
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
 _bs4_available = importlib.util.find_spec("bs4") is not None
@@ -570,6 +571,10 @@ def is_apex_available():
     return _apex_available
 
 
+def is_aqlm_available():
+    return _aqlm_available
+
+
 def is_ninja_available():
     r"""
     Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
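A short sketch of how the new availability helpers are meant to be used in test code; the test class and its body are illustrative only:

```python
import unittest

from transformers.testing_utils import require_aqlm, require_torch_gpu
from transformers.utils import is_aqlm_available


@require_torch_gpu
@require_aqlm
class MyAqlmSmokeTest(unittest.TestCase):
    # Skipped automatically unless the `aqlm` package (and a GPU) is available.
    def test_kernel_import(self):
        self.assertTrue(is_aqlm_available())
        from aqlm import QuantizedLinear  # noqa: F401
```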
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 358c0c71cf44c8..d2ab879f24ab61 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -38,6 +38,7 @@ class QuantizationMethod(str, Enum):
     BITS_AND_BYTES = "bitsandbytes"
     GPTQ = "gptq"
     AWQ = "awq"
+    AQLM = "aqlm"
 
 
 class AWQLinearVersion(str, Enum):
@@ -731,3 +732,63 @@ def get_loading_attributes(self):
         loading_attibutes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
         loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
         return loading_attibutes_dict
+
+
+@dataclass
+class AqlmConfig(QuantizationConfigMixin):
+    """
+    This is a wrapper class around the `aqlm` quantization parameters.
+
+    Args:
+        in_group_size (`int`, *optional*, defaults to 8):
+            The group size along the input dimension.
+        out_group_size (`int`, *optional*, defaults to 1):
+            The group size along the output dimension. It's recommended to always use 1.
+        num_codebooks (`int`, *optional*, defaults to 1):
+            Number of codebooks for the Additive Quantization procedure.
+        nbits_per_codebook (`int`, *optional*, defaults to 16):
+            Number of bits encoding a single codebook vector. Codebook size is 2**nbits_per_codebook.
+        linear_weights_not_to_quantize (`Optional[List[str]]`, *optional*):
+            List of full paths of `nn.Linear` weight parameters that shall not be quantized.
+        kwargs (`Dict[str, Any]`, *optional*):
+            Additional parameters from which to initialize the configuration object.
+    """
+
+    def __init__(
+        self,
+        in_group_size: int = 8,
+        out_group_size: int = 1,
+        num_codebooks: int = 1,
+        nbits_per_codebook: int = 16,
+        linear_weights_not_to_quantize: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.AQLM
+        self.in_group_size = in_group_size
+        self.out_group_size = out_group_size
+        self.num_codebooks = num_codebooks
+        self.nbits_per_codebook = nbits_per_codebook
+        self.linear_weights_not_to_quantize = linear_weights_not_to_quantize
+
+        self.post_init()
+
+    def post_init(self):
+        r"""
+        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
+        """
+        if not isinstance(self.in_group_size, int):
+            raise ValueError("in_group_size must be an integer")
+        if not isinstance(self.out_group_size, int):
+            raise ValueError("out_group_size must be an integer")
+        if not isinstance(self.num_codebooks, int):
+            raise ValueError("num_codebooks must be an integer")
+        if not isinstance(self.nbits_per_codebook, int):
+            raise ValueError("nbits_per_codebook must be an integer")
+
+        if self.linear_weights_not_to_quantize is not None and not isinstance(
+            self.linear_weights_not_to_quantize, list
+        ):
+            raise ValueError("linear_weights_not_to_quantize must be a list of strings")
+
+        if self.linear_weights_not_to_quantize is None:
+            self.linear_weights_not_to_quantize = []
diff --git a/tests/quantization/aqlm_integration/__init__.py b/tests/quantization/aqlm_integration/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
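To illustrate the new config class (a sketch; the parameter values simply mirror the "2x8" row of the docs table):

```python
from transformers import AqlmConfig

# A "2x8" setup: two 8-bit codebooks per group of 8 input weights, keeping the LM head unquantized.
config = AqlmConfig(
    in_group_size=8,
    out_group_size=1,
    num_codebooks=2,
    nbits_per_codebook=8,
    linear_weights_not_to_quantize=["lm_head.weight"],
)
print(config.num_codebooks, config.nbits_per_codebook)  # 2 8

# post_init rejects non-integer settings early:
try:
    AqlmConfig(nbits_per_codebook=8.5)
except ValueError as e:
    print(e)  # nbits_per_codebook must be an integer
```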
diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py
new file mode 100644
index 00000000000000..6a5cefea2fb177
--- /dev/null
+++ b/tests/quantization/aqlm_integration/test_aqlm.py
@@ -0,0 +1,183 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM
+from transformers.testing_utils import (
+    require_accelerate,
+    require_aqlm,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    slow,
+    torch_device,
+)
+from transformers.utils import is_accelerate_available, is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+
+
+@require_torch_gpu
+class AqlmConfigTest(unittest.TestCase):
+    def test_to_dict(self):
+        """
+        Simple test that checks that a config converted to a dict matches the original config object
+        """
+        quantization_config = AqlmConfig()
+        config_to_dict = quantization_config.to_dict()
+
+        for key in config_to_dict:
+            self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
+
+    def test_from_dict(self):
+        """
+        Simple test that checks that a config created from a dict matches the original dict
+        """
+        dict = {
+            "in_group_size": 32,
+            "num_codebooks": 8,
+            "nbits_per_codebook": 8,
+            "linear_weights_not_to_quantize": ["lm_head.weight"],
+        }
+        quantization_config = AqlmConfig.from_dict(dict)
+
+        self.assertEqual(dict["in_group_size"], quantization_config.in_group_size)
+        self.assertEqual(dict["num_codebooks"], quantization_config.num_codebooks)
+        self.assertEqual(dict["nbits_per_codebook"], quantization_config.nbits_per_codebook)
+        self.assertEqual(dict["linear_weights_not_to_quantize"], quantization_config.linear_weights_not_to_quantize)
+
+
+@slow
+@require_torch_gpu
+@require_aqlm
+@require_accelerate
+class AqlmTest(unittest.TestCase):
+    model_name = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"
+
+    input_text = "Hello my name is"
+
+    EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am currently a sophomore and am majoring in Psychology. I am"
+
+    device_map = "cuda"
+
+    # called only once for all tests in this class
+    @classmethod
+    def setUpClass(cls):
+        """
+        Setup quantized model
+        """
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
+            cls.model_name,
+            device_map=cls.device_map,
+        )
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def test_quantized_model_conversion(self):
+        """
+        Simple test that checks if the quantized model has been converted properly
+        """
+        from aqlm import QuantizedLinear
+
+        from transformers.integrations import replace_with_aqlm_linear
+
+        model_id = "facebook/opt-350m"
+        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
+        quantization_config = AqlmConfig()
+
+        with init_empty_weights():
+            model = OPTForCausalLM(config)
+
+        nb_linears = 0
+        for module in model.modules():
+            if isinstance(module, torch.nn.Linear):
+                nb_linears += 1
+
+        model, _ = replace_with_aqlm_linear(model, quantization_config=quantization_config)
+        nb_aqlm_linear = 0
+        for module in model.modules():
+            if isinstance(module, QuantizedLinear):
+                nb_aqlm_linear += 1
+
+        self.assertEqual(nb_linears, nb_aqlm_linear)
+
+        # Try with `linear_weights_not_to_quantize`
+        with init_empty_weights():
+            model = OPTForCausalLM(config)
+
+        model, _ = replace_with_aqlm_linear(
+            model, quantization_config=quantization_config, linear_weights_not_to_quantize=["lm_head.weight"]
+        )
+        nb_aqlm_linear = 0
+        for module in model.modules():
+            if isinstance(module, QuantizedLinear):
+                nb_aqlm_linear += 1
+
+        self.assertEqual(nb_linears - 1, nb_aqlm_linear)
+
+    def test_quantized_model(self):
+        """
+        Simple test that checks if the quantized model is working properly
+        """
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+        output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
+    def test_raise_if_non_quantized(self):
+        model_id = "facebook/opt-125m"
+        quantization_config = AqlmConfig(bits=4)
+
+        with self.assertRaises(ValueError):
+            _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
+
+    def test_save_pretrained(self):
+        """
+        Simple test that checks if the quantized model is working properly after being saved and loaded
+        """
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.quantized_model.save_pretrained(tmpdirname)
+            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
+
+            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+            output = model.generate(**input_ids, max_new_tokens=40)
+            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
+    @require_torch_multi_gpu
+    def test_quantized_model_multi_gpu(self):
+        """
+        Simple test that checks if the quantized model is working properly with multiple GPUs
+        """
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")
+
+        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=40)
+
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)