Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Pipelines] feat: support data binding expression inside entity runsettings #29559

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from marshmallow import INCLUDE, Schema

from ... import MpiDistribution, PyTorchDistribution, TensorFlowDistribution
from ..._restclient.v2023_02_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
from ..._schema import PathAwareSchema
from ..._schema.core.fields import DistributionField
from ...entities import CommandJobLimits, JobResourceConfiguration
Expand Down Expand Up @@ -106,12 +105,11 @@ def _from_rest_object_to_init_params(cls, obj):
obj = InternalBaseNode._from_rest_object_to_init_params(obj)

if "resources" in obj and obj["resources"]:
resources = RestJobResourceConfiguration.from_dict(obj["resources"])
obj["resources"] = JobResourceConfiguration._from_rest_object(resources)
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])
wangchao1230 marked this conversation as resolved.
Show resolved Hide resolved

# handle limits
if "limits" in obj and obj["limits"]:
obj["limits"] = CommandJobLimits()._from_rest_object(obj["limits"])
obj["limits"] = CommandJobLimits._from_rest_object(obj["limits"])
return obj


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from marshmallow import Schema, fields

from azure.ai.ml._schema.core.fields import DataBindingStr, NestedField, UnionField
from azure.ai.ml._schema.core.fields import DataBindingStr, ExperimentalField, NestedField, UnionField
from azure.ai.ml._schema.core.schema import PathAwareSchema

DATA_BINDING_SUPPORTED_KEY = "_data_binding_supported"
Expand All @@ -31,6 +31,15 @@ def _add_data_binding_to_field(field, attrs_to_skip, schema_stack):
elif isinstance(field, fields.List):
# handle list
field.inner = _add_data_binding_to_field(field.inner, attrs_to_skip, schema_stack=schema_stack)
elif isinstance(field, ExperimentalField):
field = ExperimentalField(
_add_data_binding_to_field(field.experimental_field, attrs_to_skip, schema_stack=schema_stack),
data_key=field.data_key,
attribute=field.attribute,
dump_only=field.dump_only,
required=field.required,
allow_none=field.allow_none,
)
elif isinstance(field, NestedField):
# handle nested field
support_data_binding_expression_for_fields(field.schema, attrs_to_skip, schema_stack=schema_stack)
Expand Down
38 changes: 13 additions & 25 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,12 @@
from marshmallow import RAISE, fields
from marshmallow.exceptions import ValidationError
from marshmallow.fields import _T, Field, Nested
from marshmallow.utils import (
FieldInstanceResolutionError,
from_iso_datetime,
resolve_field_instance,
)
from marshmallow.utils import FieldInstanceResolutionError, from_iso_datetime, resolve_field_instance

from azure.ai.ml._schema.core.schema import PathAwareSchema
from azure.ai.ml._utils._arm_id_utils import (
AMLVersionedArmId,
is_ARM_id_for_resource,
parse_name_label,
parse_name_version,
)
from azure.ai.ml._utils._experimental import _is_warning_cached
from azure.ai.ml._utils.utils import (
is_data_binding_expression,
is_valid_node_name,
load_file,
load_yaml,
)
from azure.ai.ml.constants._common import (
from ..._utils._arm_id_utils import AMLVersionedArmId, is_ARM_id_for_resource, parse_name_label, parse_name_version
from ..._utils._experimental import _is_warning_cached
from ..._utils.utils import is_data_binding_expression, is_valid_node_name, load_file, load_yaml
from ...constants._common import (
ARM_ID_PREFIX,
AZUREML_RESOURCE_PROVIDER,
BASE_PATH_CONTEXT_KEY,
Expand All @@ -47,16 +32,15 @@
FILE_PREFIX,
INTERNAL_REGISTRY_URI_FORMAT,
LOCAL_COMPUTE_TARGET,
SERVERLESS_COMPUTE,
LOCAL_PATH,
REGISTRY_URI_FORMAT,
RESOURCE_ID_FORMAT,
SERVERLESS_COMPUTE,
AzureMLResourceType,
)
from azure.ai.ml.entities._job.pipeline._attr_dict import (
try_get_non_arbitrary_attr_for_potential_attr_dict,
)
from azure.ai.ml.exceptions import ValidationException
from ...entities._job.pipeline._attr_dict import try_get_non_arbitrary_attr_for_potential_attr_dict
from ...exceptions import ValidationException
from ..core.schema import PathAwareSchema

module_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -783,6 +767,10 @@ def __init__(self, experimental_field: fields.Field, **kwargs):
'"experimental_field" must be subclasses or ' "instances of marshmallow.base.FieldABC."
) from error

