Code for the llm-finetuning blog #12

Merged · 12 commits · Jan 31, 2024
3 changes: 2 additions & 1 deletion README.md
@@ -11,7 +11,8 @@ This repository contains a variety of Determined examples that are not actively
## Blog posts
| Example | Description |
|:---------------------------------------:|:----------------------------------------------------------------------------:|
| [Python SDK demo](blog/python_sdk_demo) | Example usage of the Determined Python SDK to run and administer experiments |
| [LLM Finetuning](blog/llm-finetuning) | Finetuning the TinyLlama-1.1B Model on Text-to-SQL. |
| [Python SDK demo](blog/python_sdk_demo) | Example usage of the Determined Python SDK to run and administer experiments. |

## Computer Vision

2 changes: 2 additions & 0 deletions blog/llm-finetuning/.detignore
@@ -0,0 +1,2 @@
text-to-sql*
checkpoints
4 changes: 4 additions & 0 deletions blog/llm-finetuning/.gitignore
@@ -0,0 +1,4 @@
__pycache__
.DS_STORE
text-to-sql*
checkpoints
40 changes: 40 additions & 0 deletions blog/llm-finetuning/README.md
@@ -0,0 +1,40 @@
# LLM Finetuning using HuggingFace + Determined

In this demo, we finetune the [TinyLlama-1.1B-Chat](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4) model on a [text-to-SQL dataset](https://huggingface.co/datasets/Clinton/Text-to-sql-v1). We ran this on two 80 GB A100 GPUs.

To get started, first install Determined on your local machine:
```bash
pip install determined
```

Then finetune:
```bash
det e create distributed.yaml .
```

Change configuration options in `distributed.yaml`. Some important options are:
- `slots_per_trial`: the number of GPUs to use.
- `dataset_subset`: the difficulty subset to train on.
- `per_device_train_batch_size`: the batch size per GPU.


Test your model's generation capabilities:

```bash
python test_model.py --exp_id <exp_id> --dataset_subset <dataset_subset>
```

where:
- `<exp_id>` is the ID of your finetuning experiment in the Determined UI.
- `<dataset_subset>` is one of `easy`, `medium`, or `hard`.

To test the pretrained model (not finetuned), leave out `--exp_id`. For example:

```bash
python test_model.py --dataset_subset easy
```

## Contributors

- [Kevin Musgrave](https://github.com/KevinMusgrave)
- [Agnieszka Ciborowska](https://github.com/aciborowska)
31 changes: 31 additions & 0 deletions blog/llm-finetuning/chat_format.py
@@ -0,0 +1,31 @@
CHAT_ML_TEMPLATE = """
{% for message in messages %}
{% if message['role'] == 'user' %}
{{'<|im_start|>user\n' + message['content'].strip() + '<|im_end|>' }}
{% elif message['role'] == 'system' %}
{{'<|im_start|>system\n' + message['content'].strip() + '<|im_end|>' }}
{% elif message['role'] == 'assistant' %}
{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}
{% endif %}
{% endfor %}
"""

ASSISTANT_PROMPT = "<|im_start|>assistant\n"

EOS_TOKEN = "<|im_end|>"


def get_chat_format(element):
system_prompt = (
"You are a helpful programmer assistant that excels at SQL. "
"When prompted with a task and a definition of an SQL table, you "
"respond with a SQL query to retrieve information from the table. "
"Don't explain your reasoning, only provide the SQL query."
)
user_prompt = "Task: {instruction}\nSQL table: {input}\nSQL query: "

return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt.format_map(element)},
{"role": "assistant", "content": element["response"]},
]
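A minimal sketch of what the ChatML template above produces for one element (the element values here are hypothetical; the keys match the `Clinton/Text-to-sql-v1` columns that `get_chat_format` expects, and the script assumes it runs from `blog/llm-finetuning` with `requirements.txt` installed):

```python
# Sketch: render one toy element through CHAT_ML_TEMPLATE to show the ChatML layout.
from transformers import AutoTokenizer

from chat_format import CHAT_ML_TEMPLATE, EOS_TOKEN, get_chat_format

tokenizer = AutoTokenizer.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v0.4", eos_token=EOS_TOKEN
)
tokenizer.chat_template = CHAT_ML_TEMPLATE

# Hypothetical row with the same keys as the dataset: instruction, input, response.
element = {
    "instruction": "List all customers older than 30.",
    "input": "CREATE TABLE customers (id INT, name TEXT, age INT)",
    "response": "SELECT * FROM customers WHERE age > 30",
}
print(tokenizer.apply_chat_template(get_chat_format(element), tokenize=False))
# Prints (whitespace aside) the three ChatML turns:
#   <|im_start|>system ... <|im_end|>
#   <|im_start|>user ... <|im_end|>
#   <|im_start|>assistant ... <|im_end|>
```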
68 changes: 68 additions & 0 deletions blog/llm-finetuning/dataset_utils.py
@@ -0,0 +1,68 @@
import datasets


def add_length_column(dataset):
df = dataset.to_pandas()
df["total_length"] = 0
for column_name in ["instruction", "input", "response"]:
num_words = df[column_name].astype(str).str.split().apply(len)
df["total_length"] += num_words

return df


def filter_by_total_length(df, difficulty, number_of_samples):
if difficulty == "easy":
return df[df["total_length"].between(10, 100)].iloc[:number_of_samples]
elif difficulty == "medium":
return df[df["total_length"].between(101, 200)].iloc[:number_of_samples]
elif difficulty == "hard":
return df[df["total_length"].between(201, 800)].iloc[:number_of_samples]


def get_dataset_subset_name(difficulty):
return f"text-to-sql-v1-{difficulty}"


def create_and_save_datasets(
df, difficulty, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1
):
seed = 123
# remove total_length column because we don't need it anymore
df = df.drop(columns=["total_length"])
dataset = datasets.Dataset.from_pandas(df, preserve_index=False)

# split into training and "the rest"
train_valtest = dataset.train_test_split(train_size=train_ratio, seed=seed)

# split "the rest" into validation and testing
val_test = train_valtest["test"].train_test_split(
test_size=test_ratio / (test_ratio + val_ratio), seed=seed
)

dataset = datasets.DatasetDict(
{
"train": train_valtest["train"],
"valid": val_test["train"],
"test": val_test["test"],
}
)
dataset_name = get_dataset_subset_name(difficulty)
dataset.save_to_disk(dataset_name)
return dataset


def load_dataset(difficulty):
return datasets.load_from_disk(get_dataset_subset_name(difficulty))


def load_or_create_dataset(difficulty, num_samples=10000):
try:
return load_dataset(difficulty)
except FileNotFoundError:
dataset = datasets.load_dataset("Clinton/Text-to-sql-v1")
dataset = dataset["train"]
dataset = dataset.remove_columns(["text", "source"])
df = add_length_column(dataset)
df = filter_by_total_length(df, difficulty, num_samples)
return create_and_save_datasets(df, difficulty)
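A minimal usage sketch of these helpers, assuming `Clinton/Text-to-sql-v1` is reachable from Hugging Face and the script runs from `blog/llm-finetuning`:

```python
# Sketch: build (or reload) the "easy" subset and inspect its splits.  On first
# run this downloads Clinton/Text-to-sql-v1, keeps rows whose instruction +
# input + response total 10-100 words, and saves an 80/10/10 split to
# ./text-to-sql-v1-easy; later runs load it straight from disk.
from dataset_utils import load_or_create_dataset

dataset = load_or_create_dataset("easy")
print(dataset)                                  # DatasetDict: train / valid / test
print({split: len(ds) for split, ds in dataset.items()})
print(dataset["train"][0]["instruction"])       # one raw text-to-SQL task
```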
33 changes: 33 additions & 0 deletions blog/llm-finetuning/distributed.yaml
@@ -0,0 +1,33 @@
name: Text-to-SQL
debug: false
environment:
environment_variables:
- NCCL_DEBUG=INFO
resources:
slots_per_trial: 2
searcher:
name: single
max_length:
batches: 5000
metric: eval_accuracy
smaller_is_better: false
hyperparameters:
model: "TinyLlama/TinyLlama-1.1B-Chat-v0.4"
dataset_subset: "easy"
training_args:
output_dir: "/tmp/llm_finetuning"
max_steps: 5000
per_device_train_batch_size: 1
per_device_eval_batch_size: 4
fp16: true
evaluation_strategy: "steps"
eval_steps: 1000
logging_strategy: "steps"
logging_steps: 100
save_strategy: "steps"
save_steps: 1000
learning_rate: 1e-5
entrypoint: >-
python -m determined.launch.torch_distributed
python finetune.py
max_restarts: 0
93 changes: 93 additions & 0 deletions blog/llm-finetuning/finetune.py
@@ -0,0 +1,93 @@
import determined as det
import evaluate
from determined.transformers import DetCallback
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from trl import DataCollatorForCompletionOnlyLM

from chat_format import ASSISTANT_PROMPT, CHAT_ML_TEMPLATE, EOS_TOKEN, get_chat_format
from dataset_utils import load_or_create_dataset


def get_model_and_tokenizer(model_name):
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, eos_token=EOS_TOKEN)
tokenizer.chat_template = CHAT_ML_TEMPLATE
return model, tokenizer


def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
# Depending on the model and config, logits may contain extra tensors,
# like past_key_values, but logits always come first
logits = logits[0]
return logits.argmax(dim=-1)


def main(training_args, det_callback, hparams):
model_name = hparams["model"]
model, tokenizer = get_model_and_tokenizer(model_name)

def tokenize(element):
formatted = tokenizer.apply_chat_template(
get_chat_format(element), tokenize=False
)
outputs = tokenizer(formatted)
return {
"input_ids": outputs["input_ids"],
"attention_mask": outputs["attention_mask"],
}

dataset = load_or_create_dataset(hparams["dataset_subset"])
for k in dataset.keys():
dataset[k] = dataset[k].map(tokenize)

response_template_ids = tokenizer.encode(ASSISTANT_PROMPT, add_special_tokens=False)
collator = DataCollatorForCompletionOnlyLM(
response_template_ids, tokenizer=tokenizer
)

bleu = evaluate.load("bleu")
acc = evaluate.load("accuracy")

def compute_metrics(eval_preds):
preds, labels = eval_preds
# preds have the same shape as the labels, after the argmax(-1) has been calculated
# by preprocess_logits_for_metrics but we need to shift the labels
labels = labels[:, 1:]
preds = preds[:, :-1]
# -100 is a default value for ignore_index used by DataCollatorForCompletionOnlyLM
mask = labels == -100
labels[mask] = tokenizer.pad_token_id
preds[mask] = tokenizer.pad_token_id

decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
bleu_score = bleu.compute(predictions=decoded_preds, references=decoded_labels)
accuracy = acc.compute(predictions=preds[~mask], references=labels[~mask])

return {**bleu_score, **accuracy}

trainer = Trainer(
args=training_args,
model=model,
tokenizer=tokenizer,
data_collator=collator,
train_dataset=dataset["train"],
eval_dataset=dataset["valid"],
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
compute_metrics=compute_metrics,
)

trainer.add_callback(det_callback)
trainer.evaluate()
trainer.train()


if __name__ == "__main__":
info = det.get_cluster_info()
hparams = info.trial.hparams
distributed = det.core.DistributedContext.from_torch_distributed()
with det.core.init(distributed=distributed) as core_context:
training_args = TrainingArguments(**hparams["training_args"])
det_callback = DetCallback(core_context, training_args)
main(training_args, det_callback, hparams)
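To make the label handling in `compute_metrics` concrete, here is a small self-contained sketch of the shift-and-mask step; the tensors are made-up toy values, and only the shapes and the `-100` ignore index follow the training code above:

```python
# Toy illustration of the shift-and-mask logic in compute_metrics.
import torch

# preprocess_logits_for_metrics already reduced the logits to token ids, so
# preds and labels share the shape (batch, seq_len).  -100 marks prompt/pad
# positions that DataCollatorForCompletionOnlyLM excludes from the loss.
labels = torch.tensor([[-100, -100, 9, 11]])
preds = torch.tensor([[3, 9, 11, 0]])  # preds[t] is the model's guess for token t+1

# A causal LM predicts token t+1 from tokens <= t, so drop the first label and
# the last prediction to align the two sequences.
labels = labels[:, 1:]  # [[-100,  9, 11]]
preds = preds[:, :-1]   # [[   3,  9, 11]]

mask = labels == -100   # the remaining prompt position is still ignored
accuracy = (preds[~mask] == labels[~mask]).float().mean()
print(accuracy.item())  # 1.0: both unmasked predictions match their labels
```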
5 changes: 5 additions & 0 deletions blog/llm-finetuning/requirements.txt
@@ -0,0 +1,5 @@
transformers==4.36.2
datasets==2.16.1
evaluate==0.4.1
trl==0.7.9
scikit-learn==1.4.0
3 changes: 3 additions & 0 deletions blog/llm-finetuning/startup-hook.sh
@@ -0,0 +1,3 @@
#!/bin/bash
pip install --upgrade pip
pip install -r requirements.txt
51 changes: 51 additions & 0 deletions blog/llm-finetuning/test_model.py
@@ -0,0 +1,51 @@
import argparse
import glob

from determined.experimental import client

from chat_format import ASSISTANT_PROMPT, EOS_TOKEN, get_chat_format
from dataset_utils import load_or_create_dataset
from finetune import get_model_and_tokenizer


def main(exp_id, dataset_subset):
if exp_id is None:
checkpoint_dir = "TinyLlama/TinyLlama-1.1B-Chat-v0.4"
else:
exp = client.get_experiment(exp_id)
checkpoint = exp.list_checkpoints(
max_results=1,
sort_by=client.CheckpointSortBy.SEARCHER_METRIC,
order_by=client.OrderBy.DESCENDING,
)[0]
checkpoint_dir = checkpoint.download(mode=client.DownloadMode.MASTER)
checkpoint_dir = glob.glob(f"{checkpoint_dir}/checkpoint-*")[0]

model, tokenizer = get_model_and_tokenizer(checkpoint_dir)
eos_token_id = tokenizer.get_vocab()[EOS_TOKEN]

dataset = load_or_create_dataset(dataset_subset)["test"]
element = dataset[0]
formatted = tokenizer.apply_chat_template(
get_chat_format(element)[:2],
tokenize=False,
)
formatted += ASSISTANT_PROMPT
print(formatted)

inputs = tokenizer(formatted, return_tensors="pt")
outputs = model.generate(**inputs, eos_token_id=eos_token_id, max_new_tokens=1000)
input_length = inputs["input_ids"].shape[1]
response = tokenizer.batch_decode(
outputs[:, input_length:], skip_special_tokens=True
)
print(f"\n\nCorrect response:\n{element['response']}")
print(f"\n\nLLM response:\n{response[0]}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--exp_id", type=int, default=None, required=False)
parser.add_argument("--dataset_subset", type=str, default="easy", required=False)
args = parser.parse_args()
main(args.exp_id, args.dataset_subset)
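As a follow-up, a short sketch of how one might extend `test_model.py` to generate SQL for several test examples in one go. This loop is an assumed extension that reuses the helpers above; it loads the pretrained checkpoint, so pass a downloaded Determined checkpoint directory to `get_model_and_tokenizer` instead to test a finetuned model:

```python
# Sketch: generate SQL for the first few test examples instead of only one.
from chat_format import ASSISTANT_PROMPT, EOS_TOKEN, get_chat_format
from dataset_utils import load_or_create_dataset
from finetune import get_model_and_tokenizer

model, tokenizer = get_model_and_tokenizer("TinyLlama/TinyLlama-1.1B-Chat-v0.4")
eos_token_id = tokenizer.get_vocab()[EOS_TOKEN]
test_set = load_or_create_dataset("easy")["test"]

for element in test_set.select(range(3)):
    # Keep only the system + user turns, then append the assistant prompt so the
    # model completes the SQL query.
    prompt = tokenizer.apply_chat_template(get_chat_format(element)[:2], tokenize=False)
    prompt += ASSISTANT_PROMPT
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, eos_token_id=eos_token_id, max_new_tokens=256)
    generated = tokenizer.decode(
        outputs[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    print(f"Expected:  {element['response']}")
    print(f"Generated: {generated}\n")
```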