diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/component.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/component.py index 32a645eff9b9..ebec3d3e32ed 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/component.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/component.py @@ -7,7 +7,7 @@ from azure.ai.ml._schema import NestedField, StringTransformedEnum, UnionField from azure.ai.ml._schema.component.component import ComponentSchema from azure.ai.ml._schema.core.fields import ArmVersionedStr, CodeField -from azure.ai.ml.constants._common import AzureMLResourceType +from azure.ai.ml.constants._common import AzureMLResourceType, LABELLED_RESOURCE_NAME from .environment import InternalEnvironmentSchema from .input_output import ( @@ -17,6 +17,7 @@ InternalParameterSchema, InternalPrimitiveOutputSchema, ) +from ..._utils._arm_id_utils import parse_name_label class NodeType: @@ -77,7 +78,8 @@ class Meta: # type field is required for registration type = StringTransformedEnum( allowed_values=NodeType.all_values(), - casing_transform=lambda x: x.rsplit("@", 1)[0], + casing_transform=lambda x: parse_name_label(x)[0], + pass_original=True, ) # need to resolve as it can be a local field @@ -122,3 +124,10 @@ def simplify_input_output_port(self, data, original, **kwargs): # pylint:disabl del port_definition["mode"] return data + + @post_dump(pass_original=True) + def add_back_type_label(self, data, original, **kwargs): # pylint:disable=unused-argument, no-self-use + type_label = original._type_label # pylint:disable=protected-access + if type_label: + data["type"] = LABELLED_RESOURCE_NAME.format(data['type'], type_label) + return data diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/component.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/component.py index 60cf248fe5bd..38eac72c9efa 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/component.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/component.py @@ -26,6 +26,7 @@ from .environment import InternalEnvironment from .node import InternalBaseNode from .code import InternalCode, InternalComponentIgnoreFile +from ..._utils._arm_id_utils import parse_name_label from ...entities._job.distribution import DistributionConfiguration @@ -95,6 +96,7 @@ def __init__( launcher: Dict = None, **kwargs, ): + type, self._type_label = parse_name_label(type) super().__init__( name=name, version=version, diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py index b7e6dbcd78ff..2aff862cd78f 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py @@ -53,6 +53,7 @@ def __init__(self, **kwargs): # pop marshmallow unknown args to avoid warnings self.allowed_values = kwargs.pop("allowed_values", None) self.casing_transform = kwargs.pop("casing_transform", lambda x: x.lower()) + self.pass_original = kwargs.pop("pass_original", False) super().__init__(**kwargs) if isinstance(self.allowed_values, str): self.allowed_values = [self.allowed_values] @@ -70,12 +71,12 @@ def _serialize(self, value, attr, obj, **kwargs): if not value: return if isinstance(value, str) and self.casing_transform(value) in self.allowed_values: - return self.casing_transform(value) + return value if self.pass_original else self.casing_transform(value) raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}") def _deserialize(self, value, attr, data, **kwargs): if isinstance(value, str) and self.casing_transform(value) in self.allowed_values: - return self.casing_transform(value) + return value if self.pass_original else self.casing_transform(value) raise ValidationError(f"Value {value!r} passed is not in set {self.allowed_values}") diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py b/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py index 6eb9fc5ec99f..9f35d889575f 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py @@ -39,6 +39,7 @@ ASSET_ID_FORMAT = "azureml://locations/{}/workspaces/{}/{}/{}/versions/{}" VERSIONED_RESOURCE_NAME = "{}:{}" LABELLED_RESOURCE_NAME = "{}@{}" +LABEL_SPLITTER = "@" PYTHON = "python" AML_TOKEN_YAML = "aml_token" AAD_TOKEN_YAML = "aad_token" diff --git a/sdk/ml/azure-ai-ml/tests/internal/unittests/test_component.py b/sdk/ml/azure-ai-ml/tests/internal/unittests/test_component.py index 43d4f935066d..c086bb43c543 100644 --- a/sdk/ml/azure-ai-ml/tests/internal/unittests/test_component.py +++ b/sdk/ml/azure-ai-ml/tests/internal/unittests/test_component.py @@ -189,12 +189,8 @@ def test_component_serialization(self, yaml_path): entity = load_component(yaml_path) expected_dict = copy.deepcopy(yaml_dict) - type_value = ( - expected_dict["type"].rsplit("@", 1)[0] - if expected_dict["type"].endswith("@1-legacy") - else expected_dict["type"] - ) - pydash.set_(expected_dict, "type", type_value) + + # Linux is the default value of os in InternalEnvironment if "environment" in expected_dict: expected_dict["environment"]["os"] = "Linux" @@ -222,6 +218,17 @@ def test_component_serialization(self, yaml_path): result = entity._validate() assert result._to_dict() == {"result": "Succeeded"} + @pytest.mark.parametrize( + "yaml_path,label", + [ + ("preview_command_component.yaml", "1-preview"), + ("legacy_distributed_component.yaml", "1-legacy"), + ] + ) + def test_command_mode_command_component(self, yaml_path: str, label: str): + component = load_component("./tests/test_configs/internal/command-mode/{}".format(yaml_path)) + assert component._to_rest_object().properties.component_spec["type"] == f"{component.type}@{label}" + def test_ipp_component_serialization(self): yaml_path = "./tests/test_configs/internal/ipp-component/spec.yaml" load_component(yaml_path) diff --git a/sdk/ml/azure-ai-ml/tests/test_configs/internal/command-mode/legacy_distributed_component.yaml b/sdk/ml/azure-ai-ml/tests/test_configs/internal/command-mode/legacy_distributed_component.yaml new file mode 100644 index 000000000000..208d122b5266 --- /dev/null +++ b/sdk/ml/azure-ai-ml/tests/test_configs/internal/command-mode/legacy_distributed_component.yaml @@ -0,0 +1,26 @@ +$schema: https://componentsdk.azureedge.net/jsonschema/DistributedComponent.json +name: microsoft.com.azureml.samples.mpi_example +version: 0.0.1 +display_name: MPI Example +type: DistributedComponent@1-legacy +inputs: + input_path: + type: path + description: The directory contains input data. + optional: false + string_parameter: + type: String + description: A parameter accepts a string value. + optional: true +outputs: + output_path: + type: path + description: The directory contains output data. +launcher: + type: mpi + additional_arguments: >- + python train.py --input-path {inputs.input_path} [--string-parameter {inputs.string_parameter}] + --output-path {outputs.output_path} +environment: + name: AzureML-Minimal + diff --git a/sdk/ml/azure-ai-ml/tests/test_configs/internal/command-mode/preview_command_component.yaml b/sdk/ml/azure-ai-ml/tests/test_configs/internal/command-mode/preview_command_component.yaml new file mode 100644 index 000000000000..e13e12e82a67 --- /dev/null +++ b/sdk/ml/azure-ai-ml/tests/test_configs/internal/command-mode/preview_command_component.yaml @@ -0,0 +1,10 @@ +$schema: https://componentsdk.azureedge.net/jsonschema/CommandComponent.json +name: ls_command +display_name: Ls Command +version: 0.0.1 +type: CommandComponent@1-preview +is_deterministic: true +command: >- + ls +environment: + name: AzureML-Designer