# Public accessor for the field object stored in ``self._experimental_field``.
# ``_add_data_binding_to_field`` reads ``field.experimental_field`` to rebuild
# an ExperimentalField around a data-binding-aware copy of that inner field,
# so this property lets it do that without touching the private attribute.
# NOTE(review): presumably ``_experimental_field`` is the resolved instance of
# the ``experimental_field`` argument given to ``__init__`` — confirm there.
@property
def experimental_field(self):
return self._experimental_field

# This sets the parent for the schema and also handles nesting.
def _bind_to_schema(self, field_name, schema):
super()._bind_to_schema(field_name, schema)
Expand Down
50 changes: 30 additions & 20 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/component_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,52 @@

from marshmallow import INCLUDE, ValidationError, fields, post_dump, post_load, pre_dump, validates

from azure.ai.ml._schema.assets.environment import AnonymousEnvironmentSchema
from azure.ai.ml._schema.component import (
from ..._schema.assets.environment import AnonymousEnvironmentSchema
from ..._schema.component import (
AnonymousCommandComponentSchema,
AnonymousDataTransferCopyComponentSchema,
AnonymousImportComponentSchema,
AnonymousParallelComponentSchema,
AnonymousSparkComponentSchema,
AnonymousDataTransferCopyComponentSchema,
ComponentFileRefField,
DataTransferCopyComponentFileRefField,
ImportComponentFileRefField,
ParallelComponentFileRefField,
SparkComponentFileRefField,
DataTransferCopyComponentFileRefField,
)
from azure.ai.ml._schema.core.fields import ArmVersionedStr, NestedField, RegistryStr, UnionField
from azure.ai.ml._schema.core.schema import PathAwareSchema
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
from azure.ai.ml._schema.job.input_output_entry import OutputSchema, DatabaseSchema, FileSystemSchema
from azure.ai.ml._schema.job.input_output_fields_provider import InputsField
from azure.ai.ml._schema.pipeline.pipeline_job_io import OutputBindingStr
from azure.ai.ml._schema.spark_resource_configuration import SparkResourceConfigurationSchema
from azure.ai.ml._utils.utils import is_data_binding_expression
from azure.ai.ml.constants._common import AzureMLResourceType
from azure.ai.ml.constants._component import NodeType, DataTransferTaskType
from azure.ai.ml.entities._inputs_outputs import Input

from ..._utils.utils import is_data_binding_expression
from ...constants._common import AzureMLResourceType
from ...constants._component import DataTransferTaskType, NodeType
from ...entities._inputs_outputs import Input
from ...entities._job.pipeline._attr_dict import _AttrDict
from ...exceptions import ValidationException
from .._sweep.parameterized_sweep import ParameterizedSweepSchema
from .._utils.data_binding_expression import support_data_binding_expression_for_fields
from ..core.fields import ComputeField, StringTransformedEnum, TypeSensitiveUnionField
from ..core.fields import (
ArmVersionedStr,
ComputeField,
NestedField,
RegistryStr,
StringTransformedEnum,
TypeSensitiveUnionField,
UnionField,
)
from ..core.schema import PathAwareSchema
from ..job import ParameterizedCommandSchema, ParameterizedParallelSchema, ParameterizedSparkSchema
from ..job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
from ..job.input_output_entry import DatabaseSchema, FileSystemSchema, OutputSchema
from ..job.input_output_fields_provider import InputsField
from ..job.job_limits import CommandJobLimitsSchema
from ..job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
from ..job.services import (
JobServiceSchema,
SshJobServiceSchema,
JupyterLabJobServiceSchema,
VsCodeJobServiceSchema,
SshJobServiceSchema,
TensorBoardJobServiceSchema,
VsCodeJobServiceSchema,
)
from ..pipeline.pipeline_job_io import OutputBindingStr
from ..spark_resource_configuration import SparkResourceConfigurationSchema

module_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,7 +82,11 @@ def add_user_setting_attr_dict(self, data, original_data, **kwargs): # pylint:
"""Support serializing unknown fields for pipeline node."""
if isinstance(original_data, _AttrDict):
user_setting_attr_dict = original_data._get_attrs()
data.update(user_setting_attr_dict)
# TODO: dump _AttrDict values to serializable data like dict instead of original object
# skip fields that are already serialized
for key, value in user_setting_attr_dict.items():
if key not in data:
data[key] = value
return data

# an alternative would be set schema property to be load_only, but sub-schemas like CommandSchema usually also
Expand Down
10 changes: 3 additions & 7 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

from azure.ai.ml._restclient.v2023_02_01_preview.models import CommandJob as RestCommandJob
from azure.ai.ml._restclient.v2023_02_01_preview.models import JobBase
from azure.ai.ml._restclient.v2023_02_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
from azure.ai.ml._restclient.v2023_02_01_preview.models import QueueSettings as RestQueueSettings
from azure.ai.ml._schema.core.fields import NestedField, UnionField
from azure.ai.ml._schema.job.command_job import CommandJobSchema
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
Expand Down Expand Up @@ -44,8 +42,8 @@
from azure.ai.ml.entities._job.job_limits import CommandJobLimits
from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
from azure.ai.ml.entities._job.job_service import (
JobServiceBase,
JobService,
JobServiceBase,
JupyterLabJobService,
SshJobService,
TensorBoardJobService,
Expand Down Expand Up @@ -591,8 +589,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
obj = BaseNode._from_rest_object_to_init_params(obj)

if "resources" in obj and obj["resources"]:
resources = RestJobResourceConfiguration.from_dict(obj["resources"])
obj["resources"] = JobResourceConfiguration._from_rest_object(resources)
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])

