Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLI Implementation with Click #2107

Merged
merged 30 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
6e97854
Initial CLI implementation with click package
Nov 27, 2024
034f2bf
Adding fetch command for pulling examples and deepspeed configs
Nov 27, 2024
2992b11
Automating default options for CliArgs classes
Nov 27, 2024
ebd8e45
Mimicking existing no config behavior
Nov 27, 2024
02e432e
bugfix in choose_config
Nov 27, 2024
f4109b1
Updating fetch to sync instead of re-download
Nov 27, 2024
f0790b7
bugfix
Nov 27, 2024
422d91b
isort fix
Nov 27, 2024
b11fba4
fixing yaml isort order
Dec 2, 2024
0016e1f
pre-commit fixes
Dec 2, 2024
49d6cda
simplifying argument parsing -- pass through kwargs to do_cli
Dec 2, 2024
44caf3a
make accelerate launch default for non-preprocess commands
Dec 2, 2024
276198f
fixing arg handling
Dec 2, 2024
8a39b7f
testing None placeholder approach
Dec 2, 2024
e5c2a51
removing hacky --use-gpu argument to preprocess command
Dec 2, 2024
507513e
Adding brief README documentation for CLI
Dec 6, 2024
fb5431b
remove (New)
Dec 2, 2024
ac05e29
Initial CLI pytest tests
Dec 3, 2024
363b5a0
progress on CLI pytest
djsaunde Dec 3, 2024
59abdc6
adding inference CLI tests; cleanup
Dec 3, 2024
5d18989
Refactor train CLI tests to remove various mocking
Dec 4, 2024
a8a7819
Major CLI test refator; adding remaining CLI codepath test coverage
Dec 4, 2024
ff0ddf1
pytest fixes
Dec 4, 2024
fe244cf
remove integration markers
Dec 4, 2024
8f31101
parallelizing examples, deepspeed config downloads; rename test to ma…
Dec 4, 2024
1cd2647
moving cli pytest due to isolation issues; cleanup
Dec 5, 2024
d5f49a9
testing fixes; various minor improvements
Dec 5, 2024
5df0b2f
fix
Dec 5, 2024
cafc208
tests fix
Dec 5, 2024
996bf34
Update tests/cli/conftest.py
djsaunde Dec 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/tests-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ jobs:

- name: Run tests
run: |
pytest --ignore=tests/e2e/ tests/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
djsaunde marked this conversation as resolved.
Show resolved Hide resolved

- name: cleanup pip cache
run: |
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ jobs:

- name: Run tests
run: |
pytest -n8 --ignore=tests/e2e/ tests/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
djsaunde marked this conversation as resolved.
Show resolved Hide resolved

- name: cleanup pip cache
run: |
Expand Down Expand Up @@ -123,7 +124,8 @@ jobs:

- name: Run tests
run: |
pytest -n8 --ignore=tests/e2e/ tests/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
djsaunde marked this conversation as resolved.
Show resolved Hide resolved

- name: cleanup pip cache
run: |
Expand Down
41 changes: 40 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,46 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```

### Axolotl CLI

If you've installed this package using `pip` from source, we now support a new, more
streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). Rewriting
the above commands:

```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml

# finetune lora
axolotl train examples/openllama-3b/lora.yml

# inference
axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out"

# gradio
axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out" --gradio

# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```

We've also added a new command for fetching `examples` and `deepspeed_configs` to your
local machine. This will come in handy when installing `axolotl` from PyPI.

```bash
# Fetch example YAML files (stores in "examples/" folder)
axolotl fetch examples

# Fetch deepspeed config files (stores in "deepspeed_configs/" folder)
axolotl fetch deepspeed_configs

# Optionally, specify a destination folder
axolotl fetch examples --dest path/to/folder
```

## Badge ❤🏷️

Building something cool with Axolotl? Consider adding a badge to your model card.
Expand Down Expand Up @@ -206,7 +246,6 @@ Thanks to all of our contributors to date. Help drive open source AI progress fo
❌: not supported
❓: untested


## Advanced Setup

### Environment
Expand Down
5 changes: 3 additions & 2 deletions cicd/cicd.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/bin/bash
set -e

pytest -v --durations=10 -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
1 change: 1 addition & 0 deletions outputs
djsaunde marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ def parse_requirements():
packages=find_packages("src"),
install_requires=install_requires,
dependency_links=dependency_links,
entry_points={
"console_scripts": [
"axolotl=axolotl.cli.main:main",
],
},
extras_require={
"flash-attn": [
"flash-attn==2.7.0.post2",
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def choose_config(path: Path):

if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
djsaunde marked this conversation as resolved.
Show resolved Hide resolved
return str(yaml_files[0])

print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
Expand All @@ -391,7 +391,7 @@ def choose_config(path: Path):
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = yaml_files[choice - 1]
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
Expand Down
3 changes: 2 additions & 1 deletion src/axolotl/cli/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
CLI to run inference on a trained model
"""
from pathlib import Path
from typing import Union

import fire
import transformers
Expand All @@ -16,7 +17,7 @@
from axolotl.common.cli import TrainerCliArgs


def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, **kwargs)
Expand Down
231 changes: 231 additions & 0 deletions src/axolotl/cli/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional

import click

from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
)
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig


@click.group()
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
"""Preprocess datasets before training."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

from axolotl.cli.preprocess import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs):
"""Train or fine-tune a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.train import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing LoRA model",
)
@click.option(
"--base-model",
type=click.Path(exists=True, path_type=str),
help="Path to base model for non-LoRA models",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def inference(
config: str,
accelerate: bool,
lora_model_dir: Optional[str] = None,
base_model: Optional[str] = None,
**kwargs,
):
"""Run inference with a trained model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
del kwargs["inference"] # interferes with inference.do_cli

if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["output_dir"] = base_model

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.inference import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU operations",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing model weights to shard",
)
@click.option(
"--save-dir",
type=click.Path(path_type=str),
help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs):
"""Shard model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.shard import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for weight merging",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing sharded weights",
)
@click.option(
"--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
"""Merge sharded FSDP model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

if accelerate:
base_cmd = [
"accelerate",
"launch",
"-m",
"axolotl.cli.merge_sharded_fsdp_weights",
]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.merge_sharded_fsdp_weights import do_cli

do_cli(config=config, **kwargs)

djsaunde marked this conversation as resolved.
Show resolved Hide resolved

@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing the LoRA model to merge",
)
@click.option(
"--output-dir",
type=click.Path(path_type=str),
help="Directory to save the merged model",
)
def merge_lora(
config: str,
lora_model_dir: Optional[str] = None,
output_dir: Optional[str] = None,
):
"""Merge a trained LoRA into a base model"""
kwargs = {}
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if output_dir:
kwargs["output_dir"] = output_dir

from axolotl.cli.merge_lora import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
"""
Fetch example configs or other resources.

Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
"""
fetch_from_github(f"{directory}/", dest)


def main():
cli()


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion src/axolotl/cli/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
CLI to run merge a trained LoRA into a base model
"""
from pathlib import Path
from typing import Union

import fire
import transformers
Expand All @@ -11,7 +12,7 @@
from axolotl.common.cli import TrainerCliArgs


def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
Expand Down
Loading
Loading