From 38b7d3b93e7c88700db134548866a587fdef7d10 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 10:13:07 -0500 Subject: [PATCH 01/12] wip qlora + fsdp fixes --- requirements.txt | 1 + src/axolotl/core/trainer_builder.py | 111 ++++++++++++++++ src/axolotl/utils/models.py | 192 +++++++++++++++++++++++++++- 3 files changed, 299 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 718896783b..a0d88d215f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,3 +40,4 @@ gcsfs # adlfs trl>=0.7.9 +fastcore>=1.5.29 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index e051d4e69c..b4f5d1e755 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -4,10 +4,12 @@ """ import abc +import functools import importlib import importlib.util import logging import math +import os import sys from abc import abstractmethod from dataclasses import dataclass, field @@ -17,7 +19,17 @@ import torch import transformers +from accelerate import FullyShardedDataParallelPlugin +from accelerate.utils import str_to_bool from datasets import Dataset +from peft import PrefixEncoder, PromptEmbedding, PromptEncoder +from torch import nn +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp.wrap import ( + _or_policy, + lambda_auto_wrap_policy, + transformer_auto_wrap_policy, +) from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import ( @@ -26,6 +38,11 @@ TrainerCallback, TrainingArguments, ) +from transformers.models.llama.modeling_llama import ( + LLAMA_ATTENTION_CLASSES, + LlamaDecoderLayer, + LlamaMLP, +) from transformers.trainer_utils import seed_worker from transformers.utils import is_sagemaker_mp_enabled from trl import DPOTrainer @@ -193,6 +210,57 @@ class AxolotlTrainingArguments(TrainingArguments): ) +# FIXME, this should be some sort of generator based on the model arch +# This checks for lora layers (has weight and requires_grad) +def get_wrapping_policy(custom_policy: bool = False): + if custom_policy: + + def lambda_policy_fn(module): + # LORA trainable layers. + return isinstance(module, nn.Sequential) and all( + m.weight.requires_grad for m in module + ) + + else: + + def lambda_policy_fn(module): + return ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ) + + def self_attn_policy_fn(module): + # Check module name is self_attn. + return isinstance(module, tuple(LLAMA_ATTENTION_CLASSES.values())) + + def mlp_policy_fn(module): + # Check module name is self_attn. 
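+        # (Actually matches the decoder MLP block, LlamaMLP, rather than self_attn.)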
+ return isinstance(module, LlamaMLP) + + lambda_policy = functools.partial( + lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn + ) + self_attn_policy = functools.partial( + lambda_auto_wrap_policy, lambda_fn=self_attn_policy_fn + ) + mlp_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=mlp_policy_fn) + transformer_layer_name = LlamaDecoderLayer + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=( + PrefixEncoder, + PromptEncoder, + PromptEmbedding, + transformer_layer_name, + ), + ) + policies = [lambda_policy, transformer_wrap_policy] + if custom_policy: + policies.extend([self_attn_policy, mlp_policy]) + return functools.partial(_or_policy, policies=policies) + + class AxolotlTrainer(Trainer): """ Extend the base Trainer for axolotl helpers @@ -468,6 +536,49 @@ def push_to_hub(self, *args, **kwargs) -> str: return super().push_to_hub(*args, **kwargs) + @wraps(Trainer.create_accelerator_and_postprocess) + def create_accelerator_and_postprocess(self): + rank = int(os.environ.get("LOCAL_RANK", 0)) + res = super().create_accelerator_and_postprocess() + sync_module_states = ( + str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1 + ) + + mp_policy = None + amp = os.environ["ACCELERATE_MIXED_PRECISION"] + if amp == "fp16": + mp_policy = MixedPrecision( + param_dtype=torch.float32, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ) + elif amp == "bf16": + mp_policy = MixedPrecision( + param_dtype=torch.float32, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ) + + # If somehow we figure out how we want to parameterize we want to autocast buffers... + # mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32) + # load_param_skip_names = ['inv_freq'] + + if self.is_fsdp_enabled: + fsdp_plugin = FullyShardedDataParallelPlugin( + auto_wrap_policy=get_wrapping_policy(False), + use_orig_params=False, + limit_all_gathers=True, + param_init_fn=lambda module: module.to_empty( + device=torch.device("cuda"), recurse=False + ) + if (rank != 0 and sync_module_states) + else None, + mixed_precision_policy=mp_policy, + ) + self.accelerator.state.fsdp_plugin = fsdp_plugin + + return res + class AxolotlMambaTrainer(AxolotlTrainer): """ diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 5407245ac6..8cb75ffd4c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -2,12 +2,16 @@ import logging import math import os -from typing import Any, Dict, Optional, Tuple, Union # noqa: F401 +from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401 import addict import bitsandbytes as bnb +import safetensors import torch import transformers +from accelerate import init_empty_weights +from bitsandbytes.nn import Linear4bit, Params4bit +from fastcore.parallel import parallel from peft import ( LoftQConfig, PeftConfig, @@ -16,6 +20,7 @@ prepare_model_for_kbit_training, ) from peft.tuners.lora import QuantLinear +from torch import Tensor, nn from transformers import ( # noqa: F401 AddedToken, AutoConfig, @@ -27,6 +32,7 @@ PreTrainedTokenizerBase, ) from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.monkeypatch.multipack import ( @@ -262,6 +268,113 @@ def load_tokenizer(cfg): return tokenizer +def replace_linear( + model: 
nn.Module, + linear_replacement: Type[nn.Module], + quant_config: Union[dict, None] = None, + skip_modules=None, + **kwargs, +): + """ + Replace linear modules with a new Linear module. + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + linear_replacement (`torch.nn.Module`): + The linear module that replaces the old one. Only expects standard arguments. + If other arguments need to be passed, use a lambda. + skip_modules (`List[str]`, *optional*, defaults to `lm_head`): + List of modules names not to convert. Defaults to `lm_head`. + """ + if skip_modules is None: + skip_modules = ["lm_head"] + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear( + module, linear_replacement, quant_config, skip_modules, **kwargs + ) + + if isinstance(module, torch.nn.Linear) and name not in skip_modules: + if issubclass(linear_replacement, Linear4bit): + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + **kwargs, + ) + else: + raise ValueError( + f"Unsupported linear replacement: {type(linear_replacement)}" + ) + return model + + +def load_and_quantize( + module: nn.Module, + name: str, + value: Tensor, + device: torch.device = None, + dtype: torch.dtype = None, + skip_names: List[str] = [], + is_meta_rank: bool = False, + low_memory: bool = True, + verbose: bool = False, + quant_method: str = "bnb", +): + """ + Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`. + + Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True. + """ + + def place_on_device(value): + if is_meta_rank: + device = "meta" + elif low_memory: + device = "cpu" + else: + device = "cuda" + return value.to(device=device, dtype=dtype) + + if any([skip_name in name for skip_name in skip_names]): + if verbose: + print(f"Skipping {name} because it is in skip_names") + return + + module_key, _, value_key = name.rpartition(".") + try: + submodule = module.get_submodule(module_key) + except AttributeError as e: + print(f"Module {module_key} not found:\n{e}") + return + + try: + if quant_method == "bnb": + param = submodule.get_parameter(value_key) + if isinstance(param, Params4bit): + # With `sync_module_states=True`, a meta device Params4bit needs to be the same + # shape as the quantized Params4bit with an initialized quant_state. However, + # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This + # workaround quantizes Params4bit to initialize quant_state on all ranks, then + # replaces Params4bit's data with a meta tensor to free memory on non-rank 0. 
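+                # Rebuilding Params4bit from the loaded tensor and calling .cuda()
+                # below triggers bnb quantization, initializing quant_state on
+                # every rank before the data is parked on "meta" or "cpu".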
+ value = type(param)( + value.to(device=device, dtype=dtype).data, **param.__dict__ + ).cuda(device) + if is_meta_rank: + value = type(param)(value.data.to("meta"), **value.__dict__) + elif low_memory: + value = type(param)(value.data.to("cpu"), **value.__dict__) + else: + value = type(param)(place_on_device(value).data) + + except AttributeError: + # it's a buffer + value = place_on_device(value) + pass + + setattr(submodule, value_key, value) + + def load_model( cfg: DictDefault, tokenizer: PreTrainedTokenizerBase, @@ -394,7 +507,7 @@ def load_model( if max_memory is not None: # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py - from accelerate import infer_auto_device_map, init_empty_weights + from accelerate import infer_auto_device_map with init_empty_weights(): model_canvas = AutoModelForCausalLM.from_config(model_config) @@ -498,6 +611,73 @@ def load_model( try: if ( + model_config.model_type == "llama" + and cfg.adapter == "qlora" + and cfg.fsdp is not None + ): + if cfg.bf16 or cfg.bfloat16: + torch_dtype, compute_dtype = torch.bfloat16, torch.bfloat16 + elif cfg.fp16 or cfg.float16: + torch_dtype, compute_dtype = torch.float32, torch.float16 + else: + torch_dtype, compute_dtype = torch.float32, torch.float16 + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(model_config) + model.model = replace_linear( + model.model, + Linear4bit, + compute_dtype=compute_dtype, + quant_type="nf4", + quant_storage=torch_dtype, + ) + + model.is_loaded_in_4bit = True + + # Grab the safetensors files that hold the weights + try: + idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME) + files, _ = hub.get_checkpoint_shard_files(base_model, idx) + except OSError: + try: + # This means the model doesn't have a model.safetensors.index.json because it is not sharded + files = [] + files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME)) + except OSError as e: + # This means the model probably doesn't have a safetensors file + raise e + + # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly + # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage + def load_and_quantize_parallel(name_param, model, **kwargs): + name, param = name_param + load_and_quantize(model, name, param, **kwargs) + + param_count = sum((p.numel() for n, p in model.named_parameters())) + for filename in files: + weights = safetensors.torch.load_file(filename) + quant_method = "bnb" + devprops = torch.cuda.get_device_properties(torch.cuda.current_device()) + left = int(os.cpu_count() / torch.cuda.device_count()) + right = int( + 8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9)) + ) + n_workers = min(left, right) + parallel( + load_and_quantize_parallel, + weights.items(), + n_workers=n_workers, + threadpool=True, + model=model, + dtype=torch_dtype, + device=cfg.local_rank, + skip_names=[], + is_meta_rank=(cfg.local_rank != 0), + verbose=False, + quant_method=quant_method, + ) + + elif ( model_config.model_type == "llama" and not cfg.trust_remote_code and not cfg.gptq @@ -613,7 +793,9 @@ def load_model( LOG.exception(err) raise err - if isinstance(model, (PeftModel, PeftModelForCausalLM)): + qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora" and model_config.model_type == "llama" + + if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: model = model.merge_and_unload() embeddings_len = ( @@ -692,7 +874,7 @@ def load_model( if cfg.adapter == 
"lora" and loftq_bits: skip_prepare_model_for_kbit_training = True - if cfg.adapter in ["lora", "qlora"]: + if cfg.adapter in ["lora", "qlora"] and not qlora_fsdp: if cfg.gradient_checkpointing: model.gradient_checkpointing_enable() if ( @@ -721,7 +903,7 @@ def load_model( # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora: _, lora_config = load_lora(model, cfg, inference=False, config_only=True) - else: + elif not qlora_fsdp: model, lora_config = load_adapter(model, cfg, cfg.adapter) if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): From 43a766a8708b0470dc1f77b77c444e8a79eef4e6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 10:59:18 -0500 Subject: [PATCH 02/12] more fixes --- requirements.txt | 2 +- src/axolotl/utils/bench.py | 2 +- src/axolotl/utils/models.py | 17 ++++++++++++++--- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index a0d88d215f..191948a400 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ packaging==23.2 peft==0.9.0 transformers==4.38.2 tokenizers==0.15.0 -bitsandbytes>=0.41.1 +bitsandbytes>=0.43.0 accelerate==0.26.1 deepspeed==0.13.1 pydantic==2.6.3 diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py index c039e790a1..11c25160da 100644 --- a/src/axolotl/utils/bench.py +++ b/src/axolotl/utils/bench.py @@ -24,9 +24,9 @@ def wrapper(*args, **kwargs): or not torch.cuda.is_available() or device == "auto" or torch.device(device).type == "cpu" + or torch.device(device).type == "meta" ): return default_value - return func(*args, **kwargs) return wrapper diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8cb75ffd4c..8241123302 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -623,6 +623,7 @@ def load_model( torch_dtype, compute_dtype = torch.float32, torch.float16 with init_empty_weights(): + LOG.info("Loading model with empty weights.") model = AutoModelForCausalLM.from_config(model_config) model.model = replace_linear( model.model, @@ -793,7 +794,9 @@ def load_and_quantize_parallel(name_param, model, **kwargs): LOG.exception(err) raise err - qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora" and model_config.model_type == "llama" + qlora_fsdp = ( + cfg.fsdp and cfg.adapter == "qlora" and model_config.model_type == "llama" + ) if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: model = model.merge_and_unload() @@ -874,7 +877,10 @@ def load_and_quantize_parallel(name_param, model, **kwargs): if cfg.adapter == "lora" and loftq_bits: skip_prepare_model_for_kbit_training = True - if cfg.adapter in ["lora", "qlora"] and not qlora_fsdp: + if qlora_fsdp: + skip_prepare_model_for_kbit_training = True + + if cfg.adapter in ["lora", "qlora"]: if cfg.gradient_checkpointing: model.gradient_checkpointing_enable() if ( @@ -906,7 +912,12 @@ def load_and_quantize_parallel(name_param, model, **kwargs): elif not qlora_fsdp: model, lora_config = load_adapter(model, cfg, cfg.adapter) - if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): + if ( + cfg.ddp + and not load_in_8bit + and not (cfg.rl and cfg.load_in_4bit) + and not qlora_fsdp + ): # TODO revaldate this conditional model.to(f"cuda:{cfg.local_rank}") From 3773ecdf2042b67b494532e5f21e4d62ed3586e0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 11:12:40 -0500 Subject: [PATCH 03/12] 
make sure to load the lora :facepalm: --- src/axolotl/utils/models.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8241123302..dd81365dfe 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -2,6 +2,7 @@ import logging import math import os +import types from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401 import addict @@ -909,7 +910,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs): # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora: _, lora_config = load_lora(model, cfg, inference=False, config_only=True) - elif not qlora_fsdp: + else: model, lora_config = load_adapter(model, cfg, cfg.adapter) if ( @@ -1006,6 +1007,26 @@ def find_all_linear_names(model): return list(lora_module_names) +def setup_quantized_meta_for_peft(model: nn.Module): + """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device""" + + def temp_to_method(self, *args, **kwargs): + return self + + for param in model.parameters(): + if isinstance(param, Params4bit): + param.quant_state._orig_to = param.quant_state.to + param.quant_state.to = types.MethodType(temp_to_method, param.quant_state) + + +def setup_quantized_peft_meta_for_training(model: nn.Module): + """Replaces dummy `quant_state.to` method with the original function to allow training to continue""" + for param in model.parameters(): + if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"): + param.quant_state.to = param.quant_state._orig_to + param.quant_state._orig_to = None + + def load_lora(model, cfg, inference=False, config_only=False): # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]] @@ -1042,6 +1063,9 @@ def load_lora(model, cfg, inference=False, config_only=False): if config_only: return None, lora_config + if cfg.fsdp and cfg.adapter == "qlora": + setup_quantized_meta_for_peft(model) + if cfg.lora_model_dir: LOG.debug("Loading pretrained PEFT - LoRA") model_kwargs: Any = {} @@ -1057,6 +1081,9 @@ def load_lora(model, cfg, inference=False, config_only=False): else: model = get_peft_model(model, lora_config) - model.print_trainable_parameters() + if int(os.environ.get("LOCAL_RANK", 0)) == 0: + model.print_trainable_parameters() + else: + setup_quantized_peft_meta_for_training(model) return model, lora_config From b32de13ce10b7c50531d844af481ac7fff38f7d7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 11:19:38 -0500 Subject: [PATCH 04/12] only setup quantized meta on non-zero rank: --- src/axolotl/utils/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index dd81365dfe..bd42b3268a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1063,7 +1063,9 @@ def load_lora(model, cfg, inference=False, config_only=False): if config_only: return None, lora_config - if cfg.fsdp and cfg.adapter == "qlora": + rank = int(os.environ.get("LOCAL_RANK", 0)) + + if cfg.fsdp and cfg.adapter == "qlora" and rank != 0: setup_quantized_meta_for_peft(model) if cfg.lora_model_dir: @@ -1081,7 +1083,7 @@ def load_lora(model, cfg, inference=False, config_only=False): else: model = get_peft_model(model, lora_config) - if 
int(os.environ.get("LOCAL_RANK", 0)) == 0: + if rank == 0: model.print_trainable_parameters() else: setup_quantized_peft_meta_for_training(model) From af885de9fec1aa2b8a551a34823eedbeefe91596 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 11:31:03 -0500 Subject: [PATCH 05/12] only run setup_quantized_peft_meta_for_training for qlora+fsdp --- src/axolotl/utils/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index bd42b3268a..faabda0a49 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1085,7 +1085,7 @@ def load_lora(model, cfg, inference=False, config_only=False): if rank == 0: model.print_trainable_parameters() - else: + elif cfg.fsdp and cfg.adapter == "qlora": setup_quantized_peft_meta_for_training(model) return model, lora_config From dfcfe695a5471f55500790dae83d7ebe93043c83 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 11:46:44 -0500 Subject: [PATCH 06/12] more fixes for qlora+fsdp --- src/axolotl/utils/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index faabda0a49..1e66f34a0f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -617,7 +617,7 @@ def load_model( and cfg.fsdp is not None ): if cfg.bf16 or cfg.bfloat16: - torch_dtype, compute_dtype = torch.bfloat16, torch.bfloat16 + torch_dtype, compute_dtype = torch.float32, torch.bfloat16 elif cfg.fp16 or cfg.float16: torch_dtype, compute_dtype = torch.float32, torch.float16 else: @@ -895,7 +895,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs): # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # convert them back to fp16/bf16 for flash-attn compatibility. - if needs_fa2_dtype or cfg.flash_attention: + if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp: LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) for name, module in model.named_modules(): if "norm" in name: From 6a8224193da3e27904eb7b6cfec4b499e23bf4f5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 11:53:00 -0500 Subject: [PATCH 07/12] chore: lint --- src/axolotl/utils/models.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 1e66f34a0f..b37ba69d11 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1,4 +1,6 @@ """Module for models and model loading""" +# pylint: disable=too-many-lines + import logging import math import os @@ -297,7 +299,9 @@ def replace_linear( if isinstance(module, torch.nn.Linear) and name not in skip_modules: if issubclass(linear_replacement, Linear4bit): - model._modules[name] = linear_replacement( + model._modules[ # pylint: disable=protected-access + name + ] = linear_replacement( module.in_features, module.out_features, module.bias is not None, @@ -316,7 +320,7 @@ def load_and_quantize( value: Tensor, device: torch.device = None, dtype: torch.dtype = None, - skip_names: List[str] = [], + skip_names: Optional[List[str]] = None, is_meta_rank: bool = False, low_memory: bool = True, verbose: bool = False, @@ -328,6 +332,9 @@ def load_and_quantize( Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True. 
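+
+    `skip_names` defaults to None, which is treated as an empty list (skip nothing).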
""" + if skip_names is None: + skip_names = [] + def place_on_device(value): if is_meta_rank: device = "meta" @@ -337,7 +344,7 @@ def place_on_device(value): device = "cuda" return value.to(device=device, dtype=dtype) - if any([skip_name in name for skip_name in skip_names]): + if any(skip_name in name for skip_name in skip_names): if verbose: print(f"Skipping {name} because it is in skip_names") return @@ -345,8 +352,8 @@ def place_on_device(value): module_key, _, value_key = name.rpartition(".") try: submodule = module.get_submodule(module_key) - except AttributeError as e: - print(f"Module {module_key} not found:\n{e}") + except AttributeError as exc: + print(f"Module {module_key} not found:\n{exc}") return try: @@ -371,7 +378,6 @@ def place_on_device(value): except AttributeError: # it's a buffer value = place_on_device(value) - pass setattr(submodule, value_key, value) @@ -645,9 +651,9 @@ def load_model( # This means the model doesn't have a model.safetensors.index.json because it is not sharded files = [] files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME)) - except OSError as e: + except OSError as exc: # This means the model probably doesn't have a safetensors file - raise e + raise exc # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage @@ -1010,12 +1016,14 @@ def find_all_linear_names(model): def setup_quantized_meta_for_peft(model: nn.Module): """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device""" - def temp_to_method(self, *args, **kwargs): + def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument return self for param in model.parameters(): if isinstance(param, Params4bit): - param.quant_state._orig_to = param.quant_state.to + param.quant_state._orig_to = ( # pylint: disable=protected-access + param.quant_state.to + ) param.quant_state.to = types.MethodType(temp_to_method, param.quant_state) @@ -1023,8 +1031,10 @@ def setup_quantized_peft_meta_for_training(model: nn.Module): """Replaces dummy `quant_state.to` method with the original function to allow training to continue""" for param in model.parameters(): if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"): - param.quant_state.to = param.quant_state._orig_to - param.quant_state._orig_to = None + param.quant_state.to = ( + param.quant_state._orig_to # pylint: disable=protected-access + ) + param.quant_state._orig_to = None # pylint: disable=protected-access def load_lora(model, cfg, inference=False, config_only=False): From ee76506a6edfb8a7db45d3e961a567f2573a5180 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 11:59:19 -0500 Subject: [PATCH 08/12] add example yml --- examples/llama-2/qlora-fsdp.yml | 70 +++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 examples/llama-2/qlora-fsdp.yml diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml new file mode 100644 index 0000000000..da6c06020f --- /dev/null +++ b/examples/llama-2/qlora-fsdp.yml @@ -0,0 +1,70 @@ +base_model: NousResearch/Llama-2-7b-hf +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer + +load_in_8bit: false +load_in_4bit: true +strict: false + +datasets: + - path: yahma/alpaca-cleaned + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.05 +output_dir: ./qlora-out + +adapter: qlora +lora_model_dir: + 
+sequence_len: 512 +sample_packing: false +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 4 +num_epochs: 4 +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 0.00001 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: + - full_shard +fsdp_config: + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer +special_tokens: From 24340cd9e7f1a3425ca3749eaac3acebfcfb079a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 12:19:47 -0500 Subject: [PATCH 09/12] support mistral too --- src/axolotl/core/policies/__init__.py | 0 src/axolotl/core/policies/auto_wrap.py | 87 ++++++++++++++++++++++++++ src/axolotl/core/trainer_builder.py | 68 +------------------- src/axolotl/utils/models.py | 7 ++- 4 files changed, 95 insertions(+), 67 deletions(-) create mode 100644 src/axolotl/core/policies/__init__.py create mode 100644 src/axolotl/core/policies/auto_wrap.py diff --git a/src/axolotl/core/policies/__init__.py b/src/axolotl/core/policies/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/core/policies/auto_wrap.py b/src/axolotl/core/policies/auto_wrap.py new file mode 100644 index 0000000000..5d49facc2b --- /dev/null +++ b/src/axolotl/core/policies/auto_wrap.py @@ -0,0 +1,87 @@ +"""module for building the auto wrap policy for FSDP""" +import functools + +from peft import PrefixEncoder, PromptEmbedding, PromptEncoder +from torch import nn +from torch.distributed.fsdp.wrap import ( + _or_policy, + lambda_auto_wrap_policy, + transformer_auto_wrap_policy, +) +from transformers.models.llama.modeling_llama import ( + LLAMA_ATTENTION_CLASSES, + LlamaDecoderLayer, + LlamaMLP, +) +from transformers.models.mistral.modeling_mistral import ( + MISTRAL_ATTENTION_CLASSES, + MistralDecoderLayer, + MistralMLP, +) + +SUPPORTED_AUTO_WRAP_MODEL_TYPES = [ + "mistral", + "llama", +] + + +def get_wrapping_policy_factory(model_type): + if model_type == "llama": + attention_classes = LLAMA_ATTENTION_CLASSES + layer_to_wrap = LlamaDecoderLayer + model_mlp = LlamaMLP + elif model_type == "mistral": + attention_classes = MISTRAL_ATTENTION_CLASSES + layer_to_wrap = MistralDecoderLayer + model_mlp = MistralMLP + + def get_wrapping_policy(custom_policy: bool = False): + """This checks for lora layers (has weight and requires_grad)""" + if custom_policy: + + def lambda_policy_fn(module): + # LORA trainable layers. + return isinstance(module, nn.Sequential) and all( + m.weight.requires_grad for m in module + ) + + else: + + def lambda_policy_fn(module): + return ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ) + + def self_attn_policy_fn(module): + # Check module name is self_attn. + return isinstance(module, tuple(attention_classes.values())) + + def mlp_policy_fn(module): + # Check module name is self_attn. 
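+            # (Actually matches the architecture's MLP block rather than self_attn.)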
+ return isinstance(module, model_mlp) + + lambda_policy = functools.partial( + lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn + ) + self_attn_policy = functools.partial( + lambda_auto_wrap_policy, lambda_fn=self_attn_policy_fn + ) + mlp_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=mlp_policy_fn) + transformer_layer_name = layer_to_wrap + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=( + PrefixEncoder, + PromptEncoder, + PromptEmbedding, + transformer_layer_name, + ), + ) + policies = [lambda_policy, transformer_wrap_policy] + if custom_policy: + policies.extend([self_attn_policy, mlp_policy]) + return functools.partial(_or_policy, policies=policies) + + return get_wrapping_policy diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b4f5d1e755..b28c385728 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -4,7 +4,6 @@ """ import abc -import functools import importlib import importlib.util import logging @@ -22,14 +21,7 @@ from accelerate import FullyShardedDataParallelPlugin from accelerate.utils import str_to_bool from datasets import Dataset -from peft import PrefixEncoder, PromptEmbedding, PromptEncoder -from torch import nn from torch.distributed.fsdp import MixedPrecision -from torch.distributed.fsdp.wrap import ( - _or_policy, - lambda_auto_wrap_policy, - transformer_auto_wrap_policy, -) from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import ( @@ -38,15 +30,11 @@ TrainerCallback, TrainingArguments, ) -from transformers.models.llama.modeling_llama import ( - LLAMA_ATTENTION_CLASSES, - LlamaDecoderLayer, - LlamaMLP, -) from transformers.trainer_utils import seed_worker from transformers.utils import is_sagemaker_mp_enabled from trl import DPOTrainer +from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory from axolotl.loraplus import create_loraplus_optimizer from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler @@ -210,57 +198,6 @@ class AxolotlTrainingArguments(TrainingArguments): ) -# FIXME, this should be some sort of generator based on the model arch -# This checks for lora layers (has weight and requires_grad) -def get_wrapping_policy(custom_policy: bool = False): - if custom_policy: - - def lambda_policy_fn(module): - # LORA trainable layers. - return isinstance(module, nn.Sequential) and all( - m.weight.requires_grad for m in module - ) - - else: - - def lambda_policy_fn(module): - return ( - len(list(module.named_children())) == 0 - and getattr(module, "weight", None) is not None - and module.weight.requires_grad - ) - - def self_attn_policy_fn(module): - # Check module name is self_attn. - return isinstance(module, tuple(LLAMA_ATTENTION_CLASSES.values())) - - def mlp_policy_fn(module): - # Check module name is self_attn. 
- return isinstance(module, LlamaMLP) - - lambda_policy = functools.partial( - lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn - ) - self_attn_policy = functools.partial( - lambda_auto_wrap_policy, lambda_fn=self_attn_policy_fn - ) - mlp_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=mlp_policy_fn) - transformer_layer_name = LlamaDecoderLayer - transformer_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls=( - PrefixEncoder, - PromptEncoder, - PromptEmbedding, - transformer_layer_name, - ), - ) - policies = [lambda_policy, transformer_wrap_policy] - if custom_policy: - policies.extend([self_attn_policy, mlp_policy]) - return functools.partial(_or_policy, policies=policies) - - class AxolotlTrainer(Trainer): """ Extend the base Trainer for axolotl helpers @@ -564,8 +501,9 @@ def create_accelerator_and_postprocess(self): # load_param_skip_names = ['inv_freq'] if self.is_fsdp_enabled: + wrapping_policy = get_wrapping_policy_factory(self.model.config.model_type) fsdp_plugin = FullyShardedDataParallelPlugin( - auto_wrap_policy=get_wrapping_policy(False), + auto_wrap_policy=wrapping_policy(False), use_orig_params=False, limit_all_gathers=True, param_init_fn=lambda module: module.to_empty( diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b37ba69d11..65b75dbe25 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -37,6 +37,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub +from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, @@ -618,7 +619,7 @@ def load_model( try: if ( - model_config.model_type == "llama" + model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES and cfg.adapter == "qlora" and cfg.fsdp is not None ): @@ -802,7 +803,9 @@ def load_and_quantize_parallel(name_param, model, **kwargs): raise err qlora_fsdp = ( - cfg.fsdp and cfg.adapter == "qlora" and model_config.model_type == "llama" + cfg.fsdp + and cfg.adapter == "qlora" + and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES ) if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: From 716133c35b8342fe4976175db840ffcf0e1574a4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 13:23:14 -0500 Subject: [PATCH 10/12] fix for model_type and add mixtral support too --- examples/mistral/mixtral-qlora-fsdp.yml | 74 +++++++++++++++++++++++++ src/axolotl/core/policies/auto_wrap.py | 60 +++++--------------- src/axolotl/core/trainer_builder.py | 4 +- 3 files changed, 90 insertions(+), 48 deletions(-) create mode 100644 examples/mistral/mixtral-qlora-fsdp.yml diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml new file mode 100644 index 0000000000..32db7073b7 --- /dev/null +++ b/examples/mistral/mixtral-qlora-fsdp.yml @@ -0,0 +1,74 @@ +base_model: mistralai/Mixtral-8x7B-v0.1 +model_type: AutoModelForCausalLM +tokenizer_type: LlamaTokenizer +trust_remote_code: true + +load_in_8bit: false +load_in_4bit: true +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.02 +output_dir: ./qlora-out + +model_config: + output_router_logits: true + +adapter: qlora +lora_model_dir: + +sequence_len: 1024 +sample_packing: false 
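+# fsdp_config at the bottom of this file wraps MixtralSparseMoeBlock for sharding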
+pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: paged_adamw_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +weight_decay: 0.0 +fsdp: + - full_shard +fsdp_config: + fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock +special_tokens: diff --git a/src/axolotl/core/policies/auto_wrap.py b/src/axolotl/core/policies/auto_wrap.py index 5d49facc2b..d42b62ee08 100644 --- a/src/axolotl/core/policies/auto_wrap.py +++ b/src/axolotl/core/policies/auto_wrap.py @@ -2,73 +2,43 @@ import functools from peft import PrefixEncoder, PromptEmbedding, PromptEncoder -from torch import nn from torch.distributed.fsdp.wrap import ( _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy, ) -from transformers.models.llama.modeling_llama import ( - LLAMA_ATTENTION_CLASSES, - LlamaDecoderLayer, - LlamaMLP, -) -from transformers.models.mistral.modeling_mistral import ( - MISTRAL_ATTENTION_CLASSES, - MistralDecoderLayer, - MistralMLP, -) +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer +from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer SUPPORTED_AUTO_WRAP_MODEL_TYPES = [ - "mistral", "llama", + "mistral", + "mixtral", ] def get_wrapping_policy_factory(model_type): if model_type == "llama": - attention_classes = LLAMA_ATTENTION_CLASSES layer_to_wrap = LlamaDecoderLayer - model_mlp = LlamaMLP elif model_type == "mistral": - attention_classes = MISTRAL_ATTENTION_CLASSES layer_to_wrap = MistralDecoderLayer - model_mlp = MistralMLP + elif model_type == "mixtral": + layer_to_wrap = MixtralDecoderLayer - def get_wrapping_policy(custom_policy: bool = False): + def get_wrapping_policy(): """This checks for lora layers (has weight and requires_grad)""" - if custom_policy: - - def lambda_policy_fn(module): - # LORA trainable layers. - return isinstance(module, nn.Sequential) and all( - m.weight.requires_grad for m in module - ) - else: - - def lambda_policy_fn(module): - return ( - len(list(module.named_children())) == 0 - and getattr(module, "weight", None) is not None - and module.weight.requires_grad - ) - - def self_attn_policy_fn(module): - # Check module name is self_attn. - return isinstance(module, tuple(attention_classes.values())) - - def mlp_policy_fn(module): - # Check module name is self_attn. 
- return isinstance(module, model_mlp) + def lambda_policy_fn(module): + return ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ) lambda_policy = functools.partial( lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn ) - self_attn_policy = functools.partial( - lambda_auto_wrap_policy, lambda_fn=self_attn_policy_fn - ) - mlp_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=mlp_policy_fn) transformer_layer_name = layer_to_wrap transformer_wrap_policy = functools.partial( transformer_auto_wrap_policy, @@ -80,8 +50,6 @@ def mlp_policy_fn(module): ), ) policies = [lambda_policy, transformer_wrap_policy] - if custom_policy: - policies.extend([self_attn_policy, mlp_policy]) return functools.partial(_or_policy, policies=policies) return get_wrapping_policy diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b28c385728..990d814d9f 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -501,9 +501,9 @@ def create_accelerator_and_postprocess(self): # load_param_skip_names = ['inv_freq'] if self.is_fsdp_enabled: - wrapping_policy = get_wrapping_policy_factory(self.model.config.model_type) + wrapping_policy = get_wrapping_policy_factory(self.args.model_type) fsdp_plugin = FullyShardedDataParallelPlugin( - auto_wrap_policy=wrapping_policy(False), + auto_wrap_policy=wrapping_policy(), use_orig_params=False, limit_all_gathers=True, param_init_fn=lambda module: module.to_empty( From 5d072dd60ce7a2f353fcb0f09030c491342d67cd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Mar 2024 13:49:39 -0500 Subject: [PATCH 11/12] set cpu_offload: false to reduce vram, constrain new accleerator logic to qlora + fsdp --- src/axolotl/core/trainer_builder.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 990d814d9f..d11f0c6532 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -196,6 +196,10 @@ class AxolotlTrainingArguments(TrainingArguments): default=1e-6, metadata={"help": "loraplus learning rate for lora embedding layers."}, ) + qlora: bool = field( + default=False, + metadata={"help": "whether this is a qlora training"}, + ) class AxolotlTrainer(Trainer): @@ -477,6 +481,11 @@ def push_to_hub(self, *args, **kwargs) -> str: def create_accelerator_and_postprocess(self): rank = int(os.environ.get("LOCAL_RANK", 0)) res = super().create_accelerator_and_postprocess() + + if self.args.qlora is False: + return res + + # the rest of this method override is specific to fsdp + qlora (for now) sync_module_states = ( str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1 ) @@ -504,6 +513,7 @@ def create_accelerator_and_postprocess(self): wrapping_policy = get_wrapping_policy_factory(self.args.model_type) fsdp_plugin = FullyShardedDataParallelPlugin( auto_wrap_policy=wrapping_policy(), + cpu_offload=False, use_orig_params=False, limit_all_gathers=True, param_init_fn=lambda module: module.to_empty( @@ -836,6 +846,9 @@ def build(self, total_num_steps): if self.cfg.fsdp_config: training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config) + if self.cfg.adapter == "qlora": + training_arguments_kwargs["qlora"] = True + # deepspeed if self.cfg.deepspeed: training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed From 1fa1a23c00d3c9429ded91ec6dcad004e7ae0cbe Mon Sep 17 00:00:00 2001 From: Wing Lian 
Date: Fri, 8 Mar 2024 13:53:27 -0500 Subject: [PATCH 12/12] refactor for duplicate code --- src/axolotl/utils/models.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 65b75dbe25..36c9c17e35 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -617,12 +617,14 @@ def load_model( model_kwargs["attn_implementation"] = "eager" model_config._attn_implementation = "eager" # pylint: disable=protected-access + qlora_fsdp = ( + cfg.fsdp + and cfg.adapter == "qlora" + and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES + ) + try: - if ( - model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES - and cfg.adapter == "qlora" - and cfg.fsdp is not None - ): + if qlora_fsdp: if cfg.bf16 or cfg.bfloat16: torch_dtype, compute_dtype = torch.float32, torch.bfloat16 elif cfg.fp16 or cfg.float16: @@ -802,12 +804,6 @@ def load_and_quantize_parallel(name_param, model, **kwargs): LOG.exception(err) raise err - qlora_fsdp = ( - cfg.fsdp - and cfg.adapter == "qlora" - and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES - ) - if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: model = model.merge_and_unload()
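
Usage sketch (illustrative; not part of the patch series): a minimal driver showing
how the pieces above compose -- replace_linear() swaps nn.Linear for Linear4bit on a
meta-initialized skeleton, load_and_quantize() streams tensors in shard by shard
(quantizing on the fly and keeping non-zero ranks on the meta device), and the policy
factory builds the FSDP auto_wrap_policy. Module paths follow the patched axolotl
tree; the model name, shard filename, and single-rank handling are assumptions.

    import torch
    import safetensors.torch
    from accelerate import init_empty_weights
    from bitsandbytes.nn import Linear4bit
    from transformers import AutoConfig, AutoModelForCausalLM

    from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
    from axolotl.utils.models import load_and_quantize, replace_linear

    rank = 0  # normally int(os.environ["LOCAL_RANK"])
    config = AutoConfig.from_pretrained("NousResearch/Llama-2-7b-hf")

    # Build the skeleton on the meta device, then swap in 4-bit linear layers.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
        model.model = replace_linear(
            model.model,
            Linear4bit,
            compute_dtype=torch.bfloat16,
            quant_type="nf4",
            quant_storage=torch.float32,
        )
    model.is_loaded_in_4bit = True

    # Quantize each tensor as it is read; only rank 0 keeps real (cpu) data.
    for name, param in safetensors.torch.load_file("model.safetensors").items():
        load_and_quantize(
            model,
            name,
            param,
            dtype=torch.float32,
            device=rank,
            is_meta_rank=(rank != 0),
        )

    # Same factory the patched trainer uses for FSDP's auto_wrap_policy.
    wrap_policy = get_wrapping_policy_factory(config.model_type)()

On the worker heuristic in load_model: n_workers is min(cpu_count / gpu_count,
8 * (gpu_mem_GB / 40) * (70 / params_B)). Roughly, a 7B model on 80 GB cards gives a
right-hand term of about 8 * 2 * 10 = 160, so the CPU-per-GPU term (e.g. 64 / 8 = 8)
is what usually binds.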