Skip to content

Commit

Permalink
Merge pull request #387 from markstur/embedding_get_config
Browse files Browse the repository at this point in the history
Read embeddings config during bootstrap/load and instance creation
  • Loading branch information
evaline-ju authored Sep 5, 2024
2 parents 97ae2bf + 8488582 commit fbe5637
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 66 deletions.
102 changes: 62 additions & 40 deletions caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
import alog

# Local
from caikit_nlp.modules.text_embedding.utils import env_val_to_bool, env_val_to_int
from caikit_nlp.modules.text_embedding.utils import env_val_to_bool

logger = alog.use_channel("TXT_EMB")
error = error_handler.get(logger)
Expand Down Expand Up @@ -99,19 +99,6 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument

SentenceTransformer = SentenceTransformerNotAvailable

embedding_cfg = get_config().get("embedding", {})

AUTOCAST = env_val_to_bool(val=embedding_cfg.get("autocast"))
IPEX = env_val_to_bool(val=embedding_cfg.get("ipex"))
PT2_COMPILE = env_val_to_bool(val=embedding_cfg.get("pt2_compile"))
RETRIES = env_val_to_int(val=embedding_cfg.get("retries"), default=0)
BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0)
NO_IMPLICIT_TRUNCATION = env_val_to_bool(
val=embedding_cfg.get("implicit_truncation_errors", True)
)
DEVICE = embedding_cfg.get("device", "")
TRUST_REMOTE_CODE = embedding_cfg.get("trust_remote_code")

RT = TypeVar("RT") # return type


Expand Down Expand Up @@ -146,8 +133,6 @@ class TruncatedTokensTuple(NamedTuple):
],
)
class EmbeddingModule(ModuleBase):
# Retry count if enabled to try again (was for thread contention errors)
RETRY_COUNT = max(RETRIES, 0) # Ensure non-negative, before using in loop!

_ARTIFACTS_PATH_KEY = "artifacts_path"
_ARTIFACTS_PATH_DEFAULT = "artifacts"
Expand All @@ -159,42 +144,72 @@ def __init__(
super().__init__()
self.model = model

# Read config/env settings that are needed at run_* time.
embedding_cfg = get_config().get("embedding", {})

self.autocast = env_val_to_bool(embedding_cfg.get("autocast"))
self.no_implicit_truncation = env_val_to_bool(
embedding_cfg.get("implicit_truncation_errors", True)
)

self.batch_size = embedding_cfg.get("batch_size", 0)
error.type_check("<NLP83816537E>", int, EMBEDDING_BATCH_SIZE=self.batch_size)

# Retry count if enabled to try again (was for thread contention errors)
retries = embedding_cfg.get("retries", 0)
error.type_check("<NLP41910524E>", int, EMBEDDING_RETRIES=retries)
self.retry_count = max(
retries, 0
) # Ensure non-negative, before using in loop! (treat <0 as zero)

@classmethod
def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
def load(
cls, model_path: Union[str, ModuleConfig], *args, **kwargs
) -> "EmbeddingModule":
"""Load model
Args:
model_path: str
Path to the config dir under the model_id (where the config.yml lives)
model_path (Union[str, ModuleConfig]): Path to saved model or
in-memory ModuleConfig
Returns:
EmbeddingModule
Instance of this class built from the model.
"""

config = ModuleConfig.load(model_path)
artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY)
error.dir_check("<NLP19403057E>", config.model_path)

artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY)
error.value_check(
"<NLP07391618E>",
artifacts_path,
ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"),
)

artifacts_path = os.path.abspath(os.path.join(model_path, artifacts_path))
artifacts_path = os.path.abspath(
os.path.join(config.model_path, artifacts_path)
)
error.dir_check("<NLP34197772E>", artifacts_path)

ipex = cls._get_ipex(IPEX)
device = cls._select_device(ipex, DEVICE)
# Read config/env settings that are needed at load time.
embedding_cfg = get_config().get("embedding", {})

autocast = env_val_to_bool(embedding_cfg.get("autocast"))
pt2_compile = env_val_to_bool(embedding_cfg.get("pt2_compile"))
trust_remote_code = env_val_to_bool(embedding_cfg.get("trust_remote_code"))
ipex = cls._get_ipex(env_val_to_bool(embedding_cfg.get("ipex")))
device = cls._select_device(ipex, embedding_cfg.get("device", ""))

