Skip to content

Commit

Permalink
cleanup and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MarleneKress79789 committed Nov 14, 2024
1 parent fd06cd8 commit 2593dfa
Showing 1 changed file with 102 additions and 69 deletions.
171 changes: 102 additions & 69 deletions tests/unit_tests/udfs/test_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@
from unittest.mock import patch, MagicMock, create_autospec

import pytest
import transformers
from exasol_udf_mock_python.column import Column
from exasol_udf_mock_python.mock_context import StandaloneMockContext
from exasol_udf_mock_python.mock_exa_environment import MockExaEnvironment
from exasol_udf_mock_python.mock_meta_data import MockMetaData
from transformers import Pipeline
from transformers import Pipeline, AutoModel

from exasol_transformers_extension.udfs.models.token_classification_udf import TokenClassificationUDF
from exasol_transformers_extension.utils.model_factory_protocol import ModelFactoryProtocol
from tests.unit_tests.udfs.output_matcher import Output, OutputMatcher
from tests.utils.mock_bucketfs_location import fake_bucketfs_location_from_conn_object, fake_local_bucketfs_path
from tests.utils.mock_cast import mock_cast

# test params:
from tests.unit_tests.udf_wrapper_params.token_classification.error_not_cached_multiple_model_multiple_batch import \
ErrorNotCachedMultipleModelMultipleBatch
from tests.unit_tests.udf_wrapper_params.token_classification.error_not_cached_single_model_multiple_batch import \
Expand Down Expand Up @@ -50,9 +54,6 @@
from tests.unit_tests.udf_wrapper_params.token_classification.single_model_single_batch_incomplete import \
SingleModelSingleBatchIncomplete

from tests.unit_tests.udfs.output_matcher import Output, OutputMatcher
from tests.utils.mock_bucketfs_location import (fake_bucketfs_location_from_conn_object, fake_local_bucketfs_path)
from tests.utils.mock_cast import mock_cast

def udf_wrapper_empty():
# placeholder to use for MockMetaData creation.
Expand Down Expand Up @@ -123,6 +124,66 @@ def create_mock_metadata():
)
return meta

# TODO: these helper functions should be reusable by the other UDF unit tests. Consider moving them to a shared test-utils module.
def create_db_mocks(bfs_connection, model_input_data, mock_meta):
    """Build the mock UDF context and mock Exasol environment used by the tests.

    :param bfs_connection: mapping of connection names to BucketFS connection objects
    :param model_input_data: rows fed to the UDF as its input table
    :param mock_meta: MockMetaData describing the UDF's input/output columns
    :return: tuple of (mock standalone context, mock Exasol environment)
    """
    context = StandaloneMockContext(inp=model_input_data, metadata=mock_meta)
    environment = MockExaEnvironment(metadata=mock_meta, connections=bfs_connection)
    return context, environment

def create_mock_model_factorys(number_of_intended_used_models):
    """
    Create mock factories for the base model and the tokenizer.

    Builds one ``AutoModel`` mock per intended model use and hands the list to
    the base-model factory mock as the ``side_effect`` of its ``from_pretrained``
    method, so each factory call made by the UDF returns the next mock model in
    order. In test cases where model loading is expected to fail, fewer mock
    models are created than the UDF will request; once the ``side_effect`` list
    is exhausted, the next call raises, triggering the error path under test.
    The tokenizer factory mock does not need to return anything for these tests.

    :param number_of_intended_used_models: how many successful model loads the
        base-model factory mock should support
    :return: tuple of (mock base model factory, mock tokenizer factory)
    """
    mock_tokenizer_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)
    mock_base_model_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(
        ModelFactoryProtocol, _name="mock_base_model_factory")
    mock_models: List[Union[AutoModel, MagicMock]] = [
        create_autospec(AutoModel) for _ in range(number_of_intended_used_models)
    ]
    mock_cast(mock_base_model_factory.from_pretrained).side_effect = mock_models

    return mock_base_model_factory, mock_tokenizer_factory

