Skip to content

Commit

Permalink
ci: add mypy for static type checking
Browse files Browse the repository at this point in the history
- Enable mypy to run in the CI on a subset of the repository
- Fix a few mypy errors
- Run mypy from pre-commit

Signed-off-by: Sébastien Han <[email protected]>
  • Loading branch information
leseb committed Feb 20, 2025
1 parent fb6a3ef commit fc1dd9e
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 93 deletions.
27 changes: 15 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,26 @@ repos:
hooks:
- id: uv-export
args: [
"--frozen",
"--no-hashes",
"--no-emit-project",
"--frozen",
"--no-hashes",
"--no-emit-project",
"--output-file=requirements.txt"
]
files: ^pyproject\.toml$
- id: uv-sync

# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.14.0
# hooks:
# - id: mypy
# additional_dependencies:
# - types-requests
# - types-setuptools
# - pydantic
# args: [--ignore-missing-imports]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies:
- uv==0.6.2
- mypy
- pytest
- rich
- types-requests
- pydantic
pass_filenames: false

# - repo: https://github.com/jsh9/pydoclint
# rev: d88180a8632bb1602a4d81344085cf320f288c5a
Expand Down
15 changes: 9 additions & 6 deletions llama_stack/apis/common/type_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,18 @@ class DialogType(BaseModel):
name="ParamType",
)

"""
# TODO: recursive definition of ParamType in these containers
# will cause infinite recursion in OpenAPI generation script
# since we are going with ChatCompletionInputType and CompletionInputType
# we don't need to worry about ArrayType/ObjectType/UnionType for now
# ArrayType.model_rebuild()
# ObjectType.model_rebuild()
# UnionType.model_rebuild()
ArrayType.model_rebuild()
ObjectType.model_rebuild()
UnionType.model_rebuild()
# class CustomType(BaseModel):
# type: Literal["custom"] = "custom"
# validator_class: str
class CustomType(BaseModel):
pylint: disable=syntax-error
type: Literal["custom"] = "custom"
validator_class: str
"""
11 changes: 8 additions & 3 deletions llama_stack/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
# the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable, List, Optional, TypeVar
from typing import Any, Callable, List, Optional, Protocol, TypeVar

from .strong_typing.schema import json_schema_type, register_schema # noqa: F401

T = TypeVar("T")


@dataclass
class WebMethod:
Expand All @@ -21,6 +19,13 @@ class WebMethod:
method: Optional[str] = None


class HasWebMethod(Protocol):
__webmethod__: WebMethod


T = TypeVar("T", bound=HasWebMethod) # Bound T to classes that match this protocol


def webmethod(
route: Optional[str] = None,
method: Optional[str] = None,
Expand Down
6 changes: 3 additions & 3 deletions llama_stack/scripts/distro_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
from functools import partial
from pathlib import Path
from typing import Iterator
from typing import Iterable

from rich.progress import Progress, SpinnerColumn, TextColumn

Expand Down Expand Up @@ -39,7 +39,7 @@ def changed_paths(self):
return self._changed_paths


def find_template_dirs(templates_dir: Path) -> Iterator[Path]:
def find_template_dirs(templates_dir: Path) -> Iterable[Path]:
"""Find immediate subdirectories in the templates folder."""
if not templates_dir.exists():
raise FileNotFoundError(f"Templates directory not found: {templates_dir}")
Expand Down Expand Up @@ -90,7 +90,7 @@ def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
return has_changes


def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]:
def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[str]]:
try:
module_name = f"llama_stack.templates.{template_dir.name}"
module = importlib.import_module(module_name)
Expand Down
2 changes: 1 addition & 1 deletion llama_stack/scripts/run_client_sdk_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def main(parser: argparse.ArgumentParser):
pytest_args,
"-s",
"-v",
REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH,
str(REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH),
]
)

Expand Down
23 changes: 23 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,26 @@ ignore = [
"B007",
"B008",
]

[tool.mypy]
mypy_path = ["llama_stack"]
packages = ["llama_stack"]
disable_error_code = []
warn_return_any = true
# # honor excludes by not following there through imports
follow_imports = "silent"
exclude = [
# As we fix more and more of these, we should remove them from the list
"llama_stack/providers",
"llama_stack/distribution",
"llama_stack/apis",
"llama_stack/cli",
"llama_stack/models",
"llama_stack/strong_typing",
"llama_stack/templates",
]

[[tool.mypy.overrides]]
# packages that lack typing annotations, do not have stubs, or are unavailable.
module = ["llama_models.*", "yaml", "fire"]
ignore_missing_imports = true
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fsspec==2025.2.0
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
huggingface-hub==0.28.1
huggingface-hub==0.29.0
idna==3.10
jinja2==3.1.5
jsonschema==4.23.0
Expand Down
Loading

0 comments on commit fc1dd9e

Please sign in to comment.