diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py
index 777c3d8602d..f75c0d078bd 100644
--- a/ludwig/trainers/trainer.py
+++ b/ludwig/trainers/trainer.py
@@ -40,6 +40,7 @@
     MAX_CPU_BATCH_SIZE,
     MINIMIZE,
     MODEL_ECD,
+    MODEL_LLM,
     TEST,
     TRAINING,
     USED_TOKENS,
@@ -68,6 +69,7 @@
 from ludwig.utils import time_utils
 from ludwig.utils.batch_size_tuner import BatchSizeEvaluator
 from ludwig.utils.checkpoint_utils import Checkpoint, CheckpointManager
+from ludwig.utils.config_utils import get_quantization
 from ludwig.utils.data_utils import load_json
 from ludwig.utils.defaults import default_random_seed
 from ludwig.utils.fs_utils import path_exists
@@ -1133,19 +1135,19 @@ def train(

             # For a full explanation of this 8-bit workaround, see https://github.com/ludwig-ai/ludwig/pull/3606
             # TODO (jeffkinnison): Determine why `SCB` and `CB` are deleted from parameter state
-            if (
-                hasattr(self.model.config_obj, "quantization")
-                and self.model.config_obj.quantization
-                and self.model.config_obj.quantization.bits == 8
-            ):
+            quantization = get_quantization(self.model.config_obj)
+            quantization_bits = quantization if isinstance(quantization, list) else [quantization]
+            if 8 in quantization_bits:
                 # If the model was previously placed on GPU, 8-bit parameter state will be updated with several
                 # matrices containing quantization information. These matrices are recorded in the
                 # training checkpoint state dicts, but do not necessarily exist in the parameter object, leading
                 # to a RuntimeError in `load_state_dict`. Explicitly call `model.cuda()` to make sure the
                 # matrices are part of model state. This workaround is necessary because the matrices are
                 # deleted during the model's forward pass.
-                if self.model.model.device.type == "cuda":
+                if self.model.config_obj.model_type == MODEL_LLM and self.model.model.device.type == "cuda":
                     self.model.model.cuda()
+                elif self.model.config_obj.model_type == MODEL_ECD and self.model.device.type == "cuda":
+                    self.model.cuda()

             _, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
             only_weights_format_keys = ["weights_format" in k for k in unexpected_keys]
diff --git a/ludwig/utils/config_utils.py b/ludwig/utils/config_utils.py
index 7eb4b9b5cb8..e890fb39cfc 100644
--- a/ludwig/utils/config_utils.py
+++ b/ludwig/utils/config_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Set, Union
+from typing import Any, Dict, List, Set, Union
 
 from ludwig.api_annotations import DeveloperAPI
 from ludwig.constants import (
@@ -142,3 +142,43 @@ def config_uses_llm(config: Union[Dict[str, Any], ModelConfig]) -> bool:
         raise ValueError(f"Invalid config cannot be checked for LLM usage. Config: {config}")
 
     return uses_llm
+
+
+def get_quantization(config: Union[Dict[str, Any], ModelConfig]) -> Union[int, List[int], None]:
+    """Get the quantization specified in a config at any level.
+
+    Args:
+        config: Ludwig config object or dictionary
+
+    Returns:
+        For LLM models, the value of quantization.bits or None if it is not specified.
+        For ECD and GBM models, the list of values of quantization.bits for each encoder. If the encoder does not
+        support quantization or no quantization config is specified, the list entry is None.
+ """ + if isinstance(config, ModelConfig): + if config.model_type == MODEL_LLM: + return config.quantization.bits if config.quantization else None + else: + quantization_bits = [] + for feature in config.input_features: + try: + quantization = feature.encoder.quantization.bits + except AttributeError: + quantization = None + quantization_bits.append(quantization) + return quantization_bits + elif isinstance(config, dict) and config: + if config.get(MODEL_TYPE, MODEL_ECD) == MODEL_LLM: + return config.get("quantization", {}).get("bits") + elif INPUT_FEATURES in config: + quantization_bits = [] + for feature in config.get(INPUT_FEATURES, []): + quantization_bits.append(feature.get(ENCODER, {}).get("quantization", {}).get("bits")) + return quantization_bits + else: + raise ValueError( + "Invalid config cannot be checked for quantization because it has no input features." + f"Config: {config}" + ) + else: + raise ValueError(f"Invalid config cannot be checked for quantization. Config: {config}") diff --git a/tests/ludwig/utils/test_config_utils.py b/tests/ludwig/utils/test_config_utils.py index 348abcb6667..a38d23415e2 100644 --- a/tests/ludwig/utils/test_config_utils.py +++ b/tests/ludwig/utils/test_config_utils.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional +import copy +from typing import Any, Dict, List, Optional, Union import pytest @@ -20,7 +21,7 @@ from ludwig.schema.encoders.utils import get_encoder_cls from ludwig.schema.features.preprocessing.text import TextPreprocessingConfig from ludwig.schema.model_config import ModelConfig -from ludwig.utils.config_utils import config_uses_llm +from ludwig.utils.config_utils import config_uses_llm, get_quantization @pytest.mark.parametrize( @@ -84,11 +85,6 @@ def llm_config_dict() -> Dict[str, Any]: } -@pytest.fixture(scope="module") -def llm_config_object(llm_config_dict: Dict[str, Any]) -> ModelConfig: - return ModelConfig.from_dict(llm_config_dict) - - @pytest.fixture(scope="module") def ecd_config_dict_llm_encoder() -> Dict[str, Any]: return { @@ -104,11 +100,6 @@ def ecd_config_dict_llm_encoder() -> Dict[str, Any]: } -@pytest.fixture(scope="module") -def ecd_config_object_llm_encoder(ecd_config_dict_llm_encoder: Dict[str, Any]) -> ModelConfig: - return ModelConfig.from_dict(ecd_config_dict_llm_encoder) - - @pytest.fixture(scope="module") def ecd_config_dict_llm_encoder_multiple_features() -> Dict[str, Any]: return { @@ -125,13 +116,6 @@ def ecd_config_dict_llm_encoder_multiple_features() -> Dict[str, Any]: } -@pytest.fixture(scope="module") -def ecd_config_object_llm_encoder_multiple_features( - ecd_config_dict_llm_encoder_multiple_features: Dict[str, Any] -) -> ModelConfig: - return ModelConfig.from_dict(ecd_config_dict_llm_encoder_multiple_features) - - @pytest.fixture(scope="module") def ecd_config_dict_no_llm_encoder() -> Dict[str, Any]: return { @@ -141,11 +125,6 @@ def ecd_config_dict_no_llm_encoder() -> Dict[str, Any]: } -@pytest.fixture(scope="module") -def ecd_config_object_no_llm_encoder(ecd_config_dict_no_llm_encoder: Dict[str, Any]) -> ModelConfig: - return ModelConfig.from_dict(ecd_config_dict_no_llm_encoder) - - @pytest.fixture(scope="module") def ecd_config_dict_no_text_features() -> Dict[str, Any]: return { @@ -155,11 +134,6 @@ def ecd_config_dict_no_text_features() -> Dict[str, Any]: } -@pytest.fixture(scope="module") -def ecd_config_object_no_text_features(ecd_config_dict_no_text_features: Dict[str, Any]) -> ModelConfig: - return ModelConfig.from_dict(ecd_config_dict_no_text_features) - - 
@pytest.fixture(scope="module") def gbm_config_dict() -> Dict[str, Any]: return { @@ -169,11 +143,6 @@ def gbm_config_dict() -> Dict[str, Any]: } -@pytest.fixture(scope="module") -def gbm_config_object(gbm_config_dict: Dict[str, Any]) -> ModelConfig: - return ModelConfig.from_dict(gbm_config_dict) - - @pytest.fixture(scope="module") def gbm_config_dict_no_text_features() -> Dict[str, Any]: return { @@ -183,38 +152,27 @@ def gbm_config_dict_no_text_features() -> Dict[str, Any]: } -@pytest.fixture(scope="module") -def gbm_config_object_no_text_features(gbm_config_dict_no_text_features: Dict[str, Any]) -> ModelConfig: - return ModelConfig.from_dict(gbm_config_dict_no_text_features) - - @pytest.mark.parametrize( "config,expectation", [ # LLM configurations ("llm_config_dict", True), - ("llm_config_object", True), # LLM encoder configurations ("ecd_config_dict_llm_encoder", True), - ("ecd_config_object_llm_encoder", True), # LLM encoder configurations, multiple features ("ecd_config_dict_llm_encoder_multiple_features", True), - ("ecd_config_object_llm_encoder_multiple_features", True), # ECD configuration with text feature and non-LLM encoder ("ecd_config_dict_no_llm_encoder", False), - ("ecd_config_object_no_llm_encoder", False), # ECD configuration with no text features ("ecd_config_dict_no_text_features", False), - ("ecd_config_object_no_text_features", False), # GBM configuration with text feature. "tf_idf" is the only valid text encoder ("gbm_config_dict", False), - ("gbm_config_object", False), # GBM configuration with no text features ("gbm_config_dict_no_text_features", False), - ("gbm_config_object_no_text_features", False), ], ) -def test_is_or_uses_llm(config, expectation, request): +@pytest.mark.parametrize("config_type", ["dict", "object"]) +def test_is_or_uses_llm(config: Dict[str, Any], expectation: bool, config_type, request): """Test LLM detection on a variety of configs. Configs that use an LLM anywhere should return True, otherwise False. 
@@ -224,6 +182,8 @@ def test_is_or_uses_llm(config, expectation, request):
         request: pytest `request` fixture
     """
     config = request.getfixturevalue(config)
+    if config_type == "object":
+        config = ModelConfig.from_dict(config)
     assert config_uses_llm(config) == expectation
 
 
@@ -238,3 +198,164 @@ def test_is_or_uses_llm_invalid_input(invalid_config):
     """
     with pytest.raises(ValueError):
         config_uses_llm(invalid_config)
+
+
+@pytest.fixture(scope="module")
+def quantization_4bit_config() -> Dict[str, Any]:
+    return {"quantization": {"bits": 4}}
+
+
+@pytest.fixture(scope="module")
+def quantization_8bit_config() -> Dict[str, Any]:
+    return {"quantization": {"bits": 8}}
+
+
+@pytest.fixture(scope="module")
+def llm_config_dict_4bit(llm_config_dict: Dict[str, Any], quantization_4bit_config: Dict[str, Any]) -> Dict[str, Any]:
+    config = copy.deepcopy(llm_config_dict)
+    config.update(quantization_4bit_config)
+    return config
+
+
+@pytest.fixture(scope="module")
+def llm_config_dict_8bit(llm_config_dict: Dict[str, Any], quantization_8bit_config: Dict[str, Any]) -> Dict[str, Any]:
+    config = copy.deepcopy(llm_config_dict)
+    config.update(quantization_8bit_config)
+    return config
+
+
+@pytest.fixture(scope="module")
+def ecd_config_dict_llm_encoder_4bit(
+    ecd_config_dict_llm_encoder: Dict[str, Any], quantization_4bit_config: Dict[str, Any]
+) -> Dict[str, Any]:
+    config = copy.deepcopy(ecd_config_dict_llm_encoder)
+    config[INPUT_FEATURES][0][ENCODER].update(quantization_4bit_config)
+    return config
+
+
+@pytest.fixture(scope="module")
+def ecd_config_dict_llm_encoder_8bit(
+    ecd_config_dict_llm_encoder: Dict[str, Any], quantization_8bit_config: Dict[str, Any]
+) -> Dict[str, Any]:
+    config = copy.deepcopy(ecd_config_dict_llm_encoder)
+    config[INPUT_FEATURES][0][ENCODER].update(quantization_8bit_config)
+    return config
+
+
+@pytest.mark.parametrize(
+    "config,expectation",
+    [
+        # LLM configurations
+        ("llm_config_dict", None),
+        ("llm_config_dict_4bit", 4),
+        ("llm_config_dict_8bit", 8),
+        # LLM encoder configurations with one feature
+        ("ecd_config_dict_llm_encoder", [None]),
+        ("ecd_config_dict_llm_encoder_4bit", [4]),
+        ("ecd_config_dict_llm_encoder_8bit", [8]),
+        # GBM configuration with text feature. "tf_idf" is the only valid text encoder
+        ("gbm_config_dict", [None]),
+        # GBM configuration with no text features
+        ("gbm_config_dict_no_text_features", [None]),
+    ],
+)
+@pytest.mark.parametrize("config_type", ["dict", "object"])
+def test_get_quantization(
+    config: str, expectation: Union[int, List[int], None, List[None]], config_type: str, request
+):
+    """Test get_quantization with LLM and single-feature ECD/GBM configs.
+
+    Args:
+        config: The configuration to test
+        expectation: The expected quantization
+        config_type: Whether to test the config as a dict or object
+        request: pytest builtin fixture
+    """
+    config = request.getfixturevalue(config)
+    if config_type == "object":
+        config = ModelConfig.from_dict(config)
+    assert get_quantization(config) == expectation
+
+
+TEST_FEATURE_CONFIGS = [
+    (
+        {
+            TYPE: BINARY,
+        },
+        None,
+    ),
+    (
+        {
+            TYPE: TEXT,
+        },
+        None,
+    ),
+    ({TYPE: TEXT, ENCODER: {TYPE: MODEL_LLM, BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM"}}, None),
+    (
+        {
+            TYPE: TEXT,
+            ENCODER: {
+                TYPE: MODEL_LLM,
+                BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
+                "quantization": {"bits": 4},
+            },
+        },
+        4,
+    ),
+    (
+        {
+            TYPE: TEXT,
+            ENCODER: {
+                TYPE: MODEL_LLM,
+                BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
+                "quantization": {"bits": 8},
+            },
+        },
+        8,
+    ),
+]
+
+TEST_FEATURE_CONFIGS_IDS = [BINARY, TEXT, MODEL_LLM, f"{MODEL_LLM}-4bit", f"{MODEL_LLM}-8bit"]
+
+
+@pytest.mark.parametrize("feature1,quantization1", TEST_FEATURE_CONFIGS, ids=TEST_FEATURE_CONFIGS_IDS)
+@pytest.mark.parametrize("feature2,quantization2", TEST_FEATURE_CONFIGS, ids=TEST_FEATURE_CONFIGS_IDS)
+@pytest.mark.parametrize("config_type", ["dict", "object"])
+def test_get_quantization_multiple_features(
+    ecd_config_dict_llm_encoder_multiple_features: Dict[str, Any],
+    feature1: Dict[str, Any],
+    quantization1: int,
+    feature2: Dict[str, Any],
+    quantization2: int,
+    config_type: str,
+):
+    """Test get_quantization with multiple features.
+
+    Args:
+        ecd_config_dict_llm_encoder_multiple_features: Baseline config to add features to.
+        feature1: First input feature config dict
+        quantization1: First input feature expected quantization
+        feature2: Second input feature config dict
+        quantization2: Second input feature expected quantization
+        config_type: Whether to test the config as a dict or object
+    """
+    config = copy.deepcopy(ecd_config_dict_llm_encoder_multiple_features)
+    feature1 = dict(name="in1", **feature1)
+    feature2 = dict(name="in2", **feature2)
+    config[INPUT_FEATURES] = [feature1, feature2]
+
+    if config_type == "object":
+        config = ModelConfig.from_dict(config)
+
+    assert get_quantization(config) == [quantization1, quantization2]
+
+
+@pytest.mark.parametrize("invalid_config", [1, 1.0, "foo", True, False, None, [], {}, {"foo": "bar"}])
+def test_get_quantization_invalid_input(invalid_config):
+    """Test get_quantization with invalid configs. These should always raise a ValueError.
+
+    Args:
+        invalid_config: The invalid config to test
+    """
+    with pytest.raises(ValueError):
+        get_quantization(invalid_config)
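
Usage note (illustration only, not part of the patch): the sketch below shows how the new `get_quantization` helper is expected to behave on plain dict configs, per the docstring added in `ludwig/utils/config_utils.py`. It assumes the `MODEL_TYPE`, `INPUT_FEATURES`, and `ENCODER` constants resolve to `"model_type"`, `"input_features"`, and `"encoder"`; the example configs are hypothetical minimal dicts and are not validated against Ludwig's full schema.

```python
# Minimal sketch of get_quantization on dict configs (hypothetical example configs).
from ludwig.utils.config_utils import get_quantization

# LLM config: quantization is specified at the top level, so a single value (or None) is returned.
llm_config = {
    "model_type": "llm",
    "base_model": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
    "input_features": [{"name": "in1", "type": "text"}],
    "output_features": [{"name": "out1", "type": "text"}],
    "quantization": {"bits": 8},
}
assert get_quantization(llm_config) == 8

# ECD config: one list entry per input feature; features whose encoder carries no
# quantization section contribute None.
ecd_config = {
    "model_type": "ecd",
    "input_features": [
        {
            "name": "in1",
            "type": "text",
            "encoder": {
                "type": "llm",
                "base_model": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
                "quantization": {"bits": 4},
            },
        },
        {"name": "in2", "type": "binary"},
    ],
    "output_features": [{"name": "out1", "type": "binary"}],
}
assert get_quantization(ecd_config) == [4, None]
```

Because the LLM path returns a scalar while the ECD/GBM path returns a per-feature list, the trainer hunk above normalizes the result into a single list before checking for an 8-bit entry.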