Skip to content

Commit

Permalink
Major CLI test refator; adding remaining CLI codepath test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Saunders committed Dec 4, 2024
1 parent b27e2af commit 66428fd
Show file tree
Hide file tree
Showing 11 changed files with 524 additions and 305 deletions.
106 changes: 4 additions & 102 deletions src/axolotl/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,117 +3,18 @@
"""
# pylint: disable=redefined-outer-name
import dataclasses
import hashlib
import json
import subprocess # nosec B404
from pathlib import Path
from types import NoneType
from typing import Any, Dict, List, Optional, Type, Union, get_args, get_origin
from typing import Any, Optional, Type, Union, get_args, get_origin

import click
import requests
from pydantic import BaseModel

from axolotl.cli.utils import build_command, fetch_from_github
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig


def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
"""Build command list from base command and options."""
cmd = base_cmd.copy()

for key, value in options.items():
if value is None:
continue

key = key.replace("_", "-")

if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.extend([f"--{key}", str(value)])

return cmd


def fetch_from_github(dir_prefix: str, dest_dir: Optional[str] = None) -> None:
"""
Sync files from a specific directory in the GitHub repository.
Only downloads files that don't exist locally or have changed.
Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
dest_dir: Local destination directory
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"

# Get repository tree with timeout
response = requests.get(api_url, timeout=30)
response.raise_for_status()
tree = json.loads(response.text)

# Filter for files and get their SHA
files = {
item["path"]: item["sha"]
for item in tree["tree"]
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
}

if not files:
raise click.ClickException(f"No files found in {dir_prefix}")

# Default destination directory is the last part of dir_prefix
default_dest = Path(dir_prefix.rstrip("/"))
dest_path = Path(dest_dir) if dest_dir else default_dest

# Keep track of processed files for summary
files_processed: Dict[str, List[str]] = {"new": [], "updated": [], "unchanged": []}

for file_path, remote_sha in files.items():
# Create full URLs and paths
raw_url = f"{raw_base_url}/{file_path}"
dest_file = dest_path / file_path.split(dir_prefix)[-1]

# Check if file exists and needs updating
if dest_file.exists():
# Git blob SHA is calculated with a header
with open(dest_file, "rb") as file:
content = file.read()

# Calculate git blob SHA
blob = b"blob " + str(len(content)).encode() + b"\0" + content
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()

if local_sha == remote_sha:
print(f"Skipping {file_path} (unchanged)")
files_processed["unchanged"].append(file_path)
continue

print(f"Updating {file_path}")
files_processed["updated"].append(file_path)
else:
print(f"Downloading {file_path}")
files_processed["new"].append(file_path)

# Create directories if needed
dest_file.parent.mkdir(parents=True, exist_ok=True)

# Download and save file
response = requests.get(raw_url, timeout=30)
response.raise_for_status()

with open(dest_file, "wb") as file:
file.write(response.content)

# Print summary
print("\nSync Summary:")
print(f"New files: {len(files_processed['new'])}")
print(f"Updated files: {len(files_processed['updated'])}")
print(f"Unchanged files: {len(files_processed['unchanged'])}")


@click.group()
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
Expand Down Expand Up @@ -176,6 +77,7 @@ def decorator(function):
@cli.command()
@click.argument("config", type=str)
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
"""Preprocess datasets before training."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
Expand Down Expand Up @@ -243,7 +145,7 @@ def inference(config: str, accelerate: bool, **kwargs):
@click.argument("config", type=str)
@click.option(
"--accelerate/--no-accelerate",
default=True,
default=False,
help="Use accelerate launch for multi-GPU operations",
)
@click.option("--model-dir", help="Directory containing model weights to shard")
Expand Down
4 changes: 0 additions & 4 deletions src/axolotl/cli/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print("in do_cli")

# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
Expand All @@ -42,8 +40,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return_remaining_strings=True
)

print(f"parsed_cli_args: {parsed_cli_args}")

if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
Expand Down
106 changes: 106 additions & 0 deletions src/axolotl/cli/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
utility methods for axoltl CLI
"""
import hashlib
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

import click
import requests


def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
"""Build command list from base command and options."""
cmd = base_cmd.copy()

for key, value in options.items():
if value is None:
continue

key = key.replace("_", "-")

if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.extend([f"--{key}", str(value)])

return cmd


def fetch_from_github(dir_prefix: str, dest_dir: Optional[str] = None) -> None:
"""
Sync files from a specific directory in the GitHub repository.
Only downloads files that don't exist locally or have changed.
Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
dest_dir: Local destination directory
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"

# Get repository tree with timeout
response = requests.get(api_url, timeout=30)
response.raise_for_status()
tree = json.loads(response.text)

# Filter for files and get their SHA
files = {
item["path"]: item["sha"]
for item in tree["tree"]
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
}

if not files:
raise click.ClickException(f"No files found in {dir_prefix}")

# Default destination directory is the last part of dir_prefix
default_dest = Path(dir_prefix.rstrip("/"))
dest_path = Path(dest_dir) if dest_dir else default_dest

# Keep track of processed files for summary
files_processed: Dict[str, List[str]] = {"new": [], "updated": [], "unchanged": []}

for file_path, remote_sha in files.items():
# Create full URLs and paths
raw_url = f"{raw_base_url}/{file_path}"
dest_file = dest_path / file_path.split(dir_prefix)[-1]

# Check if file exists and needs updating
if dest_file.exists():
# Git blob SHA is calculated with a header
with open(dest_file, "rb") as file:
content = file.read()

# Calculate git blob SHA
blob = b"blob " + str(len(content)).encode() + b"\0" + content
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()

if local_sha == remote_sha:
print(f"Skipping {file_path} (unchanged)")
files_processed["unchanged"].append(file_path)
continue

print(f"Updating {file_path}")
files_processed["updated"].append(file_path)
else:
print(f"Downloading {file_path}")
files_processed["new"].append(file_path)

# Create directories if needed
dest_file.parent.mkdir(parents=True, exist_ok=True)

# Download and save file
response = requests.get(raw_url, timeout=30)
response.raise_for_status()

with open(dest_file, "wb") as file:
file.write(response.content)

# Print summary
print("\nSync Summary:")
print(f"New files: {len(files_processed['new'])}")
print(f"Updated files: {len(files_processed['updated'])}")
print(f"Unchanged files: {len(files_processed['unchanged'])}")
35 changes: 35 additions & 0 deletions tests/cli/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,45 @@
shared pytest fixtures for cli module
"""

import shutil
from pathlib import Path

import pytest
from click.testing import CliRunner

VALID_TEST_CONFIG = """
base_model: HuggingFaceTB/SmolLM2-135M
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
sequence_len: 2048
max_steps: 1
micro_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1e-3
special_tokens:
pad_token: <|end_of_text|>
"""


@pytest.fixture
def cli_runner():
return CliRunner()


@pytest.fixture
def config_path(tmp_path):
"""Creates a temporary config file"""
path = tmp_path / "config.yml"
path.write_text(VALID_TEST_CONFIG)

return path


@pytest.fixture(autouse=True)
def cleanup_model_out():
yield

# Clean up after the test
if Path("model-out").exists():
shutil.rmtree("model-out")
Loading

0 comments on commit 66428fd

Please sign in to comment.