-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e4c85a6
commit 6f37648
Showing
6 changed files
with
185 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
...df_wrapper_params/token_classification/error_prediction_containing_only_unknown_fields.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from pathlib import PurePosixPath | ||
from exasol_udf_mock_python.connection import Connection | ||
from tests.unit_tests.udf_wrapper_params.token_classification.make_data_row_functions import make_input_row, \ | ||
make_output_row, make_input_row_with_span, make_output_row_with_span, bucketfs_conn, \ | ||
text_docid, text_start, text_end, agg_strategy_simple, make_model_output_for_one_input_row, sub_dir, model_name | ||
|
||
|
||
class ErrorPredictionOnlyContainsUnknownFields:
    """
    Token-classification test parameters for a mocked model whose result
    dicts contain only the keys "unknown key" and "diff unknown key".

    The expected output rows carry ``None`` for score, start, end, word and
    entity, and "Traceback" as the error message.
    """
    expected_model_counter = 1
    batch_size = 2
    data_size = 5
    n_entities = 3

    text_data = "error_result_contains_only_unknown fields"

    # Plain (non-span) variant: the same input row repeated data_size times.
    input_data = make_input_row(text_data=text_data) * data_size
    output_data = make_output_row(
        text_data=text_data,
        score=None,
        start=None,
        end=None,
        word=None,
        entity=None,
        error_msg="Traceback",
    ) * n_entities * data_size

    # Span variant: the output row is spelled out as a tuple; the six None
    # entries sit between the aggregation strategy and the error message.
    work_with_span_input_data = make_input_row_with_span(text_data=text_data) * data_size
    work_with_span_output_data = [
        (bucketfs_conn, sub_dir, model_name,
         text_docid, text_start, text_end, agg_strategy_simple)
        + (None,) * 6
        + ("Traceback",)
    ] * n_entities * data_size

    # Split data_size into full batches plus the size of the trailing batch.
    number_complete_batches, number_remaining_data_entries_in_last_batch = divmod(data_size, batch_size)

    # One single-dict row per entity, holding nothing but unknown keys.
    model_output_row_wrong_keys = [
        [{"unknown key": "some value", "diff unknown key": entity_index}]
        for entity_index in range(n_entities)
    ]
    tokenizer_model_output_df_model1 = (
        [model_output_row_wrong_keys * batch_size] * number_complete_batches
        + [model_output_row_wrong_keys * number_remaining_data_entries_in_last_batch]
    )
    tokenizer_models_output_df = [tokenizer_model_output_df_model1]

    tmpdir_name = f"/tmpdir_{__qualname__}"
    base_cache_dir1 = PurePosixPath(tmpdir_name, bucketfs_conn)
    bfs_connections = {
        bucketfs_conn: Connection(address=f"file://{base_cache_dir1}"),
    }
44 changes: 44 additions & 0 deletions
44
..._tests/udf_wrapper_params/token_classification/error_prediction_missing_expected_field.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from pathlib import PurePosixPath | ||
from exasol_udf_mock_python.connection import Connection | ||
from tests.unit_tests.udf_wrapper_params.token_classification.make_data_row_functions import make_input_row, \ | ||
make_output_row, make_input_row_with_span, make_output_row_with_span, bucketfs_conn, \ | ||
text_docid, text_start, text_end, agg_strategy_simple, make_model_output_for_one_input_row, sub_dir, model_name | ||
|
||
|
||
class ErrorPredictionMissingExpectedFields:
    """
    Token-classification test parameters for a mocked model whose result
    dicts lack the "score" key.

    The expected output rows carry ``score=None`` and "Traceback" as the
    error message.
    """
    expected_model_counter = 1
    batch_size = 2
    data_size = 5
    n_entities = 3

    # NOTE(review): the text names 'word', but the key removed from the model
    # output below is "score" - confirm which field is meant.
    text_data = "error_result_missing_field_'word'"
    # TODO: do we want tests for different combinations? Seems like a lot of work.
    # TODO: do we want tests for multiple models? Multiple inputs where one works
    #  and one does not? How many test cases are too many test cases?
    # TODO: these should be moved to the base model tests together with the others.

    input_data = make_input_row(text_data=text_data) * data_size
    output_data = make_output_row(text_data=text_data, score=None, error_msg="Traceback") * n_entities * data_size

    work_with_span_input_data = make_input_row_with_span(text_data=text_data) * data_size
    work_with_span_output_data = make_output_row_with_span(score=None,
                                                           error_msg="Traceback") * n_entities * data_size

    number_complete_batches = data_size // batch_size
    number_remaining_data_entries_in_last_batch = data_size % batch_size

    # dict.pop() returns the removed value, not the dict, so popping inside the
    # comprehension would collect bare scores. Build a copy of each result dict
    # without "score" instead; this also leaves the helper's dicts unmodified.
    model_output_row_missing_key = [
        [{key: value for key, value in model_output_row[0].items() if key != "score"}]
        for model_output_row in make_model_output_for_one_input_row(number_entities=n_entities)
    ]

    tokenizer_model_output_df_model1 = [model_output_row_missing_key * batch_size] * \
                                       number_complete_batches + \
                                       [model_output_row_missing_key * number_remaining_data_entries_in_last_batch]
    tokenizer_models_output_df = [tokenizer_model_output_df_model1]

    tmpdir_name = "_".join(("/tmpdir", __qualname__))
    base_cache_dir1 = PurePosixPath(tmpdir_name, bucketfs_conn)
    bfs_connections = {
        bucketfs_conn: Connection(address=f"file://{base_cache_dir1}")
    }