model = SentenceTransformerWithTruncate(
model_name_or_path=artifacts_path,
device=device,
trust_remote_code=TRUST_REMOTE_CODE,
trust_remote_code=trust_remote_code,
)
model.eval() # required for IPEX at least
if device is not None:
model.to(torch.device(device))
model = EmbeddingModule._optimize(model, ipex, device, AUTOCAST, PT2_COMPILE)
model = EmbeddingModule._optimize(model, ipex, device, autocast, pt2_compile)
return cls(model)

@property
Expand Down Expand Up @@ -310,16 +325,16 @@ def _optimize(model, ipex, device, autocast, pt2_compile):

def _with_retry(self, fn: Callable[..., RT], *args, **kwargs) -> RT:
first_exception = None
for count in range(1 + self.RETRY_COUNT): # try once plus retries (if needed)
for count in range(1 + self.retry_count): # try once plus retries (if needed)
try:
return fn(*args, **kwargs)
except Exception as e: # pylint: disable=broad-exception-caught
if first_exception is None:
first_exception = e
if self.RETRY_COUNT > 0:
if self.retry_count > 0:
warn_msg = f"Try {count + 1}: {fn} failed due to: {e}"
logger.warning("<NLP54902271W>", warn_msg, exc_info=True)
if count + 1 < self.RETRY_COUNT:
if count + 1 < self.retry_count:
time.sleep(0.1 * (count * 2))

# If above return did not happen, raise the first exception
Expand All @@ -334,16 +349,17 @@ def _encode_with_retry(
"""All encode calls should use this for consistent param adding and retry loop"""

# Add the batch_size kwarg if not passed in and given a usable BATCH_SIZE
if BATCH_SIZE > 0:
if self.batch_size > 0:
if kwargs is None:
kwargs = {}
if "batch_size" not in kwargs:
kwargs["batch_size"] = BATCH_SIZE
kwargs["batch_size"] = self.batch_size

if isinstance(self.model, SentenceTransformerWithTruncate):
kwargs[
"implicit_truncation_errors"
] = NO_IMPLICIT_TRUNCATION # config/env overrides default
] = self.no_implicit_truncation # config/env overrides default
kwargs["autocast"] = self.autocast # config/env overrides default
return self._with_retry(self.model.encode, *args, **kwargs)

# Else...
Expand All @@ -357,6 +373,8 @@ def _encode_with_retry(
del kwargs["return_token_count"]
if "implicit_truncation_errors" in kwargs:
del kwargs["implicit_truncation_errors"]
if "autocast" in kwargs:
del kwargs["autocast"]
return self._with_retry(self.model.encode, *args, **kwargs)

@EmbeddingTask.taskmethod()
Expand Down Expand Up @@ -718,19 +736,21 @@ def add_query(q):
)

