Merge pull request #3450 from BUAADreamer/mllm
Add Multimodal LLM Finetuning
hiyouga authored Apr 25, 2024
2 parents fcfbd8c + 7f3bd35 commit c20f750
Showing 13 changed files with 230 additions and 38 deletions.
27 changes: 21 additions & 6 deletions data/dataset_info.json
@@ -58,6 +58,21 @@
"tools": "tools"
}
},
"mllm_demo": {
"file_name": "mllm_demo.json",
"file_sha1": "b6709b23657d5c42a701f1c5574f3a6edaa40a20",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"example": {
"script_url": "example_dataset",
"columns": {
@@ -185,6 +200,7 @@
"ultrachat_200k": {
"hf_hub_url": "HuggingFaceH4/ultrachat_200k",
"ms_hub_url": "AI-ModelScope/ultrachat_200k",
"formatting": "sharegpt",
"columns": {
"messages": "messages"
},
@@ -193,8 +209,7 @@
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
},
"formatting": "sharegpt"
}
},
"agent_instruct": {
"hf_hub_url": "THUDM/AgentInstruct",
@@ -204,6 +219,7 @@
"lmsys_chat": {
"hf_hub_url": "lmsys/lmsys-chat-1m",
"ms_hub_url": "AI-ModelScope/lmsys-chat-1m",
"formatting": "sharegpt",
"columns": {
"messages": "conversation"
},
@@ -212,8 +228,7 @@
"content_tag": "content",
"user_tag": "human",
"assistant_tag": "assistant"
},
"formatting": "sharegpt"
}
},
"evol_instruct": {
"hf_hub_url": "WizardLM/WizardLM_evol_instruct_V2_196k",
@@ -340,7 +355,7 @@
"history": "history"
}
},
"orca_dpo_de" : {
"orca_dpo_de": {
"hf_hub_url": "mayflowergmbh/intel_orca_dpo_pairs_de",
"ranking": true
},
@@ -414,4 +429,4 @@
},
"folder": "python"
}
}
}
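The new mllm_demo entry uses the sharegpt format and adds an "images" column mapping; the parser.py diff further down wires that column into DatasetAttr. A minimal sketch of the well-formedness this entry must satisfy — a quick standalone check, not the repo's own validation logic:

import json

# Load the registry and inspect the new entry: sharegpt formatting, plus
# "messages" and "images" column mappings and the four role tags.
with open("data/dataset_info.json", encoding="utf-8") as f:
    info = json.load(f)["mllm_demo"]

assert info["formatting"] == "sharegpt"
assert set(info["columns"]) == {"messages", "images"}
assert set(info["tags"]) == {"role_tag", "content_tag", "user_tag", "assistant_tag"}
print(info["file_name"])  # mllm_demo.json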
Binary file added data/images/1.jpg
Binary file added data/images/2.jpg
Binary file added data/images/3.jpg
71 changes: 71 additions & 0 deletions data/mllm_demo.json
@@ -0,0 +1,71 @@
[
{
"messages": [
{
"content": "Who are they?<image>",
"role": "user"
},
{
"content": "They're Kane and Gretzka from Bayern Munich.",
"role": "assistant"
},
{
"content": "What are they doing?",
"role": "user"
},
{
"content": "They are celebrating on the soccer field",
"role": "assistant"
}
],
"images": [
"images/1.jpg"
]
},
{
"messages": [
{
"content": "Who is he?<image>",
"role": "user"
},
{
"content": "He's Thomas Muller from Bayern Munich.",
"role": "assistant"
},
{
"content": "Why is he on the ground?",
"role": "user"
},
{
"content": "Because he's sliding on his knees to celebrate.",
"role": "assistant"
}
],
"images": [
"images/2.jpg"
]
},
{
"messages": [
{
"content": "Please describe this image<image>",
"role": "user"
},
{
"content": "Chinese astronaut Gui Haichao is giving a speech.",
"role": "assistant"
},
{
"content": "What has he accomplished?",
"role": "user"
},
{
"content": "He was appointed to be a payload specialist on Shenzhou 16 mission in June 2022, thus becoming the first Chinese civilian of Group 3 in space on 30 May 2023. He is responsible for the on-orbit operation of space science experimental payloads.",
"role": "assistant"
}
],
"images": [
"images/3.jpg"
]
}
]
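Each sample pairs one image with a multi-turn conversation, marking the image's position with an <image> placeholder in the first user turn. A quick sanity check over the file, assuming the one-placeholder-per-image convention visible in the samples:

import json

# Verify that every conversation carries exactly one <image> placeholder per
# entry in its "images" list (each demo sample here uses a single image).
with open("data/mllm_demo.json", encoding="utf-8") as f:
    examples = json.load(f)

for idx, example in enumerate(examples):
    tags = sum(message["content"].count("<image>") for message in example["messages"])
    assert tags == len(example["images"]), f"sample {idx}: {tags} placeholders, {len(example['images'])} images"
print(f"{len(examples)} samples OK")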
32 changes: 32 additions & 0 deletions examples/mllm/sft_llava.sh
@@ -0,0 +1,32 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft_mm \
--do_train \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
--dataset mllm_demo \
--dataset_dir data \
--template default \
--finetuning_type lora \
--lora_target all \
--output_dir saves/llava-1.5-7b/lora/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 3 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--logging_steps 1 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 100 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--bf16
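With num_train_epochs set to 100 on a tiny demo dataset, run length is dominated by the epoch count. A back-of-the-envelope step calculation, assuming a single GPU (CUDA_VISIBLE_DEVICES=0), no packing, and that the validation split rounds roughly the way Hugging Face's train_test_split does:

import math

# Rough total optimizer steps for the script above: the train split is
# (1 - val_size) of the dataset, consumed in batches of
# per_device_train_batch_size * gradient_accumulation_steps, for
# num_train_epochs epochs.
def total_steps(n_examples, val_size=0.1, batch_size=3, grad_accum=1, epochs=100):
    n_train = math.floor(n_examples * (1 - val_size))
    steps_per_epoch = math.ceil(n_train / (batch_size * grad_accum))
    return steps_per_epoch * epochs

print(total_steps(3))     # the 3-sample mllm_demo file -> 100 steps
print(total_steps(3000))  # at the --max_samples cap -> 90000 steps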
31 changes: 24 additions & 7 deletions src/llmtuner/data/aligner.py
@@ -1,3 +1,4 @@
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union

@@ -13,8 +14,10 @@
from .parser import DatasetAttr


def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
@@ -44,12 +47,19 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
outputs["images"].append(
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
if dataset_attr.images
else []
)

return outputs


def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": []}
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
tag_mapping = {
dataset_attr.user_tag: Role.USER.value,
dataset_attr.assistant_tag: Role.ASSISTANT.value,
@@ -84,6 +94,11 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
if dataset_attr.images
else []
)

return outputs

@@ -96,12 +111,13 @@ def align_dataset(
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..."
tools: "..."
tools: "...",
images: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr)
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
@@ -114,6 +130,7 @@
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
}
)
kwargs = {}
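The key addition in both converters is the images column: relative paths from the dataset file are resolved against data_args.dataset_dir, and text-only rows fall back to an empty list. A standalone illustration of that path handling (dataset_dir hard-coded here for the example):

import os

# Mirror of the new list comprehension in convert_alpaca / convert_sharegpt:
# join each relative image path with the dataset directory.
dataset_dir = "data"
row_images = ["images/1.jpg"]  # as stored in mllm_demo.json

resolved = [os.path.join(dataset_dir, path) for path in row_images] if row_images else []
print(resolved)  # ['data/images/1.jpg'] -> later cast to the datasets Image feature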
9 changes: 5 additions & 4 deletions src/llmtuner/data/loader.py
@@ -1,6 +1,6 @@
import inspect
import os
from typing import TYPE_CHECKING, Literal, Union
from typing import TYPE_CHECKING, Literal, Optional, Union

from datasets import load_dataset, load_from_disk

@@ -16,7 +16,7 @@

if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer

from ..hparams import DataArguments, ModelArguments
@@ -115,11 +115,12 @@ def load_single_dataset(


def get_dataset(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos:
@@ -149,7 +150,7 @@ def get_dataset(

with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage
data_args, training_args, stage, template, tokenizer, processor
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}
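get_dataset now takes the argument objects first and accepts an optional processor for multimodal models. A hedged call-site sketch: the hparam objects are assumed to come from llmtuner's usual argument parsing, so the get_dataset call itself is left commented since those objects are not constructed here.

from transformers import AutoProcessor, AutoTokenizer

# For a LLaVA run, the processor bundles the tokenizer with the image
# processor; text-only stages can keep processor=None.
model_name = "llava-hf/llava-1.5-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# dataset = get_dataset(model_args, data_args, training_args, stage="sft",
#                       tokenizer=tokenizer, processor=processor)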
3 changes: 2 additions & 1 deletion src/llmtuner/data/parser.py
@@ -28,6 +28,7 @@ class DatasetAttr:
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """
system: Optional[str] = None
images: Optional[str] = None
""" columns for the alpaca format """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
@@ -105,7 +106,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")

if "columns" in dataset_info[name]:
column_names = ["system"]
column_names = ["system", "images"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
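With "images" added to the shared column list, the same attribute pass now covers both formats. A simplified sketch of the resolution for the mllm_demo entry, with a plain dict standing in for DatasetAttr:

# Simplified column resolution: "system" and "images" are shared across
# formats, the remaining column names depend on the formatting; unmapped
# columns stay None.
entry = {"formatting": "sharegpt", "columns": {"messages": "messages", "images": "images"}}

column_names = ["system", "images"]
if entry["formatting"] == "alpaca":
    column_names.extend(["prompt", "query", "response", "history"])
else:
    column_names.append("messages")

resolved = {name: entry["columns"].get(name) for name in column_names}
print(resolved)  # {'system': None, 'images': 'images', 'messages': 'messages'}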