diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index 817f032635..6deb14e76a 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -2,11 +2,11 @@
 CLI definition for various axolotl commands
 """
 # pylint: disable=redefined-outer-name
+import dataclasses
 import hashlib
 import json
 import os
 import subprocess  # nosec B404
-from dataclasses import fields as dataclass_fields
 from pathlib import Path
 from types import NoneType
 from typing import Any, Dict, List, Optional, Type, Union, get_args, get_origin
@@ -140,38 +140,63 @@ def get_click_type(python_type: Type) -> Any:
     return type_map.get(python_type, str)
 
 
-def generate_click_options(model: Union[Type[BaseModel], Type]):
-    """Generate Click options from a Pydantic model or dataclass."""
+def add_options_from_dataclass(config_class: Type[Any]):
+    """Create Click options from the fields of a dataclass.
+
+    Options default to ``None`` so callers can drop the ones that were not
+    passed on the command line. Help text is read from the ``description``
+    entry of each field's metadata, when present.
+    """
 
     def decorator(function):
-        # Handle Pydantic models
-        if isinstance(model, type) and issubclass(model, BaseModel):
-            for field_name, field in model.model_fields.items():
-                # Convert snake_case to kebab-case for CLI
-                cli_name = f"--{field_name.replace('_', '-')}"
-                field_type = get_click_type(field.annotation)
-
-                # Handle boolean flags specially
-                if field_type is bool:
-                    function = click.option(
-                        cli_name, is_flag=True, help=field.description
-                    )(function)
-                else:
-                    function = click.option(
-                        cli_name, type=field_type, help=field.description
-                    )(function)
-
-        # Handle dataclasses
-        elif hasattr(model, "__dataclass_fields__"):
-            for field in dataclass_fields(model):
-                cli_name = f"--{field.name.replace('_', '-')}"
-                field_type = get_click_type(field.type)
-
-                if field_type is bool:
-                    function = click.option(cli_name, is_flag=True)(function)
-                else:
-                    function = click.option(cli_name, type=field_type)(function)
+        # Options applied later are listed earlier in --help, so walk the
+        # fields in reverse to keep them in declaration order.
+        for field in reversed(dataclasses.fields(config_class)):
+            option_name = f"--{field.name.replace('_', '-')}"
+
+            if field.type is bool:
+                function = click.option(
+                    option_name,
+                    is_flag=True,
+                    default=None,
+                    help=field.metadata.get("description"),
+                )(function)
+            else:
+                function = click.option(
+                    option_name,
+                    # Map the annotation (e.g. Optional[str]) to a type Click
+                    # can convert with instead of passing the raw annotation.
+                    type=get_click_type(field.type),
+                    default=None,
+                    help=field.metadata.get("description"),
+                )(function)
+        return function
+
+    return decorator
+
 
+def add_options_from_config(config_class: Type[BaseModel]):
+    """Create Click options from the fields of a Pydantic model.
+
+    Options default to ``None`` so callers can drop the ones that were not
+    passed on the command line. No Click type is set for value options, and
+    only fields annotated exactly ``bool`` become flags.
+    """
+
+    def decorator(function):
+        # Options applied later are listed earlier in --help, so walk the
+        # fields in reverse to keep them in declaration order.
+        for name, field in reversed(config_class.model_fields.items()):
+            option_name = f"--{name.replace('_', '-')}"
+
+            if field.annotation is bool:
+                function = click.option(
+                    option_name, is_flag=True, default=None, help=field.description
+                )(function)
+            else:
+                function = click.option(
+                    option_name, default=None, help=field.description
+                )(function)
         return function
 
     return decorator
@@ -185,9 +210,11 @@ def decorator(function):
     default=False,
     help="Allow GPU usage during preprocessing",
 )
-@generate_click_options(PreprocessCliArgs)
+@add_options_from_dataclass(PreprocessCliArgs)
 def preprocess(config: str, use_gpu: bool, **kwargs):
     """Preprocess datasets before training."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if not use_gpu:
         os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
@@ -201,13 +228,14 @@ def preprocess(config: str, use_gpu: bool, **kwargs):
 @click.option(
-    "--accelerate",
-    is_flag=True,
-    default=False,
+    "--accelerate/--no-accelerate",
+    default=True,
     help="Use accelerate launch for multi-GPU training",
 )
-@generate_click_options(AxolotlInputConfig)
-@generate_click_options(TrainerCliArgs)
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
 def train(config: str, accelerate: bool, **kwargs):
     """Train or fine-tune a model."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
         if config:
@@ -232,10 +260,12 @@ def train(config: str, accelerate: bool, **kwargs):
 @click.option("--base-model", help="Path to base model for non-LoRA models")
 @click.option("--gradio", is_flag=True, help="Launch Gradio interface")
 @click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
-@generate_click_options(AxolotlInputConfig)
-@generate_click_options(TrainerCliArgs)
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
 def inference(config: str, accelerate: bool, **kwargs):
     """Run inference with a trained model."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
         if config:
@@ -258,10 +288,12 @@ def inference(config: str, accelerate: bool, **kwargs):
 )
 @click.option("--model-dir", help="Directory containing model weights to shard")
 @click.option("--save-dir", help="Directory to save sharded weights")
-@generate_click_options(AxolotlInputConfig)
-@generate_click_options(TrainerCliArgs)
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
 def shard(config: str, accelerate: bool, **kwargs):
     """Shard model weights."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
         if config:
@@ -284,10 +316,12 @@ def shard(config: str, accelerate: bool, **kwargs):
 )
 @click.option("--model-dir", help="Directory containing sharded weights")
 @click.option("--save-path", help="Path to save merged weights")
-@generate_click_options(AxolotlInputConfig)
-@generate_click_options(TrainerCliArgs)
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
 def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
     """Merge sharded FSDP model weights."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if accelerate:
         base_cmd = [
             "accelerate",