
Commit

Replace all trainer smart_open+get_artifact_link with get_artifact_handle

GitOrigin-RevId: d192abacedea7430175f4e9920530cc80c32feb6
mikeknep committed Aug 11, 2023
1 parent e037974 commit 25aeaa3
Showing 6 changed files with 25 additions and 27 deletions.
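Every file below applies the same pattern: instead of resolving an artifact URL with get_artifact_link and opening it with smart_open, the code asks the SDK for a context-managed file handle via get_artifact_handle. A minimal before/after sketch of that pattern, assuming a gretel_client model object with a "report_json" artifact (the bare model argument and the function names are illustrative, not taken from any file in this commit):

import json

import smart_open


def report_via_link(model) -> dict:
    # Old pattern: resolve a signed artifact URL, then open and read it ourselves.
    # (Note: the opened stream is never explicitly closed.)
    return json.loads(
        smart_open.open(model.get_artifact_link("report_json")).read()
    )


def report_via_handle(model) -> dict:
    # New pattern: the SDK returns a context-managed, file-like handle
    # that is closed automatically when the block exits.
    with model.get_artifact_handle("report_json") as report:
        return json.loads(report.read())
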
7 changes: 2 additions & 5 deletions src/gretel_trainer/benchmark/sdk_extras.py
@@ -3,8 +3,6 @@
 
 from typing import Any
 
-import smart_open
-
 from gretel_client.projects.jobs import (
     ACTIVE_STATES,
     END_STATES,
@@ -41,9 +39,8 @@ def run_evaluate(
     job_status = await_job(run_identifier, evaluate_model, "evaluation", wait)
     if job_status in END_STATES and job_status != Status.COMPLETED:
         raise BenchmarkException("Evaluate failed")
-    return json.loads(
-        smart_open.open(evaluate_model.get_artifact_link("report_json")).read()
-    )
+    with evaluate_model.get_artifact_handle("report_json") as report:
+        return json.loads(report.read())
 
 
 def _make_evaluate_config(run_identifier: str) -> dict:
5 changes: 2 additions & 3 deletions src/gretel_trainer/relational/strategies/common.py
@@ -40,9 +40,8 @@ def read_report_json_data(model: Model, report_path: Path) -> Optional[dict]:
 
 def _get_report_json(model: Model) -> Optional[dict]:
     try:
-        return json.loads(
-            smart_open.open(model.get_artifact_link("report_json")).read()
-        )
+        with model.get_artifact_handle("report_json") as report:
+            return json.loads(report.read())
     except:
         logger.warning("Failed to fetch model evaluation report JSON.")
         return None
6 changes: 3 additions & 3 deletions src/gretel_trainer/relational/tasks/classify.py
@@ -151,8 +151,8 @@ def _write_results(self, job: Job, table: str) -> None:
 
         destpath = self.out_dir / filename
 
-        with smart_open.open(
-            job.get_artifact_link(artifact_name), "rb"
-        ) as src, smart_open.open(str(destpath), "wb") as dest:
+        with job.get_artifact_handle(artifact_name) as src, smart_open.open(
+            str(destpath), "wb"
+        ) as dest:
             shutil.copyfileobj(src, dest)
         self.result_filepaths.append(destpath)
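For the classify outputs the artifact is binary, so the handle replaces only the source side of the copy; the local destination still goes through smart_open. A hedged sketch of the resulting step as a standalone helper (the function name and its arguments are illustrative):

import shutil

import smart_open


def download_artifact(job, artifact_name: str, destpath: str) -> None:
    # Stream the SDK-managed artifact handle straight into the destination file.
    with job.get_artifact_handle(artifact_name) as src, smart_open.open(
        destpath, "wb"
    ) as dest:
        shutil.copyfileobj(src, dest)
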
5 changes: 1 addition & 4 deletions src/gretel_trainer/runner.py
@@ -38,7 +38,6 @@
 from typing import List, Optional, Tuple, Union
 
 import pandas as pd
-import smart_open
 
 from gretel_client.projects import Project
 from gretel_client.projects.jobs import ACTIVE_STATES
@@ -213,9 +212,7 @@ def _update_job_status(self):
                 report = current_model.peek_report()
 
                 if report is None:
-                    with smart_open.open(
-                        current_model.get_artifact_link("report_json")
-                    ) as fin:
+                    with current_model.get_artifact_handle("report_json") as fin:
                         report = json.loads(fin.read())
 
                 sqs = report["synthetic_data_quality_score"]["score"]
9 changes: 7 additions & 2 deletions tests/benchmark/conftest.py
@@ -48,12 +48,17 @@ def project():
 
 
 @pytest.fixture()
-def evaluate_report_path():
+def evaluate_report_handle():
     report = {"synthetic_data_quality_score": {"score": 95}}
     with tempfile.NamedTemporaryFile() as f:
         with open(f.name, "w") as j:
             json.dump(report, j)
-        yield f.name
+
+        ctxmgr = Mock()
+        ctxmgr.__enter__ = Mock(return_value=f)
+        ctxmgr.__exit__ = Mock(return_value=False)
+
+        yield ctxmgr
 
 
 @pytest.fixture()
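A plain Mock does not implement the context-manager protocol on its own, which is why the fixture assigns __enter__ and __exit__ explicitly. A hedged sketch of how a test can consume such a handle (the inline JSON string stands in for the temp file the real fixture writes):

import json
from unittest.mock import Mock

# Stand-in for what evaluate_report_handle yields: a Mock acting as a
# context manager whose __enter__ returns a readable file-like object.
fake_file = Mock()
fake_file.read.return_value = '{"synthetic_data_quality_score": {"score": 95}}'

handle = Mock()
handle.__enter__ = Mock(return_value=fake_file)
handle.__exit__ = Mock(return_value=False)

evaluate_model = Mock()
evaluate_model.get_artifact_handle.return_value = handle

# Mirrors how the production code now reads the report.
with evaluate_model.get_artifact_handle("report_json") as fin:
    report = json.loads(fin.read())
assert report["synthetic_data_quality_score"]["score"] == 95
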
20 changes: 10 additions & 10 deletions tests/benchmark/test_benchmark.py
@@ -81,11 +81,11 @@ class SharedDictLstm(GretelModel):
 }
 
 
-def test_run_with_gretel_dataset(working_dir, project, evaluate_report_path, iris):
+def test_run_with_gretel_dataset(working_dir, project, evaluate_report_handle, iris):
     evaluate_model = Mock(
         status=Status.COMPLETED,
     )
-    evaluate_model.get_artifact_link.return_value = evaluate_report_path
+    evaluate_model.get_artifact_handle.return_value = evaluate_report_handle
     project.create_model_obj.side_effect = [evaluate_model]
 
     session = compare(
@@ -107,11 +107,11 @@ def test_run_with_gretel_dataset(working_dir, project, evaluate_report_path, iris):
     assert result["SQS"] == 95
 
 
-def test_run_with_custom_csv_dataset(working_dir, project, evaluate_report_path, df):
+def test_run_with_custom_csv_dataset(working_dir, project, evaluate_report_handle, df):
     evaluate_model = Mock(
         status=Status.COMPLETED,
     )
-    evaluate_model.get_artifact_link.return_value = evaluate_report_path
+    evaluate_model.get_artifact_handle.return_value = evaluate_report_handle
     project.create_model_obj.side_effect = [evaluate_model]
 
     with tempfile.NamedTemporaryFile() as f:
@@ -137,11 +137,11 @@ def test_run_with_custom_csv_dataset(working_dir, project, evaluate_report_path, df):
     assert result["SQS"] == 95
 
 
-def test_run_with_custom_psv_dataset(working_dir, project, evaluate_report_path, df):
+def test_run_with_custom_psv_dataset(working_dir, project, evaluate_report_handle, df):
     evaluate_model = Mock(
         status=Status.COMPLETED,
     )
-    evaluate_model.get_artifact_link.return_value = evaluate_report_path
+    evaluate_model.get_artifact_handle.return_value = evaluate_report_handle
     project.create_model_obj.side_effect = [evaluate_model]
 
     with tempfile.NamedTemporaryFile() as f:
@@ -168,12 +168,12 @@ def test_run_with_custom_psv_dataset(working_dir, project, evaluate_report_path, df):
 
 
 def test_run_with_custom_dataframe_dataset(
-    working_dir, project, evaluate_report_path, df
+    working_dir, project, evaluate_report_handle, df
 ):
     evaluate_model = Mock(
         status=Status.COMPLETED,
     )
-    evaluate_model.get_artifact_link.return_value = evaluate_report_path
+    evaluate_model.get_artifact_handle.return_value = evaluate_report_handle
     project.create_model_obj.side_effect = [evaluate_model]
 
     dataset = create_dataset(df, datatype="tabular", name="pets")
@@ -205,7 +205,7 @@ def test_run_with_custom_dataframe_dataset(
 
 @pytest.mark.parametrize("benchmark_model", [GretelLSTM, TailoredActgan])
 def test_run_happy_path_gretel_sdk(
-    benchmark_model, working_dir, iris, project, evaluate_report_path
+    benchmark_model, working_dir, iris, project, evaluate_report_handle
 ):
     record_handler = Mock(
         status=Status.COMPLETED,
@@ -221,7 +221,7 @@ def test_run_happy_path_gretel_sdk(
     evaluate_model = Mock(
         status=Status.COMPLETED,
     )
-    evaluate_model.get_artifact_link.return_value = evaluate_report_path
+    evaluate_model.get_artifact_handle.return_value = evaluate_report_handle
 
     project.create_model_obj.side_effect = [model, evaluate_model]
 
