Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Pipelines] dsl.pipeline support pass non_pipeline_parameters #26920

Merged
merged 12 commits into from
Oct 24, 2022
Prev Previous commit
Next Next commit
fix comment
lalala123123 committed Oct 20, 2022
commit ffac561b493ebcb228be9f93338ca41e89affa92
Original file line number Diff line number Diff line change
@@ -577,9 +577,21 @@ def check_parameter_type(f):
error_category=ErrorCategory.USER_ERROR,
)

def check_non_pipeline_parameters(f):
"""Check whether non_pipeline_parameters exist in pipeline builder."""
if f._pipeline_builder.non_pipeline_parameter_names:
msg = "Cannot register pipeline component {!r} with non_pipeline_parameters."
raise ValidationException(
message=msg.format(f.__name__),
no_personal_data_message=msg.format(""),
target=ErrorTarget.COMPONENT,
error_category=ErrorCategory.USER_ERROR,
)

if hasattr(component_func, "_is_mldesigner_component") and component_func._is_mldesigner_component:
return component_func.component
if hasattr(component_func, "_is_dsl_func") and component_func._is_dsl_func:
check_non_pipeline_parameters(component_func)
check_parameter_type(component_func)
if component_func._job_settings:
module_logger.warning(
15 changes: 15 additions & 0 deletions sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dsl_pipeline.py
Original file line number Diff line number Diff line change
@@ -1061,6 +1061,21 @@ def pipeline_missing_type(
in e.value.message
)

@dsl.pipeline(non_pipeline_parameters=['param'])
def pipeline_with_non_pipeline_parameters(
required_input: Input,
required_param: str,
param: str,
):
default_optional_func(
required_input=required_input,
required_param=required_param,
)

with pytest.raises(ValidationException) as e:
client.components.create_or_update(pipeline_with_non_pipeline_parameters)
assert "Cannot register pipeline component 'pipeline_func' with non_pipeline_parameters." in e.value.message

def test_create_pipeline_component_by_dsl(self, caplog, client: MLClient):
default_optional_func = load_component(source=str(components_dir / "default_optional_component.yml"))

12 changes: 9 additions & 3 deletions sdk/ml/azure-ai-ml/tests/dsl/unittests/test_dsl_pipeline.py
Original file line number Diff line number Diff line change
@@ -1946,16 +1946,22 @@ def test_pipeline_with_non_pipeline_parameters(self):
component_func1 = load_component(source=component_yaml, params_override=[{"name": "component_name_1"}])
component_func2 = load_component(source=component_yaml, params_override=[{"name": "component_name_2"}])

@dsl.pipeline(non_pipeline_parameters=["other_params"])
def pipeline_func(job_in_number, job_in_path, other_params):
@dsl.pipeline(non_pipeline_parameters=["other_params", "is_add_component"])
def pipeline_func(job_in_number, job_in_path, other_params, is_add_component):
component_func1(component_in_number=job_in_number, component_in_path=job_in_path)
lalala123123 marked this conversation as resolved.
Show resolved Hide resolved
component_func2(component_in_number=other_params, component_in_path=job_in_path)
if is_add_component:
component_func2(component_in_number=other_params, component_in_path=job_in_path)

pipeline = pipeline_func(10, Input(path="/a/path/on/ds"), 15)
pipeline = pipeline_func(10, Input(path="/a/path/on/ds"), 15, False)
assert len(pipeline.jobs) == 2
assert "other_params" not in pipeline.inputs
assert isinstance(pipeline.jobs[component_func1.name].inputs["component_in_number"]._data, PipelineInput)
assert pipeline.jobs[component_func2.name].inputs["component_in_number"]._data == 15

pipeline = pipeline_func(10, Input(path="/a/path/on/ds"), 15, True)
assert len(pipeline.jobs) == 3

def test_pipeline_with_invalid_non_pipeline_parameters(self):
lalala123123 marked this conversation as resolved.
Show resolved Hide resolved

@dsl.pipeline(non_pipeline_parameters=[123])