36 changes: 36 additions & 0 deletions
36
tests/unit_tests/udf_wrapper_params/token_classification/prediction_returns_empty_result.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from pathlib import PurePosixPath | ||
from exasol_udf_mock_python.connection import Connection | ||
from tests.unit_tests.udf_wrapper_params.token_classification.make_data_row_functions import make_input_row, \ | ||
make_output_row, make_input_row_with_span, make_output_row_with_span, bucketfs_conn, \ | ||
text_docid, text_start, text_end, agg_strategy_simple, make_model_output_for_one_input_row, sub_dir, model_name | ||
|
||
|
||
class PredictionReturnsEmptyResult:
    """
    Token-classification test parameters mixing default input rows with rows
    whose text is "error_result_empty".

    Only the default rows are expected in the output.
    """
    expected_model_counter = 1
    batch_size = 4
    data_size = 5
    n_entities = 3

    text_data = "error_result_empty"
    # TODO: this throws an error, but the message could be better.
    # TODO: mention in the documentation that a row whose result is empty
    #  does not appear in the output.
    input_data = (
        make_input_row() * data_size
        + make_input_row(text_data=text_data) * data_size
    )
    # The result of input #2 is empty, so that row does not appear in the output.
    output_data = make_output_row() * n_entities * data_size

    work_with_span_input_data = (
        make_input_row_with_span() * data_size
        + make_input_row_with_span(text_data=text_data) * data_size
    )
    # The result of input #2 is empty, so that row does not appear in the output.
    work_with_span_output_data = make_output_row_with_span() * n_entities * data_size

    # Original note: error on prediction -> only one output per input.
    tokenizer_model_output_df_model1 = make_model_output_for_one_input_row(
        number_entities=n_entities,
    ) * data_size
    tokenizer_models_output_df = [tokenizer_model_output_df_model1]

    tmpdir_name = f"/tmpdir_{__qualname__}"
    base_cache_dir1 = PurePosixPath(tmpdir_name, bucketfs_conn)
    bfs_connections = {
        bucketfs_conn: Connection(address=f"file://{base_cache_dir1}"),
    }
37 changes: 37 additions & 0 deletions
37
tests/unit_tests/udf_wrapper_params/token_classification/result_contains_additional_keys.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from pathlib import PurePosixPath | ||
from exasol_udf_mock_python.connection import Connection | ||
from tests.unit_tests.udf_wrapper_params.token_classification.make_data_row_functions import make_input_row, \ | ||
make_output_row, make_input_row_with_span, make_output_row_with_span, bucketfs_conn, \ | ||
text_docid, text_start, text_end, agg_strategy_simple, make_model_output_for_one_input_row, sub_dir, model_name | ||
|
||
# todo do we want to throw in this case? or just ignore additional results? | ||
|
||
class ErrorPredictionContainsAdditionalFields:
    """
    Token-classification test parameters for a mocked model whose result
    dicts carry the extra keys "unknown key" and "diff unknown key" on top of
    the keys produced by ``make_model_output_for_one_input_row``.

    The expected output rows carry "Traceback" as the error message.
    """
    expected_model_counter = 1
    batch_size = 2
    data_size = 2
    n_entities = 3

    text_data = "result contains additional keys"

    # TODO: these are not filled out
    input_data = make_input_row(text_data=text_data) * data_size
    output_data = make_output_row(text_data=text_data, error_msg="Traceback") * n_entities * data_size

    work_with_span_input_data = make_input_row_with_span(text_data=text_data) * data_size
    work_with_span_output_data = make_output_row_with_span(error_msg="Traceback") * n_entities * data_size

    model_output_rows = make_model_output_for_one_input_row(number_entities=n_entities)
    # dict.update() returns None, so collecting its return value would yield a
    # list of None. Build new dicts with the extra keys instead; this also
    # leaves the dicts in model_output_rows unmodified.
    # NOTE(review): each result is wrapped as a single-dict row, the shape used
    # by the "only unknown fields" parameter set - confirm this is the shape
    # the mocked pipeline should return.
    model_output_row_wrong_keys = [
        [{**model_output_row[0], "unknown key": "some value", "diff unknown key": 1}]
        for model_output_row in model_output_rows
    ]
    tokenizer_model_output_df_model1 = [model_output_row_wrong_keys * data_size]
    tokenizer_models_output_df = [tokenizer_model_output_df_model1]

    tmpdir_name = "_".join(("/tmpdir", __qualname__))
    base_cache_dir1 = PurePosixPath(tmpdir_name, bucketfs_conn)
    bfs_connections = {
        bucketfs_conn: Connection(address=f"file://{base_cache_dir1}")
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters