diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py
index b9b67f8750..6f8a64ad85 100644
--- a/src/axolotl/utils/collators/mm_chat.py
+++ b/src/axolotl/utils/collators/mm_chat.py
@@ -1,8 +1,10 @@
 """
 Collators for multi-modal chat messages and packing
 """
+
+from copy import deepcopy
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 from PIL import Image
 from transformers import PreTrainedTokenizerBase, ProcessorMixin
@@ -30,8 +32,8 @@ def __post_init__(self):
             raise ValueError("Packing is currently not supported.")
 
     def torch_call(
-        self, examples: List[Union[List[int], Any, Dict[str, Any]]]
-    ) -> Dict[str, Any]:
+        self, examples: list[Union[list[int], Any, dict[str, Any]]]
+    ) -> dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
 
         return self.__class__.process_rows(
@@ -46,6 +48,120 @@ def process_rows(examples, processor, chat_template, max_images, length_only=Fal
         # *** This is COPIED from the trl example sft_vlm.py code ***
         # use this as a starting point
 
+        def _preprocess(examples: list[dict]) -> list[dict]:
+            """
+            Preprocess conversation examples to ensure consistent format.
+
+            Converts different conversation formats to OpenAI format with 'messages'.
+            Supports two formats:
+            1. OpenAI format with 'messages'
+            2. Legacy format with 'conversations'
+
+            Args:
+                examples: list of conversation dictionaries
+
+            Returns:
+                list of dicts in OpenAI format with 'messages' key
+
+            Raises:
+                ValueError: If the conversation format is not supported
+            """
+            role_mapping = {
+                "human": "user",
+                "gpt": "assistant",
+            }
+
+            def normalize_role(role: str) -> str:
+                """Normalize role names to OpenAI format. Default to original role if not found."""
+                return role_mapping.get(role, role)
+
+            def convert_legacy_format(example: dict) -> dict:
+                """Convert legacy 'conversations' format to OpenAI 'messages' format."""
+                messages = [
+                    {
+                        "role": normalize_role(convo["from"]),
+                        "content": convo["value"],
+                    }
+                    for convo in example["conversations"]
+                ]
+
+                # Create new dict without 'conversations' key
+                result = deepcopy(example)
+                result.pop("conversations")
+                return {"messages": messages, **result}
+
+            processed_examples = []
+            for example in examples:
+                # OpenAI format
+                if "messages" in example:
+                    processed_examples.append(example)
+
+                # Legacy format
+                elif "conversations" in example:
+                    processed_examples.append(convert_legacy_format(example))
+
+                else:
+                    raise ValueError(
+                        "Only `messages` and `conversations` message keys are currently supported."
+                    )
+
+            return processed_examples
+
+        def _process_images(examples, max_images):
+            """
+            Process images from examples, ensuring consistency in image presence and applying the max_images limit.
+
+            Args:
+                examples: List of dictionaries that may contain 'images' key
+                max_images: Maximum number of images to keep per example (0 means no limit)
+
+            Returns:
+                Either None (if no images) or List[Image objects] (if all examples have images)
+
+            Raises:
+                ValueError: If there's a mix of None and non-None images
+            """
+
+            def get_image(example):
+                if "images" not in example:
+                    return None
+                images = example["images"]
+                if isinstance(images, str):
+                    return Image.open(images)
+                return images
+
+            images = [get_image(example) for example in examples]
+
+            # Count None and non-None images
+            none_count = sum(1 for img in images if img is None)
+
+            # All images are None
+            if none_count == len(images):
+                return None
+
+            # Mix of None and non-None images
+            if none_count > 0:
+                raise ValueError(
+                    "All images should be either None or not None. "
+                    "Please provide images for all examples or None."
+                )
+
+            # Apply max_images limit if specified
+            if max_images > 0:
+                images = [
+                    (
+                        img_batch[:max_images]
+                        if isinstance(img_batch, (list, tuple))
+                        else img_batch
+                    )
+                    for img_batch in images
+                ]
+
+            return images
+
+        # Preprocess the examples
+        examples = _preprocess(examples)
+
         # Get the texts and images, and apply the chat template
         texts = [
             processor.apply_chat_template(
@@ -53,15 +169,8 @@ def process_rows(examples, processor, chat_template, max_images, length_only=Fal
             )
             for example in examples
         ]
-        images = [
-            Image.open(example["images"])
-            if isinstance(example["images"], str)
-            else example["images"]
-            for example in examples
-        ]
 
-        if max_images > 0:
-            images = [img_batch[:max_images] for img_batch in images]
+        images = _process_images(examples, max_images=max_images)
 
         # Tokenize the texts and process the images
         batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py
new file mode 100644
index 0000000000..1d583a3267
--- /dev/null
+++ b/tests/e2e/test_llama_vision.py
@@ -0,0 +1,116 @@
+"""
+E2E tests for LoRA Llama Vision
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestLlamaVision(unittest.TestCase):
+    """
+    Test case for Llama Vision models
+    """
+
+    @with_temp_dir
+    def test_lora_llama_vision_text_only_dataset(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
+                "processor_type": "AutoProcessor",
+                "skip_prepare_dataset": True,
+                "remove_unused_columns": False,
+                "sample_packing": False,
+                "sequence_len": 1024,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
+                "val_set_size": 0,
+                "chat_template": "llama3_2_vision",
+                "datasets": [
+                    {
+                        "path": "LDJnr/Puffin",
+                        "type": "chat_template",
+                        "field_messages": "conversations",
+                        "message_field_role": "from",
+                        "message_field_content": "value",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
"max_steps": 5, + "save_safetensors": True, + "bf16": True, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.safetensors").exists() + + @with_temp_dir + def test_lora_llama_vision_multimodal_dataset(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "axolotl-ai-co/Llama-3.2-39M-Vision", + "processor_type": "AutoProcessor", + "skip_prepare_dataset": True, + "remove_unused_columns": False, + "sample_packing": False, + "sequence_len": 1024, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj", + "val_set_size": 0, + "chat_template": "llama3_2_vision", + "datasets": [ + { + "path": "axolotl-ai-co/llava-instruct-mix-vsft-small", + "type": "chat_template", + "split": "train", + "field_messages": "messages", + }, + ], + "num_epochs": 1, + "micro_batch_size": 1, + "gradient_accumulation_steps": 4, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_safetensors": True, + "bf16": True, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.safetensors").exists() diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index 4c6fdaaa91..d06be60b96 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -57,6 +57,7 @@ def test_lora(self, temp_dir): "learning_rate": 0.00001, "optimizer": "adamw_torch", "lr_scheduler": "cosine", + "max_steps": 20, } ) normalize_config(cfg) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 119dd3d7cf..258cdb1c18 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -56,6 +56,7 @@ def test_optimi_adamw(self, temp_dir): "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "optimi_adamw", + "max_steps": 5, "lr_scheduler": "cosine", } ) diff --git a/tests/e2e/test_relora_llama.py b/tests/e2e/test_relora_llama.py index 4ba130c9dc..4e428038c8 100644 --- a/tests/e2e/test_relora_llama.py +++ b/tests/e2e/test_relora_llama.py @@ -57,6 +57,7 @@ def test_relora(self, temp_dir): "output_dir": temp_dir, "learning_rate": 0.00001, "optimizer": "adamw_torch", + "max_steps": 5, "lr_scheduler": "cosine", } )