Skip to content

Commit

Permalink
[ML] fix: Add type instance var to MonitorInputData (Azure#32281)
Browse files Browse the repository at this point in the history
* refactor: Use MonintorInputDataType instead of magic constants

* refactor: input_type -> type

* docs: Document type instance var
  • Loading branch information
kdestin authored Sep 28, 2023
1 parent 4ac2a37 commit 973d6e5
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
9 changes: 9 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/constants/_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ class MonitorTargetTasks(str, Enum, metaclass=CaseInsensitiveEnumMeta):
QUESTION_ANSWERING = "QuestionAnswering"


class MonitorInputDataType(str, Enum, metaclass=CaseInsensitiveEnumMeta):
#: An input data with a fixed window size.
STATIC = "Static"
#: An input data which trailing relatively to the monitor's current run.
TRAILING = "Trailing"
#: An input data with tabular format which doesn't require preprocessing.
FIXED = "Fixed"


class FADColumnNames(str, Enum, metaclass=CaseInsensitiveEnumMeta):
PREDICTION = "prediction"
PREDICTION_PROBABILITY = "prediction_probability"
Expand Down
35 changes: 26 additions & 9 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_monitoring/input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
from azure.ai.ml.constants._monitoring import MonitorDatasetContext
from azure.ai.ml.constants._monitoring import MonitorDatasetContext, MonitorInputDataType


@experimental
class MonitorInputData(RestTranslatableMixin):
"""Monitor input data.
:keyword type: Specifies the type of monitoring input data.
:paramtype type: MonitorInputDataType
:keyword input_dataset: Input data used by the monitor
:paramtype input_dataset: Optional[~azure.ai.ml.Input]
:keyword dataset_context: The context of the input dataset. Accepted values are "model_inputs",
Expand All @@ -40,32 +42,37 @@ class MonitorInputData(RestTranslatableMixin):
def __init__(
self,
*,
input_type: str = None,
type: MonitorInputDataType = None,
data_context: MonitorDatasetContext = None,
target_columns: Dict = None,
job_type: str = None,
uri: str = None,
):
self.input_type = input_type
self.type = type
self.data_context = data_context
self.target_columns = target_columns
self.job_type = job_type
self.uri = uri

@classmethod
def _from_rest_object(cls, obj: RestMonitorInputBase) -> Optional["MonitorInputData"]:
if obj.input_data_type == "Fixed":
if obj.input_data_type == MonitorInputDataType.FIXED:
return FixedInputData._from_rest_object(obj)
if obj.input_data_type == "Trailing":
if obj.input_data_type == MonitorInputDataType.TRAILING:
return TrailingInputData._from_rest_object(obj)
if obj.input_data_type == "Static":
if obj.input_data_type == MonitorInputDataType.STATIC:
return StaticInputData._from_rest_object(obj)

return None


@experimental
class FixedInputData(MonitorInputData):
"""
:ivar type: Specifies the type of monitoring input data. Set automatically to "Fixed" for this class.
:var type: MonitorInputDataType
"""

def __init__(
self,
*,
Expand All @@ -75,7 +82,7 @@ def __init__(
uri: str = None,
):
super().__init__(
input_type="Fixed",
type=MonitorInputDataType.FIXED,
data_context=data_context,
target_columns=target_columns,
job_type=job_type,
Expand All @@ -102,6 +109,11 @@ def _from_rest_object(cls, obj: RestFixedInputData) -> "FixedInputData":

@experimental
class TrailingInputData(MonitorInputData):
"""
:ivar type: Specifies the type of monitoring input data. Set automatically to "Trailing" for this class.
:var type: MonitorInputDataType
"""

def __init__(
self,
*,
Expand All @@ -114,7 +126,7 @@ def __init__(
pre_processing_component_id: str = None,
):
super().__init__(
input_type="Trailing",
type=MonitorInputDataType.TRAILING,
data_context=data_context,
target_columns=target_columns,
job_type=job_type,
Expand Down Expand Up @@ -150,6 +162,11 @@ def _from_rest_object(cls, obj: RestTrailingInputData) -> "TrailingInputData":

@experimental
class StaticInputData(MonitorInputData):
"""
:ivar type: Specifies the type of monitoring input data. Set automatically to "Static" for this class.
:var type: MonitorInputDataType
"""

def __init__(
self,
*,
Expand All @@ -162,7 +179,7 @@ def __init__(
window_end: str = None,
):
super().__init__(
input_type="Static",
type=MonitorInputDataType.STATIC,
data_context=data_context,
target_columns=target_columns,
job_type=job_type,
Expand Down

0 comments on commit 973d6e5

Please sign in to comment.