From c650606a3f5bff9c819fbb8f53b76f2164236aaf Mon Sep 17 00:00:00 2001
From: Ajay Patel
Date: Wed, 24 Apr 2024 18:19:41 -0400
Subject: [PATCH] Fix coverage

---
 src/llms/hf_transformers.py    | 10 ++++++++--
 src/llms/petals.py             |  4 ++--
 src/tests/llms/test_llms.py    | 22 ++++++++++++++++++++++
 src/trainers/_train_hf_base.py | 16 ++++++++++++----
 src/utils/device_utils.py      |  6 ++++--
 src/utils/hf_model_utils.py    | 26 +++++++++++++++++++++-----
 src/utils/import_utils.py      |  6 ++++--
 7 files changed, 73 insertions(+), 17 deletions(-)

diff --git a/src/llms/hf_transformers.py b/src/llms/hf_transformers.py
index 2245bad..5c64030 100644
--- a/src/llms/hf_transformers.py
+++ b/src/llms/hf_transformers.py
@@ -239,7 +239,11 @@ def model(self) -> PreTrainedModel:
             low_cpu_mem_usage=True,
             torch_dtype=self.dtype,
             attn_implementation=get_attn_implementation(
-                model_cls=auto_cls, model_kwargs=self.kwargs, optimize=True
+                model_name=self.model_name,
+                revision=self.revision,
+                trust_remote_code=self.trust_remote_code,
+                model_kwargs=self.kwargs,
+                optimize=True,
             ),
             device_map=to_device_map,
             max_memory=to_device_map_max_memory,
@@ -425,7 +429,9 @@ def _run_batch(  # noqa: C901
             **kwargs,
         )
         if not use_pipeline:
-            if _is_petals(self) and "attention_mask" in model_inputs:
+            if (
+                _is_petals(self) and "attention_mask" in model_inputs
+            ):  # pragma: no cover
                 del model_inputs["attention_mask"]
             outputs = model.generate(**model_inputs, **generation_kwargs)
             texts = cached_tokenizer.batch_decode(
diff --git a/src/llms/petals.py b/src/llms/petals.py
index d15d9df..0e36501 100644
--- a/src/llms/petals.py
+++ b/src/llms/petals.py
@@ -20,7 +20,7 @@
 _ServerInferenceSession_step: None | Callable = None
 
 
-def _is_batch_size_exception_func(e: BaseException) -> bool:
+def _is_batch_size_exception_func(e: BaseException) -> bool:  # pragma: no cover
     from hivemind.p2p.p2p_daemon_bindings.utils import P2PHandlerError
 
     return (
@@ -66,7 +66,7 @@ def _patched_on_request_failure(
 
 
 @cache
-def _monkey_patch_ServerInferenceSession_step():
+def _monkey_patch_ServerInferenceSession_step():  # pragma: no cover
     from hivemind.p2p.p2p_daemon_bindings.utils import P2PHandlerError  # noqa: F401
 
     try:
diff --git a/src/tests/llms/test_llms.py b/src/tests/llms/test_llms.py
index 0cb4b9a..0db4ab8 100644
--- a/src/tests/llms/test_llms.py
+++ b/src/tests/llms/test_llms.py
@@ -1523,6 +1523,25 @@ def test_run(self, create_datadreamer):
             llm.unload_model()
             assert "model" not in llm.__dict__ and "tokenizer" not in llm.__dict__
 
+    def test_sdpa(self, create_datadreamer):
+        with create_datadreamer():
+            llm = HFTransformers("Qwen/Qwen1.5-0.5B-Chat")
+            generated_texts = llm.run(
+                [
+                    "Question: What color is the sky?\nSingle-Word Answer:",
+                    "Question: What color are trees?\nSingle-Word Answer:",
+                ],
+                max_new_tokens=1,
+                temperature=0.0,
+                top_p=0.0,
+                n=1,
+                stop="Question:",
+                repetition_penalty=None,
+                logit_bias=None,
+                batch_size=2,
+            )
+            assert generated_texts == ["Blue", "Green"]
+
     def test_adapter_metadata(self, create_datadreamer):
         with create_datadreamer():
             # Load LLaMa
@@ -2962,10 +2981,12 @@ def chat_mocked(**kwargs):
 
 class TestPetals:
     pydantic_version = None
+    bitsandbytes_version = None
 
     @classmethod
     def setup_class(cls):
         cls.pydantic_version = importlib.metadata.version("pydantic")
+        cls.bitsandbytes_version = importlib.metadata.version("bitsandbytes")
         os.system("pip3 install petals==2.2.0")
         os.system("pip3 install 'pydantic>=1.10,<2.0'")
         _reload_pydantic()
@@ -2973,6 +2994,7 @@ def setup_class(cls):
     @classmethod
     def teardown_class(cls):
         os.system(f"pip3 install pydantic=={cls.pydantic_version}")
+        os.system(f"pip3 install bitsandbytes=={cls.bitsandbytes_version}")
         _reload_pydantic()
 
     @pytest.mark.xfail  # Petals network is unreliable currently
diff --git a/src/trainers/_train_hf_base.py b/src/trainers/_train_hf_base.py
index 21f6af1..abcdbd3 100644
--- a/src/trainers/_train_hf_base.py
+++ b/src/trainers/_train_hf_base.py
@@ -1108,7 +1108,11 @@ def _create_model(
             trust_remote_code=self.trust_remote_code,
             torch_dtype=self.dtype,
             attn_implementation=get_attn_implementation(
-                model_cls=self.auto_cls, model_kwargs=self.kwargs, optimize=True
+                model_name=self.model_name,
+                revision=self.revision,
+                trust_remote_code=self.trust_remote_code,
+                model_kwargs=self.kwargs,
+                optimize=True,
             ),
             device_map=to_device_map,
             max_memory=to_device_map_max_memory,
@@ -1143,7 +1147,7 @@ def _create_model(
         with ignore_transformers_warnings():
             from peft import get_peft_model, prepare_model_for_kbit_training
 
-        if self.quantization_config:
+        if self.quantization_config:  # pragma: no cover
             model = prepare_model_for_kbit_training(model)
         model = get_peft_model(
             model, validate_peft_config(model=model, peft_config=self.peft_config)
@@ -1290,7 +1294,9 @@ def _load_model(
             trust_remote_code=self.trust_remote_code,
             torch_dtype=self.dtype,
             attn_implementation=get_attn_implementation(
-                model_cls=self.auto_cls,
+                model_name=self.model_name,
+                revision=self.revision,
+                trust_remote_code=self.trust_remote_code,
                 model_kwargs=self.kwargs,
                 optimize=with_optimizations,
             ),
@@ -1316,7 +1322,9 @@ def _load_model(
             MODEL_DIR,
             torch_dtype=self.dtype,
             attn_implementation=get_attn_implementation(
-                model_cls=self.auto_cls,
+                model_name=self.model_name,
+                revision=self.revision,
+                trust_remote_code=self.trust_remote_code,
                 model_kwargs=self.kwargs,
                 optimize=with_optimizations,
             ),
diff --git a/src/utils/device_utils.py b/src/utils/device_utils.py
index dc006ac..6afad75 100644
--- a/src/utils/device_utils.py
+++ b/src/utils/device_utils.py
@@ -227,7 +227,7 @@ def model_to_device(
         revision=revision,
         trust_remote_code=trust_remote_code,
         quantization_config=quantization_config,
-    ):
+    ):  # pragma: no cover
         if is_train:
             to_device_map = {
                 "": get_global_rank() if isinstance(device, list) else device
@@ -261,7 +261,9 @@ def model_to_device(
             to_device_map_max_memory = None
         else:
             to_device = "cpu" if isinstance(device, list) else device
-    else:
+    else:  # pragma: no cover
         to_device_map = device_map
         to_device_map_max_memory = None
+    if to_device is None and to_device_map is None and to_device_map_max_memory is None:
+        to_device = "cpu"
     return to_device, to_device_map, to_device_map_max_memory
diff --git a/src/utils/hf_model_utils.py b/src/utils/hf_model_utils.py
index e6900d3..0a89f3d 100644
--- a/src/utils/hf_model_utils.py
+++ b/src/utils/hf_model_utils.py
@@ -1,6 +1,6 @@
 from copy import copy
 from functools import cache
-from typing import Any, Type
+from typing import Any
 
 import torch
 
@@ -8,6 +8,7 @@
 from .import_utils import ignore_transformers_warnings
 
 with ignore_transformers_warnings():
+    import transformers
     from transformers import (
         AutoConfig,
         AutoTokenizer,
@@ -233,7 +234,9 @@ def validate_quantization_config(
     dtype: None | str | torch.dtype,
 ):
     quantization_config = copy(quantization_config)
-    if getattr(quantization_config, "quant_method", None) == "bitsandbytes":
+    if (
+        getattr(quantization_config, "quant_method", None) == "bitsandbytes"
+    ):  # pragma: no cover
         quantization_config.bnb_4bit_compute_dtype = dtype  # type:ignore[union-attr]
         quantization_config.bnb_4bit_quant_storage = dtype  # type:ignore[union-attr]
     return quantization_config
@@ -260,10 +263,23 @@ def peft_module_casting_to_dtype(model, dtype: None | str | torch.dtype):
 
 
 def get_attn_implementation(
-    model_cls: Type, model_kwargs: dict[str, Any], optimize: bool
+    model_name: str,
+    revision: None | str,
+    trust_remote_code: bool,
+    model_kwargs: dict[str, Any],
+    optimize: bool,
 ):
+    model_config = get_config(
+        model_name=model_name, revision=revision, trust_remote_code=trust_remote_code
+    )
+    architecture = getattr(model_config, "architectures", ["CannotDetectArchitecture"])[
+        0
+    ]
     attn_implementation = "eager"
-    if getattr(model_cls, "_supports_sdpa", False) and optimize:
+    if (
+        getattr(getattr(transformers, architecture, object()), "_supports_sdpa", False)
+        and optimize
+    ):
         attn_implementation = "sdpa"
     return model_kwargs.get("attn_implementation", attn_implementation)
 
@@ -297,6 +313,6 @@ def get_model_optional_kwargs(
     quantization_config: None | QuantizationConfigMixin | dict,
 ) -> dict[str, Any]:
     optional_kwargs = {}
-    if quantization_config is not None:
+    if quantization_config is not None:  # pragma: no cover
         optional_kwargs["quantization_config"] = quantization_config
     return optional_kwargs
diff --git a/src/utils/import_utils.py b/src/utils/import_utils.py
index bfc594d..539eb43 100644
--- a/src/utils/import_utils.py
+++ b/src/utils/import_utils.py
@@ -77,7 +77,9 @@ def ignore_training_warnings():
     ]:
         model_logger = logging.getLogger(model_logger_name)
 
-        class NoUseCacheIsIncompatibleWarningFilter(logging.Filter):
+        class NoUseCacheIsIncompatibleWarningFilter(
+            logging.Filter
+        ):  # pragma: no cover
             def filter(self, record):
                 return not record.getMessage().startswith(
                     "`use_cache=True` is incompatible with gradient checkpointing"
@@ -136,7 +138,7 @@ def ignore_faiss_warnings():
 
 
 @contextlib.contextmanager
-def ignore_hivemind_warnings():
+def ignore_hivemind_warnings():  # pragma: no cover
     with warnings.catch_warnings():
         warnings.filterwarnings(
             "ignore",
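
For reference, below is a minimal standalone sketch (not part of the patch) of the attention-implementation selection that get_attn_implementation now performs from a model name rather than a model class. It substitutes transformers' public AutoConfig.from_pretrained for DataDreamer's internal get_config helper, and the function name pick_attn_implementation and the example checkpoint are illustrative assumptions only.

# Standalone illustration -- NOT part of the patch above. Approximates the new
# config-based attn_implementation selection, using the public
# AutoConfig.from_pretrained instead of DataDreamer's internal get_config helper.
from typing import Any

import transformers
from transformers import AutoConfig


def pick_attn_implementation(
    model_name: str,
    revision: None | str,
    trust_remote_code: bool,
    model_kwargs: dict[str, Any],
    optimize: bool,
) -> str:
    # Resolve the checkpoint's architecture class name from its config,
    # e.g. "Qwen2ForCausalLM" for Qwen/Qwen1.5-0.5B-Chat.
    config = AutoConfig.from_pretrained(
        model_name, revision=revision, trust_remote_code=trust_remote_code
    )
    architecture = getattr(config, "architectures", ["CannotDetectArchitecture"])[0]

    # Default to "eager"; use "sdpa" only if the resolved transformers class
    # advertises support for it and optimizations were requested.
    attn_implementation = "eager"
    model_cls = getattr(transformers, architecture, object())
    if getattr(model_cls, "_supports_sdpa", False) and optimize:
        attn_implementation = "sdpa"

    # A caller-supplied attn_implementation always takes precedence.
    return model_kwargs.get("attn_implementation", attn_implementation)


if __name__ == "__main__":
    # Downloads the model config from the Hugging Face Hub.
    print(pick_attn_implementation("Qwen/Qwen1.5-0.5B-Chat", None, False, {}, True))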