ENH [AutoQuantizer]: enhance trainer + not supported quant methods (#28991)

* enhance trainer + not support quant methods

* remove all old logic

* add version
younesbelkada authored Feb 14, 2024
1 parent 1d12b8b commit 164bdef
Showing 5 changed files with 19 additions and 6 deletions.
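Note: in short, the commit deprecates the model-level `_is_quantized_training_enabled` flag in favor of asking the quantizer directly, and teaches the Trainer to reject quantization methods that cannot train. A minimal sketch of the new recommended check (the checkpoint id is illustrative; any bitsandbytes-quantized model works, and bitsandbytes plus a CUDA device are assumed):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# New, recommended check introduced by this commit:
print(model.hf_quantizer.is_trainable)

# Old check, deprecated and slated for removal in transformers 4.39.0:
print(model._is_quantized_training_enabled)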
12 changes: 12 additions & 0 deletions src/transformers/modeling_utils.py
@@ -4190,6 +4190,18 @@ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):

        logger.warning_once(warn_string)

+    @property
+    def _is_quantized_training_enabled(self):
+        warnings.warn(
+            "`_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead",
+            FutureWarning,
+        )
+
+        if not hasattr(self, "hf_quantizer"):
+            return False
+
+        return self.hf_quantizer.is_trainable
+

PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
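Note: the property is now just a deprecation shim over `hf_quantizer.is_trainable`. A small sketch showing the warning fire, reusing the `model` loaded above and assuming the deprecation is emitted through Python's `warnings` machinery as in the diff:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    enabled = model._is_quantized_training_enabled  # still works, but warns

print(enabled)
print([w.category.__name__ for w in caught])  # expect ['FutureWarning']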
1 change: 0 additions & 1 deletion src/transformers/quantizers/base.py
@@ -176,7 +176,6 @@ def postprocess_model(self, model: "PreTrainedModel", **kwargs):
            kwargs (`dict`, *optional*):
                The keyword arguments that are passed along `_process_model_after_weight_loading`.
        """
-        model._is_quantized_training_enabled = self.is_trainable
        return self._process_model_after_weight_loading(model, **kwargs)

    @abstractmethod
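Note: `postprocess_model` no longer mutates the model; trainability is read from the quantizer itself. A toy subclass sketch of how a quantization method advertises that it cannot train (the class body is illustrative; only the method and property names come from base.py):

from transformers.quantizers.base import HfQuantizer

class NoTrainQuantizer(HfQuantizer):
    """Illustrative quantizer whose method does not support fine-tuning."""

    requires_calibration = False

    def _process_model_before_weight_loading(self, model, **kwargs):
        return model

    def _process_model_after_weight_loading(self, model, **kwargs):
        # Note: no `_is_quantized_training_enabled` assignment anymore.
        return model

    @property
    def is_trainable(self):
        return False

    @property
    def is_serializable(self):
        return True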
1 change: 0 additions & 1 deletion src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -289,7 +289,6 @@ def _process_model_before_weight_loading(

    # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        model._is_quantized_training_enabled = self.is_trainable
        model.is_loaded_in_4bit = True
        model.is_4bit_serializable = self.is_serializable
        return model
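Note: the 4-bit quantizer still stamps bookkeeping flags on the model after weight loading; only the training flag moved off. Continuing with the 4-bit `model` from the first sketch:

# Flags still set by _process_model_after_weight_loading:
print(model.is_loaded_in_4bit)      # True
print(model.is_4bit_serializable)   # depends on the installed bitsandbytes

# Trainability is now read from the quantizer alone:
print(model.hf_quantizer.is_trainable)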
1 change: 0 additions & 1 deletion src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -205,7 +205,6 @@ def create_quantized_param(
            unexpected_keys.remove(fp16_statistics_key)

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        model._is_quantized_training_enabled = self.is_trainable
        model.is_loaded_in_8bit = True
        model.is_8bit_serializable = self.is_serializable
        return model
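Note: the 8-bit path is symmetric. A minimal sketch (same illustrative checkpoint; bitsandbytes assumed installed):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

print(model_8bit.is_loaded_in_8bit)          # True
print(model_8bit.hf_quantizer.is_trainable)  # True for recent bitsandbytes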
10 changes: 7 additions & 3 deletions src/transformers/trainer.py
@@ -420,6 +420,9 @@ def __init__(
        _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
            model, "_hf_peft_config_loaded", False
        )
+        _quantization_method_supports_training = (
+            getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
+        )

        # At this stage the model is already loaded
        if _is_quantized_and_base_model and not _is_peft_model(model):
@@ -428,10 +431,11 @@
" the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
" for more details"
)
elif _is_quantized_and_base_model and not getattr(model, "_is_quantized_training_enabled", False):
elif _is_quantized_and_base_model and not _quantization_method_supports_training:
raise ValueError(
"The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
" but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}"
)

        self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
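Note: the Trainer guard now distinguishes two failure modes: a quantized base model without PEFT adapters (first branch), and a quantization method that cannot train at all (second branch). A hedged sketch of the path that passes both checks, assuming the `peft` package and reusing the 4-bit `model` from the first sketch:

from peft import LoraConfig, get_peft_model
from transformers import Trainer, TrainingArguments

# `_is_peft_model(peft_model)` is now True, so the first branch is skipped;
# bitsandbytes 4-bit reports `is_trainable`, so the second branch passes too.
peft_model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))

# Passing the raw quantized `model` here would raise the first ValueError;
# a non-trainable quantization method would raise the second.
trainer = Trainer(
    model=peft_model,
    args=TrainingArguments(output_dir="out"),
)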
