diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 66de85397ab3..606144dfdbfb 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2578,11 +2578,11 @@ def from_pretrained(
         if quantization_method_from_config == QuantizationMethod.GPTQ:
             quantization_config = GPTQConfig.from_dict(config.quantization_config)
             config.quantization_config = quantization_config
-            logger.info(
-                f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
-                "requirements of `auto-gptq` to enable model quantization "
-            )
-            torch_dtype = torch.float16
+            if torch_dtype is None:
+                torch_dtype = torch.float16
+            else:
+                logger.info("We suggest you set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
+
             quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict())

         if (
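For context, a minimal sketch of the caller-visible behavior after this patch: an explicit `torch_dtype` is now respected instead of being silently forced to `torch.float16`, and `float16` is only applied as the default when no dtype is passed. The model id below is illustrative, not part of the patch.

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical GPTQ-quantized checkpoint, used only for illustration.
model_id = "TheBloke/Llama-2-7B-GPTQ"

# Before this patch, an explicit torch_dtype was overridden to torch.float16
# for GPTQ models. After the patch it is kept as-is; a suggestion to use
# float16 is merely logged via logger.info.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # now respected instead of being forced to float16
    device_map="auto",
)

# Omitting torch_dtype still defaults to torch.float16 for GPTQ models,
# but without the old "Overriding torch_dtype" log message.
model_fp16 = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
```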