Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change available_models to return List[Model], previously List[str] #16968

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@
import aiohttp
import json

from llama_index.core.bridge.pydantic import BaseModel


class Model(BaseModel):
    """Entry describing one available model; ``id`` is the model's name/identifier
    (populated from the keys of ``NVIDIA_MULTI_MODAL_MODELS``)."""

    # Model identifier string, e.g. a key of NVIDIA_MULTI_MODAL_MODELS.
    id: str


class NVIDIAClient:
def __init__(
Expand All @@ -58,14 +64,14 @@ def _get_headers(self, stream: bool) -> Dict[str, str]:
headers["accept"] = "text/event-stream" if stream else "application/json"
return headers

def get_model_details(self) -> List[Model]:
    """
    Get details of the supported models.

    Returns:
        A list of ``Model`` objects, one per entry in
        ``NVIDIA_MULTI_MODAL_MODELS`` (in that mapping's iteration order).
    """
    details: List[Model] = []
    for model_name in NVIDIA_MULTI_MODAL_MODELS:
        details.append(Model(id=model_name))
    return details

def request(
self,
Expand Down Expand Up @@ -198,7 +204,7 @@ def metadata(self) -> MultiModalLLMMetadata:
)

@property
def available_models(self) -> List[Model]:
    """Models available to this instance, as reported by the underlying client."""
    details: List[Model] = self._client.get_model_details()
    return details

def _get_credential_kwargs(self) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ license = "MIT"
name = "llama-index-multi-modal-llms-nvidia"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.1.0"
version = "0.2.0"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down
Original file line number Diff line number Diff line change
def test_available_models() -> None:
    """``available_models`` yields a non-empty list whose items expose a string ``id``."""
    models = NVIDIAMultiModal().available_models
    assert models
    assert isinstance(models, list)
    for model in models:
        assert isinstance(model.id, str)
Loading