Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support listing all for llama stack list-providers #1056

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions llama_stack/cli/stack/list_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,49 @@ def __init__(self, subparsers: argparse._SubParsersAction):
self._add_arguments()
self.parser.set_defaults(func=self._run_providers_list_cmd)

def _add_arguments(self):
@property
def providable_apis(self):
from llama_stack.distribution.distribution import providable_apis

api_values = [api.value for api in providable_apis()]
return [api.value for api in providable_apis()]

def _add_arguments(self):
self.parser.add_argument(
"api",
type=str,
choices=api_values,
help="API to list providers for (one of: {})".format(api_values),
choices=self.providable_apis,
nargs="?",
help="API to list providers for. List all if not specified.",
)

def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
from llama_stack.cli.table import print_table
from llama_stack.distribution.distribution import Api, get_provider_registry

all_providers = get_provider_registry()
providers_for_api = all_providers[Api(args.api)]
if args.api:
providers = [(args.api, all_providers[Api(args.api)])]
else:
providers = [(k.value, prov) for k, prov in all_providers.items()]

providers = [p for api, p in providers if api in self.providable_apis]

# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"API Type",
"Provider Type",
"PIP Package Dependencies",
]

rows = []
for spec in providers_for_api.values():

specs = [spec for p in providers for spec in p.values()]
for spec in specs:
if spec.is_sample:
continue
rows.append(
[
spec.api.value,
spec.provider_type,
",".join(spec.pip_packages),
]
Expand All @@ -59,4 +72,5 @@ def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
rows,
headers,
separate_rows=True,
sort_by=(0, 1),
)
7 changes: 6 additions & 1 deletion llama_stack/cli/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import re
import textwrap
from typing import Iterable

from termcolor import cprint

Expand Down Expand Up @@ -39,11 +40,15 @@ def wrap(text, width):
return "\n".join(lines)


def print_table(rows, headers=None, separate_rows: bool = False):
def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()):
def itemlen(item):
return max([len(line) for line in strip_ansi_colors(item).split("\n")])

rows = [[x or "" for x in row] for row in rows]

if sort_by:
rows.sort(key=lambda x: tuple(x[i] for i in sort_by))

if not headers:
col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)]
else:
Expand Down