ENH [AutoQuantizer]: enhance trainer + not supported quant methods #28991

Merged
Changes from 2 commits
src/transformers/modeling_utils.py (12 additions, 0 deletions)

@@ -4190,6 +4190,18 @@ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):

        logger.warning_once(warn_string)

+    @property
+    def _is_quantized_training_enabled(self):
+        logger.warning(
+            "`_is_quantized_training_enabled` is going to be deprecated in a future version. Please use `model.hf_quantizer.is_trainable` instead",
Collaborator: Message should have a specific version listed here for deprecation.

Contributor Author: OK perfect, I will set it to two minor releases from now.

+            FutureWarning,
+        )
Collaborator: Nice and safe :) TBH, with private methods we can probably get away with no deprecation warning, as it's not public.

Contributor Author: Yeah, in my first commits I just removed that private attribute, but after giving it some thought I realised it was better to go with this option, just to stay on the safe side.


+        if not hasattr(self, "hf_quantizer"):
+            return False
+
+        return self.hf_quantizer.is_trainable


PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
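For downstream code that still reads the private flag, a minimal migration sketch (the checkpoint name is illustrative; 8-bit loading assumes bitsandbytes and a CUDA device are available):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# Deprecated spelling: still works, but now emits a FutureWarning.
old_flag = model._is_quantized_training_enabled

# Preferred spelling: ask the quantizer directly, guarding for non-quantized models.
new_flag = getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
assert old_flag == new_flag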
src/transformers/quantizers/base.py (0 additions, 1 deletion)

@@ -176,7 +176,6 @@ def postprocess_model(self, model: "PreTrainedModel", **kwargs):
            kwargs (`dict`, *optional*):
                The keyword arguments that are passed along `_process_model_after_weight_loading`.
        """
-        model._is_quantized_training_enabled = self.is_trainable
        return self._process_model_after_weight_loading(model, **kwargs)

    @abstractmethod
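With the mirroring removed from postprocess_model, the quantizer property is the single source of truth for trainability. A hedged sketch of what a custom quantizer looks like after this change (MyQuantizer and all of its choices are hypothetical):

from transformers.quantizers.base import HfQuantizer

class MyQuantizer(HfQuantizer):
    requires_calibration = False

    def validate_environment(self, *args, **kwargs):
        pass  # check for required libraries/devices here

    def _process_model_before_weight_loading(self, model, **kwargs):
        return model

    def _process_model_after_weight_loading(self, model, **kwargs):
        # No need to set model._is_quantized_training_enabled anymore;
        # postprocess_model just calls this hook and returns the result.
        return model

    @property
    def is_trainable(self) -> bool:
        return False  # e.g. an inference-only quantization method

    @property
    def is_serializable(self) -> bool:
        return True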
src/transformers/quantizers/quantizer_bnb_4bit.py (0 additions, 1 deletion)

@@ -289,7 +289,6 @@ def _process_model_before_weight_loading(

    # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        model._is_quantized_training_enabled = self.is_trainable
        model.is_loaded_in_4bit = True
        model.is_4bit_serializable = self.is_serializable
        return model
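A usage sketch for the 4-bit path (assumes bitsandbytes and a CUDA device; the checkpoint is illustrative). The flags set in this hook are unchanged; only the training flag moved to the quantizer:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

assert model.is_loaded_in_4bit
# Trainability is now read from the quantizer instead of a mirrored attribute.
print(model.hf_quantizer.is_trainable)  # True for bnb 4-bit (PEFT-style training)
print(model.is_4bit_serializable)       # depends on the installed bitsandbytes version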
src/transformers/quantizers/quantizer_bnb_8bit.py (0 additions, 1 deletion)

@@ -205,7 +205,6 @@ def create_quantized_param(
                unexpected_keys.remove(fp16_statistics_key)

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        model._is_quantized_training_enabled = self.is_trainable
        model.is_loaded_in_8bit = True
        model.is_8bit_serializable = self.is_serializable
        return model
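The 8-bit hook keeps the same shape. A sketch of how the remaining flag is typically consumed (checkpoint and output path are illustrative; saving 8-bit weights also requires a sufficiently recent bitsandbytes):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# is_8bit_serializable gates whether the quantized weights can be saved or pushed.
if model.is_8bit_serializable:
    model.save_pretrained("./opt-350m-8bit")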
src/transformers/trainer.py (7 additions, 3 deletions)

@@ -420,6 +420,9 @@ def __init__(
        _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
            model, "_hf_peft_config_loaded", False
        )
+        _quantization_method_supports_training = (
+            getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
+        )

        # At this stage the model is already loaded
        if _is_quantized_and_base_model and not _is_peft_model(model):
@@ -428,10 +431,11 @@
                " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
                " for more details"
            )
-        elif _is_quantized_and_base_model and not getattr(model, "_is_quantized_training_enabled", False):
+        elif _is_quantized_and_base_model and not _quantization_method_supports_training:
            raise ValueError(
-                "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
-                " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
+                f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
+                " but that quantization method does not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
+                f" to request training support for {model.hf_quantizer.quantization_config.quant_method}"
            )

        self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
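For reference, a sketch of the supported path that the first branch points to: wrap the quantized base model with PEFT adapters so that _is_peft_model(model) is True and neither check fires (LoRA hyperparameters and target modules are illustrative):

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

# Trainer(model=model, ...) now takes the PEFT branch instead of raising.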