Skip to content

Commit

Permalink
change available_models to return List[Model], previously List[str] (r…
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf authored Nov 15, 2024
1 parent 018eaca commit a10af2c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@
import aiohttp
import json

from llama_index.core.bridge.pydantic import BaseModel


class Model(BaseModel):
    """Lightweight record describing one model available from the NVIDIA service.

    Attributes are intentionally minimal; callers currently only need the
    model's identifier string.
    """

    # Unique identifier of the model (e.g. the key used in NVIDIA_MULTI_MODAL_MODELS).
    id: str


class NVIDIAClient:
def __init__(
Expand All @@ -58,14 +64,14 @@ def _get_headers(self, stream: bool) -> Dict[str, str]:
headers["accept"] = "text/event-stream" if stream else "application/json"
return headers

def get_model_details(self) -> List[Model]:
    """
    Get details of the models available from the service.

    Returns:
        List[Model]: one ``Model`` entry per known multi-modal model.
        (Previously this returned ``List[str]``; callers should read
        ``model.id`` instead of the raw string.)
    """
    # Iterating the mapping directly yields its keys; wrap each key in a
    # Model so the return type carries structure rather than bare strings.
    return [Model(id=model) for model in NVIDIA_MULTI_MODAL_MODELS]

def request(
self,
Expand Down Expand Up @@ -198,7 +204,7 @@ def metadata(self) -> MultiModalLLMMetadata:
)

@property
def available_models(self) -> List[Model]:
    """Models usable with this client, as ``Model`` records (not raw strings)."""
    # Delegates to the underlying NVIDIAClient; see NVIDIAClient.get_model_details.
    return self._client.get_model_details()

def _get_credential_kwargs(self) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ license = "MIT"
name = "llama-index-multi-modal-llms-nvidia"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.2.0"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down
Original file line number Diff line number Diff line change
def test_available_models() -> None:
    """available_models must return a non-empty list of Model records with str ids."""
    models = NVIDIAMultiModal().available_models
    assert models
    assert isinstance(models, list)
    # Each entry is a Model object; its .id carries the identifier string.
    assert all(isinstance(model.id, str) for model in models)

0 comments on commit a10af2c

Please sign in to comment.