From a989c6c6eb6266a929bbfb4baf16b4e190f7e733 Mon Sep 17 00:00:00 2001
From: Omar Sanseviero
Date: Tue, 30 Jan 2024 01:43:40 +0100
Subject: [PATCH] Don't allow passing `load_in_8bit` and `load_in_4bit` at the
 same time (#28266)

* Update quantization_config.py

* Style

* Protect from setting directly

* add tests

* Update tests/quantization/bnb/test_4bit.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---
 src/transformers/utils/quantization_config.py | 28 +++++++++++++++++--
 tests/quantization/bnb/test_4bit.py           | 15 ++++++++++
 2 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 3684dcc76fce..f4c91dbf4d94 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -212,8 +212,12 @@ def __init__(
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.BITS_AND_BYTES
-        self.load_in_8bit = load_in_8bit
-        self.load_in_4bit = load_in_4bit
+
+        if load_in_4bit and load_in_8bit:
+            raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
+
+        self._load_in_8bit = load_in_8bit
+        self._load_in_4bit = load_in_4bit
         self.llm_int8_threshold = llm_int8_threshold
         self.llm_int8_skip_modules = llm_int8_skip_modules
         self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
@@ -232,6 +236,26 @@ def __init__(
 
         self.post_init()
 
+    @property
+    def load_in_4bit(self):
+        return self._load_in_4bit
+
+    @load_in_4bit.setter
+    def load_in_4bit(self, value: bool):
+        if self.load_in_8bit and value:
+            raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
+        self._load_in_4bit = value
+
+    @property
+    def load_in_8bit(self):
+        return self._load_in_8bit
+
+    @load_in_8bit.setter
+    def load_in_8bit(self, value: bool):
+        if self.load_in_4bit and value:
+            raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
+        self._load_in_8bit = value
+
     def post_init(self):
         r"""
         Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 5e034e49f9a9..4c33270af674 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -648,3 +648,18 @@ class GPTSerializationTest(BaseSerializationTest):
     """
 
     model_name = "gpt2-xl"
+
+
+@require_bitsandbytes
+@require_accelerate
+@require_torch_gpu
+@slow
+class Bnb4BitTestBasicConfigTest(unittest.TestCase):
+    def test_load_in_4_and_8_bit_fails(self):
+        with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
+            AutoModelForCausalLM.from_pretrained("facebook/opt-125m", load_in_4bit=True, load_in_8bit=True)
+
+    def test_set_load_in_8_bit(self):
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+        with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
+            quantization_config.load_in_8bit = True
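
A quick illustration of the guard this patch introduces. This is a minimal usage sketch, assuming a transformers build that includes the change; `BitsAndBytesConfig` and the error message are taken directly from the diff above.

# sketch.py -- exercises the new constructor check and property setters
from transformers import BitsAndBytesConfig

# Passing both flags at construction time now raises immediately.
try:
    BitsAndBytesConfig(load_in_4bit=True, load_in_8bit=True)
except ValueError as err:
    print(err)  # load_in_4bit and load_in_8bit are both True, but only one can be used at the same time

# The property setters also block flipping the second flag after construction.
config = BitsAndBytesConfig(load_in_4bit=True)
try:
    config.load_in_8bit = True  # intercepted by the load_in_8bit setter
except ValueError as err:
    print(err)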