Commit

testing None placeholder approach
Dan Saunders committed Dec 2, 2024
1 parent b2d6ebe commit 93588a3
Showing 1 changed file with 59 additions and 40 deletions.
99 changes: 59 additions & 40 deletions src/axolotl/cli/main.py
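The "None placeholder approach" named in the commit message is the pattern visible in the diff below: every generated Click option defaults to None, and each command drops the None-valued entries from its kwargs, so only options the user explicitly passed on the command line survive to override values from the config file. A minimal standalone sketch of that pattern (the demo command and its option names are illustrative, not part of axolotl):

# Minimal sketch of the None-placeholder pattern (illustrative names, not axolotl code).
import click


@click.command()
@click.option("--learning-rate", type=float, default=None)
@click.option("--debug", is_flag=True, default=None)
def demo(**kwargs):
    # Unset options arrive as None; drop them so only explicit CLI values remain.
    overrides = {k: v for k, v in kwargs.items() if v is not None}
    click.echo(overrides)


if __name__ == "__main__":
    demo()

Running `demo --learning-rate 3e-4` would print only {'learning_rate': 0.0003}; the untouched --debug flag never appears as an override.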
@@ -2,11 +2,11 @@
 CLI definition for various axolotl commands
 """
 # pylint: disable=redefined-outer-name
+import dataclasses
 import hashlib
 import json
 import os
 import subprocess  # nosec B404
-from dataclasses import fields as dataclass_fields
 from pathlib import Path
 from types import NoneType
 from typing import Any, Dict, List, Optional, Type, Union, get_args, get_origin
@@ -140,38 +140,47 @@ def get_click_type(python_type: Type) -> Any:
     return type_map.get(python_type, str)


-def generate_click_options(model: Union[Type[BaseModel], Type]):
-    """Generate Click options from a Pydantic model or dataclass."""
+def add_options_from_dataclass(config_class: Type[Any]):
+    """Create Click options from the fields of a dataclass."""

     def decorator(function):
-        # Handle Pydantic models
-        if isinstance(model, type) and issubclass(model, BaseModel):
-            for field_name, field in model.model_fields.items():
-                # Convert snake_case to kebab-case for CLI
-                cli_name = f"--{field_name.replace('_', '-')}"
-                field_type = get_click_type(field.annotation)
-
-                # Handle boolean flags specially
-                if field_type is bool:
-                    function = click.option(
-                        cli_name, is_flag=True, help=field.description
-                    )(function)
-                else:
-                    function = click.option(
-                        cli_name, type=field_type, help=field.description
-                    )(function)
-
-        # Handle dataclasses
-        elif hasattr(model, "__dataclass_fields__"):
-            for field in dataclass_fields(model):
-                cli_name = f"--{field.name.replace('_', '-')}"
-                field_type = get_click_type(field.type)
-
-                if field_type is bool:
-                    function = click.option(cli_name, is_flag=True)(function)
-                else:
-                    function = click.option(cli_name, type=field_type)(function)
+        for field in reversed(dataclasses.fields(config_class)):
+            option_name = f"--{field.name.replace('_', '-')}"
+
+            if field.type == bool:
+                function = click.option(
+                    option_name,
+                    is_flag=True,
+                    default=None,
+                    help=field.metadata.get("description"),
+                )(function)
+            else:
+                function = click.option(
+                    option_name,
+                    type=field.type,
+                    default=None,
+                    help=field.metadata.get("description"),
+                )(function)
         return function

     return decorator

+
+def add_options_from_config(config_class: Type[BaseModel]):
+    """Create Click options from the fields of a Pydantic model."""
+
+    def decorator(function):
+        for name, field in reversed(config_class.model_fields.items()):
+            option_name = f"--{name.replace('_', '-')}"
+
+            if field.annotation == bool:
+                function = click.option(
+                    option_name, is_flag=True, default=None, help=field.description
+                )(function)
+            else:
+                function = click.option(
+                    option_name, default=None, help=field.description
+                )(function)
+        return function
+
+    return decorator
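For orientation, a small usage sketch of the new helpers; the DemoArgs dataclass and demo command are hypothetical and not part of this commit. Boolean fields become flags, field metadata supplies the help text, and every option defaults to None:

# Hypothetical usage of add_options_from_dataclass; DemoArgs is illustrative only.
import dataclasses

import click


@dataclasses.dataclass
class DemoArgs:
    num_epochs: int = dataclasses.field(
        default=1, metadata={"description": "Number of passes over the data"}
    )
    debug: bool = False


@click.command()
@add_options_from_dataclass(DemoArgs)
def demo(**kwargs):
    # As in the commands below, None placeholders are filtered out before use.
    click.echo({k: v for k, v in kwargs.items() if v is not None})

Invoking `demo --num-epochs 3` would yield {'num_epochs': 3}, with the unset --debug flag filtered out.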
@@ -185,9 +194,11 @@ def decorator(function):
     default=False,
     help="Allow GPU usage during preprocessing",
 )
-@generate_click_options(PreprocessCliArgs)
+@add_options_from_dataclass(PreprocessCliArgs)
 def preprocess(config: str, use_gpu: bool, **kwargs):
     """Preprocess datasets before training."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if not use_gpu:
         os.environ["CUDA_VISIBLE_DEVICES"] = ""

@@ -201,13 +212,15 @@ def preprocess(config: str, use_gpu: bool, **kwargs):
 @click.option(
     "--accelerate",
     is_flag=True,
-    default=False,
+    default=True,
     help="Use accelerate launch for multi-GPU training",
 )
-@generate_click_options(AxolotlInputConfig)
-@generate_click_options(TrainerCliArgs)
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
 def train(config: str, accelerate: bool, **kwargs):
     """Train or fine-tune a model."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
         if config:
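The hunk above is cut off before showing how the filtered kwargs are forwarded to the `accelerate launch` subprocess. One plausible way to rebuild CLI arguments from them, as a sketch under the assumption that keys map back to kebab-case flags (this helper is not taken from the commit):

# Sketch only: rebuild CLI flags from the filtered kwargs for the subprocess call.
# to_cli_args is a hypothetical helper, not code from this commit.
from typing import Any, Dict, List


def to_cli_args(kwargs: Dict[str, Any]) -> List[str]:
    args: List[str] = []
    for key, value in kwargs.items():
        flag = f"--{key.replace('_', '-')}"
        if isinstance(value, bool):
            if value:
                args.append(flag)  # boolean flags carry no value
        else:
            args.extend([flag, str(value)])
    return args

Something like `subprocess.run(base_cmd + [config] + to_cli_args(kwargs), check=True)` would then launch the training module with only the explicitly supplied overrides.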
@@ -232,10 +245,12 @@ def train(config: str, accelerate: bool, **kwargs):
 @click.option("--base-model", help="Path to base model for non-LoRA models")
 @click.option("--gradio", is_flag=True, help="Launch Gradio interface")
 @click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
-@generate_click_options(AxolotlInputConfig)
-@generate_click_options(TrainerCliArgs)
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
 def inference(config: str, accelerate: bool, **kwargs):
     """Run inference with a trained model."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
         if config:
@@ -258,10 +273,12 @@
 )
 @click.option("--model-dir", help="Directory containing model weights to shard")
 @click.option("--save-dir", help="Directory to save sharded weights")
-@generate_click_options(AxolotlInputConfig)
-@generate_click_options(TrainerCliArgs)
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
 def shard(config: str, accelerate: bool, **kwargs):
     """Shard model weights."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
         if config:
@@ -284,10 +301,12 @@ def shard(config: str, accelerate: bool, **kwargs):
 )
 @click.option("--model-dir", help="Directory containing sharded weights")
 @click.option("--save-path", help="Path to save merged weights")
-@generate_click_options(AxolotlInputConfig)
-@generate_click_options(TrainerCliArgs)
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
 def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
     """Merge sharded FSDP model weights."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if accelerate:
         base_cmd = [
             "accelerate",
