From a10af2c3328432e1389f04597a6c2d0e9c9bbbb0 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Fri, 15 Nov 2024 16:48:06 -0500
Subject: [PATCH] change available_models to return List[Model], previously
 List[str] (#16968)

---
 .../llama_index/multi_modal_llms/nvidia/base.py | 12 +++++++++---
 .../pyproject.toml                              |  2 +-
 .../tests/test_available_models.py              |  2 +-
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/llama_index/multi_modal_llms/nvidia/base.py b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/llama_index/multi_modal_llms/nvidia/base.py
index 02cea8ed17ed7..e14dfa616f510 100644
--- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/llama_index/multi_modal_llms/nvidia/base.py
+++ b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/llama_index/multi_modal_llms/nvidia/base.py
@@ -39,6 +39,12 @@
 import aiohttp
 import json
 
+from llama_index.core.bridge.pydantic import BaseModel
+
+
+class Model(BaseModel):
+    id: str
+
 
 class NVIDIAClient:
     def __init__(
@@ -58,14 +64,14 @@ def _get_headers(self, stream: bool) -> Dict[str, str]:
         headers["accept"] = "text/event-stream" if stream else "application/json"
         return headers
 
-    def get_model_details(self) -> List[str]:
+    def get_model_details(self) -> List[Model]:
         """
         Get model details.
 
         Returns:
             List of models
         """
-        return list(NVIDIA_MULTI_MODAL_MODELS.keys())
+        return [Model(id=model) for model in NVIDIA_MULTI_MODAL_MODELS]
 
     def request(
         self,
@@ -198,7 +204,7 @@ def metadata(self) -> MultiModalLLMMetadata:
         )
 
     @property
-    def available_models(self):
+    def available_models(self) -> List[Model]:
         return self._client.get_model_details()
 
     def _get_credential_kwargs(self) -> Dict[str, Any]:
diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/pyproject.toml b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/pyproject.toml
index 9555020026cb1..32d8121d7eaa6 100644
--- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/pyproject.toml
+++ b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/pyproject.toml
@@ -27,7 +27,7 @@ license = "MIT"
 name = "llama-index-multi-modal-llms-nvidia"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.1.0"
+version = "0.2.0"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/tests/test_available_models.py b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/tests/test_available_models.py
index 829622e872440..26d14ad038abf 100644
--- a/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/tests/test_available_models.py
+++ b/llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-nvidia/tests/test_available_models.py
@@ -8,4 +8,4 @@ def test_available_models() -> None:
     models = NVIDIAMultiModal().available_models
     assert models
     assert isinstance(models, list)
-    assert all(isinstance(model, str) for model in models)
+    assert all(isinstance(model.id, str) for model in models)
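
Notes (not part of the patch): a minimal caller-side sketch of the behavior change, assuming the package's standard import path and that any required NVIDIA credentials are available in the environment; `supported_ids` is a hypothetical name used only for illustration.

# Sketch of migrating a caller from the old List[str] return type to the
# new List[Model] one. NVIDIAMultiModal is assumed importable from the
# package's standard path; constructing it may require an NVIDIA API key.
from llama_index.multi_modal_llms.nvidia import NVIDIAMultiModal

llm = NVIDIAMultiModal()

# Previously, available_models returned bare model-name strings, e.g.
#     "some-model" in llm.available_models
# Now each entry is a Model with an .id field, so compare against ids.
supported_ids = [model.id for model in llm.available_models]
print(supported_ids)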