Skip to content

Commit

Permalink
fix: support data binding expression for resources.xxx (#29559)
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh authored Mar 27, 2023
1 parent 9da2f42 commit 98dd8ec
Show file tree
Hide file tree
Showing 18 changed files with 314 additions and 140 deletions.
6 changes: 2 additions & 4 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from marshmallow import INCLUDE, Schema

from ... import MpiDistribution, PyTorchDistribution, TensorFlowDistribution
from ..._restclient.v2023_02_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
from ..._schema import PathAwareSchema
from ..._schema.core.fields import DistributionField
from ...entities import CommandJobLimits, JobResourceConfiguration
Expand Down Expand Up @@ -106,12 +105,11 @@ def _from_rest_object_to_init_params(cls, obj):
obj = InternalBaseNode._from_rest_object_to_init_params(obj)

if "resources" in obj and obj["resources"]:
resources = RestJobResourceConfiguration.from_dict(obj["resources"])
obj["resources"] = JobResourceConfiguration._from_rest_object(resources)
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])

# handle limits
if "limits" in obj and obj["limits"]:
obj["limits"] = CommandJobLimits()._from_rest_object(obj["limits"])
obj["limits"] = CommandJobLimits._from_rest_object(obj["limits"])
return obj


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from marshmallow import Schema, fields

from azure.ai.ml._schema.core.fields import DataBindingStr, NestedField, UnionField
from azure.ai.ml._schema.core.fields import DataBindingStr, ExperimentalField, NestedField, UnionField
from azure.ai.ml._schema.core.schema import PathAwareSchema

DATA_BINDING_SUPPORTED_KEY = "_data_binding_supported"
Expand All @@ -31,6 +31,15 @@ def _add_data_binding_to_field(field, attrs_to_skip, schema_stack):
elif isinstance(field, fields.List):
# handle list
field.inner = _add_data_binding_to_field(field.inner, attrs_to_skip, schema_stack=schema_stack)
elif isinstance(field, ExperimentalField):
field = ExperimentalField(
_add_data_binding_to_field(field.experimental_field, attrs_to_skip, schema_stack=schema_stack),
data_key=field.data_key,
attribute=field.attribute,
dump_only=field.dump_only,
required=field.required,
allow_none=field.allow_none,
)
elif isinstance(field, NestedField):
# handle nested field
support_data_binding_expression_for_fields(field.schema, attrs_to_skip, schema_stack=schema_stack)
Expand Down
38 changes: 13 additions & 25 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,12 @@
from marshmallow import RAISE, fields
from marshmallow.exceptions import ValidationError
from marshmallow.fields import _T, Field, Nested
from marshmallow.utils import (
FieldInstanceResolutionError,
from_iso_datetime,
resolve_field_instance,
)
from marshmallow.utils import FieldInstanceResolutionError, from_iso_datetime, resolve_field_instance

from azure.ai.ml._schema.core.schema import PathAwareSchema
from azure.ai.ml._utils._arm_id_utils import (
AMLVersionedArmId,
is_ARM_id_for_resource,
parse_name_label,
parse_name_version,
)
from azure.ai.ml._utils._experimental import _is_warning_cached
from azure.ai.ml._utils.utils import (
is_data_binding_expression,
is_valid_node_name,
load_file,
load_yaml,
)
from azure.ai.ml.constants._common import (
from ..._utils._arm_id_utils import AMLVersionedArmId, is_ARM_id_for_resource, parse_name_label, parse_name_version
from ..._utils._experimental import _is_warning_cached
from ..._utils.utils import is_data_binding_expression, is_valid_node_name, load_file, load_yaml
from ...constants._common import (
ARM_ID_PREFIX,
AZUREML_RESOURCE_PROVIDER,
BASE_PATH_CONTEXT_KEY,
Expand All @@ -47,16 +32,15 @@
FILE_PREFIX,
INTERNAL_REGISTRY_URI_FORMAT,
LOCAL_COMPUTE_TARGET,
SERVERLESS_COMPUTE,
LOCAL_PATH,
REGISTRY_URI_FORMAT,
RESOURCE_ID_FORMAT,
SERVERLESS_COMPUTE,
AzureMLResourceType,
)
from azure.ai.ml.entities._job.pipeline._attr_dict import (
try_get_non_arbitrary_attr_for_potential_attr_dict,
)
from azure.ai.ml.exceptions import ValidationException
from ...entities._job.pipeline._attr_dict import try_get_non_arbitrary_attr_for_potential_attr_dict
from ...exceptions import ValidationException
from ..core.schema import PathAwareSchema

module_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -783,6 +767,10 @@ def __init__(self, experimental_field: fields.Field, **kwargs):
'"experimental_field" must be subclasses or ' "instances of marshmallow.base.FieldABC."
) from error

