Skip to content

Commit

Permalink
[Trainer] Refactor trainer + bnb logic (huggingface#26248)
Browse files · Browse the repository at this point in the history
* refactor trainer + bnb logic

* remove logger.info

* oops
Loading branch information…
younesbelkada authored and EduardoPach committed Nov 18, 2023
1 parent 9132353 commit 4a5059f
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,19 +402,23 @@ def __init__(
" to `True` to avoid any unexpected behavior such as device placement mismatching."
)

_is_peft_model = is_peft_available() and isinstance(model, PeftModel)
_is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
model, "_hf_peft_config_loaded", False
)

# At this stage the model is already loaded
if getattr(model, "is_quantized", False) and not getattr(model, "_hf_peft_config_loaded", False):
if getattr(model, "_is_quantized_training_enabled", False):
logger.info(
"The model is quantized. To train this model you need to add additional modules"
" inside the model such as adapters using `peft` library and freeze the model weights. Please"
" check the examples in https://github.com/huggingface/peft for more details."
)
else:
raise ValueError(
"The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
)
if _is_quantized_and_base_model and not _is_peft_model:
raise ValueError(
"You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
" the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
" for more details"
)
elif _is_quantized_and_base_model and not getattr(model, "_is_quantized_training_enabled", False):
raise ValueError(
"The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
)

# Setup Sharded DDP training
self.sharded_ddp = None
Expand Down

0 comments on commit 4a5059f

Please sign in to comment.