Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 49a15cf0dfff4f12cc865ea2e2ed7be14bbf448b
  • Loading branch information
Gretel Team authored and mikeknep committed Mar 6, 2024
1 parent 69b78ab commit 348ea4b
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 14 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
boto3~=1.20
dask[dataframe]==2023.5.1
gretel-client>=0.17.7
gretel-client>=0.17.5
jinja2~=3.1
networkx~=3.0
numpy~=1.20
Expand Down
7 changes: 6 additions & 1 deletion src/gretel_trainer/benchmark/executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from enum import Enum
from typing import Optional, Protocol
from typing import Callable, Optional, Protocol

from gretel_client.projects.models import Model
from gretel_client.projects.projects import Project
Expand Down Expand Up @@ -66,11 +66,13 @@ def __init__(
run_identifier: str,
evaluate_project: Project,
config: BenchmarkConfig,
snapshot: Callable[[], None],
):
self.strategy = strategy
self.run_identifier = run_identifier
self.evaluate_project = evaluate_project
self.config = config
self.snapshot = snapshot

self.status = Status.NotStarted
self.exception: Optional[Exception] = None
Expand All @@ -81,10 +83,13 @@ def run(self) -> None:
self._maybe_skip()
if self.status.can_proceed:
self._train()
self.snapshot()
if self.status.can_proceed:
self._generate()
self.snapshot()
if self.status.can_proceed:
self._evaluate()
self.snapshot()

def get_report_score(self, key: str) -> Optional[int]:
if self.evaluate_report_json is None:
Expand Down
4 changes: 4 additions & 0 deletions src/gretel_trainer/benchmark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,13 @@ def _setup_gretel_run(
trainer_project_index=trainer_project_index,
artifact_key=artifact_key,
)
snapshot_dest = str(self._config.working_dir / f"{run_key.identifier}_result_data.csv")
executor = Executor(
strategy=strategy,
run_identifier=run_identifier,
evaluate_project=self._project,
config=self._config,
snapshot=lambda: self.export_results(snapshot_dest),
)
self._gretel_executors[run_key] = executor

Expand All @@ -231,11 +233,13 @@ def _setup_custom_run(
config=self._config,
artifact_key=artifact_key,
)
snapshot_dest = str(self._config.working_dir / f"{run_key.identifier}_result_data.csv")
executor = Executor(
strategy=strategy,
run_identifier=run_identifier,
evaluate_project=self._project,
config=self._config,
snapshot=lambda: self.export_results(snapshot_dest),
)
self._custom_executors[run_key] = executor

Expand Down
10 changes: 0 additions & 10 deletions src/gretel_trainer/relational/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import logging

import gretel_trainer.relational.log

from gretel_trainer.relational.connectors import (
Expand All @@ -14,11 +12,3 @@
from gretel_trainer.relational.extractor import ExtractorConfig
from gretel_trainer.relational.log import set_log_level
from gretel_trainer.relational.multi_table import MultiTable

logger = logging.getLogger(__name__)

logger.warn(
"Relational Trainer is deprecated, and will be removed in the next Trainer release. "
"To transform and synthesize relational data, use Gretel Workflows. "
"Visit the docs to learn more: https://docs.gretel.ai/create-synthetic-data/workflows-and-connectors"
)
12 changes: 10 additions & 2 deletions tests/benchmark/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,16 +247,24 @@ def test_run_happy_path_gretel_sdk(
assert result["Generate time (sec)"] == 15
assert result["Total time (sec)"] == 45

# The synthetic data is written to the working directory
working_dir_contents = os.listdir(working_dir)
assert len(working_dir_contents) == 1
assert len(working_dir_contents) == 2
print(working_dir_contents)

# The synthetic data is written to the working directory
filename = f"synth_{model_name}-iris.csv"
assert filename in working_dir_contents
df = pd.read_csv(f"{working_dir}/{filename}")
pdtest.assert_frame_equal(
df, pd.DataFrame(data={"synthetic": [1, 2], "data": [3, 4]})
)

# Snapshot results are written to the working directory
filename = f"{model_name}-iris_result_data.csv"
assert filename in working_dir_contents
df = pd.read_csv(f"{working_dir}/{filename}")
pdtest.assert_frame_equal(df, session.results)


def test_sdk_model_failure(working_dir, iris, project):
model = Mock(
Expand Down

0 comments on commit 348ea4b

Please sign in to comment.