From a5671c4959aec7b7637ea8a777a28c03103f90f8 Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Tue, 1 Aug 2023 15:22:30 -0500 Subject: [PATCH] Export helper as fixture instead of importing from test package (#147) --- tests/benchmark/mocks.py | 63 ------ tests/benchmark/test_benchmark.py | 68 +++++- tests/relational/conftest.py | 10 +- tests/relational/test_backup.py | 3 +- .../test_relational_data_with_json.py | 204 +++++++++--------- tests/relational/test_train_synthetics.py | 5 +- 6 files changed, 179 insertions(+), 174 deletions(-) delete mode 100644 tests/benchmark/mocks.py diff --git a/tests/benchmark/mocks.py b/tests/benchmark/mocks.py deleted file mode 100644 index e1d6f1ec..00000000 --- a/tests/benchmark/mocks.py +++ /dev/null @@ -1,63 +0,0 @@ -import pandas as pd -from gretel_client.projects.models import read_model_config - -from gretel_trainer.benchmark.core import Dataset -from gretel_trainer.benchmark.gretel.models import GretelModel - - -class DoNothingModel: - def train(self, source: Dataset, **kwargs) -> None: - pass - - def generate(self, **kwargs) -> pd.DataFrame: - return pd.DataFrame() - - -class FailsToTrain: - def train(self, source: Dataset, **kwargs) -> None: - raise Exception("failed") - - def generate(self, **kwargs) -> pd.DataFrame: - return pd.DataFrame() - - -class FailsToGenerate: - def train(self, source: Dataset, **kwargs) -> None: - pass - - def generate(self, **kwargs) -> pd.DataFrame: - raise Exception("failed") - - -class TailoredActgan(GretelModel): - @property - def config(self): - c = read_model_config("synthetics/tabular-actgan") - c["models"][0]["actgan"]["params"]["epochs"] = 100 - return c - - -class SharedDictLstm(GretelModel): - config = { - "schema_version": "1.0", - "name": "tabular-lstm", - "models": [ - { - "synthetics": { - "data_source": "__tmp__", - "params": { - "epochs": "auto", - "vocab_size": "auto", - "learning_rate": "auto", - "batch_size": "auto", - "rnn_units": "auto", - }, - "generate": {"num_records": 5000}, - "privacy_filters": { - "outliers": "auto", - "similarity": "auto", - }, - } - } - ], - } diff --git a/tests/benchmark/test_benchmark.py b/tests/benchmark/test_benchmark.py index 3fe4fb3b..9fb80683 100644 --- a/tests/benchmark/test_benchmark.py +++ b/tests/benchmark/test_benchmark.py @@ -7,6 +7,7 @@ import pandas.testing as pdtest import pytest from gretel_client.projects.jobs import Status +from gretel_client.projects.models import read_model_config from gretel_trainer.benchmark import ( BenchmarkConfig, @@ -17,13 +18,66 @@ create_dataset, launch, ) -from tests.benchmark.mocks import ( - DoNothingModel, - FailsToGenerate, - FailsToTrain, - SharedDictLstm, - TailoredActgan, -) +from gretel_trainer.benchmark.core import Dataset +from gretel_trainer.benchmark.gretel.models import GretelModel + + +class DoNothingModel: + def train(self, source: Dataset, **kwargs) -> None: + pass + + def generate(self, **kwargs) -> pd.DataFrame: + return pd.DataFrame() + + +class FailsToTrain: + def train(self, source: Dataset, **kwargs) -> None: + raise Exception("failed") + + def generate(self, **kwargs) -> pd.DataFrame: + return pd.DataFrame() + + +class FailsToGenerate: + def train(self, source: Dataset, **kwargs) -> None: + pass + + def generate(self, **kwargs) -> pd.DataFrame: + raise Exception("failed") + + +class TailoredActgan(GretelModel): + @property + def config(self): + c = read_model_config("synthetics/tabular-actgan") + c["models"][0]["actgan"]["params"]["epochs"] = 100 + return c + + +class SharedDictLstm(GretelModel): + config = { + "schema_version": "1.0", + "name": "tabular-lstm", + "models": [ + { + "synthetics": { + "data_source": "__tmp__", + "params": { + "epochs": "auto", + "vocab_size": "auto", + "learning_rate": "auto", + "batch_size": "auto", + "rnn_units": "auto", + }, + "generate": {"num_records": 5000}, + "privacy_filters": { + "outliers": "auto", + "similarity": "auto", + }, + } + } + ], + } def test_run_with_gretel_dataset(working_dir, project, evaluate_report_path, iris): diff --git a/tests/relational/conftest.py b/tests/relational/conftest.py index e62117fc..a18ceaa2 100644 --- a/tests/relational/conftest.py +++ b/tests/relational/conftest.py @@ -2,7 +2,7 @@ import sqlite3 import tempfile from pathlib import Path -from typing import Generator +from typing import Callable, Generator from unittest.mock import Mock, patch import pandas as pd @@ -32,8 +32,12 @@ def static_suffix(request): yield make_suffix -# Doesn't work well as a fixture due to the need for an input param -def get_invented_table_suffix(make_suffix_execution_number: int): +@pytest.fixture() +def get_invented_table_suffix() -> Callable[[int], str]: + return _get_invented_table_suffix + + +def _get_invented_table_suffix(make_suffix_execution_number: int): return f"invented_{str(make_suffix_execution_number)}" diff --git a/tests/relational/test_backup.py b/tests/relational/test_backup.py index 01ab7f7f..48917a8f 100644 --- a/tests/relational/test_backup.py +++ b/tests/relational/test_backup.py @@ -11,7 +11,6 @@ BackupSyntheticsTrain, BackupTransformsTrain, ) -from tests.relational.conftest import get_invented_table_suffix def test_backup_relational_data(trips): @@ -37,7 +36,7 @@ def test_backup_relational_data(trips): assert BackupRelationalData.from_relational_data(trips) == expected -def test_backup_relational_data_with_json(documents): +def test_backup_relational_data_with_json(documents, get_invented_table_suffix): purchases_root_invented_table = f"purchases_{get_invented_table_suffix(1)}" purchases_data_years_invented_table = f"purchases_{get_invented_table_suffix(2)}" diff --git a/tests/relational/test_relational_data_with_json.py b/tests/relational/test_relational_data_with_json.py index 5c1df56f..4c6c4ffc 100644 --- a/tests/relational/test_relational_data_with_json.py +++ b/tests/relational/test_relational_data_with_json.py @@ -8,13 +8,17 @@ from gretel_trainer.relational.core import ForeignKey, RelationalData, Scope from gretel_trainer.relational.json import generate_unique_table_name, get_json_columns -from tests.relational.conftest import get_invented_table_suffix -purchases_root_invented_table = f"purchases_{get_invented_table_suffix(1)}" -purchases_data_years_invented_table = f"purchases_{get_invented_table_suffix(2)}" -bball_root_invented_table = f"bball_{get_invented_table_suffix(1)}" -bball_suspensions_invented_table = f"bball_{get_invented_table_suffix(2)}" -bball_teams_invented_table = f"bball_{get_invented_table_suffix(3)}" + +@pytest.fixture +def invented_tables(get_invented_table_suffix) -> dict[str, str]: + return { + "purchases_root": f"purchases_{get_invented_table_suffix(1)}", + "purchases_data_years": f"purchases_{get_invented_table_suffix(2)}", + "bball_root": f"bball_{get_invented_table_suffix(1)}", + "bball_suspensions": f"bball_{get_invented_table_suffix(2)}", + "bball_teams": f"bball_{get_invented_table_suffix(3)}", + } @pytest.fixture @@ -55,9 +59,9 @@ def test_list_json_cols(documents, bball): } -def test_json_columns_produce_invented_flattened_tables(documents): +def test_json_columns_produce_invented_flattened_tables(documents, invented_tables): pdtest.assert_frame_equal( - documents.get_table_data(purchases_root_invented_table), + documents.get_table_data(invented_tables["purchases_root"]), pd.DataFrame( data={ "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5], @@ -72,7 +76,7 @@ def test_json_columns_produce_invented_flattened_tables(documents): ) pdtest.assert_frame_equal( - documents.get_table_data(purchases_data_years_invented_table), + documents.get_table_data(invented_tables["purchases_data_years"]), pd.DataFrame( data={ "content": [2023, 2023, 2022, 2020, 2019, 2021], @@ -85,17 +89,17 @@ def test_json_columns_produce_invented_flattened_tables(documents): check_dtype=False, # Without this, test fails asserting dtype mismatch in `content` field (object vs. int) ) - assert documents.get_foreign_keys(purchases_data_years_invented_table) == [ + assert documents.get_foreign_keys(invented_tables["purchases_data_years"]) == [ ForeignKey( - table_name=purchases_data_years_invented_table, + table_name=invented_tables["purchases_data_years"], columns=["purchases~id"], - parent_table_name=purchases_root_invented_table, + parent_table_name=invented_tables["purchases_root"], parent_columns=["~PRIMARY_KEY_ID~"], ) ] -def test_list_tables_accepts_various_scopes(documents): +def test_list_tables_accepts_various_scopes(documents, invented_tables): # PUBLIC reflects the user's source assert set(documents.list_all_tables(Scope.PUBLIC)) == { "users", @@ -107,21 +111,21 @@ def test_list_tables_accepts_various_scopes(documents): assert set(documents.list_all_tables(Scope.MODELABLE)) == { "users", "payments", - purchases_root_invented_table, - purchases_data_years_invented_table, + invented_tables["purchases_root"], + invented_tables["purchases_data_years"], } # EVALUATABLE is similar to MODELABLE, but omits invented child tables—we only evaluate the root invented table assert set(documents.list_all_tables(Scope.EVALUATABLE)) == { "users", "payments", - purchases_root_invented_table, + invented_tables["purchases_root"], } # INVENTED returns only tables invented from source tables with JSON assert set(documents.list_all_tables(Scope.INVENTED)) == { - purchases_root_invented_table, - purchases_data_years_invented_table, + invented_tables["purchases_root"], + invented_tables["purchases_data_years"], } # ALL returns every table name, including both source-with-JSON tables and those invented from such tables @@ -129,8 +133,8 @@ def test_list_tables_accepts_various_scopes(documents): "users", "purchases", "payments", - purchases_root_invented_table, - purchases_data_years_invented_table, + invented_tables["purchases_root"], + invented_tables["purchases_data_years"], } # Default scope is MODELABLE @@ -139,59 +143,59 @@ def test_list_tables_accepts_various_scopes(documents): ) -def test_get_modelable_table_names(documents): +def test_get_modelable_table_names(documents, invented_tables): # Given a source-with-JSON name, returns the tables invented from that source assert set(documents.get_modelable_table_names("purchases")) == { - purchases_root_invented_table, - purchases_data_years_invented_table, + invented_tables["purchases_root"], + invented_tables["purchases_data_years"], } # Invented tables are modelable - assert documents.get_modelable_table_names(purchases_root_invented_table) == [ - purchases_root_invented_table - ] - assert documents.get_modelable_table_names(purchases_data_years_invented_table) == [ - purchases_data_years_invented_table + assert documents.get_modelable_table_names(invented_tables["purchases_root"]) == [ + invented_tables["purchases_root"] ] + assert documents.get_modelable_table_names( + invented_tables["purchases_data_years"] + ) == [invented_tables["purchases_data_years"]] # Unknown tables return empty list assert documents.get_modelable_table_names("nonsense") == [] -def test_get_modelable_names_ignores_empty_mapped_tables(bball): +def test_get_modelable_names_ignores_empty_mapped_tables(bball, invented_tables): # The `suspensions` column in the source data contained empty lists for all records. # The normalization process transforms that into a standalone, empty table. # We need to hold onto that table name to support denormalizing back to the original # source data shape. It is therefore present when listing ALL tables... assert set(bball.list_all_tables(Scope.ALL)) == { "bball", - bball_root_invented_table, - bball_suspensions_invented_table, - bball_teams_invented_table, + invented_tables["bball_root"], + invented_tables["bball_suspensions"], + invented_tables["bball_teams"], } # ...and the producer metadata is aware of it... assert set(bball.get_producer_metadata("bball").table_names) == { - bball_root_invented_table, - bball_suspensions_invented_table, - bball_teams_invented_table, + invented_tables["bball_root"], + invented_tables["bball_suspensions"], + invented_tables["bball_teams"], } # ...BUT most clients only care about invented tables that can be modeled # (i.e. that contain data), so the empty table does not appear in these contexts: assert set(bball.get_modelable_table_names("bball")) == { - bball_root_invented_table, - bball_teams_invented_table, + invented_tables["bball_root"], + invented_tables["bball_teams"], } assert set(bball.list_all_tables()) == { - bball_root_invented_table, - bball_teams_invented_table, + invented_tables["bball_root"], + invented_tables["bball_teams"], } -def test_invented_json_column_names_documents(documents): +def test_invented_json_column_names_documents(documents, invented_tables): # The root invented table adds columns for dictionary properties lifted from nested JSON objects - assert documents.get_table_columns(purchases_root_invented_table) == [ + assert documents.get_table_columns(invented_tables["purchases_root"]) == [ "~PRIMARY_KEY_ID~", "id", "user_id", @@ -202,7 +206,7 @@ def test_invented_json_column_names_documents(documents): # JSON lists lead to invented child tables. These tables store the original content, # a new primary key, a foreign key back to the parent, and the original array index - assert documents.get_table_columns(purchases_data_years_invented_table) == [ + assert documents.get_table_columns(invented_tables["purchases_data_years"]) == [ "~PRIMARY_KEY_ID~", "purchases~id", "content", @@ -210,9 +214,9 @@ def test_invented_json_column_names_documents(documents): ] -def test_invented_json_column_names_bball(bball): +def test_invented_json_column_names_bball(bball, invented_tables): # If the source table does not have a primary key defined, one is created on the root invented table - assert bball.get_table_columns(bball_root_invented_table) == [ + assert bball.get_table_columns(invented_tables["bball_root"]) == [ "~PRIMARY_KEY_ID~", "name", "age", @@ -221,11 +225,11 @@ def test_invented_json_column_names_bball(bball): ] -def test_set_some_primary_key_to_none(static_suffix, documents): +def test_set_some_primary_key_to_none(static_suffix, documents, invented_tables): # The producer table has a single column primary key, # so the root invented table has a composite key that includes the source PK and an invented column assert documents.get_primary_key("purchases") == ["id"] - assert documents.get_primary_key(purchases_root_invented_table) == [ + assert documents.get_primary_key(invented_tables["purchases_root"]) == [ "id", "~PRIMARY_KEY_ID~", ] @@ -242,25 +246,25 @@ def test_set_some_primary_key_to_none(static_suffix, documents): documents.set_primary_key(table="purchases", primary_key=None) assert len(documents.list_all_tables(Scope.ALL)) == 5 assert documents.get_primary_key("purchases") == [] - assert documents.get_primary_key(purchases_root_invented_table) == [ + assert documents.get_primary_key(invented_tables["purchases_root"]) == [ "~PRIMARY_KEY_ID~" ] - assert documents.get_foreign_keys(purchases_data_years_invented_table) == [ + assert documents.get_foreign_keys(invented_tables["purchases_data_years"]) == [ ForeignKey( - table_name=purchases_data_years_invented_table, + table_name=invented_tables["purchases_data_years"], columns=["purchases~id"], - parent_table_name=purchases_root_invented_table, + parent_table_name=invented_tables["purchases_root"], parent_columns=["~PRIMARY_KEY_ID~"], ) ] assert documents.get_foreign_keys("payments") == original_payments_fks -def test_set_none_primary_key_to_some_value(static_suffix, bball): +def test_set_none_primary_key_to_some_value(static_suffix, bball, invented_tables): # The producer table has no primary key, # so the root invented table has a single invented key column assert bball.get_primary_key("bball") == [] - assert bball.get_primary_key(bball_root_invented_table) == ["~PRIMARY_KEY_ID~"] + assert bball.get_primary_key(invented_tables["bball_root"]) == ["~PRIMARY_KEY_ID~"] # Setting a None primary key to some column puts us in the correct state assert len(bball.list_all_tables(Scope.ALL)) == 4 @@ -272,30 +276,30 @@ def test_set_none_primary_key_to_some_value(static_suffix, bball): bball.set_primary_key(table="bball", primary_key="name") assert len(bball.list_all_tables(Scope.ALL)) == 4 assert bball.get_primary_key("bball") == ["name"] - assert bball.get_primary_key(bball_root_invented_table) == [ + assert bball.get_primary_key(invented_tables["bball_root"]) == [ "name", "~PRIMARY_KEY_ID~", ] - assert bball.get_foreign_keys(bball_suspensions_invented_table) == [ + assert bball.get_foreign_keys(invented_tables["bball_suspensions"]) == [ ForeignKey( - table_name=bball_suspensions_invented_table, + table_name=invented_tables["bball_suspensions"], columns=["bball~id"], - parent_table_name=bball_root_invented_table, + parent_table_name=invented_tables["bball_root"], parent_columns=["~PRIMARY_KEY_ID~"], ) ] -def test_foreign_keys(documents): +def test_foreign_keys(documents, invented_tables): # Foreign keys from the source-with-JSON table are present on the root invented table assert documents.get_foreign_keys("purchases") == documents.get_foreign_keys( - purchases_root_invented_table + invented_tables["purchases_root"] ) # The root invented table name is used in the ForeignKey assert documents.get_foreign_keys("purchases") == [ ForeignKey( - table_name=purchases_root_invented_table, + table_name=invented_tables["purchases_root"], columns=["user_id"], parent_table_name="users", parent_columns=["id"], @@ -303,11 +307,11 @@ def test_foreign_keys(documents): ] # Invented children point to invented parents - assert documents.get_foreign_keys(purchases_data_years_invented_table) == [ + assert documents.get_foreign_keys(invented_tables["purchases_data_years"]) == [ ForeignKey( - table_name=purchases_data_years_invented_table, + table_name=invented_tables["purchases_data_years"], columns=["purchases~id"], - parent_table_name=purchases_root_invented_table, + parent_table_name=invented_tables["purchases_root"], parent_columns=["~PRIMARY_KEY_ID~"], ) ] @@ -317,7 +321,7 @@ def test_foreign_keys(documents): ForeignKey( table_name="payments", columns=["purchase_id"], - parent_table_name=purchases_root_invented_table, + parent_table_name=invented_tables["purchases_root"], parent_columns=["id"], ) ] @@ -345,10 +349,12 @@ def test_foreign_keys(documents): table="purchases", constrained_columns=["user_id"] ) assert documents.get_foreign_keys("purchases") == [] - assert documents.get_foreign_keys(purchases_root_invented_table) == [] + assert documents.get_foreign_keys(invented_tables["purchases_root"]) == [] -def test_update_data_with_existing_json_to_new_json(static_suffix, documents): +def test_update_data_with_existing_json_to_new_json( + static_suffix, documents, invented_tables +): new_purchases_jsonl = """ {"id": 1, "user_id": 1, "data": {"item": "watercolor", "cost": 200, "details": {"color": "aquamarine"}, "years": [1999]}} {"id": 2, "user_id": 2, "data": {"item": "watercolor", "cost": 200, "details": {"color": "aquamarine"}, "years": [1999]}} @@ -369,7 +375,7 @@ def test_update_data_with_existing_json_to_new_json(static_suffix, documents): assert len(documents.list_all_tables(Scope.MODELABLE)) == 4 expected = { - purchases_root_invented_table: pd.DataFrame( + invented_tables["purchases_root"]: pd.DataFrame( data={ "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5], "id": [1, 2, 3, 4, 5, 6], @@ -393,7 +399,7 @@ def test_update_data_with_existing_json_to_new_json(static_suffix, documents): ], } ), - purchases_data_years_invented_table: pd.DataFrame( + invented_tables["purchases_data_years"]: pd.DataFrame( data={ "content": [1999, 1999, 1999, 1998, 1998, 1998], "array~order": [0, 0, 0, 0, 0, 0], @@ -404,14 +410,14 @@ def test_update_data_with_existing_json_to_new_json(static_suffix, documents): } pdtest.assert_frame_equal( - documents.get_table_data(purchases_root_invented_table), - expected[purchases_root_invented_table], + documents.get_table_data(invented_tables["purchases_root"]), + expected[invented_tables["purchases_root"]], check_like=True, ) pdtest.assert_frame_equal( - documents.get_table_data(purchases_data_years_invented_table), - expected[purchases_data_years_invented_table], + documents.get_table_data(invented_tables["purchases_data_years"]), + expected[invented_tables["purchases_data_years"]], check_like=True, check_dtype=False, # Without this, test fails asserting dtype mismatch in `content` field (object vs. int) ) @@ -421,7 +427,7 @@ def test_update_data_with_existing_json_to_new_json(static_suffix, documents): ForeignKey( table_name="payments", columns=["purchase_id"], - parent_table_name=purchases_root_invented_table, + parent_table_name=invented_tables["purchases_root"], parent_columns=["id"], ) ] @@ -456,7 +462,7 @@ def test_update_data_existing_json_to_no_json(documents): ] -def test_update_data_existing_flat_to_json(static_suffix, documents): +def test_update_data_existing_flat_to_json(static_suffix, documents, invented_tables): # Build up a RelationalData instance that basically mirrors documents, # but purchases is flat to start and thus there are no RelationalJson instances flat_purchases_df = pd.DataFrame( @@ -498,22 +504,24 @@ def test_update_data_existing_flat_to_json(static_suffix, documents): assert set(rel_data.list_all_tables(Scope.ALL)) == { "users", "purchases", - purchases_root_invented_table, - purchases_data_years_invented_table, + invented_tables["purchases_root"], + invented_tables["purchases_data_years"], "payments", } # the original purchases table is no longer flat, nor (therefore) MODELABLE assert set(rel_data.list_all_tables(Scope.MODELABLE)) == { "users", - purchases_root_invented_table, - purchases_data_years_invented_table, + invented_tables["purchases_root"], + invented_tables["purchases_data_years"], "payments", } assert rel_data.get_foreign_keys("payments") == [ ForeignKey( table_name="payments", columns=["purchase_id"], - parent_table_name=purchases_root_invented_table, # The foreign key now points to the root invented table + parent_table_name=invented_tables[ + "purchases_root" + ], # The foreign key now points to the root invented table parent_columns=["id"], ) ] @@ -521,7 +529,7 @@ def test_update_data_existing_flat_to_json(static_suffix, documents): # Simulates output tables from MultiTable transforms or synthetics, which will only include the MODELABLE tables @pytest.fixture() -def mt_output_tables(): +def mt_output_tables(invented_tables): return { "users": pd.DataFrame( data={ @@ -536,7 +544,7 @@ def mt_output_tables(): "purchase_id": [1, 2, 3, 4], } ), - purchases_root_invented_table: pd.DataFrame( + invented_tables["purchases_root"]: pd.DataFrame( data={ "~PRIMARY_KEY_ID~": [0, 1, 2, 3], "id": [1, 2, 3, 4], @@ -546,7 +554,7 @@ def mt_output_tables(): "data>details>color": ["blue", "yellow", "pink", "orange"], } ), - purchases_data_years_invented_table: pd.DataFrame( + invented_tables["purchases_data_years"]: pd.DataFrame( data={ "content": [2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007], "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5, 6, 7], @@ -603,15 +611,17 @@ def test_restoring_output_tables_to_original_shape(documents, mt_output_tables): pdtest.assert_frame_equal(df, expected[t]) -def test_restore_with_incomplete_tableset(documents, mt_output_tables): +def test_restore_with_incomplete_tableset(documents, mt_output_tables, invented_tables): without_invented_root = { - k: v for k, v in mt_output_tables.items() if k != purchases_root_invented_table + k: v + for k, v in mt_output_tables.items() + if k != invented_tables["purchases_root"] } without_invented_child = { k: v for k, v in mt_output_tables.items() - if k != purchases_data_years_invented_table + if k != invented_tables["purchases_data_years"] } restored_without_invented_root = documents.restore(without_invented_root) @@ -672,9 +682,9 @@ def test_restore_with_incomplete_tableset(documents, mt_output_tables): ) -def test_restore_with_empty_tables(bball): +def test_restore_with_empty_tables(bball, invented_tables): synthetic_bball_output_tables = { - bball_root_invented_table: pd.DataFrame( + invented_tables["bball_root"]: pd.DataFrame( data={ "name": ["Jimmy Butler"], "age": [33], @@ -683,7 +693,7 @@ def test_restore_with_empty_tables(bball): "~PRIMARY_KEY_ID~": [0], } ), - bball_teams_invented_table: pd.DataFrame( + invented_tables["bball_teams"]: pd.DataFrame( data={ "content": ["Bulls", "Timberwolves", "Sixers", "Heat"], "array~order": [0, 1, 2, 3], @@ -703,7 +713,7 @@ def test_restore_with_empty_tables(bball): assert jimmy["suspensions"] == [] -def test_flatten_and_restore_all_sorts_of_json(tmpdir): +def test_flatten_and_restore_all_sorts_of_json(tmpdir, get_invented_table_suffix): json = """ [ { @@ -843,7 +853,7 @@ def test_only_lists_edge_case(tmpdir): assert rel_data.list_all_tables(Scope.ALL) == [] -def test_lists_of_lists(tmpdir): +def test_lists_of_lists(tmpdir, get_invented_table_suffix): # Enough flat data in the source to create a root invented table. # Upping the complexity by making the special value a list of lists, # but not to fear: we can handle this correctly. @@ -892,7 +902,7 @@ def test_lists_of_lists(tmpdir): ) -def test_mix_of_dict_and_list_cols(tmpdir): +def test_mix_of_dict_and_list_cols(tmpdir, get_invented_table_suffix): df = pd.DataFrame( data={ "id": [1, 2], @@ -922,7 +932,7 @@ def test_mix_of_dict_and_list_cols(tmpdir): ] -def test_all_tables_are_present_in_debug_summary(documents): +def test_all_tables_are_present_in_debug_summary(documents, invented_tables): assert documents.debug_summary() == { "foreign_key_count": 4, "max_depth": 2, @@ -943,7 +953,7 @@ def test_all_tables_are_present_in_debug_summary(documents): "foreign_keys": [ { "columns": ["purchase_id"], - "parent_table_name": purchases_root_invented_table, + "parent_table_name": invented_tables["purchases_root"], "parent_columns": ["id"], } ], @@ -964,12 +974,12 @@ def test_all_tables_are_present_in_debug_summary(documents): "invented_table_details": { "table_type": "producer", "json_to_table_mappings": { - "purchases": purchases_root_invented_table, - "purchases^data>years": purchases_data_years_invented_table, + "purchases": invented_tables["purchases_root"], + "purchases^data>years": invented_tables["purchases_data_years"], }, }, }, - purchases_root_invented_table: { + invented_tables["purchases_root"]: { "column_count": 6, "primary_key": ["id", "~PRIMARY_KEY_ID~"], "foreign_key_count": 1, @@ -986,14 +996,14 @@ def test_all_tables_are_present_in_debug_summary(documents): "json_breadcrumb_path": "purchases", }, }, - purchases_data_years_invented_table: { + invented_tables["purchases_data_years"]: { "column_count": 4, "primary_key": ["~PRIMARY_KEY_ID~"], "foreign_key_count": 1, "foreign_keys": [ { "columns": ["purchases~id"], - "parent_table_name": purchases_root_invented_table, + "parent_table_name": invented_tables["purchases_root"], "parent_columns": ["~PRIMARY_KEY_ID~"], } ], diff --git a/tests/relational/test_train_synthetics.py b/tests/relational/test_train_synthetics.py index 2d1ef919..fac0bcef 100644 --- a/tests/relational/test_train_synthetics.py +++ b/tests/relational/test_train_synthetics.py @@ -5,7 +5,6 @@ from gretel_trainer.relational.core import MultiTableException from gretel_trainer.relational.multi_table import MultiTable -from tests.relational.conftest import get_invented_table_suffix # The assertions in this file are concerned with setting up the synthetics train @@ -206,7 +205,9 @@ def test_train_synthetics_multiple_calls_additive(ecom, tmpdir): assert set(mt._synthetics_train.models.keys()) == {"products", "users"} -def test_train_synthetics_models_for_dbs_with_invented_tables(documents, tmpdir): +def test_train_synthetics_models_for_dbs_with_invented_tables( + documents, tmpdir, get_invented_table_suffix +): mt = MultiTable(documents, project_display_name=tmpdir) mt.train_synthetics()