Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read embeddings config during bootstrap/load and instance creation #387

Merged
merged 3 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 62 additions & 40 deletions caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
import alog

# Local
from caikit_nlp.modules.text_embedding.utils import env_val_to_bool, env_val_to_int
from caikit_nlp.modules.text_embedding.utils import env_val_to_bool

logger = alog.use_channel("TXT_EMB")
error = error_handler.get(logger)
Expand Down Expand Up @@ -99,19 +99,6 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument

SentenceTransformer = SentenceTransformerNotAvailable

embedding_cfg = get_config().get("embedding", {})

AUTOCAST = env_val_to_bool(val=embedding_cfg.get("autocast"))
IPEX = env_val_to_bool(val=embedding_cfg.get("ipex"))
PT2_COMPILE = env_val_to_bool(val=embedding_cfg.get("pt2_compile"))
RETRIES = env_val_to_int(val=embedding_cfg.get("retries"), default=0)
BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0)
NO_IMPLICIT_TRUNCATION = env_val_to_bool(
val=embedding_cfg.get("implicit_truncation_errors", True)
)
DEVICE = embedding_cfg.get("device", "")
TRUST_REMOTE_CODE = embedding_cfg.get("trust_remote_code")

RT = TypeVar("RT") # return type


Expand Down Expand Up @@ -146,8 +133,6 @@ class TruncatedTokensTuple(NamedTuple):
],
)
class EmbeddingModule(ModuleBase):
# Retry count if enabled to try again (was for thread contention errors)
RETRY_COUNT = max(RETRIES, 0) # Ensure non-negative, before using in loop!

_ARTIFACTS_PATH_KEY = "artifacts_path"
_ARTIFACTS_PATH_DEFAULT = "artifacts"
Expand All @@ -159,42 +144,72 @@ def __init__(
super().__init__()
self.model = model

# Read config/env settings that are needed at run_* time.
embedding_cfg = get_config().get("embedding", {})

self.autocast = env_val_to_bool(embedding_cfg.get("autocast"))
self.no_implicit_truncation = env_val_to_bool(
embedding_cfg.get("implicit_truncation_errors", True)
)

self.batch_size = embedding_cfg.get("batch_size", 0)
error.type_check("<NLP83816537E>", int, EMBEDDING_BATCH_SIZE=self.batch_size)

# Retry count if enabled to try again (was for thread contention errors)
retries = embedding_cfg.get("retries", 0)
error.type_check("<NLP41910524E>", int, EMBEDDING_RETRIES=retries)
self.retry_count = max(
retries, 0
) # Ensure non-negative, before using in loop! (treat <0 as zero)

@classmethod
def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
def load(
cls, model_path: Union[str, ModuleConfig], *args, **kwargs
) -> "EmbeddingModule":
"""Load model

Args:
model_path: str
Path to the config dir under the model_id (where the config.yml lives)
model_path (Union[str, ModuleConfig]): Path to saved model or
evaline-ju marked this conversation as resolved.
Show resolved Hide resolved
in-memory ModuleConfig

Returns:
EmbeddingModule
Instance of this class built from the model.
"""

config = ModuleConfig.load(model_path)
artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY)
error.dir_check("<NLP19403057E>", config.model_path)

artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY)
error.value_check(
"<NLP07391618E>",
artifacts_path,
ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"),
)

artifacts_path = os.path.abspath(os.path.join(model_path, artifacts_path))
artifacts_path = os.path.abspath(
os.path.join(config.model_path, artifacts_path)
)
error.dir_check("<NLP34197772E>", artifacts_path)

ipex = cls._get_ipex(IPEX)
device = cls._select_device(ipex, DEVICE)
# Read config/env settings that are needed at load time.
embedding_cfg = get_config().get("embedding", {})

autocast = env_val_to_bool(embedding_cfg.get("autocast"))
pt2_compile = env_val_to_bool(embedding_cfg.get("pt2_compile"))
trust_remote_code = env_val_to_bool(embedding_cfg.get("trust_remote_code"))
ipex = cls._get_ipex(env_val_to_bool(embedding_cfg.get("ipex")))
device = cls._select_device(ipex, embedding_cfg.get("device", ""))

model = SentenceTransformerWithTruncate(
model_name_or_path=artifacts_path,
device=device,
trust_remote_code=TRUST_REMOTE_CODE,
trust_remote_code=trust_remote_code,
)
model.eval() # required for IPEX at least
if device is not None:
model.to(torch.device(device))
model = EmbeddingModule._optimize(model, ipex, device, AUTOCAST, PT2_COMPILE)
model = EmbeddingModule._optimize(model, ipex, device, autocast, pt2_compile)
return cls(model)