# services, sweep won't have services
if "services" in obj and obj["services"]:
Expand All @@ -614,8 +611,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
obj["identity"] = _BaseJobIdentityConfiguration._load(obj["identity"])

if "queue_settings" in obj and obj["queue_settings"]:
queue_settings = RestQueueSettings.from_dict(obj["queue_settings"])
obj["queue_settings"] = QueueSettings._from_rest_object(queue_settings)
obj["queue_settings"] = QueueSettings._from_rest_object(obj["queue_settings"])

return obj

Expand Down
20 changes: 10 additions & 10 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
from typing import Dict, List, Optional, Union

from marshmallow import Schema
from azure.ai.ml.constants._common import ARM_ID_PREFIX
from azure.ai.ml.constants._component import NodeType
from azure.ai.ml.entities._component.component import Component
from azure.ai.ml.entities._component.parallel_component import ParallelComponent
from azure.ai.ml.entities._inputs_outputs import Input, Output
from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
from azure.ai.ml.entities._job.parallel.parallel_job import ParallelJob
from azure.ai.ml.entities._job.parallel.parallel_task import ParallelTask
from azure.ai.ml.entities._job.parallel.retry_settings import RetrySettings

from ..._schema import PathAwareSchema
from ...constants._common import ARM_ID_PREFIX
from ...constants._component import NodeType
from .._component.component import Component
from .._component.parallel_component import ParallelComponent
from .._inputs_outputs import Input, Output
from .._job.job_resource_configuration import JobResourceConfiguration
from .._job.parallel.parallel_job import ParallelJob
from .._job.parallel.parallel_task import ParallelTask
from .._job.parallel.retry_settings import RetrySettings
from .._job.pipeline._io import NodeOutput
from .._util import convert_ordered_dict_to_dict, get_rest_dict_for_node_attrs, validate_attribute_type
from .base_node import BaseNode
Expand Down Expand Up @@ -355,7 +355,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
obj["task"].environment = task_env[len(ARM_ID_PREFIX) :]

if "resources" in obj and obj["resources"]:
obj["resources"] = JobResourceConfiguration._from_dict(obj["resources"])
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])

if "partition_keys" in obj and obj["partition_keys"]:
obj["partition_keys"] = json.dumps(obj["partition_keys"])
Expand Down
12 changes: 3 additions & 9 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@

from marshmallow import INCLUDE, Schema

from ..._restclient.v2023_02_01_preview.models import IdentityConfiguration
from ..._restclient.v2023_02_01_preview.models import JobBase as JobBaseData
from ..._restclient.v2023_02_01_preview.models import SparkJob as RestSparkJob
from ..._restclient.v2023_02_01_preview.models import SparkJobEntry as RestSparkJobEntry
from ..._restclient.v2023_02_01_preview.models import SparkResourceConfiguration as RestSparkResourceConfiguration
from ..._schema import NestedField, PathAwareSchema, UnionField
from ..._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
from ..._schema.job.parameterized_spark import CONF_KEY_MAP, SparkConfSchema
Expand Down Expand Up @@ -299,16 +296,13 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
obj = super()._from_rest_object_to_init_params(obj)

if "resources" in obj and obj["resources"]:
resources = RestSparkResourceConfiguration.from_dict(obj["resources"])
obj["resources"] = SparkResourceConfiguration._from_rest_object(resources)
obj["resources"] = SparkResourceConfiguration._from_rest_object(obj["resources"])

if "identity" in obj and obj["identity"]:
identity = IdentityConfiguration.from_dict(obj["identity"])
obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(identity)
obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(obj["identity"])

