From 8466f7ffef22b8e93e8b8f6a6e4177808a35689a Mon Sep 17 00:00:00 2001 From: MarleneKress79789 Date: Fri, 22 Nov 2024 11:51:10 +0100 Subject: [PATCH] revert accidentally committed file --- .../udfs/test_token_classification_udf.py | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/tests/integration_tests/without_db/udfs/test_token_classification_udf.py b/tests/integration_tests/without_db/udfs/test_token_classification_udf.py index 81f94416..46fbe2e5 100644 --- a/tests/integration_tests/without_db/udfs/test_token_classification_udf.py +++ b/tests/integration_tests/without_db/udfs/test_token_classification_udf.py @@ -7,8 +7,6 @@ from exasol_transformers_extension.udfs.models.token_classification_udf import \ TokenClassificationUDF -from exasol_transformers_extension.utils.model_specification import ModelSpecification -from tests.fixtures.model_fixture_utils import prepare_model_for_local_bucketfs from tests.integration_tests.without_db.udfs.matcher import Result, ShapeMatcher, NewColumnsEmptyMatcher, \ ErrorMessageMatcher, NoErrorMessageMatcher, ColumnsMatcher @@ -127,31 +125,27 @@ def test_token_classification_udf( ("on GPU with single input with max aggregation", 0, 1, "max") ]) def test_token_classification_udf_with_span( - description, device_id, n_rows, agg,tmpdir_factory, + description, device_id, n_rows, agg, prepare_token_classification_model_for_local_bucketfs): if device_id is not None and not torch.cuda.is_available(): pytest.skip(f"There is no available device({device_id}) " f"to execute the test") - model_spec = ModelSpecification("guishe/nuner-v2_fewnerd_fine_super", "token-classification") - bucketfs_path = prepare_model_for_local_bucketfs(model_spec, tmpdir_factory) - bucketfs_base_path = prepare_token_classification_model_for_local_bucketfs bucketfs_conn_name = "bucketfs_connection" - bucketfs_connection = create_mounted_bucketfs_connection(bucketfs_path) - text_data = "Foreign governments may be spying on your smartphone notifications, senator says. Washington (CNN) — Foreign governments have reportedly attempted to spy on iPhone and Android users through the mobile app notifications they receive on their smartphones - and the US government has forced Apple and Google to keep quiet about it, according to a top US senator. Through legal demands sent to the tech giants, governments have allegedly tried to force Apple and Google to turn over sensitive information that could include the contents of a notification - such as previews of a text message displayed on a lock screen, or an update about app activity, Oregon Democratic Sen. Ron Wyden said in a new report. Wyden''s report reflects the latest example of long-running tensions between tech companies and governments over law enforcement demands, which have stretched on for more than a decade. Governments around the world have particularly battled with tech companies over encryption, which provides critical protections to users and businesses while in some cases preventing law enforcement from pursuing investigations into messages sent over the internet.'" - text_data2 = "This is a test." + bucketfs_connection = create_mounted_bucketfs_connection(bucketfs_base_path) + batch_size = 2 sample_data = [( None, bucketfs_conn_name, model_params.sub_dir, - 'guishe/nuner-v2_fewnerd_fine_super', - text_data2, + model_params.token_model_specs.model_name, + model_params.text_data * (i + 1), i, 0, len(model_params.text_data), - "simple" + agg ) for i in range(n_rows)] columns = [ 'device_id', @@ -183,15 +177,12 @@ def test_token_classification_udf_with_span( sequence_classifier.run(ctx) result_df = ctx.get_emitted()[0][0] - with pd.option_context('display.max_rows', None, 'display.max_columns', None): # more options can be specified also - - print(result_df) new_columns = \ ['entity_covered_text', 'entity_type', 'score', 'entity_docid', 'entity_char_begin', 'entity_char_end', 'error_message'] result = Result(result_df) - assert (False and + assert ( result == ColumnsMatcher(columns=old_columns, new_columns=new_columns) and result == NoErrorMessageMatcher() ) @@ -226,7 +217,7 @@ def test_token_classification_udf_with_multiple_aggregation_strategies( 'device_id', 'bucketfs_conn', 'sub_dir', - 'guishe/nuner-v2_fewnerd_fine_super', + 'model_name', 'text_data', 'aggregation_strategy' ]