fix(vlm): handle legacy conversation data format and check image in data #2018

Merged · 6 commits · Dec 3, 2024
Changes from all commits
131 changes: 120 additions & 11 deletions src/axolotl/utils/collators/mm_chat.py
@@ -1,8 +1,10 @@
"""
Collators for multi-modal chat messages and packing
"""

from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union

from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin
@@ -30,8 +32,8 @@ def __post_init__(self):
raise ValueError("Packing is currently not supported.")

def torch_call(
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
) -> Dict[str, Any]:
self, examples: list[Union[list[int], Any, dict[str, Any]]]
) -> dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.

return self.__class__.process_rows(
@@ -46,22 +48,129 @@ def process_rows(examples, processor, chat_template, max_images, length_only=Fal
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point

def _preprocess(examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.

Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'

Args:
examples: list of conversation dictionaries

Returns:
list of dicts in OpenAI format, each with a 'messages' key

Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}

def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)

def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{
"role": normalize_role(convo["from"]),
"content": convo["value"],
}
for convo in example["conversations"]
]

# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
return {"messages": messages, **result}

processed_examples = []
for example in examples:
# OpenAI format
if "messages" in example:
processed_examples.append(example)

# Legacy format
elif "conversations" in example:
processed_examples.append(convert_legacy_format(example))

else:
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)

return processed_examples
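
For reference, a minimal sketch of the two row formats this helper accepts; the field values below are made up for illustration, only the structure matters:

# Legacy (ShareGPT-style) row: 'conversations' with from/value keys
legacy_row = {
    "conversations": [
        {"from": "human", "value": "What is in this image?"},
        {"from": "gpt", "value": "A cat sitting on a mat."},
    ],
    "images": "example.png",
}

# Equivalent OpenAI-style row produced by _preprocess: 'messages' with
# role/content keys; all other keys (such as 'images') are carried over unchanged.
openai_row = {
    "messages": [
        {"role": "user", "content": "What is in this image?"},
        {"role": "assistant", "content": "A cat sitting on a mat."},
    ],
    "images": "example.png",
}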

def _process_images(examples, max_images):
"""
Process images from examples, ensuring consistency in image presence and applying max_images limit.

Args:
examples: List of dictionaries that may contain an 'images' key
max_images: Maximum number of images to keep per example (0 means no limit)

Returns:
None (if no example has images) or a list of images (if all examples have images)

Raises:
ValueError: If there is a mix of examples with and without images
"""

def get_image(example):
if "images" not in example:
return None
images = example["images"]
if isinstance(images, str):
return Image.open(images)
return images

images = [get_image(example) for example in examples]

# Count None and non-None images
none_count = sum(1 for img in images if img is None)

# All images are None
if none_count == len(images):
return None

# Mix of None and non-None images
if none_count > 0:
raise ValueError(
"All images should be either None or not None. "
"Please provide images for all examples or None."
)

# Apply max_images limit if specified
if max_images > 0:
images = [
(
img_batch[:max_images]
if isinstance(img_batch, (list, tuple))
else img_batch
)
for img_batch in images
]

return images
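
A quick sketch of the contract this nested helper enforces, written as if it were callable directly (the PIL image below is a placeholder created only for illustration):

from PIL import Image

placeholder = Image.new("RGB", (8, 8))

# No example has images -> returns None and the collator proceeds text-only
assert _process_images([{"messages": []}, {"messages": []}], max_images=0) is None

# Per-example image lists are truncated once max_images is set
out = _process_images([{"images": [placeholder] * 3}], max_images=2)
assert len(out[0]) == 2

# A mix of examples with and without images is rejected
_process_images([{"images": placeholder}, {"messages": []}], max_images=0)  # raises ValueError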

# Preprocess the examples
examples = _preprocess(examples)

# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
)
for example in examples
]
images = [
Image.open(example["images"])
if isinstance(example["images"], str)
else example["images"]
for example in examples
]

if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]
images = _process_images(examples, max_images=max_images)

# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
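
For context, a call like the one above returns a padded batch whose exact keys depend on the processor class; a hedged inspection sketch (the key names are typical for vision-language processors, not guaranteed by this diff):

# Typical keys are input_ids, attention_mask and pixel_values; some processors
# (e.g. for Llama 3.2 Vision) add extra tensors such as aspect-ratio metadata.
print(sorted(batch.keys()))
print(batch["input_ids"].shape)  # (batch_size, padded_sequence_length)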
116 changes: 116 additions & 0 deletions tests/e2e/test_llama_vision.py
@@ -0,0 +1,116 @@
"""
E2E tests for LoRA Llama Vision
"""

import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestLlamaVision(unittest.TestCase):
"""
Test case for Llama Vision models
"""

@with_temp_dir
def test_lora_llama_vision_text_only_dataset(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
"processor_type": "AutoProcessor",
"skip_prepare_dataset": True,
"remove_unused_columns": False,
"sample_packing": False,
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"val_set_size": 0,
"chat_template": "llama3_2_vision",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()

@with_temp_dir
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
"processor_type": "AutoProcessor",
"skip_prepare_dataset": True,
"remove_unused_columns": False,
"sample_packing": False,
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"val_set_size": 0,
"chat_template": "llama3_2_vision",
"datasets": [
{
"path": "axolotl-ai-co/llava-instruct-mix-vsft-small",
"type": "chat_template",
"split": "train",
"field_messages": "messages",
},
],
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
1 change: 1 addition & 0 deletions tests/e2e/test_lora_llama.py
@@ -57,6 +57,7 @@ def test_lora(self, temp_dir):
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
}
)
normalize_config(cfg)
1 change: 1 addition & 0 deletions tests/e2e/test_optimizers.py
@@ -56,6 +56,7 @@ def test_optimi_adamw(self, temp_dir):
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "optimi_adamw",
"max_steps": 5,
"lr_scheduler": "cosine",
}
)
1 change: 1 addition & 0 deletions tests/e2e/test_relora_llama.py
@@ -57,6 +57,7 @@ def test_relora(self, temp_dir):
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"max_steps": 5,
"lr_scheduler": "cosine",
}
)