if "entry" in obj and obj["entry"]:
entry = RestSparkJobEntry.from_dict(obj["entry"])
obj["entry"] = SparkJobEntry._from_rest_object(entry)
obj["entry"] = SparkJobEntry._from_rest_object(obj["entry"])
if "conf" in obj and obj["conf"]:
identify_schema = UnionField(
[
Expand Down
18 changes: 12 additions & 6 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Dict, List, Optional, Union

from azure.ai.ml._azure_environments import _get_active_directory_url_from_metadata
from azure.ai.ml._restclient.v2022_12_01_preview.models import ConnectionAuthType
from azure.ai.ml._restclient.v2022_01_01_preview.models import Identity as RestIdentityConfiguration
from azure.ai.ml._restclient.v2022_01_01_preview.models import ManagedIdentity as RestWorkspaceConnectionManagedIdentity
from azure.ai.ml._restclient.v2022_01_01_preview.models import (
Expand All @@ -24,6 +23,8 @@
from azure.ai.ml._restclient.v2022_01_01_preview.models import (
UsernamePassword as RestWorkspaceConnectionUsernamePassword,
)
from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration
from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestUserAssignedIdentityConfiguration
from azure.ai.ml._restclient.v2022_10_01.models import (
AccountKeyDatastoreCredentials as RestAccountKeyDatastoreCredentials,
)
Expand All @@ -32,7 +33,6 @@
CertificateDatastoreCredentials as RestCertificateDatastoreCredentials,
)
from azure.ai.ml._restclient.v2022_10_01.models import CertificateDatastoreSecrets, CredentialsType
from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration
from azure.ai.ml._restclient.v2022_10_01.models import NoneDatastoreCredentials as RestNoneDatastoreCredentials
from azure.ai.ml._restclient.v2022_10_01.models import SasDatastoreCredentials as RestSasDatastoreCredentials
from azure.ai.ml._restclient.v2022_10_01.models import SasDatastoreSecrets as RestSasDatastoreSecrets
Expand All @@ -42,16 +42,16 @@
from azure.ai.ml._restclient.v2022_10_01.models import (
ServicePrincipalDatastoreSecrets as RestServicePrincipalDatastoreSecrets,
)
from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestUserAssignedIdentityConfiguration
from azure.ai.ml._restclient.v2022_12_01_preview.models import ConnectionAuthType
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey,
)
from azure.ai.ml._restclient.v2023_02_01_preview.models import AmlToken as RestAmlToken
from azure.ai.ml._restclient.v2023_02_01_preview.models import IdentityConfiguration as RestJobIdentityConfiguration
from azure.ai.ml._restclient.v2023_02_01_preview.models import IdentityConfigurationType
from azure.ai.ml._restclient.v2023_02_01_preview.models import ManagedIdentity as RestJobManagedIdentity
from azure.ai.ml._restclient.v2023_02_01_preview.models import ManagedServiceIdentity as RestRegistryManagedIdentity
from azure.ai.ml._restclient.v2023_02_01_preview.models import UserIdentity as RestUserIdentity
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey,
)
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal
from azure.ai.ml.constants._common import CommonYamlFields, IdentityType
from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin, YamlTranslatableMixin
Expand Down Expand Up @@ -327,12 +327,18 @@ def __init__(self):

@classmethod
def _from_rest_object(cls, obj: RestJobIdentityConfiguration) -> "Identity":
if obj is None:
return None
mapping = {
IdentityConfigurationType.AML_TOKEN: AmlTokenConfiguration,
IdentityConfigurationType.MANAGED: ManagedIdentityConfiguration,
IdentityConfigurationType.USER_IDENTITY: UserIdentityConfiguration,
}

if isinstance(obj, dict):
# TODO: support data binding expression
obj = RestJobIdentityConfiguration.from_dict(obj)

identity_class = mapping.get(obj.identity_type, None)
if identity_class:
# pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from azure.ai.ml._restclient.v2022_12_01_preview.models import SweepJobLimits as RestSweepJobLimits
from azure.ai.ml._utils.utils import from_iso_duration_format, is_data_binding_expression, to_iso_duration_format
from azure.ai.ml.constants import JobType
from azure.ai.ml.entities._job.pipeline._io import PipelineInput
from azure.ai.ml.entities._mixins import RestTranslatableMixin

module_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,7 +43,7 @@ def __init__(self, *, timeout: Union[int, str, None] = None):
self.timeout = timeout

def _to_rest_object(self) -> RestCommandJobLimits:
if isinstance(self.timeout, PipelineInput):
if is_data_binding_expression(self.timeout):
return RestCommandJobLimits(timeout=self.timeout)
return RestCommandJobLimits(timeout=to_iso_duration_format(self.timeout))

Expand Down
Loading