diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 2cc8dbbbe639f8..a6dc313fbaa172 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4190,6 +4190,18 @@ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
 
         logger.warning_once(warn_string)
 
+    @property
+    def _is_quantized_training_enabled(self):
+        logger.warning(
+            "`_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead",
+            FutureWarning,
+        )
+
+        if not hasattr(self, "hf_quantizer"):
+            return False
+
+        return self.hf_quantizer.is_trainable
+
 
 PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
 if PreTrainedModel.push_to_hub.__doc__ is not None:
diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
index 68adc3954df45d..345b19a14e3dc7 100644
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -176,7 +176,6 @@ def postprocess_model(self, model: "PreTrainedModel", **kwargs):
             kwargs (`dict`, *optional*):
                 The keyword arguments that are passed along `_process_model_after_weight_loading`.
         """
-        model._is_quantized_training_enabled = self.is_trainable
         return self._process_model_after_weight_loading(model, **kwargs)
 
     @abstractmethod
diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
index 7cc9ef6560e941..16745f756ca525 100644
--- a/src/transformers/quantizers/quantizer_bnb_4bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -289,7 +289,6 @@ def _process_model_before_weight_loading(
 
     # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        model._is_quantized_training_enabled = self.is_trainable
         model.is_loaded_in_4bit = True
         model.is_4bit_serializable = self.is_serializable
         return model
diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
index 6428b13c250b19..d41a280f89a4f8 100644
--- a/src/transformers/quantizers/quantizer_bnb_8bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -205,7 +205,6 @@ def create_quantized_param(
                 unexpected_keys.remove(fp16_statistics_key)
 
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        model._is_quantized_training_enabled = self.is_trainable
         model.is_loaded_in_8bit = True
         model.is_8bit_serializable = self.is_serializable
         return model
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 905744a64ed4c6..f4a54ecc4dabbd 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -420,6 +420,9 @@ def __init__(
         _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
             model, "_hf_peft_config_loaded", False
         )
+        _quantization_method_supports_training = (
+            getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
+        )
 
         # At this stage the model is already loaded
         if _is_quantized_and_base_model and not _is_peft_model(model):
@@ -428,10 +431,11 @@ def __init__(
                 " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
                 " for more details"
             )
-        elif _is_quantized_and_base_model and not getattr(model, "_is_quantized_training_enabled", False):
+        elif _is_quantized_and_base_model and not _quantization_method_supports_training:
             raise ValueError(
-                "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
-                " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
+                f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
+                " but that quantization method does not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
+                f" to request training support for {model.hf_quantizer.quantization_config.quant_method}"
            )
 
         self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
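For reference, a minimal sketch of the check this PR moves to: asking the attached `hf_quantizer` whether the quantization method is trainable, instead of reading the now-deprecated `_is_quantized_training_enabled` attribute. The checkpoint name and 8-bit config below are only illustrative, and the snippet assumes `bitsandbytes` and a CUDA device are available; it is not part of the PR itself.

```python
# Sketch only: checkpoint and quantization config are illustrative.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# Deprecated after this change (emits a warning and proxies to the quantizer):
# model._is_quantized_training_enabled

# Preferred: mirror the Trainer's new check and ask the quantizer directly.
hf_quantizer = getattr(model, "hf_quantizer", None)
supports_training = hf_quantizer is not None and hf_quantizer.is_trainable
print(f"quantized training supported: {supports_training}")
```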