@property
def experimental_field(self):
return self._experimental_field

# This sets the parent for the schema and also handles nesting.
def _bind_to_schema(self, field_name, schema):
super()._bind_to_schema(field_name, schema)
Expand Down
50 changes: 30 additions & 20 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/component_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,52 @@

from marshmallow import INCLUDE, ValidationError, fields, post_dump, post_load, pre_dump, validates

from azure.ai.ml._schema.assets.environment import AnonymousEnvironmentSchema
from azure.ai.ml._schema.component import (
from ..._schema.assets.environment import AnonymousEnvironmentSchema
from ..._schema.component import (
AnonymousCommandComponentSchema,
AnonymousDataTransferCopyComponentSchema,
AnonymousImportComponentSchema,
AnonymousParallelComponentSchema,
AnonymousSparkComponentSchema,
AnonymousDataTransferCopyComponentSchema,
ComponentFileRefField,
DataTransferCopyComponentFileRefField,
ImportComponentFileRefField,
ParallelComponentFileRefField,
SparkComponentFileRefField,
DataTransferCopyComponentFileRefField,
)
from azure.ai.ml._schema.core.fields import ArmVersionedStr, NestedField, RegistryStr, UnionField
from azure.ai.ml._schema.core.schema import PathAwareSchema
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
from azure.ai.ml._schema.job.input_output_entry import OutputSchema, DatabaseSchema, FileSystemSchema
from azure.ai.ml._schema.job.input_output_fields_provider import InputsField
from azure.ai.ml._schema.pipeline.pipeline_job_io import OutputBindingStr
from azure.ai.ml._schema.spark_resource_configuration import SparkResourceConfigurationSchema
from azure.ai.ml._utils.utils import is_data_binding_expression
from azure.ai.ml.constants._common import AzureMLResourceType
from azure.ai.ml.constants._component import NodeType, DataTransferTaskType
from azure.ai.ml.entities._inputs_outputs import Input

from ..._utils.utils import is_data_binding_expression
from ...constants._common import AzureMLResourceType
from ...constants._component import DataTransferTaskType, NodeType
from ...entities._inputs_outputs import Input
from ...entities._job.pipeline._attr_dict import _AttrDict
from ...exceptions import ValidationException
from .._sweep.parameterized_sweep import ParameterizedSweepSchema
from .._utils.data_binding_expression import support_data_binding_expression_for_fields
from ..core.fields import ComputeField, StringTransformedEnum, TypeSensitiveUnionField
from ..core.fields import (
ArmVersionedStr,
ComputeField,
NestedField,
RegistryStr,
StringTransformedEnum,
TypeSensitiveUnionField,
UnionField,
)
from ..core.schema import PathAwareSchema
from ..job import ParameterizedCommandSchema, ParameterizedParallelSchema, ParameterizedSparkSchema
from ..job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
from ..job.input_output_entry import DatabaseSchema, FileSystemSchema, OutputSchema
from ..job.input_output_fields_provider import InputsField
from ..job.job_limits import CommandJobLimitsSchema
from ..job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
from ..job.services import (
JobServiceSchema,
SshJobServiceSchema,
JupyterLabJobServiceSchema,
VsCodeJobServiceSchema,
SshJobServiceSchema,
TensorBoardJobServiceSchema,
VsCodeJobServiceSchema,
)
from ..pipeline.pipeline_job_io import OutputBindingStr
from ..spark_resource_configuration import SparkResourceConfigurationSchema

module_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,7 +82,11 @@ def add_user_setting_attr_dict(self, data, original_data, **kwargs): # pylint:
"""Support serializing unknown fields for pipeline node."""
if isinstance(original_data, _AttrDict):
user_setting_attr_dict = original_data._get_attrs()
data.update(user_setting_attr_dict)
# TODO: dump _AttrDict values to serializable data like dict instead of original object
# skip fields that are already serialized
for key, value in user_setting_attr_dict.items():
if key not in data:
data[key] = value
return data

# an alternative would be set schema property to be load_only, but sub-schemas like CommandSchema usually also
Expand Down
10 changes: 3 additions & 7 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

from azure.ai.ml._restclient.v2023_02_01_preview.models import CommandJob as RestCommandJob
from azure.ai.ml._restclient.v2023_02_01_preview.models import JobBase
from azure.ai.ml._restclient.v2023_02_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
from azure.ai.ml._restclient.v2023_02_01_preview.models import QueueSettings as RestQueueSettings
from azure.ai.ml._schema.core.fields import NestedField, UnionField
from azure.ai.ml._schema.job.command_job import CommandJobSchema
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
Expand Down Expand Up @@ -44,8 +42,8 @@
from azure.ai.ml.entities._job.job_limits import CommandJobLimits
from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
from azure.ai.ml.entities._job.job_service import (
JobServiceBase,
JobService,
JobServiceBase,
JupyterLabJobService,
SshJobService,
TensorBoardJobService,
Expand Down Expand Up @@ -591,8 +589,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
obj = BaseNode._from_rest_object_to_init_params(obj)

if "resources" in obj and obj["resources"]:
resources = RestJobResourceConfiguration.from_dict(obj["resources"])
obj["resources"] = JobResourceConfiguration._from_rest_object(resources)
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])

