
Commit

Fix coverage
AjayP13 committed Apr 24, 2024
1 parent 68dfb8f commit c650606
Showing 7 changed files with 73 additions and 17 deletions.
src/llms/hf_transformers.py: 10 changes (8 additions, 2 deletions)
@@ -239,7 +239,11 @@ def model(self) -> PreTrainedModel:
low_cpu_mem_usage=True,
torch_dtype=self.dtype,
attn_implementation=get_attn_implementation(
- model_cls=auto_cls, model_kwargs=self.kwargs, optimize=True
+ model_name=self.model_name,
+ revision=self.revision,
+ trust_remote_code=self.trust_remote_code,
+ model_kwargs=self.kwargs,
+ optimize=True,
),
device_map=to_device_map,
max_memory=to_device_map_max_memory,
@@ -425,7 +429,9 @@ def _run_batch( # noqa: C901
**kwargs,
)
if not use_pipeline:
- if _is_petals(self) and "attention_mask" in model_inputs:
+ if (
+ _is_petals(self) and "attention_mask" in model_inputs
+ ): # pragma: no cover
del model_inputs["attention_mask"]
outputs = model.generate(**model_inputs, **generation_kwargs)
texts = cached_tokenizer.batch_decode(
src/llms/petals.py: 4 changes (2 additions, 2 deletions)
@@ -20,7 +20,7 @@
_ServerInferenceSession_step: None | Callable = None


- def _is_batch_size_exception_func(e: BaseException) -> bool:
+ def _is_batch_size_exception_func(e: BaseException) -> bool:  # pragma: no cover
from hivemind.p2p.p2p_daemon_bindings.utils import P2PHandlerError

return (
@@ -66,7 +66,7 @@ def _patched_on_request_failure(


@cache
- def _monkey_patch_ServerInferenceSession_step():
+ def _monkey_patch_ServerInferenceSession_step():  # pragma: no cover
from hivemind.p2p.p2p_daemon_bindings.utils import P2PHandlerError # noqa: F401

try:
src/tests/llms/test_llms.py: 22 changes (22 additions, 0 deletions)
@@ -1523,6 +1523,25 @@ def test_run(self, create_datadreamer):
llm.unload_model()
assert "model" not in llm.__dict__ and "tokenizer" not in llm.__dict__

+ def test_sdpa(self, create_datadreamer):
+ with create_datadreamer():
+ llm = HFTransformers("Qwen/Qwen1.5-0.5B-Chat")
+ generated_texts = llm.run(
+ [
+ "Question: What color is the sky?\nSingle-Word Answer:",
+ "Question: What color are trees?\nSingle-Word Answer:",
+ ],
+ max_new_tokens=1,
+ temperature=0.0,
+ top_p=0.0,
+ n=1,
+ stop="Question:",
+ repetition_penalty=None,
+ logit_bias=None,
+ batch_size=2,
+ )
+ assert generated_texts == ["Blue", "Green"]
+
def test_adapter_metadata(self, create_datadreamer):
with create_datadreamer():
# Load LLaMa
@@ -2962,17 +2981,20 @@ def chat_mocked(**kwargs):

class TestPetals:
pydantic_version = None
+ bitsandbytes_version = None

@classmethod
def setup_class(cls):
cls.pydantic_version = importlib.metadata.version("pydantic")
+ cls.bitsandbytes_version = importlib.metadata.version("bitsandbytes")
os.system("pip3 install petals==2.2.0")
os.system("pip3 install 'pydantic>=1.10,<2.0'")
_reload_pydantic()

@classmethod
def teardown_class(cls):
os.system(f"pip3 install pydantic=={cls.pydantic_version}")
os.system(f"pip3 install bitsandbytes=={cls.bitsandbytes_version}")
_reload_pydantic()

@pytest.mark.xfail # Petals network is unreliable currently
src/trainers/_train_hf_base.py: 16 changes (12 additions, 4 deletions)
@@ -1108,7 +1108,11 @@ def _create_model(
trust_remote_code=self.trust_remote_code,
torch_dtype=self.dtype,
attn_implementation=get_attn_implementation(
- model_cls=self.auto_cls, model_kwargs=self.kwargs, optimize=True
+ model_name=self.model_name,
+ revision=self.revision,
+ trust_remote_code=self.trust_remote_code,
+ model_kwargs=self.kwargs,
+ optimize=True,
),
device_map=to_device_map,
max_memory=to_device_map_max_memory,
@@ -1143,7 +1147,7 @@ def _create_model(
with ignore_transformers_warnings():
from peft import get_peft_model, prepare_model_for_kbit_training

- if self.quantization_config:
+ if self.quantization_config:  # pragma: no cover
model = prepare_model_for_kbit_training(model)
model = get_peft_model(
model, validate_peft_config(model=model, peft_config=self.peft_config)
@@ -1290,7 +1294,9 @@ def _load_model(
trust_remote_code=self.trust_remote_code,
torch_dtype=self.dtype,
attn_implementation=get_attn_implementation(
- model_cls=self.auto_cls,
+ model_name=self.model_name,
+ revision=self.revision,
+ trust_remote_code=self.trust_remote_code,
model_kwargs=self.kwargs,
optimize=with_optimizations,
),
@@ -1316,7 +1322,9 @@
MODEL_DIR,
torch_dtype=self.dtype,
attn_implementation=get_attn_implementation(
- model_cls=self.auto_cls,
+ model_name=self.model_name,
+ revision=self.revision,
+ trust_remote_code=self.trust_remote_code,
model_kwargs=self.kwargs,
optimize=with_optimizations,
),
src/utils/device_utils.py: 6 changes (4 additions, 2 deletions)
@@ -227,7 +227,7 @@ def model_to_device(
revision=revision,
trust_remote_code=trust_remote_code,
quantization_config=quantization_config,
- ):
+ ): # pragma: no cover
if is_train:
to_device_map = {
"": get_global_rank() if isinstance(device, list) else device
@@ -261,7 +261,9 @@ def model_to_device(
to_device_map_max_memory = None
else:
to_device = "cpu" if isinstance(device, list) else device
- else:
+ else: # pragma: no cover
+ to_device_map = device_map
+ to_device_map_max_memory = None
if to_device is None and to_device_map is None and to_device_map_max_memory is None:
to_device = "cpu"
return to_device, to_device_map, to_device_map_max_memory
src/utils/hf_model_utils.py: 26 changes (21 additions, 5 deletions)
@@ -1,13 +1,14 @@
from copy import copy
from functools import cache
- from typing import Any, Type
+ from typing import Any

import torch

from .arg_utils import Default, default_to
from .import_utils import ignore_transformers_warnings

with ignore_transformers_warnings():
+ import transformers
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -233,7 +234,9 @@ def validate_quantization_config(
dtype: None | str | torch.dtype,
):
quantization_config = copy(quantization_config)
- if getattr(quantization_config, "quant_method", None) == "bitsandbytes":
+ if (
+ getattr(quantization_config, "quant_method", None) == "bitsandbytes"
+ ): # pragma: no cover
quantization_config.bnb_4bit_compute_dtype = dtype # type:ignore[union-attr]
quantization_config.bnb_4bit_quant_storage = dtype # type:ignore[union-attr]
return quantization_config
@@ -260,10 +263,23 @@ def peft_module_casting_to_dtype(model, dtype: None | str | torch.dtype):


def get_attn_implementation(
- model_cls: Type, model_kwargs: dict[str, Any], optimize: bool
+ model_name: str,
+ revision: None | str,
+ trust_remote_code: bool,
+ model_kwargs: dict[str, Any],
+ optimize: bool,
):
+ model_config = get_config(
+ model_name=model_name, revision=revision, trust_remote_code=trust_remote_code
+ )
+ architecture = getattr(model_config, "architectures", ["CannotDetectArchitecture"])[
+ 0
+ ]
attn_implementation = "eager"
- if getattr(model_cls, "_supports_sdpa", False) and optimize:
+ if (
+ getattr(getattr(transformers, architecture, object()), "_supports_sdpa", False)
+ and optimize
+ ):
attn_implementation = "sdpa"
return model_kwargs.get("attn_implementation", attn_implementation)

@@ -297,6 +313,6 @@ def get_model_optional_kwargs(
quantization_config: None | QuantizationConfigMixin | dict,
) -> dict[str, Any]:
optional_kwargs = {}
- if quantization_config is not None:
+ if quantization_config is not None:  # pragma: no cover
optional_kwargs["quantization_config"] = quantization_config
return optional_kwargs
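
Note on the change above: get_attn_implementation no longer receives an Auto* model class; it now loads the model's config, reads the first entry of its architectures field, and checks the matching class on the transformers module for SDPA support. Below is a minimal standalone sketch of that logic; the helper name and defaults are illustrative rather than the project's actual API, and the Qwen checkpoint is the one exercised by the new test_sdpa test.

import transformers
from transformers import AutoConfig

def resolve_attn_implementation(model_name, revision=None, trust_remote_code=False,
                                model_kwargs=None, optimize=True):
    # Load only the config (no weights) and read the architecture name it declares.
    config = AutoConfig.from_pretrained(
        model_name, revision=revision, trust_remote_code=trust_remote_code
    )
    architecture = getattr(config, "architectures", ["CannotDetectArchitecture"])[0]
    attn_implementation = "eager"
    # Look the architecture class up on the transformers module and check its
    # _supports_sdpa flag, instead of relying on the Auto* class as before.
    architecture_cls = getattr(transformers, architecture, object())
    if getattr(architecture_cls, "_supports_sdpa", False) and optimize:
        attn_implementation = "sdpa"
    # An explicit attn_implementation passed by the caller still takes precedence.
    return (model_kwargs or {}).get("attn_implementation", attn_implementation)

# Expected to print "sdpa" on a transformers version where Qwen2 declares SDPA support.
print(resolve_attn_implementation("Qwen/Qwen1.5-0.5B-Chat"))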
src/utils/import_utils.py: 6 changes (4 additions, 2 deletions)
@@ -77,7 +77,9 @@ def ignore_training_warnings():
]:
model_logger = logging.getLogger(model_logger_name)

- class NoUseCacheIsIncompatibleWarningFilter(logging.Filter):
+ class NoUseCacheIsIncompatibleWarningFilter(
+ logging.Filter
+ ): # pragma: no cover
def filter(self, record):
return not record.getMessage().startswith(
"`use_cache=True` is incompatible with gradient checkpointing"
@@ -136,7 +138,7 @@ def ignore_faiss_warnings():


@contextlib.contextmanager
- def ignore_hivemind_warnings():
+ def ignore_hivemind_warnings():  # pragma: no cover
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
