From 348ea4b8c0fa6e46b59b5660ef2c50cf534fd8d4 Mon Sep 17 00:00:00 2001 From: Gretel Team Date: Wed, 6 Mar 2024 12:18:34 -0600 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: 49a15cf0dfff4f12cc865ea2e2ed7be14bbf448b --- requirements.txt | 2 +- src/gretel_trainer/benchmark/executor.py | 7 ++++++- src/gretel_trainer/benchmark/session.py | 4 ++++ src/gretel_trainer/relational/__init__.py | 10 ---------- tests/benchmark/test_benchmark.py | 12 ++++++++++-- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/requirements.txt b/requirements.txt index 24e9f1fc..d35e5c4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ boto3~=1.20 dask[dataframe]==2023.5.1 -gretel-client>=0.17.7 +gretel-client>=0.17.5 jinja2~=3.1 networkx~=3.0 numpy~=1.20 diff --git a/src/gretel_trainer/benchmark/executor.py b/src/gretel_trainer/benchmark/executor.py index 2ad1099b..c3ddd594 100644 --- a/src/gretel_trainer/benchmark/executor.py +++ b/src/gretel_trainer/benchmark/executor.py @@ -1,7 +1,7 @@ import logging from enum import Enum -from typing import Optional, Protocol +from typing import Callable, Optional, Protocol from gretel_client.projects.models import Model from gretel_client.projects.projects import Project @@ -66,11 +66,13 @@ def __init__( run_identifier: str, evaluate_project: Project, config: BenchmarkConfig, + snapshot: Callable[[], None], ): self.strategy = strategy self.run_identifier = run_identifier self.evaluate_project = evaluate_project self.config = config + self.snapshot = snapshot self.status = Status.NotStarted self.exception: Optional[Exception] = None @@ -81,10 +83,13 @@ def run(self) -> None: self._maybe_skip() if self.status.can_proceed: self._train() + self.snapshot() if self.status.can_proceed: self._generate() + self.snapshot() if self.status.can_proceed: self._evaluate() + self.snapshot() def get_report_score(self, key: str) -> Optional[int]: if self.evaluate_report_json is None: diff --git a/src/gretel_trainer/benchmark/session.py b/src/gretel_trainer/benchmark/session.py index 95d339e8..747a1f45 100644 --- a/src/gretel_trainer/benchmark/session.py +++ b/src/gretel_trainer/benchmark/session.py @@ -209,11 +209,13 @@ def _setup_gretel_run( trainer_project_index=trainer_project_index, artifact_key=artifact_key, ) + snapshot_dest = str(self._config.working_dir / f"{run_key.identifier}_result_data.csv") executor = Executor( strategy=strategy, run_identifier=run_identifier, evaluate_project=self._project, config=self._config, + snapshot=lambda: self.export_results(snapshot_dest), ) self._gretel_executors[run_key] = executor @@ -231,11 +233,13 @@ def _setup_custom_run( config=self._config, artifact_key=artifact_key, ) + snapshot_dest = str(self._config.working_dir / f"{run_key.identifier}_result_data.csv") executor = Executor( strategy=strategy, run_identifier=run_identifier, evaluate_project=self._project, config=self._config, + snapshot=lambda: self.export_results(snapshot_dest), ) self._custom_executors[run_key] = executor diff --git a/src/gretel_trainer/relational/__init__.py b/src/gretel_trainer/relational/__init__.py index 979b25bb..810d51f8 100644 --- a/src/gretel_trainer/relational/__init__.py +++ b/src/gretel_trainer/relational/__init__.py @@ -1,5 +1,3 @@ -import logging - import gretel_trainer.relational.log from gretel_trainer.relational.connectors import ( @@ -14,11 +12,3 @@ from gretel_trainer.relational.extractor import ExtractorConfig from gretel_trainer.relational.log import set_log_level from gretel_trainer.relational.multi_table import MultiTable - -logger = logging.getLogger(__name__) - -logger.warn( - "Relational Trainer is deprecated, and will be removed in the next Trainer release. " - "To transform and synthesize relational data, use Gretel Workflows. " - "Visit the docs to learn more: https://docs.gretel.ai/create-synthetic-data/workflows-and-connectors" -) diff --git a/tests/benchmark/test_benchmark.py b/tests/benchmark/test_benchmark.py index 103e6968..adeb85ce 100644 --- a/tests/benchmark/test_benchmark.py +++ b/tests/benchmark/test_benchmark.py @@ -247,9 +247,11 @@ def test_run_happy_path_gretel_sdk( assert result["Generate time (sec)"] == 15 assert result["Total time (sec)"] == 45 - # The synthetic data is written to the working directory working_dir_contents = os.listdir(working_dir) - assert len(working_dir_contents) == 1 + assert len(working_dir_contents) == 2 + print(working_dir_contents) + + # The synthetic data is written to the working directory filename = f"synth_{model_name}-iris.csv" assert filename in working_dir_contents df = pd.read_csv(f"{working_dir}/{filename}") @@ -257,6 +259,12 @@ def test_run_happy_path_gretel_sdk( df, pd.DataFrame(data={"synthetic": [1, 2], "data": [3, 4]}) ) + # Snapshot results are written to the working directory + filename = f"{model_name}-iris_result_data.csv" + assert filename in working_dir_contents + df = pd.read_csv(f"{working_dir}/{filename}") + pdtest.assert_frame_equal(df, session.results) + def test_sdk_model_failure(working_dir, iris, project): model = Mock(