Skip to content

Commit

Permalink
PLAT 1588: relational buffering
Browse files Browse the repository at this point in the history
GitOrigin-RevId: d5225c3e822b539e045c849f49509e05887161f5
  • Loading branch information
mikeknep committed Mar 1, 2024
1 parent dfbceb7 commit 2b634c2
Show file tree
Hide file tree
Showing 15 changed files with 356 additions and 422 deletions.
42 changes: 25 additions & 17 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,13 @@
from gretel_trainer.relational.strategies.ancestral import AncestralStrategy
from gretel_trainer.relational.strategies.independent import IndependentStrategy
from gretel_trainer.relational.table_evaluation import TableEvaluation
from gretel_trainer.relational.task_runner import run_task
from gretel_trainer.relational.tasks import (
ClassifyTask,
SyntheticsEvaluateTask,
SyntheticsRunTask,
SyntheticsTrainTask,
TransformsRunTask,
TransformsTrainTask,
)
from gretel_trainer.relational.task_runner import run_task, TaskContext
from gretel_trainer.relational.tasks.classify import ClassifyTask
from gretel_trainer.relational.tasks.synthetics_evaluate import SyntheticsEvaluateTask
from gretel_trainer.relational.tasks.synthetics_run import SyntheticsRunTask
from gretel_trainer.relational.tasks.synthetics_train import SyntheticsTrainTask
from gretel_trainer.relational.tasks.transforms_run import TransformsRunTask
from gretel_trainer.relational.tasks.transforms_train import TransformsTrainTask
from gretel_trainer.relational.workflow_state import (
Classify,
SyntheticsRun,
Expand Down Expand Up @@ -451,7 +449,7 @@ def classify(self, config: GretelModelConfig, all_rows: bool = False) -> None:
classify=self._classify,
data_sources=classify_data_sources,
all_rows=all_rows,
multitable=self,
ctx=self._new_task_context(),
output_handler=self._output_handler,
)
run_task(task, self._extended_sdk)
Expand Down Expand Up @@ -487,7 +485,7 @@ def transform_v2(
self._setup_transforms_train_state(configs)
task = TransformsTrainTask(
transforms_train=self._transforms_train,
multitable=self,
ctx=self._new_task_context(),
)
run_task(task, self._extended_sdk)

Expand Down Expand Up @@ -518,7 +516,7 @@ def train_transforms(
self._setup_transforms_train_state(configs)
task = TransformsTrainTask(
transforms_train=self._transforms_train,
multitable=self,
ctx=self._new_task_context(),
)
run_task(task, self._extended_sdk)

Expand Down Expand Up @@ -593,7 +591,7 @@ def run_transforms(

task = TransformsRunTask(
record_handlers=transforms_record_handlers,
multitable=self,
ctx=self._new_task_context(),
)
run_task(task, self._extended_sdk)

Expand Down Expand Up @@ -677,7 +675,7 @@ def _train_synthetics_models(self, configs: dict[str, dict[str, Any]]) -> None:

task = SyntheticsTrainTask(
synthetics_train=self._synthetics_train,
multitable=self,
ctx=self._new_task_context(),
)
run_task(task, self._extended_sdk)

Expand Down Expand Up @@ -828,7 +826,9 @@ def generate(
synthetics_train=self._synthetics_train,
subdir=run_subdir,
output_handler=self._output_handler,
multitable=self,
ctx=self._new_task_context(),
rel_data=self.relational_data,
strategy=self._strategy,
)
run_task(task, self._extended_sdk)

Expand Down Expand Up @@ -894,11 +894,10 @@ def generate(
synthetics_evaluate_task = SyntheticsEvaluateTask(
individual_evaluate_models=individual_evaluate_models,
cross_table_evaluate_models=cross_table_evaluate_models,
project=self._project,
subdir=run_subdir,
output_handler=self._output_handler,
evaluations=self._evaluations,
multitable=self,
ctx=self._new_task_context(),
)
run_task(synthetics_evaluate_task, self._extended_sdk)

Expand Down Expand Up @@ -993,6 +992,15 @@ def create_relational_report(self, run_identifier: str, filepath: str) -> None:
html_content = ReportRenderer().render(presenter)
report.write(html_content)

def _new_task_context(self) -> TaskContext:
    """Build a fresh TaskContext wired to this instance's shared resources.

    Every task run begins with zero in-flight jobs; the context tracks
    submissions independently for each task.
    """
    context = TaskContext(
        in_flight_jobs=0,
        refresh_interval=self._refresh_interval,
        project=self._project,
        extended_sdk=self._extended_sdk,
        backup=self._backup,
    )
    return context

def _validate_synthetics_config(self, config_dict: dict[str, Any]) -> None:
"""
Validates that the provided config (in dict form)
Expand Down
38 changes: 17 additions & 21 deletions src/gretel_trainer/relational/sdk_extras.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging
import shutil

from contextlib import suppress
from pathlib import Path
from typing import Any, Optional, Union
from typing import Optional, Union

import pandas as pd

Expand All @@ -12,11 +11,12 @@
from gretel_client.projects.models import Model
from gretel_client.projects.projects import Project
from gretel_client.projects.records import RecordHandler
from gretel_client.rest import ApiException
from gretel_trainer.relational.core import MultiTableException

logger = logging.getLogger(__name__)

MAX_PROJECT_ARTIFACTS = 10_000
MAX_IN_FLIGHT_JOBS = 10


class ExtendedGretelSDK:
Expand Down Expand Up @@ -65,12 +65,6 @@ def download_file_artifact(
logger.warning(f"Failed to download `{artifact_name}`")
return False

def sqs_score_from_full_report(self, report: dict[str, Any]) -> Optional[int]:
    """Pull the synthetic data quality score out of a full report dict.

    Returns None when the report lacks a "summary" section, when any summary
    entry is missing an expected key, or when no entry carries the score.
    """
    try:
        for entry in report["summary"]:
            if entry["field"] == "synthetic_data_quality_score":
                return entry["value"]
    except KeyError:
        # Malformed report: treat as "no score available", same as absence.
        pass
    return None

def get_record_handler_data(self, record_handler: RecordHandler) -> pd.DataFrame:
    """Read the record handler's "data" artifact into a DataFrame."""
    with record_handler.get_artifact_handle("data") as handle:
        frame = pd.read_csv(handle)
    return frame
Expand All @@ -81,25 +75,27 @@ def start_job_if_possible(
table_name: str,
action: str,
project: Project,
number_of_artifacts: int,
) -> None:
if job.data_source is None or self._room_in_project(
project, number_of_artifacts
):
in_flight_jobs: int,
) -> int:
if in_flight_jobs < MAX_IN_FLIGHT_JOBS:
self._log_start(table_name, action)
job.submit()
try:
job.submit()
return 1
except ApiException as ex:
if "Maximum number of" in str(ex):
self._log_waiting(table_name, action)
return 0
else:
raise
else:
self._log_waiting(table_name, action)

def _room_in_project(self, project: Project, count: int) -> bool:
    """Whether `count` additional artifacts fit under the project artifact cap.

    Hybrid deployments are exempt from the cap.
    """
    if self._hybrid:
        return True
    used = len(project.artifacts)
    return used + count <= MAX_PROJECT_ARTIFACTS
return 0

def _log_start(self, table_name: str, action: str) -> None:
    """Log that `action` is being submitted for `table_name`."""
    message = f"Starting {action} for `{table_name}`."
    logger.info(message)

def _log_waiting(self, table_name: str, action: str) -> None:
    """Log that submission of `table_name`'s `action` is deferred.

    Emitted when the concurrent-job cap is reached; the job is retried on a
    later polling pass.
    """
    # Diff residue left two adjacent f-strings here (old and new message),
    # which Python would silently concatenate; keep only the current message.
    logger.info(
        f"Maximum concurrent jobs reached. Deferring start of `{table_name}` {action}."
    )
49 changes: 31 additions & 18 deletions src/gretel_trainer/relational/task_runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging

from collections import defaultdict
from typing import Protocol
from dataclasses import dataclass
from typing import Callable, Protocol

import gretel_trainer.relational.tasks.common as common

from gretel_client.projects.jobs import END_STATES, Job, Status
from gretel_client.projects.projects import Project
Expand All @@ -12,28 +15,39 @@
logger = logging.getLogger(__name__)


class Task(Protocol):
def action(self, job: Job) -> str:
...
@dataclass
class TaskContext:
    """Mutable state shared between a task and the task-runner loop."""

    in_flight_jobs: int  # jobs submitted and not yet observed in a terminal state
    refresh_interval: int  # seconds between polling passes
    project: Project
    extended_sdk: ExtendedGretelSDK
    backup: Callable[[], None]  # callback persisting progress each iteration

    def maybe_start_job(self, job: Job, table_name: str, action: str) -> None:
        """Attempt to submit `job`; bump the in-flight counter if it started."""
        started = self.extended_sdk.start_job_if_possible(
            job=job,
            table_name=table_name,
            action=action,
            project=self.project,
            in_flight_jobs=self.in_flight_jobs,
        )
        self.in_flight_jobs += started


class Task(Protocol):
@property
def table_collection(self) -> list[str]:
def ctx(self) -> TaskContext:
...

@property
def artifacts_per_job(self) -> int:
def action(self, job: Job) -> str:
...

@property
def project(self) -> Project:
def table_collection(self) -> list[str]:
...

def more_to_do(self) -> bool:
...

def wait(self) -> None:
...

def is_finished(self, table: str) -> bool:
...

Expand Down Expand Up @@ -64,20 +78,16 @@ def run_task(task: Task, extended_sdk: ExtendedGretelSDK) -> None:
if first_pass:
first_pass = False
else:
task.wait()
common.wait(task.ctx.refresh_interval)

for table_name in task.table_collection:
if task.is_finished(table_name):
continue

job = task.get_job(table_name)
if extended_sdk.get_job_id(job) is None:
extended_sdk.start_job_if_possible(
job=job,
table_name=table_name,
action=task.action(job),
project=task.project,
number_of_artifacts=task.artifacts_per_job,
task.ctx.maybe_start_job(
job=job, table_name=table_name, action=task.action(job)
)
continue

Expand All @@ -86,12 +96,15 @@ def run_task(task: Task, extended_sdk: ExtendedGretelSDK) -> None:
)

if refresh_attempts[table_name] >= MAX_REFRESH_ATTEMPTS:
task.ctx.in_flight_jobs -= 1
task.handle_lost_contact(table_name, job)
continue

if status == Status.COMPLETED:
task.ctx.in_flight_jobs -= 1
task.handle_completed(table_name, job)
elif status in END_STATES:
task.ctx.in_flight_jobs -= 1
task.handle_failed(table_name, job)
else:
task.handle_in_progress(table_name, job)
Expand Down
6 changes: 0 additions & 6 deletions src/gretel_trainer/relational/tasks/__init__.py

This file was deleted.

40 changes: 10 additions & 30 deletions src/gretel_trainer/relational/tasks/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from gretel_client.projects.artifact_handlers import open_artifact
from gretel_client.projects.jobs import Job
from gretel_client.projects.models import Model
from gretel_client.projects.projects import Project
from gretel_client.projects.records import RecordHandler
from gretel_trainer.relational.output_handler import OutputHandler
from gretel_trainer.relational.task_runner import TaskContext
from gretel_trainer.relational.workflow_state import Classify


Expand All @@ -17,13 +17,13 @@ def __init__(
classify: Classify,
data_sources: dict[str, str],
all_rows: bool,
multitable: common._MultiTable,
ctx: TaskContext,
output_handler: OutputHandler,
):
self.classify = classify
self.data_sources = data_sources
self.all_rows = all_rows
self.multitable = multitable
self.ctx = ctx
self.output_handler = output_handler
self.classify_record_handlers: dict[str, RecordHandler] = {}
self.completed_models = []
Expand All @@ -41,14 +41,6 @@ def action(self, job: Job) -> str:
else:
return "classification"

@property
def project(self) -> Project:
    # Delegates to the owning MultiTable instance's project handle.
    return self.multitable._project

@property
def artifacts_per_job(self) -> int:
    # NOTE(review): presumably used to budget project artifact capacity when
    # deciding whether a new job may start — confirm against the task runner.
    return 1

@property
def table_collection(self) -> list[str]:
    """Names of all tables that have a classify model."""
    return [name for name in self.classify.models]
Expand All @@ -65,13 +57,6 @@ def more_to_do(self) -> bool:
else:
return any_unfinished_models

def wait(self) -> None:
    """Pause between polling passes.

    All-rows runs honor the configured refresh interval; otherwise a fixed
    15-second pause is used.
    """
    duration = self.multitable._refresh_interval if self.all_rows else 15
    common.wait(duration)

@property
def _finished_models(self) -> list[str]:
    """Tables whose classify models reached a terminal state (success or failure)."""
    return [*self.completed_models, *self.failed_models]
Expand Down Expand Up @@ -103,41 +88,36 @@ def handle_completed(self, table: str, job: Job) -> None:
data_source=self.data_sources[table]
)
self.classify_record_handlers[table] = record_handler
self.multitable._extended_sdk.start_job_if_possible(
job=record_handler,
table_name=table,
action=self.action(record_handler),
project=self.project,
number_of_artifacts=self.artifacts_per_job,
self.ctx.maybe_start_job(
job=record_handler, table_name=table, action=self.action(job)
)
elif isinstance(job, RecordHandler):
self.completed_record_handlers.append(table)
common.log_success(table, self.action(job))
self._write_results(job=job, table=table)
common.cleanup(sdk=self.multitable._extended_sdk, project=self.project, job=job)
common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job)

def handle_failed(self, table: str, job: Job) -> None:
    """Record `table`'s job as failed, log it, and clean up its artifacts."""
    failure_bucket = None
    if isinstance(job, Model):
        failure_bucket = self.failed_models
    elif isinstance(job, RecordHandler):
        failure_bucket = self.failed_record_handlers
    if failure_bucket is not None:
        failure_bucket.append(table)
    common.log_failed(table, self.action(job))
    common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job)

def handle_lost_contact(self, table: str, job: Job) -> None:
    """Treat a job we can no longer poll as failed and clean up after it."""
    if isinstance(job, Model):
        target = self.failed_models
    elif isinstance(job, RecordHandler):
        target = self.failed_record_handlers
    else:
        target = None
    if target is not None:
        target.append(table)
    common.log_lost_contact(table)
    common.cleanup(sdk=self.ctx.extended_sdk, project=self.ctx.project, job=job)

def handle_in_progress(self, table: str, job: Job) -> None:
    """Log a status update for a job that is still running."""
    # Diff residue left both the old two-line form and the new one-line form
    # here, which would log the update twice; keep only the current form.
    common.log_in_progress(table, job.status, self.action(job))

def each_iteration(self) -> None:
    """Persist a backup snapshot after every polling pass."""
    do_backup = self.ctx.backup
    do_backup()

def _write_results(self, job: Job, table: str) -> None:
if isinstance(job, Model):
Expand Down
Loading

0 comments on commit 2b634c2

Please sign in to comment.