# services, sweep won't have services
if "services" in obj and obj["services"]:
Expand All @@ -614,8 +611,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
obj["identity"] = _BaseJobIdentityConfiguration._load(obj["identity"])

if "queue_settings" in obj and obj["queue_settings"]:
queue_settings = RestQueueSettings.from_dict(obj["queue_settings"])
obj["queue_settings"] = QueueSettings._from_rest_object(queue_settings)
obj["queue_settings"] = QueueSettings._from_rest_object(obj["queue_settings"])

return obj

Expand Down
20 changes: 10 additions & 10 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
from typing import Dict, List, Optional, Union

from marshmallow import Schema
from azure.ai.ml.constants._common import ARM_ID_PREFIX
from azure.ai.ml.constants._component import NodeType
from azure.ai.ml.entities._component.component import Component
from azure.ai.ml.entities._component.parallel_component import ParallelComponent
from azure.ai.ml.entities._inputs_outputs import Input, Output
from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
from azure.ai.ml.entities._job.parallel.parallel_job import ParallelJob
from azure.ai.ml.entities._job.parallel.parallel_task import ParallelTask
from azure.ai.ml.entities._job.parallel.retry_settings import RetrySettings

from ..._schema import PathAwareSchema
from ...constants._common import ARM_ID_PREFIX
from ...constants._component import NodeType
from .._component.component import Component
from .._component.parallel_component import ParallelComponent
from .._inputs_outputs import Input, Output
from .._job.job_resource_configuration import JobResourceConfiguration
from .._job.parallel.parallel_job import ParallelJob
from .._job.parallel.parallel_task import ParallelTask
from .._job.parallel.retry_settings import RetrySettings
from .._job.pipeline._io import NodeOutput
from .._util import convert_ordered_dict_to_dict, get_rest_dict_for_node_attrs, validate_attribute_type
from .base_node import BaseNode
Expand Down Expand Up @@ -355,7 +355,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
obj["task"].environment = task_env[len(ARM_ID_PREFIX) :]

if "resources" in obj and obj["resources"]:
obj["resources"] = JobResourceConfiguration._from_dict(obj["resources"])
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])

if "partition_keys" in obj and obj["partition_keys"]:
obj["partition_keys"] = json.dumps(obj["partition_keys"])
Expand Down
12 changes: 3 additions & 9 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@

from marshmallow import INCLUDE, Schema

from ..._restclient.v2023_02_01_preview.models import IdentityConfiguration
from ..._restclient.v2023_02_01_preview.models import JobBase as JobBaseData
from ..._restclient.v2023_02_01_preview.models import SparkJob as RestSparkJob
from ..._restclient.v2023_02_01_preview.models import SparkJobEntry as RestSparkJobEntry
from ..._restclient.v2023_02_01_preview.models import SparkResourceConfiguration as RestSparkResourceConfiguration
from ..._schema import NestedField, PathAwareSchema, UnionField
from ..._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
from ..._schema.job.parameterized_spark import CONF_KEY_MAP, SparkConfSchema
Expand Down Expand Up @@ -299,16 +296,13 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
obj = super()._from_rest_object_to_init_params(obj)

if "resources" in obj and obj["resources"]:
resources = RestSparkResourceConfiguration.from_dict(obj["resources"])
obj["resources"] = SparkResourceConfiguration._from_rest_object(resources)
obj["resources"] = SparkResourceConfiguration._from_rest_object(obj["resources"])

if "identity" in obj and obj["identity"]:
identity = IdentityConfiguration.from_dict(obj["identity"])
obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(identity)
obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(obj["identity"])

