From 89c82dce4fec6cf6dbae756accb1e40871fed3f9 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Tue, 11 Feb 2025 17:44:40 -0500 Subject: [PATCH] feat: support listing all for `llama stack list-providers` For ease of reading, sort the output rows by type. Signed-off-by: Ihar Hrachyshka --- llama_stack/cli/stack/list_providers.py | 26 +++++++++++++++++++------ llama_stack/cli/table.py | 7 ++++++- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/llama_stack/cli/stack/list_providers.py b/llama_stack/cli/stack/list_providers.py index bd152c9800..bfe11aa2c7 100644 --- a/llama_stack/cli/stack/list_providers.py +++ b/llama_stack/cli/stack/list_providers.py @@ -21,15 +21,19 @@ def __init__(self, subparsers: argparse._SubParsersAction): self._add_arguments() self.parser.set_defaults(func=self._run_providers_list_cmd) - def _add_arguments(self): + @property + def providable_apis(self): from llama_stack.distribution.distribution import providable_apis - api_values = [api.value for api in providable_apis()] + return [api.value for api in providable_apis()] + + def _add_arguments(self): self.parser.add_argument( "api", type=str, - choices=api_values, - help="API to list providers for (one of: {})".format(api_values), + choices=self.providable_apis, + nargs="?", + help="API to list providers for. List all if not specified.", ) def _run_providers_list_cmd(self, args: argparse.Namespace) -> None: @@ -37,20 +41,29 @@ def _run_providers_list_cmd(self, args: argparse.Namespace) -> None: from llama_stack.distribution.distribution import Api, get_provider_registry all_providers = get_provider_registry() - providers_for_api = all_providers[Api(args.api)] + if args.api: + providers = [(args.api, all_providers[Api(args.api)])] + else: + providers = [(k.value, prov) for k, prov in all_providers.items()] + + providers = [p for api, p in providers if api in self.providable_apis] # eventually, this should query a registry at llama.meta.com/llamastack/distributions headers = [ + "API Type", "Provider Type", "PIP Package Dependencies", ] rows = [] - for spec in providers_for_api.values(): + + specs = [spec for p in providers for spec in p.values()] + for spec in specs: if spec.is_sample: continue rows.append( [ + spec.api.value, spec.provider_type, ",".join(spec.pip_packages), ] @@ -59,4 +72,5 @@ def _run_providers_list_cmd(self, args: argparse.Namespace) -> None: rows, headers, separate_rows=True, + sort_by=(0, 1), ) diff --git a/llama_stack/cli/table.py b/llama_stack/cli/table.py index 50f54852bc..847719f817 100644 --- a/llama_stack/cli/table.py +++ b/llama_stack/cli/table.py @@ -6,6 +6,7 @@ import re import textwrap +from typing import Iterable from termcolor import cprint @@ -39,11 +40,15 @@ def wrap(text, width): return "\n".join(lines) -def print_table(rows, headers=None, separate_rows: bool = False): +def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()): def itemlen(item): return max([len(line) for line in strip_ansi_colors(item).split("\n")]) rows = [[x or "" for x in row] for row in rows] + + if sort_by: + rows.sort(key=lambda x: tuple(x[i] for i in sort_by)) + if not headers: col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)] else: