From 9f6c55ee1d41c0f9d9f547fc9837d51dd823c4b4 Mon Sep 17 00:00:00 2001 From: Matt Kornfield Date: Tue, 8 Aug 2023 08:45:05 -0700 Subject: [PATCH] Use get_artifact_handle instead of smart_open directly --- requirements.txt | 2 +- src/gretel_trainer/relational/sdk_extras.py | 9 ++++----- tests/relational/test_multi_table_restore.py | 11 ++++++++++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2062012f..0ee5a4fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ boto3~=1.20 dask[dataframe]==2023.5.1 -gretel-client>=0.16.0 +gretel-client>=0.16.12 jinja2~=3.1 networkx~=3.0 numpy~=1.20 diff --git a/src/gretel_trainer/relational/sdk_extras.py b/src/gretel_trainer/relational/sdk_extras.py index e9550bad..b636cf59 100644 --- a/src/gretel_trainer/relational/sdk_extras.py +++ b/src/gretel_trainer/relational/sdk_extras.py @@ -52,11 +52,10 @@ def download_file_artifact( artifact_name: str, out_path: Union[str, Path], ) -> None: - download_link = gretel_object.get_artifact_link(artifact_name) try: - with smart_open.open(download_link, "rb") as src, smart_open.open( - out_path, "wb" - ) as dest: + with gretel_object.get_artifact_handle( + artifact_name + ) as src, smart_open.open(out_path, "wb") as dest: shutil.copyfileobj(src, dest) except: logger.warning(f"Failed to download `{artifact_name}`") @@ -80,7 +79,7 @@ def sqs_score_from_full_report(self, report: dict[str, Any]) -> Optional[int]: return field_dict["value"] def get_record_handler_data(self, record_handler: RecordHandler) -> pd.DataFrame: - with smart_open.open(record_handler.get_artifact_link("data"), "rb") as data: + with record_handler.get_artifact_handle("data") as data: return pd.read_csv(data) def start_job_if_possible( diff --git a/tests/relational/test_multi_table_restore.py b/tests/relational/test_multi_table_restore.py index 756d5ec5..31e7262e 100644 --- a/tests/relational/test_multi_table_restore.py +++ b/tests/relational/test_multi_table_restore.py @@ -148,7 +148,7 @@ def create_backup( def get_local_name(artifact_id): local_name = None - for key, pointers in ARTIFACTS.items(): + for _, pointers in ARTIFACTS.items(): if pointers["artifact_id"] == artifact_id: local_name = pointers["local_name"] if local_name is None: @@ -163,6 +163,13 @@ def get_artifact_link(artifact_id): return get_artifact_link +def make_mock_get_artifact_handle(setup_path: Path): + def get_artifact_handle(artifact_id): + return open(setup_path / get_local_name(artifact_id), "rb") + + return get_artifact_handle + + def make_mock_download_tar_artifact(setup_path: Path, working_path: Path): def download_tar_artifact(project, name, out_path): local_name = get_local_name(name) @@ -193,6 +200,7 @@ def make_mock_model( model.model_id = name model.data_source = ARTIFACTS[f"train_{name}"]["artifact_id"] model.get_artifact_link = make_mock_get_artifact_link(setup_path) + model.get_artifact_handle = make_mock_get_artifact_handle(setup_path) model.get_record_handler.return_value = record_handler return model @@ -329,6 +337,7 @@ def configure_mocks( models: dict[str, Mock] = {}, ) -> None: project.get_artifact_link = make_mock_get_artifact_link(setup_path) + project.get_artifact_handle = make_mock_get_artifact_handle(setup_path) project.get_model = make_mock_get_model(models) download_tar_artifact.side_effect = make_mock_download_tar_artifact( setup_path,