def create_mock_pipeline_factory(tokenizer_models_output_df, number_of_intended_used_models):
    """
    Create a mock pipeline factory for the UDF.

    A real pipeline is created from a model and a tokenizer and then called with
    the input data to produce results. Here each mock pipeline receives one entry
    of ``tokenizer_models_output_df`` as its ``side_effect``, enabling it to
    return those outputs in order when called. If a specific output is a
    non-valid result (its first "word" value is falsy), a list containing an
    Exception is used instead, so the mock raises when called, exercising the
    error-on-prediction path. The list of mock pipelines is fed to the returned
    mock pipeline factory as its ``side_effect``.

    :param tokenizer_models_output_df: per-model lists of pipeline outputs
    :param number_of_intended_used_models: number of pipelines the UDF is
        expected to create
    :return: mock pipeline factory
    """
    # TODO: consider putting the Exception into tokenizer_models_output_df in
    #       the params files instead of constructing it here.
    mock_pipeline: List[Union[Pipeline, MagicMock]] = [
        create_autospec(Pipeline, side_effect=tokenizer_models_output_df[i])
        if tokenizer_models_output_df[i][0][0][0]["word"]
        else [Exception("Traceback mock_pipeline is throwing an error intentionally")]
        for i in range(number_of_intended_used_models)
    ]

    mock_pipeline_factory: Union[Pipeline, MagicMock] = create_autospec(
        Pipeline, side_effect=mock_pipeline)
    return mock_pipeline_factory

def assert_correct_number_of_results(result, output_columns, output_data):
    """
    Assert that the UDF emitted the expected number of rows and columns.

    :param result: rows emitted by the UDF (list of row sequences)
    :param output_columns: expected output column definitions
    :param output_data: expected output rows
    :raises AssertionError: if the row count or column count differs
    """
    # Check the row count first: on an empty result, indexing result[0] would
    # raise an IndexError and mask the informative assertion message.
    assert len(result) == len(output_data), (f"Number of lines in result is {len(result)}, "
                                             f"not as expected {len(output_data)}")
    if result:
        assert len(result[0]) == len(output_columns), (f"Number of columns in result is {len(result[0])}, "
                                                       f"not as expected {len(output_columns)}")

def assert_result_matches_expected_output(result, expected_output_data, input_columns):
    """
    Assert that the UDF output rows match the expected output rows.

    :param result: rows emitted by the UDF
    :param expected_output_data: expected output rows
    :param input_columns: input column definitions of the UDF
    :raises AssertionError: if the outputs do not match
    """
    expected_output = Output(expected_output_data)
    actual_output = Output(result)
    # NOTE(review): n_input_columns presumably tells OutputMatcher how many
    # trailing/leading input columns to ignore in the comparison — confirm
    # against the OutputMatcher implementation.
    n_input_columns = len(input_columns) - 1
    assert OutputMatcher(actual_output, n_input_columns) == expected_output, (
        "OutputMatcher found expected_output_data and result not matching: "
        f"expected_output_data: \n"
        f"{expected_output_data}\n"
        f"actual_output_data: \n"
        f"{actual_output}")


@pytest.mark.parametrize("params", [
SingleModelSingleBatchIncomplete,
Expand All @@ -149,37 +210,27 @@ def create_mock_metadata():
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_token_classification_with_span(mock_local_path, mock_create_loc, params):
"""
This test checks combinations of input data to determine correct output data. For this everything the udf uses in
the background is mocked, and given to the udf. we then check if the resulting output matches the expected output.
"""
mock_create_loc.side_effect = fake_bucketfs_location_from_conn_object
mock_local_path.side_effect = fake_local_bucketfs_path

mock_meta = create_mock_metadata_with_span()
input = params.work_with_span_input_data
mock_ctx = StandaloneMockContext(inp=input, metadata=mock_meta)
mock_exa = MockExaEnvironment(
metadata=mock_meta,
connections=params.bfs_connections)
model_input_data = params.work_with_span_input_data
bfs_connection = params.bfs_connections
expected_model_counter = params.expected_model_counter
tokenizer_models_output_df = params.tokenizer_models_output_df
batch_size = params.batch_size
expected_output_data = params.work_with_span_output_data

mock_base_model_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol,
_name="mock_base_model_factory")
number_of_intended_used_models = params.expected_model_counter
mock_models: List[Union[transformers.AutoModel, MagicMock]] = [
create_autospec(transformers.AutoModel) for i in range (0,number_of_intended_used_models)
]
print(mock_models)
mock_cast(mock_base_model_factory.from_pretrained).side_effect = mock_models

mock_tokenizer_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)

mock_pipeline: List[Union[transformers.AutoModel, MagicMock]] = [
create_autospec(Pipeline, side_effect=params.tokenizer_models_output_df[i]) if params.tokenizer_models_output_df[i][0][0][0]["word"]
else [Exception("Traceback mock_pipeline is throwing an error intentionally")]
for i in range(0, number_of_intended_used_models)
]
mock_meta = create_mock_metadata_with_span()
mock_ctx, mock_exa = create_db_mocks(bfs_connection, model_input_data, mock_meta)
mock_base_model_factory, mock_tokenizer_factory = create_mock_model_factorys(expected_model_counter)
mock_pipeline_factory = create_mock_pipeline_factory(tokenizer_models_output_df, expected_model_counter)