@property
Expand Down Expand Up @@ -310,16 +325,16 @@ def _optimize(model, ipex, device, autocast, pt2_compile):

def _with_retry(self, fn: Callable[..., RT], *args, **kwargs) -> RT:
first_exception = None
for count in range(1 + self.RETRY_COUNT): # try once plus retries (if needed)
for count in range(1 + self.retry_count): # try once plus retries (if needed)
try:
return fn(*args, **kwargs)
except Exception as e: # pylint: disable=broad-exception-caught
if first_exception is None:
first_exception = e
if self.RETRY_COUNT > 0:
if self.retry_count > 0:
warn_msg = f"Try {count + 1}: {fn} failed due to: {e}"
logger.warning("<NLP54902271W>", warn_msg, exc_info=True)
if count + 1 < self.RETRY_COUNT:
if count + 1 < self.retry_count:
time.sleep(0.1 * (count * 2))

# If above return did not happen, raise the first exception
Expand All @@ -334,16 +349,17 @@ def _encode_with_retry(
"""All encode calls should use this for consistent param adding and retry loop"""

# Add the batch_size kwarg if not passed in and given a usable BATCH_SIZE
if BATCH_SIZE > 0:
if self.batch_size > 0:
if kwargs is None:
kwargs = {}
if "batch_size" not in kwargs:
kwargs["batch_size"] = BATCH_SIZE
kwargs["batch_size"] = self.batch_size

if isinstance(self.model, SentenceTransformerWithTruncate):
kwargs[
"implicit_truncation_errors"
] = NO_IMPLICIT_TRUNCATION # config/env overrides default
] = self.no_implicit_truncation # config/env overrides default
kwargs["autocast"] = self.autocast # config/env overrides default
return self._with_retry(self.model.encode, *args, **kwargs)

# Else...
Expand All @@ -357,6 +373,8 @@ def _encode_with_retry(
del kwargs["return_token_count"]
if "implicit_truncation_errors" in kwargs:
del kwargs["implicit_truncation_errors"]
if "autocast" in kwargs:
del kwargs["autocast"]
return self._with_retry(self.model.encode, *args, **kwargs)

@EmbeddingTask.taskmethod()
Expand Down Expand Up @@ -718,19 +736,21 @@ def add_query(q):
)

