Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#124: Better error messages in case of missing parameters #148

Merged
merged 7 commits into from
Nov 2, 2023
Merged
93 changes: 46 additions & 47 deletions tests/unit_tests/udfs/test_base_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,30 @@
import re


class regex_matcher: #this one feels more correct
"""Assert that a given string meets some expectations."""
def __init__(self, pattern, flags=0):
self._regex = re.compile(pattern, flags)

def is_in(self, actual):
return bool(self._regex.search(actual))

def __repr__(self):
return self._regex.pattern


class regex_matcher2: #this one looks nicer when used and gives better error messages
"""Assert that a given string meets some expectations."""
def __init__(self, some_sting):
self._string = some_sting

def __contains__(self, pattern):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The eq in the linked example has a reason. This gives you a good diff. You used it then

assert actual_value == regex_matcher("abc.*")

return bool(re.search(pattern, self._string))

def __repr__(self):
return self._string


def create_mock_metadata() -> MockMetaData:
def udf_wrapper():
pass
Expand Down Expand Up @@ -42,12 +66,7 @@ def udf_wrapper():
return meta


@pytest.mark.parametrize(["description", "bucketfs_conn_name", "bucketfs_conn",
"sub_dir", "model_name"], [
("all given", "test_bucketfs_con_name", Connection(address=f"file:///test"),
"test_subdir", "test_model")
])
def test_model_downloader_all_parameters(description, bucketfs_conn_name, bucketfs_conn, sub_dir, model_name):
def test_setup(description, bucketfs_conn_name, bucketfs_conn, sub_dir, model_name):
mock_base_model_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)
mock_tokenizer_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)

Expand All @@ -72,12 +91,23 @@ def test_model_downloader_all_parameters(description, bucketfs_conn_name, bucket
'',
None)
mock_ctx = create_mock_udf_context(input_data, mock_meta)

udf = DummyImplementationUDF(exa=mock_exa,
base_model=mock_base_model_factory,
tokenizer=mock_tokenizer_factory)
udf.run(mock_ctx)
res = mock_ctx.output
return res, mock_meta



@pytest.mark.parametrize(["description", "bucketfs_conn_name", "bucketfs_conn",
"sub_dir", "model_name"], [
("all given", "test_bucketfs_con_name", Connection(address=f"file:///test"),
"test_subdir", "test_model")
])
def test_model_downloader_all_parameters(description, bucketfs_conn_name, bucketfs_conn, sub_dir, model_name):

res, mock_meta = test_setup(description, bucketfs_conn_name, bucketfs_conn, sub_dir, model_name)
# check if no errors
assert res[0][-1] is None and len(res[0]) == len(mock_meta.output_columns)

Expand All @@ -95,47 +125,16 @@ def test_model_downloader_all_parameters(description, bucketfs_conn_name, bucket
"test_subdir", None)
])
def test_model_downloader_missing_parameters(description, bucketfs_conn_name, bucketfs_conn, sub_dir, model_name):
mock_base_model_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)
mock_tokenizer_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)

mock_bucketfs_factory: Union[BucketFSFactory, MagicMock] = create_autospec(BucketFSFactory)
mock_bucketfs_locations = [Mock()]
mock_cast(mock_bucketfs_factory.create_bucketfs_location).side_effect = mock_bucketfs_locations

input_data = [
(
1,
model_name,
sub_dir,
bucketfs_conn_name,
''
),
(
1,
model_name,
sub_dir,
bucketfs_conn_name,
''
)
]
mock_meta = create_mock_metadata()
mock_exa = create_mock_exa_environment(
[bucketfs_conn_name],
[bucketfs_conn],
mock_meta,
'',
None)
mock_ctx = create_mock_udf_context(input_data, mock_meta)

udf = DummyImplementationUDF(exa=mock_exa,
base_model=mock_base_model_factory.side_effect,
tokenizer=mock_tokenizer_factory.side_effect)
res, mock_meta = test_setup(description, bucketfs_conn_name, bucketfs_conn, sub_dir, model_name)

udf.run(mock_ctx)
res = mock_ctx.output
error_field = res[0][-1]
pattern = f"For each model model_name, bucketfs_conn and sub_dir need to be provided. Found model_name = " \
f"{model_name}, bucketfs_conn = .*, sub_dir = {sub_dir}."
expected_error = regex_matcher(f"For each model model_name, bucketfs_conn and sub_dir need to be provided."
f" Found model_name = {model_name}, bucketfs_conn = .*, sub_dir = {sub_dir}.")
assert expected_error.is_in(error_field)

expected_error2 =f"For each model model_name, bucketfs_conn and sub_dir need to be provided. " \
f"Found model_name = {model_name}, bucketfs_conn = .*, sub_dir = {sub_dir}."
error_field_matcher = regex_matcher2(error_field)
assert expected_error2 in error_field_matcher

assert re.search(pattern, error_field)
assert error_field is not None and len(res[0]) == len(mock_meta.output_columns)