diff --git a/README.md b/README.md index 84c7ea1..adf5b98 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ This repository contains a variety of Determined examples that are not actively | [LLM Finetuning](blog/llm-finetuning) | Finetuning TinyLlama-1.1B on Text-to-SQL. | | [LLM Finetuning 2](blog/llm-finetuning-2) | Finetuning Mistral-7B on Text-to-SQL using LoRA and DeepSpeed. | | [LLM Finetuning 3](blog/llm-finetuning-3) | Finetuning Gemma-2B using DPO. | +| [LoRA Parameters](blog/lora-parameters) | Finding the best LoRA parameters. | | [Python SDK demo](blog/python_sdk_demo) | Example usage of the Determined Python SDK to run and administer experiments. | | [Tensor Parallelism](blog/tp) | Profiling tensor parallelism in PyTorch. | diff --git a/blog/lora-parameters/.detignore b/blog/lora-parameters/.detignore new file mode 100644 index 0000000..5e741f0 --- /dev/null +++ b/blog/lora-parameters/.detignore @@ -0,0 +1,2 @@ +text-to-sql* +checkpoints \ No newline at end of file diff --git a/blog/lora-parameters/.gitignore b/blog/lora-parameters/.gitignore new file mode 100644 index 0000000..d3f89f5 --- /dev/null +++ b/blog/lora-parameters/.gitignore @@ -0,0 +1,5 @@ +__pycache__ +.DS_STORE +text-to-sql* +checkpoints +*.png \ No newline at end of file diff --git a/blog/lora-parameters/README.md b/blog/lora-parameters/README.md new file mode 100644 index 0000000..e5099c4 --- /dev/null +++ b/blog/lora-parameters/README.md @@ -0,0 +1,34 @@ +# Finding the best LoRA parameters + +We finetune [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) using [LoRA](https://arxiv.org/abs/2106.09685) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). We ran LoRA on two 40 GB A100 GPUs utilizing DeepSpeed. + +See our [blog post](https://www.determined.ai/blog/lora-parameters) for our experiment results. + +To get started, first install Determined on your local machine: +```bash +pip install determined +``` + +Then finetune with LoRA: +```bash +det e create lora.yaml . +``` + +You can view the actual training code in `finetune.py`. + + +## Configuration + +Change configuration options in `lora.yaml`. Some important options are: +- `slots_per_trial`: the number of GPUs to use. +- `dataset_subset`: the difficulty subset to train on. +- `per_device_train_batch_size`: the batch size per GPU. + + +DeepSpeed configuration files are in the `ds_configs` folder. + + +## Contributors + +- By [Sze Wai Yuen](https://github.com/szewaiyuen6) +- Built on `llm-finetuning` code by [Agnieszka Ciborowska](https://github.com/aciborowska) and [Kevin Musgrave](https://github.com/KevinMusgrave). \ No newline at end of file diff --git a/blog/lora-parameters/chat_format.py b/blog/lora-parameters/chat_format.py new file mode 100644 index 0000000..57ea591 --- /dev/null +++ b/blog/lora-parameters/chat_format.py @@ -0,0 +1,67 @@ +CHAT_ML_TEMPLATE = """ +{% for message in messages %} +{% if message['role'] == 'user' %} +{{'<|im_start|>user\n' + message['content'].strip() + '<|im_end|>' }} +{% elif message['role'] == 'system' %} +{{'<|im_start|>system\n' + message['content'].strip() + '<|im_end|>' }} +{% elif message['role'] == 'assistant' %} +{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }} +{% endif %} +{% endfor %} +""" + + +CHAT_ML_EOS_TOKEN = "<|im_end|>" + + +def get_chat_format(element, model_name, with_assistant_response=True): + system_prompt = ( + "You are a helpful programmer assistant that excels at SQL. " + "When prompted with a task and a definition of an SQL table, you " + "respond with a SQL query to retrieve information from the table. " + "Don't explain your reasoning, only provide the SQL query." + ) + + user_prompt = "Task: {instruction}\nSQL table: {input}\nSQL query: " + + if model_name == "mistralai/Mistral-7B-Instruct-v0.2": + user_prompt = f"{system_prompt}\n{user_prompt}" + output = [ + {"role": "user", "content": user_prompt.format_map(element)}, + ] + else: + output = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt.format_map(element)}, + ] + + if with_assistant_response: + output.append({"role": "assistant", "content": element["response"]}) + + return output + + +def set_special_tokens(tokenizer, model_name): + if model_name == "TinyLlama/TinyLlama-1.1B-Chat-v0.4": + tokenizer.chat_template = CHAT_ML_TEMPLATE + tokenizer.eos_token = CHAT_ML_EOS_TOKEN + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + +def get_assistant_prompt(model_name): + if model_name == "TinyLlama/TinyLlama-1.1B-Chat-v0.4": + return "<|im_start|>assistant\n" + else: + return "[/INST]" + + +def get_response_template_ids(tokenizer, model_name): + return tokenizer.encode(get_assistant_prompt(model_name), add_special_tokens=False) + + +def maybe_add_generation_prompt(x, model_name): + if model_name == "TinyLlama/TinyLlama-1.1B-Chat-v0.4": + return x + get_assistant_prompt(model_name) + else: + return x \ No newline at end of file diff --git a/blog/lora-parameters/dataset_utils.py b/blog/lora-parameters/dataset_utils.py new file mode 100644 index 0000000..f00e73f --- /dev/null +++ b/blog/lora-parameters/dataset_utils.py @@ -0,0 +1,69 @@ +import datasets +import pandas as pd + + +def add_length_column(dataset) -> pd.DataFrame: + df = dataset.to_pandas() + df["total_length"] = 0 + for column_name in ["instruction", "input", "response"]: + num_words = df[column_name].astype(str).str.split().apply(len) + df["total_length"] += num_words + + return df + + +def filter_by_total_length(df, difficulty, number_of_samples): + if difficulty == "easy": + return df[df["total_length"].between(10, 100)].iloc[:number_of_samples] + elif difficulty == "medium": + return df[df["total_length"].between(101, 200)].iloc[:number_of_samples] + elif difficulty == "hard": + return df[df["total_length"].between(201, 800)].iloc[:number_of_samples] + + +def get_dataset_subset_name(difficulty: str) -> str: + return f"text-to-sql-v1-{difficulty}" + + +def create_and_save_datasets( + df, difficulty, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1 +): + seed = 123 + # remove total_length column because we don't need it anymore + df = df.drop(columns=["total_length"]) + dataset = datasets.Dataset.from_pandas(df, preserve_index=False) + + # split into training and "the rest" + train_valtest = dataset.train_test_split(train_size=train_ratio, seed=seed) + + # split "the rest" into validation and testing + val_test = train_valtest["test"].train_test_split( + test_size=test_ratio / (test_ratio + val_ratio), seed=seed + ) + + dataset = datasets.DatasetDict( + { + "train": train_valtest["train"], + "valid": val_test["train"], + "test": val_test["test"], + } + ) + dataset_name = get_dataset_subset_name(difficulty) + dataset.save_to_disk(dataset_name) + return dataset + + +def load_dataset(difficulty): + return datasets.load_from_disk(get_dataset_subset_name(difficulty)) + + +def load_or_create_dataset(difficulty, num_samples=10000): + try: + return load_dataset(difficulty) + except FileNotFoundError: + dataset = datasets.load_dataset("Clinton/Text-to-sql-v1") + dataset = dataset["train"] + dataset = dataset.remove_columns(["text", "source"]) + df = add_length_column(dataset) + df = filter_by_total_length(df, difficulty, num_samples) + return create_and_save_datasets(df, difficulty) \ No newline at end of file diff --git a/blog/lora-parameters/ds_configs/ds_config_stage_3.json b/blog/lora-parameters/ds_configs/ds_config_stage_3.json new file mode 100644 index 0000000..9d36ec2 --- /dev/null +++ b/blog/lora-parameters/ds_configs/ds_config_stage_3.json @@ -0,0 +1,47 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto", + "total_num_steps": "auto" + } + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto" + } \ No newline at end of file diff --git a/blog/lora-parameters/finetune.py b/blog/lora-parameters/finetune.py new file mode 100644 index 0000000..7485b00 --- /dev/null +++ b/blog/lora-parameters/finetune.py @@ -0,0 +1,217 @@ +import logging +import os +import random +import sys + +import datasets +import determined as det +import evaluate +import numpy as np +import torch +import transformers +from determined.transformers import DetCallback +from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model +from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer, + TrainingArguments) +from trl import DataCollatorForCompletionOnlyLM + +from chat_format import (get_chat_format, get_response_template_ids, + set_special_tokens) +from dataset_utils import load_or_create_dataset + +logger = logging.getLogger(__name__) + + +def get_tokenizer(model_name, model_commit_hash, hparams): + tokenizer = AutoTokenizer.from_pretrained( + model_name, + padding_side="right", + truncation_side="right", + revision=model_commit_hash, + token=hparams["hf_token"], + ) + set_special_tokens(tokenizer, model_name) + return tokenizer + + +def get_model_and_tokenizer(model_name, use_lora, hparams, inference=False, device_map="auto", model_commit_hash=None): + if inference: + if use_lora: + model = AutoPeftModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.bfloat16, device_map=device_map, revision=model_commit_hash + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map=device_map, + revision=model_commit_hash, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + revision=model_commit_hash, + token=hparams["hf_token"], + ) + model.enable_input_require_grads() + + if use_lora: + r = hparams["r"] + lora_alpha = hparams["lora_alpha"] + peft_config = LoraConfig( + task_type="CAUSAL_LM", + inference_mode=False, + r=r, + lora_alpha=lora_alpha, + lora_dropout=hparams["lora_dropout"], + use_rslora=hparams["use_rslora"] + ) + + model = get_peft_model(model, peft_config) + + tokenizer = get_tokenizer(model_name, model_commit_hash=model_commit_hash, hparams=hparams) + return model, tokenizer + + +def get_tokenize_fn(tokenizer): + def fn(formatted): + return tokenizer(formatted, padding=True, truncation=True, max_length=2048) + + return fn + + +def preprocess_logits_for_metrics(logits, labels): + if isinstance(logits, tuple): + # Depending on the model and config, logits may contain extra tensors, + # like past_key_values, but logits always come first + logits = logits[0] + return logits.argmax(dim=-1) + + +def main(training_args, det_callback, hparams): + if "hf_token" in hparams: + import huggingface_hub + + huggingface_hub.login(token=hparams["hf_token"]) + + model_name = hparams["model"] + model_commit_hash = None + if "model_commit_hash" in hparams: + model_commit_hash = hparams["model_commit_hash"] + model, tokenizer = get_model_and_tokenizer(model_name, hparams["lora"], hparams=hparams, model_commit_hash=model_commit_hash) + tokenize_fn = get_tokenize_fn(tokenizer) + + def tokenize(element): + formatted = tokenizer.apply_chat_template( + get_chat_format(element, model_name), tokenize=False + ) + outputs = tokenize_fn(formatted) + return { + "input_ids": outputs["input_ids"], + "attention_mask": outputs["attention_mask"], + } + + dataset = load_or_create_dataset(hparams["dataset_subset"]) + column_names = list(dataset["train"].features) + for k in dataset.keys(): + dataset[k] = dataset[k].map(tokenize, remove_columns=column_names) + + response_template_ids = get_response_template_ids(tokenizer, model_name) + collator = DataCollatorForCompletionOnlyLM( + response_template_ids, tokenizer=tokenizer + ) + + bleu = evaluate.load("bleu") + acc = evaluate.load("accuracy") + + def compute_metrics(eval_preds): + preds, labels = eval_preds + # preds have the same shape as the labels, after the argmax(-1) has been calculated + # by preprocess_logits_for_metrics but we need to shift the labels + labels = labels[:, 1:] + preds = preds[:, :-1] + # -100 is a default value for ignore_index used by DataCollatorForCompletionOnlyLM + mask = labels == -100 + labels[mask] = tokenizer.pad_token_id + preds[mask] = tokenizer.pad_token_id + + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + + bleu_score = bleu.compute(predictions=decoded_preds, references=decoded_labels) + accuracy = acc.compute(predictions=preds[~mask], references=labels[~mask]) + + return {**bleu_score, **accuracy} + + trainer = Trainer( + args=training_args, + model=model, + tokenizer=tokenizer, + data_collator=collator, + train_dataset=dataset["train"], + eval_dataset=dataset["valid"], + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + compute_metrics=compute_metrics, + ) + + trainer.add_callback(det_callback) + trainer.train() + +def set_seed(seed: int = 42) -> None: + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PYTHONHASHSEED"] = str(seed) + print(f"Random seed set as {seed}") + + +if __name__ == "__main__": + # Setup logging + logging.basicConfig( + format=det.LOG_FORMAT, handlers=[logging.StreamHandler(sys.stdout)] + ) + log_level = logging.INFO + transformers.utils.logging.set_verbosity_info() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + info = det.get_cluster_info() + hparams = info.trial.hparams + + if "hf_token" in hparams: + hf_token = hparams["hf_token"] + import huggingface_hub + huggingface_hub.login(token=hparams["hf_token"]) + + if hparams["training_args"]["deepspeed"]: + hparams["training_args"]["deepspeed"] = "ds_configs/ds_config_stage_3.json" + + training_args = TrainingArguments(**hparams["training_args"]) + if training_args.deepspeed: + # Set env var for deepspeed distributed context + os.environ["LOCAL_SIZE"] = os.environ["LOCAL_WORLD_SIZE"] + os.environ["CROSS_RANK"] = str(int(os.environ["RANK"]) // int(os.environ["LOCAL_WORLD_SIZE"])) + os.environ["CROSS_SIZE"] = str(int(os.environ["WORLD_SIZE"]) // int(os.environ["LOCAL_WORLD_SIZE"])) + os.environ["CHIEF_IP"] = os.environ["DET_CHIEF_IP"] + distributed = det.core.DistributedContext.from_deepspeed() + else: + distributed = det.core.DistributedContext.from_torch_distributed() + + random_seed = 42 + + with det.core.init(distributed=distributed) as core_context: + set_seed(random_seed) + + det_callback = DetCallback( + core_context, + training_args, + ) + + main(training_args, det_callback, hparams) \ No newline at end of file diff --git a/blog/lora-parameters/lora.yaml b/blog/lora-parameters/lora.yaml new file mode 100644 index 0000000..e2caac5 --- /dev/null +++ b/blog/lora-parameters/lora.yaml @@ -0,0 +1,55 @@ +name: mistral lora hard +debug: false +environment: + environment_variables: + - NCCL_DEBUG=INFO + - NCCL_SOCKET_IFNAME=ens,eth,ib + image: + gpu: determinedai/environments:cuda-11.8-pytorch-2.0-gpu-95c7a14 + cpu: determinedai/environments:py-3.10-pytorch-2.0-cpu-03ae7d7 +resources: + slots_per_trial: 2 + resource_pool: # We used A100 40GB GPUs +workspace: +project: +searcher: + name: grid + max_length: + batches: 3000 + metric: eval_accuracy + smaller_is_better: false +hyperparameters: + model: "mistralai/Mistral-7B-Instruct-v0.2" + model_commit_hash: "99259002b41e116d28ccb2d04a9fbe22baed0c7f" + dataset_subset: "hard" + lora: true + r: + type: categorical + vals: [2, 8, 32, 128] + lora_alpha: + type: categorical + vals: [0.5, 1, 2, 8, 32, 128, 256, 512] + lora_dropout: + type: categorical + vals: [0.1] + hf_token: + training_args: + output_dir: "/tmp/llm_finetuning" + max_steps: 3000 + per_device_train_batch_size: 4 + per_device_eval_batch_size: 4 + bf16: true + evaluation_strategy: "steps" + eval_steps: 500 + logging_strategy: "steps" + logging_steps: 100 + save_strategy: "steps" + save_steps: 1000 + learning_rate: 1e-5 + deepspeed: true + gradient_checkpointing: true + use_rslora: false +entrypoint: >- + python -m determined.launch.torch_distributed + python finetune.py +max_restarts: 0 \ No newline at end of file diff --git a/blog/lora-parameters/requirements.txt b/blog/lora-parameters/requirements.txt new file mode 100644 index 0000000..c6cad1b --- /dev/null +++ b/blog/lora-parameters/requirements.txt @@ -0,0 +1,8 @@ +transformers==4.37.2 +datasets==2.17.0 +evaluate==0.4.1 +trl==0.7.10 +scikit-learn==1.4.0 +deepspeed==0.10.2 +peft==0.8.2 +huggingface_hub \ No newline at end of file diff --git a/blog/lora-parameters/startup-hook.sh b/blog/lora-parameters/startup-hook.sh new file mode 100644 index 0000000..c90323c --- /dev/null +++ b/blog/lora-parameters/startup-hook.sh @@ -0,0 +1,3 @@ +#!/bin/bash +pip install --upgrade pip +pip install -r requirements.txt \ No newline at end of file