Skip to content

Commit

Permalink
[ML][Pipelines]Fix if_else CLI error and add related tests (Azure#28252)
Browse files Browse the repository at this point in the history
* add test for if_else cli

* add tests
  • Loading branch information
D-W- authored Jan 11, 2023
1 parent 83fca5e commit a17fed2
Show file tree
Hide file tree
Showing 18 changed files with 3,478 additions and 473 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def _post_load_pipeline_jobs(context, data: dict) -> dict:
from azure.ai.ml.entities._builders.parallel_for import ParallelFor
from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
from azure.ai.ml.entities._builders.condition_node import ConditionNode

# parse inputs/outputs
data = parse_inputs_outputs(data)
Expand All @@ -107,6 +108,10 @@ def _post_load_pipeline_jobs(context, data: dict) -> dict:
loaded_data=job_instance,
)
jobs[key] = job_instance
elif job_instance.get("type") == ControlFlowType.IF_ELSE:
# Convert to if-else node.
job_instance = ConditionNode._create_instance_from_schema_dict(loaded_data=job_instance)
jobs[key] = job_instance
elif job_instance.get("type") == ControlFlowType.DO_WHILE:
# Convert to do-while node.
job_instance = DoWhile._create_instance_from_schema_dict(pipeline_jobs=jobs, loaded_data=job_instance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict

from azure.ai.ml._schema import PathAwareSchema
from azure.ai.ml._utils.utils import is_data_binding_expression
from azure.ai.ml.constants._component import ControlFlowType
from azure.ai.ml.entities._builders import BaseNode
from azure.ai.ml.entities._builders.control_flow_node import ControlFlowNode
Expand Down Expand Up @@ -35,6 +36,11 @@ def _create_schema_for_validation(cls, context) -> PathAwareSchema: # pylint: d
def _from_rest_object(cls, obj: dict) -> "ConditionNode":
return cls(**obj)

@classmethod
def _create_instance_from_schema_dict(cls, loaded_data: Dict) -> "ConditionNode":
"""Create a condition node instance from schema parsed dict."""
return cls(**loaded_data)

def _to_dict(self) -> Dict:
return self._dump_for_validation()

Expand Down Expand Up @@ -63,16 +69,37 @@ def _validate_params(self, raise_error=True) -> MutableValidationResult:
f"with value 'True', got {output_definition.is_control}",
)

error_msg = "{!r} of dsl.condition node must be an instance of " f"{BaseNode} or {AutoMLJob}," "got {!r}."
if self.true_block is not None and not isinstance(self.true_block, (BaseNode, AutoMLJob)):
# check if condition is valid binding
if isinstance(self.condition, str) and not is_data_binding_expression(
self.condition, ["parent"], is_singular=False):
error_tail = "for example, ${{parent.jobs.xxx.outputs.output}}"
validation_result.append_error(
yaml_path="condition",
message=f"'condition' of dsl.condition has invalid binding expression: {self.condition}, {error_tail}",
)

error_msg = "{!r} of dsl.condition node must be an instance of " \
f"{BaseNode}, {AutoMLJob} or {str}," "got {!r}."
if self.true_block is not None and not isinstance(self.true_block, (BaseNode, AutoMLJob, str)):
validation_result.append_error(
yaml_path="true_block", message=error_msg.format("true_block", type(self.true_block))
)
if self.false_block is not None and not isinstance(self.false_block, (BaseNode, AutoMLJob)):
if self.false_block is not None and not isinstance(self.false_block, (BaseNode, AutoMLJob, str)):
validation_result.append_error(
yaml_path="false_block", message=error_msg.format("false_block", type(self.false_block))
)

# check if true/false block is valid binding
for name, block in {"true_block": self.true_block, "false_block": self.false_block}.items():
if block is None or not isinstance(block, str):
continue
error_tail = "for example, ${{parent.jobs.xxx}}"
if not is_data_binding_expression(block, ["parent", "jobs"], is_singular=False):
validation_result.append_error(
yaml_path=name,
message=f"'{name}' of dsl.condition has invalid binding expression: {block}, {error_tail}",
)

if self.true_block is None and self.false_block is None:
validation_result.append_error(
yaml_path="true_block",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_dsl_condition_pipeline(self, client: MLClient):
compute="cpu-cluster",
)
def condition_pipeline():
result = basic_component(str_param="abc", int_param=1)
result = basic_component()

node1 = hello_world_component_no_paths(component_in_number=1)
node2 = hello_world_component_no_paths(component_in_number=2)
Expand Down Expand Up @@ -89,10 +89,6 @@ def condition_pipeline():
},
"result": {
"_source": "REMOTE.WORKSPACE.COMPONENT",
"inputs": {
"int_param": {"job_input_type": "literal", "value": "1"},
"str_param": {"job_input_type": "literal", "value": "abc"},
},
"name": "result",
"type": "command",
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
from typing import Callable

import pytest

from azure.ai.ml.exceptions import ValidationException
from devtools_testutils import AzureRecordedTestCase, is_live
from test_utilities.utils import _PYTEST_TIMEOUT_METHOD

from azure.ai.ml import MLClient, load_job
from azure.ai.ml._schema.pipeline import pipeline_job
from azure.ai.ml._utils.utils import load_yaml
from azure.ai.ml.entities._builders import Command, Pipeline
from azure.ai.ml.entities._builders.do_while import DoWhile
from azure.ai.ml.entities._builders.parallel_for import ParallelFor

from .._util import _PIPELINE_JOB_TIMEOUT_SECOND
from .test_pipeline_job import assert_job_cancel
from test_utilities.utils import omit_with_wildcard

omit_fields = [
"name",
"properties.display_name",
"properties.settings",
"properties.jobs.*._source",
"properties.jobs.*.componentId",
]


@pytest.fixture()
Expand Down Expand Up @@ -44,6 +54,86 @@ class TestConditionalNodeInPipeline(AzureRecordedTestCase):
pass


class TestIfElse(TestConditionalNodeInPipeline):
def test_happy_path_if_else(self, client: MLClient, randstr: Callable[[], str]) -> None:
params_override = [{"name": randstr('name')}]
my_job = load_job(
"./tests/test_configs/pipeline_jobs/control_flow/if_else/simple_pipeline.yml",
params_override=params_override,
)
created_pipeline = assert_job_cancel(my_job, client)

pipeline_job_dict = created_pipeline._to_rest_object().as_dict()

pipeline_job_dict = omit_with_wildcard(pipeline_job_dict, *omit_fields)
assert pipeline_job_dict["properties"]["jobs"] == {
'conditionnode': {'condition': '${{parent.jobs.result.outputs.output}}',
'false_block': '${{parent.jobs.node1}}',
'true_block': '${{parent.jobs.node2}}',
'type': 'if_else'},
'node1': {'inputs': {'component_in_number': {'job_input_type': 'literal',
'value': '1'}},
'name': 'node1',
'type': 'command'},
'node2': {'inputs': {'component_in_number': {'job_input_type': 'literal',
'value': '2'}},
'name': 'node2',
'type': 'command'},
'result': {'name': 'result', 'type': 'command'}
}

def test_if_else_one_branch(self, client: MLClient, randstr: Callable[[], str]) -> None:
params_override = [{"name": randstr('name')}]
my_job = load_job(
"./tests/test_configs/pipeline_jobs/control_flow/if_else/one_branch.yml",
params_override=params_override,
)
created_pipeline = assert_job_cancel(my_job, client)

pipeline_job_dict = created_pipeline._to_rest_object().as_dict()

pipeline_job_dict = omit_with_wildcard(pipeline_job_dict, *omit_fields)
assert pipeline_job_dict["properties"]["jobs"] == {
'conditionnode': {'condition': '${{parent.jobs.result.outputs.output}}',
'true_block': '${{parent.jobs.node1}}',
'type': 'if_else'},
'node1': {'inputs': {'component_in_number': {'job_input_type': 'literal',
'value': '1'}},
'name': 'node1',
'type': 'command'},
'result': {'name': 'result', 'type': 'command'}
}

def test_if_else_literal_condition(self, client: MLClient, randstr: Callable[[], str]) -> None:
params_override = [{"name": randstr('name')}]
my_job = load_job(
"./tests/test_configs/pipeline_jobs/control_flow/if_else/literal_condition.yml",
params_override=params_override,
)
created_pipeline = assert_job_cancel(my_job, client)

pipeline_job_dict = created_pipeline._to_rest_object().as_dict()

pipeline_job_dict = omit_with_wildcard(pipeline_job_dict, *omit_fields)
assert pipeline_job_dict["properties"]["jobs"] == {
'conditionnode': {'condition': True,
'true_block': '${{parent.jobs.node1}}',
'type': 'if_else'},
'node1': {'inputs': {'component_in_number': {'job_input_type': 'literal',
'value': '1'}},
'name': 'node1',
'type': 'command'}
}

def test_if_else_invalid_case(self, client: MLClient, randstr: Callable[[], str]) -> None:
my_job = load_job(
"./tests/test_configs/pipeline_jobs/control_flow/if_else/invalid_binding.yml",
)
with pytest.raises(ValidationException) as e:
my_job._validate(raise_error=True)
assert '"path": "jobs.conditionnode.true_block",' in str(e.value)
assert "'true_block' of dsl.condition has invalid binding expression:" in str(e.value)

class TestDoWhile(TestConditionalNodeInPipeline):
def test_pipeline_with_do_while_node(self, client: MLClient, randstr: Callable[[], str]) -> None:
params_override = [{"name": randstr('name')}]
Expand Down
Loading

0 comments on commit a17fed2

Please sign in to comment.