@classmethod
def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule":
def bootstrap(cls, *args, **kwargs) -> "EmbeddingModule":
"""Bootstrap a sentence-transformers model
Args:
model_name_or_path: str
Model name (Hugging Face hub) or path to model to load.
kwargs are passed to SentenceTransformer(**kwargs)
"""
return cls(
model=SentenceTransformer(
model_name_or_path=model_name_or_path,
trust_remote_code=TRUST_REMOTE_CODE,

if "trust_remote_code" not in kwargs:
# Read config/env settings that are needed at bootstrap time.
embedding_cfg = get_config().get("embedding", {})
kwargs["trust_remote_code"] = env_val_to_bool(
embedding_cfg.get("trust_remote_code")
)
)

return cls(model=SentenceTransformer(*args, **kwargs))

def save(self, model_path: str, *args, **kwargs):
"""Save model using config in model_path
Expand Down Expand Up @@ -1056,6 +1076,7 @@ def encode(
truncate_input_tokens: int = 0,
return_token_count: bool = False,
implicit_truncation_errors: bool = True,
autocast: bool = False,
) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]:
"""
Computes sentence embeddings
Expand Down Expand Up @@ -1083,6 +1104,7 @@ def encode(
:param return_token_count: If true, a tuple is returned to add the input token count.
:param implicit_truncation_errors: If true (default) implicit truncation throws an error.
If false, the model default behavior or used.
:param autocast: If true (not default) run with torch.cpu.amp.autocast()
:return:
If return_token_count is False, the embedding is returned as a numpy matrix.
Expand Down Expand Up @@ -1171,7 +1193,7 @@ def encode(

features = batch_to_device(features, device)

if AUTOCAST:
if autocast:
with torch.no_grad(), torch.cpu.amp.autocast():
out_features = self.forward(features)
embeddings = out_features["sentence_embedding"]
Expand Down
8 changes: 0 additions & 8 deletions caikit_nlp/modules/text_embedding/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,3 @@ def env_val_to_bool(val):

# For testing env vars for values that mean false (else True!)
return str(val).lower().strip() not in ("no", "n", "false", "0", "f", "off", "")


def env_val_to_int(val, default):
"""Returns the integer value of env var or default value if None or invalid integer"""
try:
return int(val)
except (TypeError, ValueError):
return default
61 changes: 43 additions & 18 deletions tests/modules/text_embedding/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Token,
TokenizationResults,
)
import aconfig

# Local
from caikit_nlp.modules.text_embedding import EmbeddingModule, utils
Expand Down Expand Up @@ -249,10 +250,22 @@ def test_save_type_checks(model_path):
BOOTSTRAPPED_MODEL.save(model_path)


def test_load_without_model_path():
"""Test coverage for the error message when config has no model_path"""
match = "stat: path should be string, bytes, os.PathLike or integer, not NoneType"
with pytest.raises(TypeError, match=match):
EmbeddingModule.load(ModuleConfig({}))


def test_load_without_artifacts():
"""Test coverage for the error message when config has no artifacts to load"""
with pytest.raises(ValueError):
EmbeddingModule.load(ModuleConfig({}))
with tempfile.TemporaryDirectory(suffix="-load") as model_dir:
config_yml_path = os.path.join(model_dir, "config.yml")
with open(config_yml_path, "a") as f:
f.write("module_id: foo")
match = "value check failed: Model config missing 'artifacts_path'"
with pytest.raises(ValueError, match=match):
EmbeddingModule.load(ModuleConfig({}).load(model_dir))


def test_run_embedding_type_check(loaded_model):
Expand Down Expand Up @@ -856,7 +869,7 @@ def fn():
def test__with_retry_fail_fail(loaded_model, monkeypatch):
"""fn needs a few tries, tries twice and fails."""

monkeypatch.setattr(loaded_model, "RETRY_COUNT", 1) # less than 3 tries
monkeypatch.setattr(loaded_model, "retry_count", 1) # less than 3 tries

def generate_ints():
yield from range(9) # More than enough for retry loop
Expand All @@ -880,7 +893,7 @@ def fail_fail_win():
def test__with_retry_fail_fail_win(loaded_model, monkeypatch):
"""fn needs a few tries, logs, loops and succeeds"""

monkeypatch.setattr(loaded_model, "RETRY_COUNT", 6) # test needs at least 3 tries
monkeypatch.setattr(loaded_model, "retry_count", 6) # test needs at least 3 tries

def generate_ints():
yield from range(9) # More than enough for retry loop
Expand Down Expand Up @@ -915,21 +928,33 @@ def test_env_val_to_bool():
assert utils.env_val_to_bool(" tRuE ")


def test_env_val_to_int():
def test_config_val_to_int():
conf = aconfig.Config(
{
"zero": 0,
"zero_str": "0",
"false": False,
"number_str": "456",
"number_str2": " 456 ",
"true": True,
"non_int_str": "non int str",
}
)
expected_default = 12345
assert expected_default == utils.env_val_to_int(None, expected_default)
assert expected_default == utils.env_val_to_int("", expected_default)
assert expected_default == utils.env_val_to_int(" ", expected_default)
assert expected_default == utils.env_val_to_int(" ss ", expected_default)
assert expected_default == utils.env_val_to_int(" sss ", expected_default)
assert expected_default == utils.env_val_to_int(" ssss ", expected_default)

assert 0 == utils.env_val_to_int(0, expected_default)
assert 0 == utils.env_val_to_int("0", expected_default)
assert 0 == utils.env_val_to_int(False, expected_default)
assert 456 == utils.env_val_to_int("456", expected_default)
assert 456 == utils.env_val_to_int(" 456 ", expected_default)
assert 1 == utils.env_val_to_int(True, expected_default)
assert expected_default == conf.get("bogus", expected_default)

assert 0 == conf.get("zero", expected_default)
assert 0 == int(conf.get("zero_str", expected_default))
assert 0 == int(conf.get("false", expected_default))
assert 456 == int(conf.get("number_str", expected_default))
assert 456 == int(conf.get("number_str2", expected_default))
assert 1 == conf.get("true", expected_default)
assert 1 == int(conf.get("true", expected_default))

assert "non int str" == conf.get("non_int_str", 123) # default not used
# Using a bad config (e.g., some non-integer string) with int() will raise ValueError
with pytest.raises(ValueError):
int(conf.get("non_int_str", 123)) # default not used, int("non int str") raises


@pytest.mark.parametrize(
Expand Down

0 comments on commit fbe5637

Please sign in to comment.