Skip to content

Commit

Permalink
Merge pull request #353 from flaviabeo/embeddings_model_info
Browse files Browse the repository at this point in the history
Expose model information for embeddings service
  • Loading branch information
gkumbhat authored May 31, 2024
2 parents 98578c9 + 756b1e1 commit 75cd9ea
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
16 changes: 15 additions & 1 deletion caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Standard
from collections.abc import Sized
from enum import Enum, auto
from typing import Callable, Dict, List, NamedTuple, Optional, TypeVar, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, TypeVar, Union
import importlib
import os
import time
Expand Down Expand Up @@ -178,6 +178,20 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":

return cls(model)

@property
def public_model_info(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument
"""Helper property to return public metadata about a specific Model. This
function is separate from `metdata` as that contains the entire ModelConfig
which might not want to be shared/exposed.
Returns:
Dict[str, str]: A dictionary of this models's public metadata
"""
return {
"max_seq_length": cls.model.max_seq_length,
"sentence_embedding_dimension": cls.model.get_sentence_embedding_dimension(),
}

@classmethod
def _get_ipex(cls, ipex_flag):
"""Get IPEX optimization library if enabled and available, else return False
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers=[
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"caikit[runtime-grpc,runtime-http]>=0.26.17,<0.27.0",
"caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0",
"caikit-tgis-backend>=0.1.27,<0.2.0",
# TODO: loosen dependencies
"grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking
Expand Down
15 changes: 15 additions & 0 deletions tests/modules/text_embedding/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,21 @@ def test_save_load_and_run():
_assert_is_expected_embedding_result(result)


def test_public_model_info():
"""Check if we can get model info successfully"""
model_id = "model_id"
with tempfile.TemporaryDirectory(suffix="-1st") as model_dir:
model_path = os.path.join(model_dir, model_id)
BOOTSTRAPPED_MODEL.save(model_path)
new_model = EmbeddingModule.load(model_path)

result = new_model.public_model_info
assert "max_seq_length" in result
assert "sentence_embedding_dimension" in result
assert type(result["max_seq_length"]) is int
assert type(result["sentence_embedding_dimension"]) is int


@pytest.mark.parametrize(
"model_path", ["", " ", " " * 100], ids=["empty", "space", "spaces"]
)
Expand Down

0 comments on commit 75cd9ea

Please sign in to comment.