Skip to content

Commit

Permalink
minor fix on report
Browse files Browse the repository at this point in the history
  • Loading branch information
sixianyi0721 committed Jan 27, 2025
1 parent 54cb729 commit 72d9624
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions tests/client-sdk/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
from termcolor import cprint


def featured_models_repo_names():
def featured_models():
models = [
*llama3_instruct_models(),
*llama3_1_instruct_models(),
*llama3_2_instruct_models(),
*llama3_3_instruct_models(),
*safety_models(),
]
return [model.huggingface_repo for model in models if not model.variant]
return {model.huggingface_repo: model for model in models if not model.variant}


SUPPORTED_MODELS = {
Expand Down Expand Up @@ -97,9 +97,10 @@ def __init__(self, report_path: Optional[str] = None):
if not config_path.exists():
raise ValueError(f"Config file {config_path} does not exist")
self.output_path = Path(config_path.parent / "report.md")
self.distro_name = None
elif os.environ.get("LLAMA_STACK_BASE_URL"):
url = get_env_or_fail("LLAMA_STACK_BASE_URL")
self.image_name = urlparse(url).netloc
self.distro_name = urlparse(url).netloc
if report_path is None:
raise ValueError(
"Report path must be provided when LLAMA_STACK_BASE_URL is set"
Expand Down Expand Up @@ -128,34 +129,34 @@ def pytest_runtest_logreport(self, report):

def pytest_sessionfinish(self, session):
report = []
report.append(f"# Report for {self.image_name} distribution")
report.append(f"# Report for {self.distro_name} distribution")
report.append("\n## Supported Models")

header = f"| Model Descriptor | {self.image_name} |"
header = f"| Model Descriptor | {self.distro_name} |"
dividor = "|:---|:---|"

report.append(header)
report.append(dividor)

rows = []
if self.image_name in SUPPORTED_MODELS:
if self.distro_name in SUPPORTED_MODELS:
for model in all_registered_models():
if (
"Instruct" not in model.core_model_id.value
and "Guard" not in model.core_model_id.value
) or (model.variant):
continue
row = f"| {model.core_model_id.value} |"
if model.core_model_id.value in SUPPORTED_MODELS[self.image_name]:
if model.core_model_id.value in SUPPORTED_MODELS[self.distro_name]:
row += " ✅ |"
else:
row += " ❌ |"
rows.append(row)
else:
supported_models = {m.identifier for m in self.client.models.list()}
for model in featured_models_repo_names():
row = f"| {model} |"
if model in supported_models:
for hf_name, model in featured_models().items():
row = f"| {model.core_model_id.value} |"
if hf_name in supported_models:
row += " ✅ |"
else:
row += " ❌ |"
Expand Down Expand Up @@ -224,8 +225,8 @@ def pytest_runtest_makereport(self, item, call):

if self.client is None and "llama_stack_client" in item.funcargs:
self.client = item.funcargs["llama_stack_client"]
self.image_name = (
self.image_name or self.client.async_client.config.image_name
self.distro_name = (
self.distro_name or self.client.async_client.config.image_name
)

def _print_result_icon(self, result):
Expand Down

0 comments on commit 72d9624

Please sign in to comment.