diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml
new file mode 100644
index 0000000000..da6c06020f
--- /dev/null
+++ b/examples/llama-2/qlora-fsdp.yml
@@ -0,0 +1,70 @@
+base_model: NousResearch/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: yahma/alpaca-cleaned
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+output_dir: ./qlora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 512
+sample_packing: false
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 4
+num_epochs: 4
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.00001
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+fsdp_config:
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+special_tokens:
diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml
new file mode 100644
index 0000000000..32db7073b7
--- /dev/null
+++ b/examples/mistral/mixtral-qlora-fsdp.yml
@@ -0,0 +1,74 @@
+base_model: mistralai/Mixtral-8x7B-v0.1
+model_type: AutoModelForCausalLM
+tokenizer_type: LlamaTokenizer
+trust_remote_code: true
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.02
+output_dir: ./qlora-out
+
+model_config:
+  output_router_logits: true
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 1024
+sample_packing: false
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+weight_decay: 0.0
+fsdp:
+  - full_shard
+fsdp_config:
+  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
+special_tokens:
diff --git a/requirements.txt b/requirements.txt
index 718896783b..191948a400 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ packaging==23.2
 peft==0.9.0
 transformers==4.38.2
 tokenizers==0.15.0
-bitsandbytes>=0.41.1
+bitsandbytes>=0.43.0
 accelerate==0.26.1
 deepspeed==0.13.1
 pydantic==2.6.3
@@ -40,3 +40,4 @@
 gcsfs
 # adlfs
 trl>=0.7.9
+fastcore>=1.5.29
diff --git a/src/axolotl/core/policies/__init__.py b/src/axolotl/core/policies/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/axolotl/core/policies/auto_wrap.py b/src/axolotl/core/policies/auto_wrap.py
new file mode 100644
index 0000000000..d42b62ee08
--- /dev/null
+++ b/src/axolotl/core/policies/auto_wrap.py
@@ -0,0 +1,55 @@
+"""module for building the auto wrap policy for FSDP"""
+import functools
+
+from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
+from torch.distributed.fsdp.wrap import (
+    _or_policy,
+    lambda_auto_wrap_policy,
+    transformer_auto_wrap_policy,
+)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
+
+SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
+    "llama",
+    "mistral",
+    "mixtral",
+]
+
+
+def get_wrapping_policy_factory(model_type):
+    if model_type == "llama":
+        layer_to_wrap = LlamaDecoderLayer
+    elif model_type == "mistral":
+        layer_to_wrap = MistralDecoderLayer
+    elif model_type == "mixtral":
+        layer_to_wrap = MixtralDecoderLayer
+
+    def get_wrapping_policy():
+        """This checks for lora layers (has weight and requires_grad)"""
+
+        def lambda_policy_fn(module):
+            return (
+                len(list(module.named_children())) == 0
+                and getattr(module, "weight", None) is not None
+                and module.weight.requires_grad
+            )
+
+        lambda_policy = functools.partial(
+            lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
+        )
+        transformer_layer_name = layer_to_wrap
+        transformer_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls=(
+                PrefixEncoder,
+                PromptEncoder,
+                PromptEmbedding,
+                transformer_layer_name,
+            ),
+        )
+        policies = [lambda_policy, transformer_wrap_policy]
+        return functools.partial(_or_policy, policies=policies)
+
+    return get_wrapping_policy
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index e051d4e69c..d11f0c6532 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -8,6 +8,7 @@
 import importlib.util
 import logging
 import math
+import os
 import sys
 from abc import abstractmethod
 from dataclasses import dataclass, field
@@ -17,7 +18,10 @@

 import torch
 import transformers
+from accelerate import FullyShardedDataParallelPlugin
+from accelerate.utils import str_to_bool
 from datasets import Dataset
+from torch.distributed.fsdp import MixedPrecision
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
@@ -30,6 +34,7 @@
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer

+from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -191,6 +196,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=1e-6,
         metadata={"help": "loraplus learning rate for lora embedding layers."},
     )
+    qlora: bool = field(
+        default=False,
+        metadata={"help": "whether this is a qlora training"},
+    )


 class AxolotlTrainer(Trainer):
@@ -468,6 +477,56 @@ def push_to_hub(self, *args, **kwargs) -> str:

         return super().push_to_hub(*args, **kwargs)

+    @wraps(Trainer.create_accelerator_and_postprocess)
+    def create_accelerator_and_postprocess(self):
+        rank = int(os.environ.get("LOCAL_RANK", 0))
+        res = super().create_accelerator_and_postprocess()
+
+        if self.args.qlora is False:
+            return res
+
+        # the rest of this method override is specific to fsdp + qlora (for now)
+        sync_module_states = (
+            str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
+        )
+
+        mp_policy = None
+        amp = os.environ["ACCELERATE_MIXED_PRECISION"]
+        if amp == "fp16":
+            mp_policy = MixedPrecision(
+                param_dtype=torch.float32,
+                reduce_dtype=torch.float32,
+                buffer_dtype=torch.float32,
+            )
+        elif amp == "bf16":
+            mp_policy = MixedPrecision(
+                param_dtype=torch.float32,
+                reduce_dtype=torch.float32,
+                buffer_dtype=torch.float32,
+            )
+
+        # If somehow we figure out how we want to parameterize we want to autocast buffers...
+        # mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
+        # load_param_skip_names = ['inv_freq']
+
+        if self.is_fsdp_enabled:
+            wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
+            fsdp_plugin = FullyShardedDataParallelPlugin(
+                auto_wrap_policy=wrapping_policy(),
+                cpu_offload=False,
+                use_orig_params=False,
+                limit_all_gathers=True,
+                param_init_fn=lambda module: module.to_empty(
+                    device=torch.device("cuda"), recurse=False
+                )
+                if (rank != 0 and sync_module_states)
+                else None,
+                mixed_precision_policy=mp_policy,
+            )
+            self.accelerator.state.fsdp_plugin = fsdp_plugin
+
+        return res
+

 class AxolotlMambaTrainer(AxolotlTrainer):
     """
@@ -787,6 +846,9 @@ def build(self, total_num_steps):
         if self.cfg.fsdp_config:
             training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)

+        if self.cfg.adapter == "qlora":
+            training_arguments_kwargs["qlora"] = True
+
         # deepspeed
         if self.cfg.deepspeed:
             training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index c039e790a1..11c25160da 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -24,9 +24,9 @@ def wrapper(*args, **kwargs):
                 or not torch.cuda.is_available()
                 or device == "auto"
                 or torch.device(device).type == "cpu"
+                or torch.device(device).type == "meta"
             ):
                 return default_value
-
             return func(*args, **kwargs)

         return wrapper
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 5407245ac6..36c9c17e35 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1,13 +1,20 @@
 """Module for models and model loading"""
+# pylint: disable=too-many-lines
+
 import logging
 import math
 import os
-from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
+import types
+from typing import Any, Dict, List, Optional, Tuple, Type, Union  # noqa: F401

 import addict
 import bitsandbytes as bnb
+import safetensors
 import torch
 import transformers
+from accelerate import init_empty_weights
+from bitsandbytes.nn import Linear4bit, Params4bit
+from fastcore.parallel import parallel
 from peft import (
     LoftQConfig,
     PeftConfig,
@@ -16,6 +23,7 @@
     prepare_model_for_kbit_training,
 )
 from peft.tuners.lora import QuantLinear
+from torch import Tensor, nn
 from transformers import (  # noqa: F401
     AddedToken,
     AutoConfig,
@@ -27,7 +35,9 @@
     PreTrainedTokenizerBase,
 )
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub

+from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -262,6 +272,117 @@ def load_tokenizer(cfg):
     return tokenizer


+def replace_linear(
+    model: nn.Module,
+    linear_replacement: Type[nn.Module],
+    quant_config: Union[dict, None] = None,
+    skip_modules=None,
+    **kwargs,
+):
+    """
+    Replace linear modules with a new Linear module.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        linear_replacement (`torch.nn.Module`):
+            The linear module that replaces the old one. Only expects standard arguments.
+            If other arguments need to be passed, use a lambda.
+        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
+            List of modules names not to convert. Defaults to `lm_head`.
+    """
+    if skip_modules is None:
+        skip_modules = ["lm_head"]
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_linear(
+                module, linear_replacement, quant_config, skip_modules, **kwargs
+            )
+
+        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+            if issubclass(linear_replacement, Linear4bit):
+                model._modules[  # pylint: disable=protected-access
+                    name
+                ] = linear_replacement(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    **kwargs,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported linear replacement: {type(linear_replacement)}"
+                )
+    return model
+
+
+def load_and_quantize(
+    module: nn.Module,
+    name: str,
+    value: Tensor,
+    device: torch.device = None,
+    dtype: torch.dtype = None,
+    skip_names: Optional[List[str]] = None,
+    is_meta_rank: bool = False,
+    low_memory: bool = True,
+    verbose: bool = False,
+    quant_method: str = "bnb",
+):
+    """
+    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
+
+    Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
+    """
+
+    if skip_names is None:
+        skip_names = []
+
+    def place_on_device(value):
+        if is_meta_rank:
+            device = "meta"
+        elif low_memory:
+            device = "cpu"
+        else:
+            device = "cuda"
+        return value.to(device=device, dtype=dtype)
+
+    if any(skip_name in name for skip_name in skip_names):
+        if verbose:
+            print(f"Skipping {name} because it is in skip_names")
+        return
+
+    module_key, _, value_key = name.rpartition(".")
+    try:
+        submodule = module.get_submodule(module_key)
+    except AttributeError as exc:
+        print(f"Module {module_key} not found:\n{exc}")
+        return
+
+    try:
+        if quant_method == "bnb":
+            param = submodule.get_parameter(value_key)
+            if isinstance(param, Params4bit):
+                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
+                # shape as the quantized Params4bit with an initialized quant_state. However,
+                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
+                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
+                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
+                value = type(param)(
+                    value.to(device=device, dtype=dtype).data, **param.__dict__
+                ).cuda(device)
+                if is_meta_rank:
+                    value = type(param)(value.data.to("meta"), **value.__dict__)
+                elif low_memory:
+                    value = type(param)(value.data.to("cpu"), **value.__dict__)
+            else:
+                value = type(param)(place_on_device(value).data)
+
+    except AttributeError:
+        # it's a buffer
+        value = place_on_device(value)
+
+    setattr(submodule, value_key, value)
+
+
 def load_model(
     cfg: DictDefault,
     tokenizer: PreTrainedTokenizerBase,
@@ -394,7 +515,7 @@ def load_model(

     if max_memory is not None:
         # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
-        from accelerate import infer_auto_device_map, init_empty_weights
+        from accelerate import infer_auto_device_map

         with init_empty_weights():
             model_canvas = AutoModelForCausalLM.from_config(model_config)
@@ -496,8 +617,78 @@ def load_model(
             model_kwargs["attn_implementation"] = "eager"
             model_config._attn_implementation = "eager"  # pylint: disable=protected-access

+    qlora_fsdp = (
+        cfg.fsdp
+        and cfg.adapter == "qlora"
+        and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
+    )
+
     try:
-        if (
+        if qlora_fsdp:
+            if cfg.bf16 or cfg.bfloat16:
+                torch_dtype, compute_dtype = torch.float32, torch.bfloat16
+            elif cfg.fp16 or cfg.float16:
+                torch_dtype, compute_dtype = torch.float32, torch.float16
+            else:
+                torch_dtype, compute_dtype = torch.float32, torch.float16
+
+            with init_empty_weights():
+                LOG.info("Loading model with empty weights.")
+                model = AutoModelForCausalLM.from_config(model_config)
+                model.model = replace_linear(
+                    model.model,
+                    Linear4bit,
+                    compute_dtype=compute_dtype,
+                    quant_type="nf4",
+                    quant_storage=torch_dtype,
+                )
+
+            model.is_loaded_in_4bit = True
+
+            # Grab the safetensors files that hold the weights
+            try:
+                idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
+                files, _ = hub.get_checkpoint_shard_files(base_model, idx)
+            except OSError:
+                try:
+                    # This means the model doesn't have a model.safetensors.index.json because it is not sharded
+                    files = []
+                    files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
+                except OSError as exc:
+                    # This means the model probably doesn't have a safetensors file
+                    raise exc
+
+            # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
+            # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
+            def load_and_quantize_parallel(name_param, model, **kwargs):
+                name, param = name_param
+                load_and_quantize(model, name, param, **kwargs)
+
+            param_count = sum((p.numel() for n, p in model.named_parameters()))
+            for filename in files:
+                weights = safetensors.torch.load_file(filename)
+                quant_method = "bnb"
+                devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
+                left = int(os.cpu_count() / torch.cuda.device_count())
+                right = int(
+                    8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
+                )
+                n_workers = min(left, right)
+                parallel(
+                    load_and_quantize_parallel,
+                    weights.items(),
+                    n_workers=n_workers,
+                    threadpool=True,
+                    model=model,
+                    dtype=torch_dtype,
+                    device=cfg.local_rank,
+                    skip_names=[],
+                    is_meta_rank=(cfg.local_rank != 0),
+                    verbose=False,
+                    quant_method=quant_method,
+                )
+
+        elif (
             model_config.model_type == "llama"
             and not cfg.trust_remote_code
             and not cfg.gptq
@@ -613,7 +804,7 @@ def load_model(
         LOG.exception(err)
         raise err

-    if isinstance(model, (PeftModel, PeftModelForCausalLM)):
+    if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
         model = model.merge_and_unload()

     embeddings_len = (
@@ -692,6 +883,9 @@ def load_model(
     if cfg.adapter == "lora" and loftq_bits:
         skip_prepare_model_for_kbit_training = True

+    if qlora_fsdp:
+        skip_prepare_model_for_kbit_training = True
+
     if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
@@ -706,7 +900,7 @@ def load_model(

     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or cfg.flash_attention:
+    if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
@@ -724,7 +918,12 @@ def load_model(
     else:
         model, lora_config = load_adapter(model, cfg, cfg.adapter)

-    if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
+    if (
+        cfg.ddp
+        and not load_in_8bit
+        and not (cfg.rl and cfg.load_in_4bit)
+        and not qlora_fsdp
+    ):
         # TODO revaldate this conditional
         model.to(f"cuda:{cfg.local_rank}")

@@ -813,6 +1012,30 @@ def find_all_linear_names(model):
     return list(lora_module_names)


+def setup_quantized_meta_for_peft(model: nn.Module):
+    """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
+
+    def temp_to_method(self, *args, **kwargs):  # pylint: disable=unused-argument
+        return self
+
+    for param in model.parameters():
+        if isinstance(param, Params4bit):
+            param.quant_state._orig_to = (  # pylint: disable=protected-access
+                param.quant_state.to
+            )
+            param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
+
+
+def setup_quantized_peft_meta_for_training(model: nn.Module):
+    """Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
+    for param in model.parameters():
+        if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
+            param.quant_state.to = (
+                param.quant_state._orig_to  # pylint: disable=protected-access
+            )
+            param.quant_state._orig_to = None  # pylint: disable=protected-access
+
+
 def load_lora(model, cfg, inference=False, config_only=False):
     # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]

@@ -849,6 +1072,11 @@ def load_lora(model, cfg, inference=False, config_only=False):
     if config_only:
         return None, lora_config

+    rank = int(os.environ.get("LOCAL_RANK", 0))
+
+    if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
+        setup_quantized_meta_for_peft(model)
+
     if cfg.lora_model_dir:
         LOG.debug("Loading pretrained PEFT - LoRA")
         model_kwargs: Any = {}
@@ -864,6 +1092,9 @@ def load_lora(model, cfg, inference=False, config_only=False):
     else:
         model = get_peft_model(model, lora_config)

-    model.print_trainable_parameters()
+    if rank == 0:
+        model.print_trainable_parameters()
+    elif cfg.fsdp and cfg.adapter == "qlora":
+        setup_quantized_peft_meta_for_training(model)

     return model, lora_config
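
Notes (illustrative sketches, not part of the patch above):

The `lambda_policy_fn` inside `get_wrapping_policy()` is what lets FSDP wrap the small trainable LoRA matrices separately from the frozen, quantized base layers. A minimal sketch of its behavior on toy modules (plain `torch` only; the toy layers are hypothetical, not axolotl code):

```python
# The predicate flags leaf modules that own a trainable weight (e.g. LoRA A/B
# matrices); anything frozen or non-leaf is left to the transformer-layer policy.
import torch.nn as nn


def lambda_policy_fn(module):
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )


frozen = nn.Linear(8, 8)
frozen.weight.requires_grad_(False)
trainable = nn.Linear(8, 8)

print(lambda_policy_fn(frozen))                    # False: leaf, but weight is frozen
print(lambda_policy_fn(trainable))                 # True: leaf with a trainable weight
print(lambda_policy_fn(nn.Sequential(trainable)))  # False: has children, not a leaf
```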
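The `replace_linear()` helper swaps every `nn.Linear` except `lm_head` for a bitsandbytes `Linear4bit` while the model still holds empty (meta) weights, pinning `quant_storage` to the fp32 `torch_dtype` so FSDP can flat-pack the 4-bit storage alongside the other tensors. A sketch of the same swap on a hypothetical two-layer toy model, assuming `bitsandbytes>=0.43.0` and `accelerate` are installed (this is not the axolotl code path itself):

```python
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit

# Build a toy model with empty weights, then replace its Linear layers the same
# way replace_linear() does: same in/out features, bias flag, and 4-bit settings.
with init_empty_weights():
    toy = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16))
    for name, module in list(toy.named_children()):
        if isinstance(module, nn.Linear):
            toy._modules[name] = Linear4bit(
                module.in_features,
                module.out_features,
                module.bias is not None,
                compute_dtype=torch.bfloat16,
                quant_type="nf4",
                quant_storage=torch.float32,
            )

print(toy)  # both children are now Linear4bit; real weights get loaded/quantized later
```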
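The `n_workers` heuristic in the quantized loading loop caps the thread pool at the CPU cores available per GPU and at a factor scaled by GPU memory and model size. A back-of-the-envelope check with assumed numbers (16 CPU cores, one 40 GB GPU, a 7B-parameter checkpoint; the real code reads these from `os.cpu_count()`, `torch.cuda.device_count()` and `torch.cuda.get_device_properties()`):

```python
cpu_count = 16
gpu_count = 1
total_memory_gb = 40.0  # devprops.total_memory / 1e9
param_count = 7e9       # sum of p.numel() over the model parameters

left = int(cpu_count / gpu_count)                                     # 16
right = int(8 * (total_memory_gb / 40) * (70 / (param_count / 1e9)))  # 8 * 1.0 * 10 = 80
n_workers = min(left, right)                                          # 16 -> CPU-bound here
print(left, right, n_workers)
```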