@classmethod
def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule":
def bootstrap(cls, *args, **kwargs) -> "EmbeddingModule":
"""Bootstrap a sentence-transformers model

Args:
model_name_or_path: str
Model name (Hugging Face hub) or path to model to load.
kwargs are passed to SentenceTransformer(**kwargs)
"""
return cls(
model=SentenceTransformer(
model_name_or_path=model_name_or_path,
trust_remote_code=TRUST_REMOTE_CODE,

if "trust_remote_code" not in kwargs:
# Read config/env settings that are needed at bootstrap time.
embedding_cfg = get_config().get("embedding", {})
kwargs["trust_remote_code"] = env_val_to_bool(
embedding_cfg.get("trust_remote_code")
)
)

return cls(model=SentenceTransformer(*args, **kwargs))

def save(self, model_path: str, *args, **kwargs):
"""Save model using config in model_path
Expand Down Expand Up @@ -1056,6 +1076,7 @@ def encode(
truncate_input_tokens: int = 0,
return_token_count: bool = False,
implicit_truncation_errors: bool = True,
autocast: bool = False,
) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]:
"""
Computes sentence embeddings
Expand Down Expand Up @@ -1083,6 +1104,7 @@ def encode(
:param return_token_count: If true, a tuple is returned to add the input token count.
:param implicit_truncation_errors: If true (default) implicit truncation throws an error.
If false, the model default behavior or used.
:param autocast: If true (not default) run with torch.cpu.amp.autocast()

:return:
If return_token_count is False, the embedding is returned as a numpy matrix.
Expand Down Expand Up @@ -1171,7 +1193,7 @@ def encode(

features = batch_to_device(features, device)

if AUTOCAST:
if autocast:
with torch.no_grad(), torch.cpu.amp.autocast():
out_features = self.forward(features)
embeddings = out_features["sentence_embedding"]
Expand Down
8 changes: 0 additions & 8 deletions caikit_nlp/modules/text_embedding/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,3 @@ def env_val_to_bool(val):

# For testing env vars for values that mean false (else True!)
return str(val).lower().strip() not in ("no", "n", "false", "0", "f", "off", "")


def env_val_to_int(val, default):
"""Returns the integer value of env var or default value if None or invalid integer"""
try:
return int(val)
except (TypeError, ValueError):
return default
61 changes: 43 additions & 18 deletions tests/modules/text_embedding/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Token,
TokenizationResults,
)
import aconfig

# Local
from caikit_nlp.modules.text_embedding import EmbeddingModule, utils
Expand Down Expand Up @@ -249,10 +250,22 @@ def test_save_type_checks(model_path):
BOOTSTRAPPED_MODEL.save(model_path)


def test_load_without_model_path():
"""Test coverage for the error message when config has no model_path"""
match = "stat: path should be string, bytes, os.PathLike or integer, not NoneType"
with pytest.raises(TypeError, match=match):
EmbeddingModule.load(ModuleConfig({}))


def test_load_without_artifacts():
"""Test coverage for the error message when config has no artifacts to load"""
with pytest.raises(ValueError):
EmbeddingModule.load(ModuleConfig({}))
with tempfile.TemporaryDirectory(suffix="-load") as model_dir:
config_yml_path = os.path.join(model_dir, "config.yml")
with open(config_yml_path, "a") as f:
f.write("module_id: foo")
match = "value check failed: Model config missing 'artifacts_path'"
with pytest.raises(ValueError, match=match):
EmbeddingModule.load(ModuleConfig({}).load(model_dir))


def test_run_embedding_type_check(loaded_model):
Expand Down Expand Up @@ -856,7 +869,7 @@ def fn():
def test__with_retry_fail_fail(loaded_model, monkeypatch):
"""fn needs a few tries, tries twice and fails."""

monkeypatch.setattr(loaded_model, "RETRY_COUNT", 1) # less than 3 tries
monkeypatch.setattr(loaded_model, "retry_count", 1) # less than 3 tries

def generate_ints():
yield from range(9) # More than enough for retry loop
Expand All @@ -880,7 +893,7 @@ def fail_fail_win():
def test__with_retry_fail_fail_win(loaded_model, monkeypatch):
"""fn needs a few tries, logs, loops and succeeds"""

monkeypatch.setattr(loaded_model, "RETRY_COUNT", 6) # test needs at least 3 tries
monkeypatch.setattr(loaded_model, "retry_count", 6) # test needs at least 3 tries

def generate_ints():
yield from range(9) # More than enough for retry loop
Expand Down Expand Up @@ -915,21 +928,33 @@ def test_env_val_to_bool():
assert utils.env_val_to_bool(" tRuE ")


def test_env_val_to_int():
def test_config_val_to_int():
conf = aconfig.Config(
{
"zero": 0,
"zero_str": "0",
"false": False,
"number_str": "456",
"number_str2": " 456 ",
"true": True,
"non_int_str": "non int str",
}
)
expected_default = 12345
assert expected_default == utils.env_val_to_int(None, expected_default)
assert expected_default == utils.env_val_to_int("", expected_default)
assert expected_default == utils.env_val_to_int(" ", expected_default)
assert expected_default == utils.env_val_to_int(" ss ", expected_default)
assert expected_default == utils.env_val_to_int(" sss ", expected_default)
assert expected_default == utils.env_val_to_int(" ssss ", expected_default)

assert 0 == utils.env_val_to_int(0, expected_default)
assert 0 == utils.env_val_to_int("0", expected_default)
assert 0 == utils.env_val_to_int(False, expected_default)
assert 456 == utils.env_val_to_int("456", expected_default)
assert 456 == utils.env_val_to_int(" 456 ", expected_default)
assert 1 == utils.env_val_to_int(True, expected_default)
assert expected_default == conf.get("bogus", expected_default)

assert 0 == conf.get("zero", expected_default)
assert 0 == int(conf.get("zero_str", expected_default))
assert 0 == int(conf.get("false", expected_default))
assert 456 == int(conf.get("number_str", expected_default))
assert 456 == int(conf.get("number_str2", expected_default))
assert 1 == conf.get("true", expected_default)
assert 1 == int(conf.get("true", expected_default))

assert "non int str" == conf.get("non_int_str", 123) # default not used
# Using a bad config (e.g., some non-integer string) with int() will raise ValueError
with pytest.raises(ValueError):
int(conf.get("non_int_str", 123)) # default not used, int("non int str") raises


@pytest.mark.parametrize(
Expand Down
Loading