Commit 55f168e

add unit tests
Signed-off-by: Sukriti-Sharma4 <[email protected]>
Ssukriti committed Nov 17, 2023
1 parent 1a5dde8 commit 55f168e
Showing 2 changed files with 44 additions and 2 deletions.
caikit_nlp/modules/text_generation/peft_config.py (1 addition, 0 deletions)
@@ -268,5 +268,6 @@ def get_lora_config(tuning_type, tuning_config, base_model) -> LoraConfig:
     log.info("<NLP61012781I>", f"Parameters used: {config_kwargs}")
     config_params = _filter_params_for_prompt_config(tuning_config, config_kwargs)
     output_model_types = _get_output_types(tuning_config, base_model)
+    del config_params["output_model_types"]
     lora_config = LoraConfig(task_type=task_type, **config_params)
     return task_type, output_model_types, lora_config, tuning_type
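
For context on the one-line fix above: `output_model_types` ends up in the dict returned by `_filter_params_for_prompt_config`, but peft's LoraConfig is a dataclass and rejects keyword arguments that are not fields, so the key presumably has to be dropped before construction. A minimal sketch of the failure mode follows; the `config_params` values are illustrative, not the real output of the helper:

# Minimal sketch: LoraConfig rejects unknown keyword arguments, so any
# helper-only keys must be stripped before construction. The dict below
# is illustrative, not real output from _filter_params_for_prompt_config.
from peft import LoraConfig, TaskType

config_params = {
    "r": 8,
    "lora_alpha": 8,
    "lora_dropout": 0.0,
    "output_model_types": ["ENCODER"],  # hypothetical leftover key
}

# LoraConfig(**config_params) would raise a TypeError for the unexpected
# keyword argument 'output_model_types', so drop the key first.
del config_params["output_model_types"]

lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, **config_params)
print(lora_config.r, lora_config.lora_alpha, lora_config.lora_dropout)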
tests/modules/text_generation/test_peft_config.py (43 additions, 2 deletions)
@@ -2,14 +2,15 @@
 from unittest.mock import Mock

 # Third Party
-from peft import PromptTuningConfig
+from peft import LoraConfig, PromptTuningConfig
 import pytest

 # Local
-from caikit_nlp.data_model import TuningConfig
+from caikit_nlp.data_model import LoraTuningConfig, TuningConfig
 from caikit_nlp.modules.text_generation import TextGeneration
 from caikit_nlp.modules.text_generation.peft_config import (
     TuningType,
+    get_lora_config,
     get_peft_config,
     resolve_base_model,
 )
@@ -74,6 +75,46 @@ def test_get_peft_config(train_kwargs, dummy_model, request):
     assert peft_config.prompt_tuning_init_text == tuning_config.prompt_tuning_init_text


+@pytest.mark.parametrize(
+    "train_kwargs,dummy_model",
+    [
+        (
+            "seq2seq_lm_train_kwargs",
+            "seq2seq_lm_dummy_model",
+        ),
+        ("causal_lm_train_kwargs", "causal_lm_dummy_model"),
+    ],
+)
+def test_get_lora_config(train_kwargs, dummy_model, request):
+    # Fixtures can't be called directly or passed to mark parametrize;
+    # currently, passing the fixture by name and retrieving it through
+    # the request is the 'right' way to do this.
+    train_kwargs = request.getfixturevalue(train_kwargs)
+    dummy_model = request.getfixturevalue(dummy_model)
+
+    # Define some sample values for testing
+    tuning_type = TuningType.LORA
+    tuning_config = LoraTuningConfig(r=8, lora_alpha=8, lora_dropout=0.0)
+    dummy_resource = train_kwargs["base_model"]
+
+    # Call the function being tested
+    task_type, output_model_types, lora_config, tuning_type = get_lora_config(
+        tuning_type, tuning_config, dummy_resource
+    )
+
+    # Add assertions to validate the behavior of the function
+    assert task_type == dummy_resource.TASK_TYPE
+    assert output_model_types == dummy_resource.PROMPT_OUTPUT_TYPES
+    assert tuning_type == TuningType.LORA
+
+    # Validation for type & important fields in the peft config
+    assert isinstance(lora_config, LoraConfig)
+    assert lora_config.task_type == dummy_resource.TASK_TYPE
+    assert lora_config.r == tuning_config.r
+    assert lora_config.lora_alpha == tuning_config.lora_alpha
+    assert lora_config.lora_dropout == tuning_config.lora_dropout
+
+
 def test_resolve_model_with_invalid_path_raises():
     """Test passing invalid path to resolve_model function raises"""

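
The new test leans on a pytest idiom worth calling out: fixtures can't be referenced directly in @pytest.mark.parametrize, so their names are passed as strings and resolved at runtime with request.getfixturevalue. Below is a self-contained sketch of the same pattern, independent of the caikit_nlp fixtures; the fixture names and values here are invented for illustration:

# Self-contained sketch of the fixture-by-name pattern used in
# test_get_lora_config above. Fixture names are parametrized as strings
# and resolved at runtime via the built-in `request` fixture.
import pytest

@pytest.fixture
def small_lora_params():
    return {"r": 4, "lora_alpha": 4}

@pytest.fixture
def large_lora_params():
    return {"r": 16, "lora_alpha": 32}

@pytest.mark.parametrize(
    "params_fixture", ["small_lora_params", "large_lora_params"]
)
def test_rank_is_positive(params_fixture, request):
    # Resolve the fixture by name; calling the fixture function directly
    # would bypass pytest's fixture machinery and fail.
    params = request.getfixturevalue(params_fixture)
    assert params["r"] > 0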