mock_pipeline_factory: Union[Pipeline, MagicMock] = create_autospec(Pipeline,
side_effect=mock_pipeline)
udf = TokenClassificationUDF(exa=mock_exa,
batch_size=params.batch_size,
batch_size=batch_size,
base_model=mock_base_model_factory,
tokenizer=mock_tokenizer_factory,
pipeline=mock_pipeline_factory,
Expand All @@ -188,14 +239,9 @@ def test_token_classification_with_span(mock_local_path, mock_create_loc, params
udf.run(mock_ctx)
result = mock_ctx.output

assert len(result[0]) == len(mock_meta.output_columns)
assert len(result) == len(params.work_with_span_output_data)

expected_output = Output(params.work_with_span_output_data)
actual_output = Output(result)
n_input_columns = len(mock_meta.input_columns) - 1
assert (OutputMatcher(actual_output, n_input_columns) == expected_output and
len(mock_pipeline_factory.mock_calls) == params.expected_model_counter)
assert_correct_number_of_results(result, mock_meta.output_columns, expected_output_data)
assert_result_matches_expected_output(result, expected_output_data, mock_meta.input_columns)
assert len(mock_pipeline_factory.mock_calls) == expected_model_counter



Expand Down Expand Up @@ -223,47 +269,34 @@ def test_token_classification_with_span(mock_local_path, mock_create_loc, params
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_token_classification(mock_local_path, mock_create_loc, params):
"""
This test checks combinations of input data to determine correct output data. For this everything the udf uses in
the background is mocked, and given to the udf. we then check if the resulting output matches the expected output.
"""
mock_create_loc.side_effect = fake_bucketfs_location_from_conn_object
mock_local_path.side_effect = fake_local_bucketfs_path

mock_meta = create_mock_metadata()
input = params.input_data
mock_ctx = StandaloneMockContext(inp=input, metadata=mock_meta)
mock_exa = MockExaEnvironment(
metadata=mock_meta,
connections=params.bfs_connections)

mock_base_model_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol,
_name="mock_base_model_factory")
number_of_intended_used_models = params.expected_model_counter
mock_models: List[Union[transformers.AutoModel, MagicMock]] = [
create_autospec(transformers.AutoModel) for i in range (0,number_of_intended_used_models)
]
mock_cast(mock_base_model_factory.from_pretrained).side_effect = mock_models
model_input_data = params.input_data
bfs_connection = params.bfs_connections
expected_model_counter = params.expected_model_counter
tokenizer_models_output_df = params.tokenizer_models_output_df
batch_size = params.batch_size
expected_output_data = params.output_data

mock_tokenizer_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)
mock_meta = create_mock_metadata()
mock_ctx, mock_exa = create_db_mocks(bfs_connection, model_input_data, mock_meta)
mock_base_model_factory, mock_tokenizer_factory = create_mock_model_factorys(expected_model_counter)
mock_pipeline_factory = create_mock_pipeline_factory(tokenizer_models_output_df, expected_model_counter)

mock_pipeline: List[Union[transformers.AutoModel, MagicMock]] = [
create_autospec(Pipeline, side_effect=params.tokenizer_models_output_df[i]) if params.tokenizer_models_output_df[i][0][0][0]["word"]
else [Exception("Traceback mock_pipeline is throwing an error intentionally")]
for i in range(0, number_of_intended_used_models)
]
mock_pipeline_factory: Union[Pipeline, MagicMock] = create_autospec(Pipeline,
side_effect=mock_pipeline)
udf = TokenClassificationUDF(exa=mock_exa,
batch_size=params.batch_size,
batch_size=batch_size,
base_model=mock_base_model_factory,
tokenizer=mock_tokenizer_factory,
pipeline=mock_pipeline_factory)

udf.run(mock_ctx)
result = mock_ctx.output

assert len(result[0]) == len(mock_meta.output_columns)
assert len(result) == len(params.work_with_span_output_data)

expected_output = Output(params.output_data)
actual_output = Output(result)
n_input_columns = len(mock_meta.input_columns) - 1
assert (OutputMatcher(actual_output, n_input_columns) == expected_output and
len(mock_pipeline_factory.mock_calls) == params.expected_model_counter)
assert_correct_number_of_results(result, mock_meta.output_columns, expected_output_data)
assert_result_matches_expected_output(result, expected_output_data, mock_meta.input_columns)
assert len(mock_pipeline_factory.mock_calls) == expected_model_counter

0 comments on commit 2593dfa

Please sign in to comment.