Skip to content

Commit

Permalink
FIX / Quantization: Add extra validation for bnb config (huggingface#…
Browse files Browse the repository at this point in the history
…31135)

add validation for bnb config
  • Loading branch information
younesbelkada authored May 30, 2024
1 parent 2b9e252 commit 5e5c4d6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/utils/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,10 @@ def __init__(
if bnb_4bit_quant_storage is None:
self.bnb_4bit_quant_storage = torch.uint8
elif isinstance(bnb_4bit_quant_storage, str):
if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
raise ValueError(
"`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
)
self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
elif isinstance(bnb_4bit_quant_storage, torch.dtype):
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
Expand Down
7 changes: 7 additions & 0 deletions tests/quantization/bnb/test_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,13 @@ def test_fp32_4bit_conversion(self):
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small", load_in_4bit=True, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)

def test_bnb_4bit_wrong_config(self):
r"""
Test whether creating a bnb config with unsupported values leads to errors.
"""
with self.assertRaises(ValueError):
_ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")


@require_bitsandbytes
@require_accelerate
Expand Down

0 comments on commit 5e5c4d6

Please sign in to comment.