From 272eeb16f8b556a39358693340eafa7276720bd5 Mon Sep 17 00:00:00 2001
From: Mark Sturdevant
Date: Wed, 28 Aug 2024 20:23:55 -0700
Subject: [PATCH 1/3] Read embeddings config during bootstrap/load and
 instance creation

When the CONFIG_FILES setting is not used, the caikit-nlp config values
are not read from environment variables unless caikit has already
imported caikit-nlp. The code was loading the embedding config too soon
in this case, so the env vars were ignored.

This commit reads the config during bootstrap and load for settings
that are needed at bootstrap/load time, and also reads the config
during module init for runtime settings.

In addition:

* bootstrap() kwargs can be used (e.g. for the trust_remote_code param)
* load() was made to work with model_path as ModuleConfig (because str
  is deprecated and spammy)
* utils/env_val_to_int was removed because the caikit config handles
  this well
* utils/env_val_to_bool was NOT YET removed because the caikit config
  does not handle this quite as well

Signed-off-by: Mark Sturdevant
---
 .../modules/text_embedding/embedding.py      | 99 +++++++++++--------
 caikit_nlp/modules/text_embedding/utils.py   |  8 --
 .../modules/text_embedding/test_embedding.py | 45 ++++++---
 3 files changed, 89 insertions(+), 63 deletions(-)
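Illustration (not part of the patch): a minimal sketch of how the lazy
config reads are expected to behave. The model name below is a
placeholder, and the EMBEDDING_RETRIES mapping is an assumption based
on caikit's usual env-var override convention for the embedding.retries
config key.

    import os

    # Assumption: EMBEDDING_RETRIES overrides embedding.retries when
    # CONFIG_FILES is not set; set it before caikit loads its config.
    os.environ["EMBEDDING_RETRIES"] = "2"

    from caikit_nlp.modules.text_embedding import EmbeddingModule

    # bootstrap() kwargs are now forwarded to SentenceTransformer(...),
    # so trust_remote_code can be passed explicitly.
    model = EmbeddingModule.bootstrap(
        "sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
        trust_remote_code=False,
    )

    # Runtime settings are read at instance creation rather than at
    # import time, so the env var above should now be honored.
    print(model.retry_count)  # expected: 2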
diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index 97fcb7ac..6ef72c74 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -71,7 +71,7 @@
 import alog
 
 # Local
-from caikit_nlp.modules.text_embedding.utils import env_val_to_bool, env_val_to_int
+from caikit_nlp.modules.text_embedding.utils import env_val_to_bool
 
 logger = alog.use_channel("TXT_EMB")
 error = error_handler.get(logger)
@@ -99,19 +99,6 @@ def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
     SentenceTransformer = SentenceTransformerNotAvailable
 
-embedding_cfg = get_config().get("embedding", {})
-
-AUTOCAST = env_val_to_bool(val=embedding_cfg.get("autocast"))
-IPEX = env_val_to_bool(val=embedding_cfg.get("ipex"))
-PT2_COMPILE = env_val_to_bool(val=embedding_cfg.get("pt2_compile"))
-RETRIES = env_val_to_int(val=embedding_cfg.get("retries"), default=0)
-BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0)
-NO_IMPLICIT_TRUNCATION = env_val_to_bool(
-    val=embedding_cfg.get("implicit_truncation_errors", True)
-)
-DEVICE = embedding_cfg.get("device", "")
-TRUST_REMOTE_CODE = embedding_cfg.get("trust_remote_code")
-
 RT = TypeVar("RT")  # return type
 
@@ -146,8 +133,6 @@ class TruncatedTokensTuple(NamedTuple):
     ],
 )
 class EmbeddingModule(ModuleBase):
-    # Retry count if enabled to try again (was for thread contention errors)
-    RETRY_COUNT = max(RETRIES, 0)  # Ensure non-negative, before using in loop!
 
     _ARTIFACTS_PATH_KEY = "artifacts_path"
     _ARTIFACTS_PATH_DEFAULT = "artifacts"
@@ -159,13 +144,33 @@ def __init__(
         super().__init__()
         self.model = model
 
+        # Read config/env settings that are needed at run_* time.
+        embedding_cfg = get_config().get("embedding", {})
+
+        self.autocast = env_val_to_bool(embedding_cfg.get("autocast"))
+        self.no_implicit_truncation = env_val_to_bool(
+            embedding_cfg.get("implicit_truncation_errors", True)
+        )
+
+        self.batch_size = embedding_cfg.get("batch_size", 0)
+        error.type_check("", int, EMBEDDING_BATCH_SIZE=self.batch_size)
+
+        # Retry count if enabled to try again (was for thread contention errors)
+        retries = embedding_cfg.get("retries", 0)
+        error.type_check("", int, EMBEDDING_RETRIES=retries)
+        self.retry_count = max(
+            retries, 0
+        )  # Ensure non-negative, before using in loop! (treat <0 as zero)
+
     @classmethod
-    def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
+    def load(
+        cls, model_path: Union[str, ModuleConfig], *args, **kwargs
+    ) -> "EmbeddingModule":
         """Load model
 
         Args:
-            model_path: str
-                Path to the config dir under the model_id (where the config.yml lives)
+            model_path (Union[str, ModuleConfig]): Path to saved model or
+                in-memory ModuleConfig
 
         Returns:
             EmbeddingModule
@@ -181,20 +186,29 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
             ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"),
         )
 
-        artifacts_path = os.path.abspath(os.path.join(model_path, artifacts_path))
+        artifacts_path = os.path.abspath(
+            os.path.join(config.model_path, artifacts_path)
+        )
         error.dir_check("", artifacts_path)
 
-        ipex = cls._get_ipex(IPEX)
-        device = cls._select_device(ipex, DEVICE)
+        # Read config/env settings that are needed at load time.
+        embedding_cfg = get_config().get("embedding", {})
+
+        autocast = env_val_to_bool(embedding_cfg.get("autocast"))
+        pt2_compile = env_val_to_bool(embedding_cfg.get("pt2_compile"))
+        trust_remote_code = env_val_to_bool(embedding_cfg.get("trust_remote_code"))
+        ipex = cls._get_ipex(env_val_to_bool(embedding_cfg.get("ipex")))
+        device = cls._select_device(ipex, embedding_cfg.get("device", ""))
+
         model = SentenceTransformerWithTruncate(
             model_name_or_path=artifacts_path,
             device=device,
-            trust_remote_code=TRUST_REMOTE_CODE,
+            trust_remote_code=trust_remote_code,
         )
         model.eval()  # required for IPEX at least
         if device is not None:
             model.to(torch.device(device))
-        model = EmbeddingModule._optimize(model, ipex, device, AUTOCAST, PT2_COMPILE)
+        model = EmbeddingModule._optimize(model, ipex, device, autocast, pt2_compile)
         return cls(model)
 
     @property
@@ -310,16 +324,16 @@ def _optimize(model, ipex, device, autocast, pt2_compile):
 
     def _with_retry(self, fn: Callable[..., RT], *args, **kwargs) -> RT:
         first_exception = None
-        for count in range(1 + self.RETRY_COUNT):  # try once plus retries (if needed)
+        for count in range(1 + self.retry_count):  # try once plus retries (if needed)
             try:
                 return fn(*args, **kwargs)
             except Exception as e:  # pylint: disable=broad-exception-caught
                 if first_exception is None:
                     first_exception = e
-                if self.RETRY_COUNT > 0:
+                if self.retry_count > 0:
                     warn_msg = f"Try {count + 1}: {fn} failed due to: {e}"
                     logger.warning("", warn_msg, exc_info=True)
-                    if count + 1 < self.RETRY_COUNT:
+                    if count + 1 < self.retry_count:
                         time.sleep(0.1 * (count * 2))
 
         # If above return did not happen, raise the first exception
@@ -334,16 +348,17 @@ def _encode_with_retry(
         """All encode calls should use this for consistent param adding and retry loop"""
 
         # Add the batch_size kwarg if not passed in and given a usable BATCH_SIZE
-        if BATCH_SIZE > 0:
+        if self.batch_size > 0:
             if kwargs is None:
                 kwargs = {}
             if "batch_size" not in kwargs:
-                kwargs["batch_size"] = BATCH_SIZE
+                kwargs["batch_size"] = self.batch_size
 
         if isinstance(self.model, SentenceTransformerWithTruncate):
             kwargs[
                 "implicit_truncation_errors"
-            ] = NO_IMPLICIT_TRUNCATION  # config/env overrides default
+            ] = self.no_implicit_truncation  # config/env overrides default
+            kwargs["autocast"] = self.autocast  # config/env overrides default
             return self._with_retry(self.model.encode, *args, **kwargs)
 
         # Else...
@@ -357,6 +372,8 @@ def _encode_with_retry(
         if "truncate_input_tokens" in kwargs:
             del kwargs["truncate_input_tokens"]
         if "return_token_count" in kwargs:
             del kwargs["return_token_count"]
         if "implicit_truncation_errors" in kwargs:
             del kwargs["implicit_truncation_errors"]
+        if "autocast" in kwargs:
+            del kwargs["autocast"]
 
         return self._with_retry(self.model.encode, *args, **kwargs)
 
     @EmbeddingTask.taskmethod()
@@ -718,19 +735,21 @@ def add_query(q):
         )
 
     @classmethod
-    def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule":
+    def bootstrap(cls, *args, **kwargs) -> "EmbeddingModule":
         """Bootstrap a sentence-transformers model
 
         Args:
-            model_name_or_path: str
-                Model name (Hugging Face hub) or path to model to load.
+            kwargs are passed to SentenceTransformer(**kwargs)
         """
-        return cls(
-            model=SentenceTransformer(
-                model_name_or_path=model_name_or_path,
-                trust_remote_code=TRUST_REMOTE_CODE,
+
+        if "trust_remote_code" not in kwargs:
+            # Read config/env settings that are needed at bootstrap time.
+            embedding_cfg = get_config().get("embedding", {})
+            kwargs["trust_remote_code"] = env_val_to_bool(
+                embedding_cfg.get("trust_remote_code")
             )
-        )
+
+        return cls(model=SentenceTransformer(*args, **kwargs))
 
     def save(self, model_path: str, *args, **kwargs):
         """Save model using config in model_path
@@ -1056,6 +1075,7 @@ def encode(
         truncate_input_tokens: int = 0,
         return_token_count: bool = False,
         implicit_truncation_errors: bool = True,
+        autocast: bool = False,
     ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
@@ -1083,6 +1103,7 @@ def encode(
         :param return_token_count: If true, a tuple is returned to add the input token count.
         :param implicit_truncation_errors: If true (default) implicit truncation throws an error.
             If false, the model default behavior or used.
+        :param autocast: If true (not default) run with torch.cpu.amp.autocast()
 
         :return:
             If return_token_count is False, the embedding is returned as a numpy matrix.
@@ -1171,7 +1192,7 @@ def encode(
             features = batch_to_device(features, device)
 
-            if AUTOCAST:
+            if autocast:
                 with torch.no_grad(), torch.cpu.amp.autocast():
                     out_features = self.forward(features)
                     embeddings = out_features["sentence_embedding"]
diff --git a/caikit_nlp/modules/text_embedding/utils.py b/caikit_nlp/modules/text_embedding/utils.py
index 39adfb82..377f6dac 100644
--- a/caikit_nlp/modules/text_embedding/utils.py
+++ b/caikit_nlp/modules/text_embedding/utils.py
@@ -22,11 +22,3 @@ def env_val_to_bool(val):
     # For testing env vars for values that mean false (else True!)
     return str(val).lower().strip() not in ("no", "n", "false", "0", "f", "off", "")
-
-
-def env_val_to_int(val, default):
-    """Returns the integer value of env var or default value if None or invalid integer"""
-    try:
-        return int(val)
-    except (TypeError, ValueError):
-        return default
diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py
index 20a70fca..afbf58af 100644
--- a/tests/modules/text_embedding/test_embedding.py
+++ b/tests/modules/text_embedding/test_embedding.py
@@ -24,6 +24,7 @@
     Token,
     TokenizationResults,
 )
+import aconfig
 
 # Local
 from caikit_nlp.modules.text_embedding import EmbeddingModule, utils
@@ -856,7 +857,7 @@ def fn():
 
 def test__with_retry_fail_fail(loaded_model, monkeypatch):
     """fn needs a few tries, tries twice and fails."""
-    monkeypatch.setattr(loaded_model, "RETRY_COUNT", 1)  # less than 3 tries
+    monkeypatch.setattr(loaded_model, "retry_count", 1)  # less than 3 tries
 
     def generate_ints():
         yield from range(9)  # More than enough for retry loop
@@ -880,7 +881,7 @@ def fail_fail_win():
 
 def test__with_retry_fail_fail_win(loaded_model, monkeypatch):
     """fn needs a few tries, logs, loops and succeeds"""
-    monkeypatch.setattr(loaded_model, "RETRY_COUNT", 6)  # test needs at least 3 tries
+    monkeypatch.setattr(loaded_model, "retry_count", 6)  # test needs at least 3 tries
 
     def generate_ints():
         yield from range(9)  # More than enough for retry loop
@@ -915,21 +916,33 @@ def test_env_val_to_bool():
     assert utils.env_val_to_bool(" tRuE ")
 
 
-def test_env_val_to_int():
+def test_config_val_to_int():
+    conf = aconfig.Config(
+        {
+            "zero": 0,
+            "zero_str": "0",
+            "false": False,
+            "number_str": "456",
+            "number_str2": " 456 ",
+            "true": True,
+            "non_int": "oh-oh",
+        }
+    )
     expected_default = 12345
-    assert expected_default == utils.env_val_to_int(None, expected_default)
-    assert expected_default == utils.env_val_to_int("", expected_default)
-    assert expected_default == utils.env_val_to_int(" ", expected_default)
-    assert expected_default == utils.env_val_to_int(" ss ", expected_default)
-    assert expected_default == utils.env_val_to_int(" sss ", expected_default)
-    assert expected_default == utils.env_val_to_int(" ssss ", expected_default)
-
-    assert 0 == utils.env_val_to_int(0, expected_default)
-    assert 0 == utils.env_val_to_int("0", expected_default)
-    assert 0 == utils.env_val_to_int(False, expected_default)
-    assert 456 == utils.env_val_to_int("456", expected_default)
-    assert 456 == utils.env_val_to_int(" 456 ", expected_default)
-    assert 1 == utils.env_val_to_int(True, expected_default)
+    assert expected_default == conf.get("bogus", expected_default)
+
+    assert 0 == conf.get("zero", expected_default)
+    assert 0 == int(conf.get("zero_str", expected_default))
+    assert 0 == int(conf.get("false", expected_default))
+    assert 456 == int(conf.get("number_str", expected_default))
+    assert 456 == int(conf.get("number_str2", expected_default))
+    assert 1 == conf.get("true", expected_default)
+    assert 1 == int(conf.get("true", expected_default))
+
+    assert "oh-oh" == conf.get("non_int", 123)  # default not used (got "oh-oh")
+    # Using a bad config (e.g., some non-integer string) with int() will raise ValueError
+    with pytest.raises(ValueError):
+        int(conf.get("non_int", 123))  # default not used, int("oh-oh") raises
 
 
 @pytest.mark.parametrize(

From 90e0ccbffd569b1140018e056d97917257ab6bb8 Mon Sep 17 00:00:00 2001
From: Mark Sturdevant
Date: Wed, 4 Sep 2024 16:11:21 -0700
Subject: [PATCH 2/3] Make sure EmbeddingModule load() is not called without
 a model_path

ModuleConfig should always have a model_path, but this check makes the
error clearer if load() gets called with an empty ModuleConfig.

Signed-off-by: Mark Sturdevant
---
 caikit_nlp/modules/text_embedding/embedding.py |  3 ++-
 tests/modules/text_embedding/test_embedding.py | 16 ++++++++++++++--
 2 files changed, 16 insertions(+), 3 deletions(-)
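Illustration (not part of the patch): the failure mode this check
improves, matching the new test below. The import paths are assumptions
based on the usual caikit/caikit-nlp entry points.

    from caikit.core import ModuleConfig
    from caikit_nlp.modules.text_embedding import EmbeddingModule

    # An empty ModuleConfig has model_path=None, so the added dir_check
    # fails fast with a clear TypeError instead of a more confusing
    # error later in load().
    try:
        EmbeddingModule.load(ModuleConfig({}))
    except TypeError as err:
        # stat: path should be string, bytes, os.PathLike or integer,
        # not NoneType
        print(err)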
diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index 6ef72c74..06779d0d 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -178,8 +178,9 @@ def load(
         """
 
         config = ModuleConfig.load(model_path)
-        artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY)
+        error.dir_check("", config.model_path)
 
+        artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY)
         error.value_check(
             "",
             artifacts_path,
diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py
index afbf58af..222341ad 100644
--- a/tests/modules/text_embedding/test_embedding.py
+++ b/tests/modules/text_embedding/test_embedding.py
@@ -250,10 +250,22 @@ def test_save_type_checks(model_path):
         BOOTSTRAPPED_MODEL.save(model_path)
 
 
+def test_load_without_model_path():
+    """Test coverage for the error message when config has no model_path"""
+    match = "stat: path should be string, bytes, os.PathLike or integer, not NoneType"
+    with pytest.raises(TypeError, match=match):
+        EmbeddingModule.load(ModuleConfig({}))
+
+
 def test_load_without_artifacts():
     """Test coverage for the error message when config has no artifacts to load"""
-    with pytest.raises(ValueError):
-        EmbeddingModule.load(ModuleConfig({}))
+    with tempfile.TemporaryDirectory(suffix="-load") as model_dir:
+        config_yml_path = os.path.join(model_dir, "config.yml")
+        with open(config_yml_path, "a") as f:
+            f.write("module_id: foo")
+        match = "value check failed: Model config missing 'artifacts_path'"
+        with pytest.raises(ValueError, match=match):
+            EmbeddingModule.load(ModuleConfig({}).load(model_dir))
 
 
 def test_run_embedding_type_check(loaded_model):

From 8488582f70e8282d49b670c9fbaad54e6157916b Mon Sep 17 00:00:00 2001
From: Mark Sturdevant
Date: Wed, 4 Sep 2024 17:56:39 -0700
Subject: [PATCH 3/3] Clarify test

Signed-off-by: Mark Sturdevant
---
 tests/modules/text_embedding/test_embedding.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py
index 222341ad..cbd22ee8 100644
--- a/tests/modules/text_embedding/test_embedding.py
+++ b/tests/modules/text_embedding/test_embedding.py
@@ -937,7 +937,7 @@ def test_config_val_to_int():
             "number_str": "456",
             "number_str2": " 456 ",
             "true": True,
-            "non_int": "oh-oh",
+            "non_int_str": "non int str",
         }
     )
     expected_default = 12345
@@ -951,10 +951,10 @@ def test_config_val_to_int():
     assert 1 == conf.get("true", expected_default)
     assert 1 == int(conf.get("true", expected_default))
 
-    assert "oh-oh" == conf.get("non_int", 123)  # default not used (got "oh-oh")
+    assert "non int str" == conf.get("non_int_str", 123)  # default not used
     # Using a bad config (e.g., some non-integer string) with int() will raise ValueError
     with pytest.raises(ValueError):
-        int(conf.get("non_int", 123))  # default not used, int("oh-oh") raises
+        int(conf.get("non_int_str", 123))  # default not used, int("non int str") raises
 
 
 @pytest.mark.parametrize(