Skip to content

Commit

Permalink
fix: support data binding expression for resources.xxx
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh committed Mar 24, 2023
1 parent 1f9a176 commit b1a19d9
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 13 deletions.
6 changes: 2 additions & 4 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from marshmallow import INCLUDE, Schema

from ... import MpiDistribution, PyTorchDistribution, TensorFlowDistribution
from ..._restclient.v2023_02_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
from ..._schema import PathAwareSchema
from ..._schema.core.fields import DistributionField
from ...entities import CommandJobLimits, JobResourceConfiguration
Expand Down Expand Up @@ -106,12 +105,11 @@ def _from_rest_object_to_init_params(cls, obj):
obj = InternalBaseNode._from_rest_object_to_init_params(obj)

if "resources" in obj and obj["resources"]:
resources = RestJobResourceConfiguration.from_dict(obj["resources"])
obj["resources"] = JobResourceConfiguration._from_rest_object(resources)
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])

# handle limits
if "limits" in obj and obj["limits"]:
obj["limits"] = CommandJobLimits()._from_rest_object(obj["limits"])
obj["limits"] = CommandJobLimits._from_rest_object(obj["limits"])
return obj


Expand Down
4 changes: 2 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from enum import Enum
from typing import Dict, List, Optional, Union

from marshmallow import Schema
from azure.ai.ml.constants._common import ARM_ID_PREFIX
from azure.ai.ml.constants._component import NodeType
from azure.ai.ml.entities._component.component import Component
Expand All @@ -21,6 +20,7 @@
from azure.ai.ml.entities._job.parallel.parallel_job import ParallelJob
from azure.ai.ml.entities._job.parallel.parallel_task import ParallelTask
from azure.ai.ml.entities._job.parallel.retry_settings import RetrySettings
from marshmallow import Schema

from ..._schema import PathAwareSchema
from .._job.pipeline._io import NodeOutput
Expand Down Expand Up @@ -355,7 +355,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
obj["task"].environment = task_env[len(ARM_ID_PREFIX) :]

if "resources" in obj and obj["resources"]:
obj["resources"] = JobResourceConfiguration._from_dict(obj["resources"])
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])

if "partition_keys" in obj and obj["partition_keys"]:
obj["partition_keys"] = json.dumps(obj["partition_keys"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import json
import logging
from typing import Any, Dict, Optional, List
from typing import Any, Dict, List, Optional

from azure.ai.ml._restclient.v2023_02_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
from azure.ai.ml.constants._job.job import JobComputePropertyFields
Expand Down Expand Up @@ -146,16 +146,12 @@ def _to_rest_object(self) -> RestJobResourceConfiguration:
shm_size=self.shm_size,
)

@classmethod
def _from_dict(cls, dct: dict):
"""Convert a dict to an Input object."""
obj = cls(**dict(dct.items()))
return obj

@classmethod
def _from_rest_object(cls, obj: Optional[RestJobResourceConfiguration]) -> Optional["JobResourceConfiguration"]:
if obj is None:
return None
if isinstance(obj, dict):
return cls(**obj)
return JobResourceConfiguration(
locations=obj.locations,
instance_count=obj.instance_count,
Expand Down

0 comments on commit b1a19d9

Please sign in to comment.