Skip to content

Commit

Permalink
feat: support setting path for node output directly
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh committed Oct 9, 2022
1 parent 8d20078 commit c4fc0bf
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 7 deletions.
14 changes: 9 additions & 5 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/pipeline/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def path(self) -> str:

@path.setter
def path(self, path):
# For un-configured input/output, we build a default data entry for them.
self._build_default_data()
if hasattr(self._data, "path"):
self._data.path = path
else:
Expand Down Expand Up @@ -361,7 +363,9 @@ def is_control(self) -> str:
def _build_default_data(self):
"""Build default data when output not configured."""
if self._data is None:
self._data = Output()
# _meta is None when node._component is not a Component object,
# so we leave type inference to the backend.
self._data = Output(type=None)

def _build_data(self, data, key=None):
"""Build output data according to assigned input, eg: node.outputs.key = data"""
Expand Down Expand Up @@ -593,15 +597,13 @@ def _validate_inputs(cls, inputs):

def __getattr__(self, name: K) -> V:
if name not in self:
# pylint: disable=unnecessary-comprehension
raise UnexpectedAttributeError(keyword=name, keywords=[key for key in self])
raise UnexpectedAttributeError(keyword=name, keywords=list(self))
return super().__getitem__(name)

def __getitem__(self, item: K) -> V:
# We raise this exception instead of KeyError
if item not in self:
# pylint: disable=unnecessary-comprehension
raise UnexpectedKeywordError(func_name="ParameterGroup", keyword=item, keywords=[key for key in self])
raise UnexpectedKeywordError(func_name="ParameterGroup", keyword=item, keywords=list(self))
return super().__getitem__(item)

# For Jupyter Notebook auto-completion
Expand Down Expand Up @@ -649,6 +651,8 @@ def __setattr__(self, key: str, value: Union[Data, Output]):
if isinstance(value, Output):
mode = value.mode
value = Output(type=value.type, path=value.path, mode=mode)
if key not in self:
raise UnexpectedAttributeError(keyword=key, keywords=list(self))
original_output = self.__getattr__(key) # Note that an exception will be raised if the keyword is invalid.
original_output._data = original_output._build_data(value)

Expand Down
13 changes: 13 additions & 0 deletions sdk/ml/azure-ai-ml/tests/internal/unittests/test_pipeline_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,3 +527,16 @@ def pipeline_func():
if key.startswith("data_"):
expected_inputs[key] = {"job_input_type": "mltable", "uri": "azureml:scope_tsv:1"}
assert rest_obj.properties.jobs["node"]["inputs"] == expected_inputs

def test_pipeline_with_setting_node_output_directly(self) -> None:
    """Assigning ``path`` on a node output that was never configured works.

    A node is created from the internal ``copy`` command component with both
    inputs set to ``None`` and without touching its outputs. The test then
    assigns a path to ``output_dir`` and checks that the path reads back
    unchanged and that the output reports the type ``"path"``.
    """
    test_configs_root = Path(__file__).parent.parent.parent / "test_configs"
    component_dir = test_configs_root / "internal" / "command-component"
    copy_func = load_component(component_dir / "command-linux/copy/component.yaml")

    node = copy_func(input_dir=None, file_names=None)

    # Set the path straight on the output, with no prior Output assignment.
    expected_path = "path_on_datastore"
    node.outputs.output_dir.path = expected_path

    output_dir = node.outputs.output_dir
    assert output_dir.path == expected_path
    assert output_dir.type == "path"
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pytest_mock import MockFixture
from test_utilities.utils import verify_entity_load_and_dump

from azure.ai.ml import MLClient, load_job
from azure.ai.ml import MLClient, load_job, Output
from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData as FebRestJob
from azure.ai.ml._restclient.v2022_10_01_preview.models import JobBase as RestJob
from azure.ai.ml._schema.automl import AutoMLRegressionSchema
Expand All @@ -26,7 +26,7 @@
from azure.ai.ml.entities._job.automl.nlp import TextClassificationJob, TextClassificationMultilabelJob, TextNerJob
from azure.ai.ml.entities._job.automl.tabular import ClassificationJob, ForecastingJob, RegressionJob
from azure.ai.ml.entities._job.pipeline._io import PipelineInput, _GroupAttrDict
from azure.ai.ml.exceptions import ValidationException
from azure.ai.ml.exceptions import ValidationException, UnexpectedAttributeError

from .._util import _PIPELINE_JOB_TIMEOUT_SECOND

Expand Down Expand Up @@ -1454,3 +1454,37 @@ def test_comment_in_pipeline(self) -> None:
rest_pipeline_dict = pipeline_job._to_rest_object().as_dict()["properties"]
assert pipeline_dict["jobs"]["hello_world_component"]["comment"] == "arbitrary string"
assert rest_pipeline_dict["jobs"]["hello_world_component"]["comment"] == "arbitrary string"

def test_pipeline_node_default_output(self):
    """Exercise path assignment on outputs of a pipeline job loaded from YAML.

    Four situations are checked in order:

    1. a pipeline-level output exposes its mode;
    2. a node output accepts a directly assigned path, and the path and
       mode can be read back afterwards;
    3. assigning a path to the ``merge_component_outputs`` node's output is
       rejected with ``ValidationException``;
    4. assigning to an output name that does not exist on the node raises
       ``UnexpectedAttributeError`` whose message lists the valid names.
    """
    pipeline: PipelineJob = load_job(
        source="./tests/test_configs/pipeline_jobs/helloworld_pipeline_job_with_component_output.yml"
    )
    test_output_path = "azureml://datastores/workspaceblobstore/paths/azureml/ps_copy_component/outputs/output_dir"

    # 1. Pipeline-level output.
    assert pipeline.outputs["job_out_path_2"].mode == "upload"

    # 2. Node-level output: assign the path, then read the output back.
    first_node = pipeline.jobs["hello_world_component_1"]
    first_node.outputs["component_out_path_1"].path = test_output_path

    node_output = first_node.outputs["component_out_path_1"]
    assert node_output.path == test_output_path
    assert node_output.mode == "mount"

    # 3. Output described in the fixture as a data-binding expression:
    #    setting a path on it must fail.
    bound_output = pipeline.jobs["merge_component_outputs"].outputs["component_out_path_1"]
    with pytest.raises(ValidationException, match="<class '.*'> does not support setting path."):
        bound_output.path = test_output_path

    # 4. Unknown output name: the error lists the valid attribute names.
    expected_message = (
        "Got an unexpected attribute 'component_out_path_non', "
        "valid attributes: 'component_out_path_1', "
        "'component_out_path_2', 'component_out_path_3'."
    )
    with pytest.raises(UnexpectedAttributeError, match=expected_message):
        first_node.outputs["component_out_path_non"] = Output(path=test_output_path, mode="upload")

0 comments on commit c4fc0bf

Please sign in to comment.