Use get_artifact_handle instead of smart_open directly #153

Merged: 1 commit, Aug 9, 2023
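This PR replaces the two-step pattern of resolving an artifact download link and opening it with smart_open with the SDK's get_artifact_handle context manager. A minimal before/after sketch of the pattern being adopted; the model variable is illustrative and "data" is simply the artifact name used in the diff below, so this is not a complete program:

# Before: resolve a download link, then open it with smart_open
import smart_open

link = model.get_artifact_link("data")
with smart_open.open(link, "rb") as src:
    contents = src.read()

# After: ask the SDK object for a readable, context-managed handle directly
with model.get_artifact_handle("data") as src:
    contents = src.read()

The requirements bump to gretel-client>=0.16.12 presumably pins the first client release that exposes get_artifact_handle.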
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
 boto3~=1.20
 dask[dataframe]==2023.5.1
-gretel-client>=0.16.0
+gretel-client>=0.16.12
 jinja2~=3.1
 networkx~=3.0
 numpy~=1.20
9 changes: 4 additions & 5 deletions src/gretel_trainer/relational/sdk_extras.py
@@ -52,11 +52,10 @@ def download_file_artifact(
         artifact_name: str,
         out_path: Union[str, Path],
     ) -> None:
-        download_link = gretel_object.get_artifact_link(artifact_name)
         try:
-            with smart_open.open(download_link, "rb") as src, smart_open.open(
-                out_path, "wb"
-            ) as dest:
+            with gretel_object.get_artifact_handle(
+                artifact_name
+            ) as src, smart_open.open(out_path, "wb") as dest:
                 shutil.copyfileobj(src, dest)
         except:
             logger.warning(f"Failed to download `{artifact_name}`")
@@ -80,7 +79,7 @@ def sqs_score_from_full_report(self, report: dict[str, Any]) -> Optional[int]:
                 return field_dict["value"]

     def get_record_handler_data(self, record_handler: RecordHandler) -> pd.DataFrame:
-        with smart_open.open(record_handler.get_artifact_link("data"), "rb") as data:
+        with record_handler.get_artifact_handle("data") as data:
             return pd.read_csv(data)

     def start_job_if_possible(
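For reference, the updated helper's pattern as a self-contained sketch rather than the repository's actual method: the function name is made up, and gretel_object stands for any Gretel SDK object (model, record handler, or project) that exposes get_artifact_handle.

import shutil
from pathlib import Path
from typing import Union

import smart_open


def save_artifact_locally(gretel_object, artifact_name: str, out_path: Union[str, Path]) -> None:
    # get_artifact_handle yields a readable stream, so no signed download
    # link has to be fetched and opened separately
    with gretel_object.get_artifact_handle(artifact_name) as src, smart_open.open(
        out_path, "wb"
    ) as dest:
        shutil.copyfileobj(src, dest)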
11 changes: 10 additions & 1 deletion tests/relational/test_multi_table_restore.py
@@ -148,7 +148,7 @@ def create_backup(

 def get_local_name(artifact_id):
     local_name = None
-    for key, pointers in ARTIFACTS.items():
+    for _, pointers in ARTIFACTS.items():
         if pointers["artifact_id"] == artifact_id:
             local_name = pointers["local_name"]
     if local_name is None:
@@ -163,6 +163,13 @@ def get_artifact_link(artifact_id):
     return get_artifact_link


+def make_mock_get_artifact_handle(setup_path: Path):
+    def get_artifact_handle(artifact_id):
+        return open(setup_path / get_local_name(artifact_id), "rb")
+
+    return get_artifact_handle
+
+
 def make_mock_download_tar_artifact(setup_path: Path, working_path: Path):
     def download_tar_artifact(project, name, out_path):
         local_name = get_local_name(name)
@@ -193,6 +200,7 @@ def make_mock_model(
     model.model_id = name
     model.data_source = ARTIFACTS[f"train_{name}"]["artifact_id"]
     model.get_artifact_link = make_mock_get_artifact_link(setup_path)
+    model.get_artifact_handle = make_mock_get_artifact_handle(setup_path)
     model.get_record_handler.return_value = record_handler
     return model

@@ -329,6 +337,7 @@ def configure_mocks(
     models: dict[str, Mock] = {},
 ) -> None:
     project.get_artifact_link = make_mock_get_artifact_link(setup_path)
+    project.get_artifact_handle = make_mock_get_artifact_handle(setup_path)
     project.get_model = make_mock_get_model(models)
     download_tar_artifact.side_effect = make_mock_download_tar_artifact(
         setup_path,
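In tests, get_artifact_handle can be stubbed with a plain callable that returns an open local file, mirroring the make_mock_get_artifact_handle helper added above. A minimal sketch, assuming a fixture file exists at the illustrative path below; the ARTIFACTS mapping, artifact ids, and paths are made up for the example:

from pathlib import Path
from unittest.mock import Mock

import pandas as pd

# Illustrative mapping from Gretel artifact ids to local fixture files
ARTIFACTS = {"train_users": {"artifact_id": "gretel_abc_users", "local_name": "users.csv"}}


def make_mock_get_artifact_handle(setup_path: Path):
    def get_artifact_handle(artifact_id):
        # Look up the local fixture backing the requested artifact id
        local_name = next(
            p["local_name"] for p in ARTIFACTS.values() if p["artifact_id"] == artifact_id
        )
        return open(setup_path / local_name, "rb")

    return get_artifact_handle


record_handler = Mock()
record_handler.get_artifact_handle = make_mock_get_artifact_handle(Path("tests/fixtures"))

# Code under test now reads the artifact exactly as it would from the real SDK
with record_handler.get_artifact_handle("gretel_abc_users") as data:
    df = pd.read_csv(data)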