From 2dd91d2a34b8ae5df751a764b32114d1e50e31ce Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Fri, 3 May 2024 08:45:50 -0300 Subject: [PATCH 1/2] Expose model information for embeddings service Signed-off-by: Flavia Beo --- caikit_nlp/modules/text_embedding/embedding.py | 16 +++++++++++++++- tests/modules/text_embedding/test_embedding.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index d4447ebe..c1eb7173 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -15,7 +15,7 @@ # Standard from collections.abc import Sized from enum import Enum, auto -from typing import Callable, Dict, List, NamedTuple, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, NamedTuple, Optional, TypeVar, Union import importlib import os import time @@ -178,6 +178,20 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": return cls(model) + @property + def public_model_info(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument + """Helper property to return public metadata about a specific Model. This + function is separate from `metadata` as that contains the entire ModelConfig + which we might not want to share/expose. 
+ + Returns: + Dict[str, Any]: A dictionary of this model's public metadata + """ + return { + "max_seq_length": cls.model.max_seq_length, + "sentence_embedding_dimension": cls.model.get_sentence_embedding_dimension(), + } + @classmethod def _get_ipex(cls, ipex_flag): """Get IPEX optimization library if enabled and available, else return False diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index e625588b..d630b747 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -197,6 +197,21 @@ def test_save_load_and_run(): _assert_is_expected_embedding_result(result) +def test_public_model_info(): + """Check if we can get model info successfully""" + model_id = "model_id" + with tempfile.TemporaryDirectory(suffix="-1st") as model_dir: + model_path = os.path.join(model_dir, model_id) + BOOTSTRAPPED_MODEL.save(model_path) + new_model = EmbeddingModule.load(model_path) + + result = new_model.public_model_info + assert "max_seq_length" in result + assert "sentence_embedding_dimension" in result + assert type(result["max_seq_length"]) is int + assert type(result["sentence_embedding_dimension"]) is int + + @pytest.mark.parametrize( "model_path", ["", " ", " " * 100], ids=["empty", "space", "spaces"] ) From 756b1e191f125ecee6f46ef94667a13132f3e0f4 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Wed, 29 May 2024 13:53:50 -0300 Subject: [PATCH 2/2] Bump lower caikit version Signed-off-by: Flavia Beo --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bcb82db0..3fc25ee2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.26.17,<0.27.0", + "caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0", "caikit-tgis-backend>=0.1.27,<0.2.0", # TODO: loosen 
dependencies "accelerate>=0.22.0",