Skip to content

Commit

Permalink
#264 fixed unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsimb committed Oct 16, 2024
1 parent 58f1370 commit ece981e
Show file tree
Hide file tree
Showing 12 changed files with 15 additions and 17 deletions.
7 changes: 3 additions & 4 deletions exasol_transformers_extension/udfs/models/base_model_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import pandas as pd
import numpy as np
import transformers
from exasol.python_extension_common.connections.bucketfs_location import (
create_bucketfs_location_from_conn_object)
import exasol.python_extension_common.connections.bucketfs_location as bfs_loc

from exasol_transformers_extension.deployment import constants
from exasol_transformers_extension.utils import device_management, dataframe_operations
Expand Down Expand Up @@ -186,8 +185,8 @@ def check_cache(self, model_df: pd.DataFrame) -> None:
current_model_specification = BucketFSModelSpecification(model_name, self.task_type, bucketfs_conn, sub_dir)

if self.model_loader.current_model_specification != current_model_specification:
bucketfs_location = \
create_bucketfs_location_from_conn_object(self.exa.get_connection(bucketfs_conn))
bucketfs_location = bfs_loc.create_bucketfs_location_from_conn_object(
self.exa.get_connection(bucketfs_conn))

self.model_loader.clear_device_memory()
self.model_loader.set_current_model_specification(current_model_specification)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Tuple

import transformers
from exasol.python_extension_common.connections.bucketfs_location import (
create_bucketfs_location_from_conn_object)
import exasol.python_extension_common.connections.bucketfs_location as bfs_loc

from exasol_transformers_extension.utils.bucketfs_model_specification import \
BucketFSModelSpecificationFactory
Expand Down Expand Up @@ -64,7 +63,7 @@ def _download_model(self, ctx) -> Tuple[str, str]:

# create bucketfs location
bfs_conn_obj = self._exa.get_connection(bfs_conn)
bucketfs_location = create_bucketfs_location_from_conn_object(bfs_conn_obj)
bucketfs_location = bfs_loc.create_bucketfs_location_from_conn_object(bfs_conn_obj)

# download base model and tokenizer into the model path
with self._huggingface_hub_bucketfs_model_transfer.create(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/udfs/test_base_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def setup_tests_and_run(bucketfs_conn_name, bucketfs_conn, sub_dir, model_name):
("all given", "test_bucketfs_con_name", Connection(address=f"file:///test"),
"test_subdir", "test_model")
])
@patch('exasol_transformers_extension.utils.bucketfs_operations.create_bucketfs_location_from_conn_object')
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_model_downloader_all_parameters(mock_local_path, mock_create_loc, description,
bucketfs_conn_name, bucketfs_conn, sub_dir, model_name):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/udfs/test_filling_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create_mock_metadata(udf_wrapper):
ErrorOnPredictionSingleModelMultipleBatch,
ErrorOnPredictionMultipleModelMultipleBatch
])
@patch('exasol_transformers_extension.utils.bucketfs_operations.create_bucketfs_location_from_conn_object')
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_filling_mask(mock_local_path, mock_create_loc, params):

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/udfs/test_model_downloader_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def udf_wrapper():
('without token', '', None, False),
('with token', 'conn_name', Connection(address="", password="valid"), "valid"),
])
@patch('exasol_transformers_extension.utils.bucketfs_operations.create_bucketfs_location_from_conn_object')
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
def test_model_downloader(mock_create_loc, description, count, token_conn_name, token_conn_obj, expected_token):

mock_tokenizer_factory: Union[ModelFactoryProtocol, MagicMock] = create_autospec(ModelFactoryProtocol)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/udfs/test_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def create_mock_metadata(udf_wrapper):
ErrorOnPredictionSingleModelMultipleBatch,
ErrorOnPredictionMultipleModelMultipleBatch
])
@patch('exasol_transformers_extension.utils.bucketfs_operations.create_bucketfs_location_from_conn_object')
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_question_answering(mock_local_path, mock_create_loc, params):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def create_mock_metadata(udf_wrapper):
ErrorNotCachedSingleModelMultipleBatch,
ErrorOnPredictionSingleModelMultipleBatch
])
@patch('exasol_transformers_extension.utils.bucketfs_operations.create_bucketfs_location_from_conn_object')
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_sequence_classification_single_text(mock_local_path, mock_create_loc, params):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def create_mock_metadata(udf_wrapper):
ErrorNotCachedSingleModelMultipleBatch,
ErrorOnPredictionSingleModelMultipleBatch
])
@patch('exasol_transformers_extension.utils.bucketfs_operations.create_bucketfs_location_from_conn_object')
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_sequence_classification_text_pair(mock_local_path, mock_create_loc, params):

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/udfs/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def create_mock_metadata(udf_wrapper):
ErrorOnPredictionSingleModelMultipleBatch,
ErrorOnPredictionMultipleModelMultipleBatch
])
@patch('exasol_transformers_extension.utils.bucketfs_operations.create_bucketfs_location_from_conn_object')
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_text_generation(mock_local_path, mock_create_loc, params):

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/udfs/test_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def create_mock_metadata(udf_wrapper):
ErrorOnPredictionMultipleModelMultipleBatch,
ErrorOnPredictionSingleModelMultipleBatch
])
@patch('exasol_transformers_extension.utils.bucketfs_operations.create_bucketfs_location_from_conn_object')
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_token_classification(mock_local_path, mock_create_loc, params):

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/udfs/test_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def create_mock_metadata(udf_wrapper):
ErrorOnPredictionMultipleModelMultipleBatch,
ErrorOnPredictionSingleModelMultipleBatch
])
@patch('exasol_transformers_extension.utils.bucketfs_operations.create_bucketfs_location_from_conn_object')
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_translation(mock_local_path, mock_create_loc, params):

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/udfs/test_zero_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def create_mock_metadata(udf_wrapper):
ErrorOnPredictionMultipleModelMultipleBatch,
ErrorOnPredictionSingleModelMultipleBatch
])
@patch('exasol_transformers_extension.utils.bucketfs_operations.create_bucketfs_location_from_conn_object')
@patch('exasol.python_extension_common.connections.bucketfs_location.create_bucketfs_location_from_conn_object')
@patch('exasol_transformers_extension.utils.bucketfs_operations.get_local_bucketfs_path')
def test_zero_shot(mock_local_path, mock_create_loc, params):

Expand Down

0 comments on commit ece981e

Please sign in to comment.