Skip to content

Commit

Permalink
feat: support listing all for llama stack list-providers
Browse files Browse the repository at this point in the history
For ease of reading, sort the output rows by type.

Signed-off-by: Ihar Hrachyshka <[email protected]>
  • Loading branch information
booxter committed Feb 11, 2025
1 parent ab7f802 commit c493afe
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
39 changes: 25 additions & 14 deletions llama_stack/cli/stack/list_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# the root directory of this source tree.

import argparse
import itertools

from llama_stack.cli.subcommand import Subcommand

Expand All @@ -21,23 +22,30 @@ def __init__(self, subparsers: argparse._SubParsersAction):
self._add_arguments()
self.parser.set_defaults(func=self._run_providers_list_cmd)

def _add_arguments(self):
@property
def providable_apis(self):
from llama_stack.distribution.distribution import providable_apis

api_values = [api.value for api in providable_apis()]
return [api.value for api in providable_apis()]

def _add_arguments(self):
self.parser.add_argument(
"api",
type=str,
choices=api_values,
help="API to list providers for (one of: {})".format(api_values),
choices=self.providable_apis,
nargs="?",
help="API to list providers for. List all if not specified.",
)

def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
from llama_stack.cli.table import print_table
from llama_stack.distribution.distribution import Api, get_provider_registry

all_providers = get_provider_registry()
providers_for_api = all_providers[Api(args.api)]
if args.api:
providers = [all_providers[Api(args.api)]]
else:
providers = [prov for k, prov in all_providers.items() if k.value in self.providable_apis]

# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
Expand All @@ -46,17 +54,20 @@ def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
]

rows = []
for spec in providers_for_api.values():
if spec.provider_type == "sample":
continue
rows.append(
[
spec.provider_type,
",".join(spec.pip_packages),
]
)

for provider in itertools.chain(p.values() for p in providers):
for spec in provider:
if spec.provider_type == "sample":
continue
rows.append(
[
spec.provider_type,
",".join(spec.pip_packages),
]
)
print_table(
rows,
headers,
separate_rows=True,
sort_by=0,
)
6 changes: 5 additions & 1 deletion llama_stack/cli/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,15 @@ def wrap(text, width):
return "\n".join(lines)


def print_table(rows, headers=None, separate_rows: bool = False):
def print_table(rows, headers=None, separate_rows: bool = False, sort_by: int = -1):
def itemlen(item):
return max([len(line) for line in strip_ansi_colors(item).split("\n")])

rows = [[x or "" for x in row] for row in rows]

if sort_by >= 0:
rows = sorted(rows, key=lambda row: row[sort_by])

if not headers:
col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)]
else:
Expand Down

0 comments on commit c493afe

Please sign in to comment.