diff --git a/examples/cohere/lora.yml b/examples/cohere/lora.yml
new file mode 100644
index 000000000..2ac15880b
--- /dev/null
+++ b/examples/cohere/lora.yml
@@ -0,0 +1,66 @@
+base_model: CohereForAI/c4ai-command-r-v01
+model_type: CohereForCausalLM
+tokenizer_type: CohereTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./lora-out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index c1eb3127d..8714a1a57 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -1,4 +1,5 @@
 """multipack patching for v2 of sample packing"""
+
 import importlib
 
 import transformers
@@ -18,6 +19,7 @@
     "gemma",
     "gemmoe",
     "starcoder2",
+    "cohere",
 ]
 
 
@@ -56,6 +58,10 @@ def patch_for_multipack(model_type, model_name=None):
         patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
     elif model_type == "jamba":
         patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
+    elif model_type == "cohere":
+        transformers.models.cohere.modeling_cohere._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
 
 
 def patch_remote(model_name, config_name, modeling_name):
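
For reference, a minimal sketch (not part of the diff) of how the new "cohere" branch is exercised, assuming patch_for_multipack() is invoked before the Cohere model is loaded, as axolotl already does for the other SUPPORTED_MULTIPACK_MODEL_TYPES entries:

    # Sketch only: apply the sample-packing monkeypatch for Cohere models.
    import transformers
    from axolotl.monkeypatch.multipack import patch_for_multipack

    patch_for_multipack("cohere")

    # After patching, transformers.models.cohere.modeling_cohere._get_unpad_data
    # points at axolotl's packing-aware get_unpad_data implementation.
    print(transformers.models.cohere.modeling_cohere._get_unpad_data)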