Skip to content

Commit

Permalink
rework based on review, fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
fdroessler committed Aug 29, 2023
1 parent d9d3d67 commit ca86f4f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
20 changes: 13 additions & 7 deletions kedro_azureml/datasets/asset_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

from kedro_azureml.client import _get_azureml_client
from kedro_azureml.config import AzureMLConfig
from kedro_azureml.datasets.pipeline_dataset import AzureMLPipelineDataSet

AzureMLDataAssetType = Literal["uri_file", "uri_folder"]
Expand Down Expand Up @@ -114,6 +115,15 @@ def __init__(
f"the dataset definition."
)

@property
def azure_config(self) -> AzureMLConfig:
"""AzureML config to be used by the dataset."""
return self._azureml_config

@azure_config.setter
def azure_config(self, azure_config: AzureMLConfig) -> None:
self._azureml_config = azure_config

@property
def path(self) -> str:
# For local runs we want to replicate the folder structure of the remote dataset.
Expand Down Expand Up @@ -201,14 +211,10 @@ def _load(self) -> Any:
def _save(self, data: Any) -> None:
self._construct_dataset().save(data)

def as_local(self, azure_config, download: bool):
self._azureml_config = azure_config
self._local_run = True
if download:
self._download = True
def as_local_intermediate(self):
self._download = False
# for local runs we want the data to be saved as a "local version"
else:
self._version = Version("local", "local")
self._version = Version("local", "local")

def as_remote(self):
self._version = None
Expand Down
6 changes: 3 additions & 3 deletions kedro_azureml/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def after_context_created(self, context) -> None:
def after_catalog_created(self, catalog):
for dataset_name, dataset in catalog._data_sets.items():
if isinstance(dataset, AzureMLAssetDataSet):
dataset.as_local(self.azure_config, download=True)
dataset.azure_config = self.azure_config
catalog.add(dataset_name, dataset, replace=True)

@hook_impl
Expand All @@ -37,8 +37,8 @@ def before_pipeline_run(self, run_params, pipeline, catalog):
# when running locally using an AzureMLAssetDataSet
# as an intermediate dataset we don't want download
# but still set to run local with a local version.
download = dataset_name in pipeline.inputs()
dataset.as_local(self.azure_config, download=download)
if dataset_name not in pipeline.inputs():
dataset.as_local_intermediate()
# when running remotely we still want to provide information
# from the azureml config for getting the dataset version during
# remote runs
Expand Down
3 changes: 2 additions & 1 deletion tests/test_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@ def test_hook_after_context_created(
azureml_local_run_hook.before_pipeline_run(
run_params, dummy_pipeline, multi_catalog
)
# if local execution
if runner == SequentialRunner.__name__:
assert multi_catalog.datasets.input_data._download is True
assert multi_catalog.datasets.input_data._local_run is True
assert (
multi_catalog.datasets.input_data._azureml_config
== azureml_local_run_hook.azure_config
)
assert multi_catalog.datasets.i2._download is True
assert multi_catalog.datasets.i2._download is False
assert multi_catalog.datasets.i2._local_run is True
assert multi_catalog.datasets.i2._version == Version("local", "local")
else:
Expand Down

0 comments on commit ca86f4f

Please sign in to comment.