if "entry" in obj and obj["entry"]:
entry = RestSparkJobEntry.from_dict(obj["entry"])
obj["entry"] = SparkJobEntry._from_rest_object(entry)
obj["entry"] = SparkJobEntry._from_rest_object(obj["entry"])
if "conf" in obj and obj["conf"]:
identify_schema = UnionField(
[
Expand Down
18 changes: 12 additions & 6 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Dict, List, Optional, Union

from azure.ai.ml._azure_environments import _get_active_directory_url_from_metadata
from azure.ai.ml._restclient.v2022_12_01_preview.models import ConnectionAuthType
from azure.ai.ml._restclient.v2022_01_01_preview.models import Identity as RestIdentityConfiguration
from azure.ai.ml._restclient.v2022_01_01_preview.models import ManagedIdentity as RestWorkspaceConnectionManagedIdentity
from azure.ai.ml._restclient.v2022_01_01_preview.models import (
Expand All @@ -24,6 +23,8 @@
from azure.ai.ml._restclient.v2022_01_01_preview.models import (
UsernamePassword as RestWorkspaceConnectionUsernamePassword,
)
from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration
from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestUserAssignedIdentityConfiguration
from azure.ai.ml._restclient.v2022_10_01.models import (
AccountKeyDatastoreCredentials as RestAccountKeyDatastoreCredentials,
)
Expand All @@ -32,7 +33,6 @@
CertificateDatastoreCredentials as RestCertificateDatastoreCredentials,
)
from azure.ai.ml._restclient.v2022_10_01.models import CertificateDatastoreSecrets, CredentialsType
from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration
from azure.ai.ml._restclient.v2022_10_01.models import NoneDatastoreCredentials as RestNoneDatastoreCredentials
from azure.ai.ml._restclient.v2022_10_01.models import SasDatastoreCredentials as RestSasDatastoreCredentials
from azure.ai.ml._restclient.v2022_10_01.models import SasDatastoreSecrets as RestSasDatastoreSecrets
Expand All @@ -42,16 +42,16 @@
from azure.ai.ml._restclient.v2022_10_01.models import (
ServicePrincipalDatastoreSecrets as RestServicePrincipalDatastoreSecrets,
)
from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestUserAssignedIdentityConfiguration
from azure.ai.ml._restclient.v2022_12_01_preview.models import ConnectionAuthType
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey,
)
from azure.ai.ml._restclient.v2023_02_01_preview.models import AmlToken as RestAmlToken
from azure.ai.ml._restclient.v2023_02_01_preview.models import IdentityConfiguration as RestJobIdentityConfiguration
from azure.ai.ml._restclient.v2023_02_01_preview.models import IdentityConfigurationType
from azure.ai.ml._restclient.v2023_02_01_preview.models import ManagedIdentity as RestJobManagedIdentity
from azure.ai.ml._restclient.v2023_02_01_preview.models import ManagedServiceIdentity as RestRegistryManagedIdentity
from azure.ai.ml._restclient.v2023_02_01_preview.models import UserIdentity as RestUserIdentity
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey,
)
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal
from azure.ai.ml.constants._common import CommonYamlFields, IdentityType
from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin, YamlTranslatableMixin
Expand Down Expand Up @@ -327,12 +327,18 @@ def __init__(self):

@classmethod
def _from_rest_object(cls, obj: RestJobIdentityConfiguration) -> "Identity":
if obj is None:
return None
mapping = {
IdentityConfigurationType.AML_TOKEN: AmlTokenConfiguration,
IdentityConfigurationType.MANAGED: ManagedIdentityConfiguration,
IdentityConfigurationType.USER_IDENTITY: UserIdentityConfiguration,
}

if isinstance(obj, dict):
# TODO: support data binding expression
obj = RestJobIdentityConfiguration.from_dict(obj)

identity_class = mapping.get(obj.identity_type, None)
if identity_class:
# pylint: disable=protected-access
Expand Down
3 changes: 1 addition & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/job_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from azure.ai.ml._restclient.v2022_12_01_preview.models import SweepJobLimits as RestSweepJobLimits
from azure.ai.ml._utils.utils import from_iso_duration_format, is_data_binding_expression, to_iso_duration_format
from azure.ai.ml.constants import JobType
from azure.ai.ml.entities._job.pipeline._io import PipelineInput
from azure.ai.ml.entities._mixins import RestTranslatableMixin

module_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,7 +43,7 @@ def __init__(self, *, timeout: Union[int, str, None] = None):
self.timeout = timeout

def _to_rest_object(self) -> RestCommandJobLimits:
if isinstance(self.timeout, PipelineInput):
if is_data_binding_expression(self.timeout):
return RestCommandJobLimits(timeout=self.timeout)
return RestCommandJobLimits(timeout=to_iso_duration_format(self.timeout))

Expand Down
Loading

0 comments on commit 98dd8ec

Please sign in to comment.