Skip to content

Commit

Permalink
Type Fix for dsl folder (Azure#33644)
Browse files Browse the repository at this point in the history
* fix dsl _utils.py

* fix pipeline_decorator

* fix misc

* fix pipeline_component_builder

* fix misc - 2

* update dynamic

* fix misc - 3

* fix misc - 4

* add type ignore and bug item number

* update pyproject.toml - 1

* remove cast

* unexclude dsl folder
  • Loading branch information
CoderKevinZhang authored and sofiar-msft committed Feb 16, 2024
1 parent 6f1ca39 commit bcc1830
Show file tree
Hide file tree
Showing 15 changed files with 171 additions and 130 deletions.
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from azure.ai.ml.dsl._pipeline_decorator import pipeline

Expand Down
13 changes: 7 additions & 6 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_component_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

# pylint: disable=protected-access

from typing import Callable, Mapping
from typing import Any, Callable, List, Mapping

from azure.ai.ml.dsl._dynamic import KwParameter, create_kw_function_from_parameters
from azure.ai.ml.entities import Component as ComponentEntity
from azure.ai.ml.entities._builders import Command
from azure.ai.ml.entities._component.datatransfer_component import DataTransferImportComponent


def get_dynamic_input_parameter(inputs: Mapping):
def get_dynamic_input_parameter(inputs: Mapping) -> List:
"""Return the dynamic parameter of the definition's input ports.
:param inputs: The mapping of input names to input objects.
Expand All @@ -31,7 +31,7 @@ def get_dynamic_input_parameter(inputs: Mapping):
]


def get_dynamic_source_parameter(source):
def get_dynamic_source_parameter(source: Any) -> List:
"""Return the dynamic parameter of the definition's source port.
:param source: The source object.
Expand All @@ -49,7 +49,7 @@ def get_dynamic_source_parameter(source):
]


def to_component_func(entity: ComponentEntity, component_creation_func) -> Callable[..., Command]:
def to_component_func(entity: ComponentEntity, component_creation_func: Callable) -> Callable[..., Command]:
"""Convert a ComponentEntity to a callable component function.
:param entity: The ComponentEntity to convert.
Expand Down Expand Up @@ -97,6 +97,7 @@ def to_component_func(entity: ComponentEntity, component_creation_func) -> Calla
flattened_group_keys=flattened_group_keys,
)

dynamic_func._func_calling_example = example
dynamic_func._has_parameters = bool(all_params)
# Bug Item number: 2883188
dynamic_func._func_calling_example = example # type: ignore
dynamic_func._has_parameters = bool(all_params) # type: ignore
return dynamic_func
11 changes: 9 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_do_while.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from typing import Dict, Optional, Union

from azure.ai.ml.entities._builders import Command
from azure.ai.ml.entities._builders.do_while import DoWhile
from azure.ai.ml.entities._builders.pipeline import Pipeline
from azure.ai.ml.entities._inputs_outputs import Output
from azure.ai.ml.entities._job.pipeline._io import NodeOutput


def do_while(body, mapping, max_iteration_count: int, condition=None):
def do_while(
body: Union[Pipeline, Command], mapping: Dict, max_iteration_count: int, condition: Optional[Output] = None
) -> DoWhile:
"""Build a do_while node by specifying the loop body, output-input mapping, and termination condition.
.. remarks::
Expand Down Expand Up @@ -63,7 +70,7 @@ def pipeline_with_do_while_node():
)
do_while_node.set_limits(max_iteration_count=max_iteration_count)

def _infer_and_update_body_input_from_mapping():
def _infer_and_update_body_input_from_mapping() -> None:
# pylint: disable=protected-access
for source_output, body_input in mapping.items():
# handle case that mapping key is a NodeOutput
Expand Down
37 changes: 21 additions & 16 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import types
from inspect import Parameter, Signature
from typing import Callable, Dict, Sequence
from typing import Any, Callable, Dict, Sequence, cast

from azure.ai.ml.entities import Component
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, UnexpectedKeywordError, ValidationException
Expand All @@ -26,7 +26,9 @@ class KwParameter(Parameter):
:type _optional: bool
"""

def __init__(self, name, default, annotation=Parameter.empty, _type="str", _optional=False) -> None:
def __init__(
self, name: str, default: Any, annotation: Any = Parameter.empty, _type: str = "str", _optional: bool = False
) -> None:
super().__init__(name, Parameter.KEYWORD_ONLY, default=default, annotation=annotation)
self._type = _type
self._optional = _optional
Expand Down Expand Up @@ -54,23 +56,25 @@ def _replace_function_name(func: types.FunctionType, new_name: str) -> types.Fun
else:
# On Python < 3.8, `replace` is not available, so we can only initialize the code object as follows.
# https://github.com/python/cpython/blob/v3.7.8/Objects/codeobject.c#L97
code = types.CodeType(

# Bug Item number: 2881688
code = types.CodeType( # type: ignore
code_template.co_argcount,
code_template.co_kwonlyargcount,
code_template.co_nlocals,
code_template.co_stacksize,
code_template.co_flags,
code_template.co_code,
code_template.co_consts,
code_template.co_code, # type: ignore
code_template.co_consts, # type: ignore
code_template.co_names,
code_template.co_varnames,
code_template.co_filename,
code_template.co_filename, # type: ignore
new_name, # Use the new name for the new code object.
code_template.co_firstlineno,
code_template.co_lnotab,
code_template.co_firstlineno, # type: ignore
code_template.co_lnotab, # type: ignore
# The following two values are required for closures.
code_template.co_freevars,
code_template.co_cellvars,
code_template.co_freevars, # type: ignore
code_template.co_cellvars, # type: ignore
)
# Initialize a new function with the code object and the new name, see the following ref for more details.
# https://github.com/python/cpython/blob/4901fe274bc82b95dc89bcb3de8802a3dfedab32/Objects/clinic/funcobject.c.h#L30
Expand All @@ -89,7 +93,7 @@ def _replace_function_name(func: types.FunctionType, new_name: str) -> types.Fun


# pylint: disable-next=docstring-missing-param
def _assert_arg_valid(kwargs: dict, keys: list, func_name: str):
def _assert_arg_valid(kwargs: dict, keys: list, func_name: str) -> None:
"""Assert the arg keys are all in keys."""
# pylint: disable=protected-access
# validate component input names
Expand All @@ -114,7 +118,7 @@ def _assert_arg_valid(kwargs: dict, keys: list, func_name: str):
kwargs[lower2original_parameter_names[key.lower()]] = kwargs.pop(key)


def _update_dct_if_not_exist(dst: Dict, src: Dict):
def _update_dct_if_not_exist(dst: Dict, src: Dict) -> None:
"""Computes the union of `src` and `dst`, in-place within `dst`
If a key exists in `dst` and `src` the value in `dst` is preserved
Expand Down Expand Up @@ -162,17 +166,18 @@ def create_kw_function_from_parameters(
)
default_kwargs = {p.name: p.default for p in parameters}

def f(**kwargs):
def f(**kwargs: Any) -> Any:
# We need to make sure all keys of kwargs are valid.
# Merge valid group keys with original keys.
_assert_arg_valid(kwargs, [*list(default_kwargs.keys()), *flattened_group_keys], func_name=func_name)
# We need to put the default args to the kwargs before invoking the original function.
_update_dct_if_not_exist(kwargs, default_kwargs)
return func(**kwargs)

f = _replace_function_name(f, func_name)
f = _replace_function_name(cast(types.FunctionType, f), func_name)
# Set the signature so jupyter notebook could have param hint by calling inspect.signature()
f.__signature__ = Signature(parameters)
# Bug Item number: 2883223
f.__signature__ = Signature(parameters) # type: ignore
# Set doc/name/module to make sure help(f) shows following expected result.
# Expected help(f):
#
Expand All @@ -183,5 +188,5 @@ def f(**kwargs):
f.__doc__ = documentation # Set documentation to update FUNC_DOC in help.
# Set module = None to avoid showing the sentence `in module 'azure.ai.ml.component._dynamic'` in help.
# See https://github.com/python/cpython/blob/2145c8c9724287a310bc77a2760d4f1c0ca9eb0c/Lib/pydoc.py#L1757
f.__module__ = None
f.__module__ = None # type: ignore
return f
12 changes: 6 additions & 6 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_fl_scatter_gather_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import importlib

# pylint: disable=protected-access
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.entities import CommandComponent, PipelineJob
from azure.ai.ml.entities._assets.federated_learning_silo import FederatedLearningSilo
from azure.ai.ml.entities._builders.fl_scatter_gather import FLScatterGather


def _check_for_import(package_name):
def _check_for_import(package_name: str) -> None:
try:
# pylint: disable=unused-import
importlib.import_module(package_name)
Expand All @@ -31,16 +31,16 @@ def fl_scatter_gather(
silo_configs: List[FederatedLearningSilo],
silo_component: Union[PipelineJob, CommandComponent],
aggregation_component: Union[PipelineJob, CommandComponent],
aggregation_compute: str = None,
aggregation_datastore: str = None,
aggregation_compute: Optional[str] = None,
aggregation_datastore: Optional[str] = None,
shared_silo_kwargs: Optional[Dict] = None,
aggregation_kwargs: Optional[Dict] = None,
silo_to_aggregation_argument_map: Optional[Dict] = None,
aggregation_to_silo_argument_map: Optional[Dict] = None,
max_iterations: int = 1,
_create_default_mappings_if_needed: bool = False,
**kwargs,
):
**kwargs: Any,
) -> FLScatterGather:
"""A federated learning scatter-gather subgraph node.
It's assumed that this will be used inside of a `@pipeline`-decorated function in order to create a subgraph which
Expand Down
32 changes: 18 additions & 14 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_group_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Attribute on customized group class to mark a value type as a group of inputs/outputs.
import _thread
import functools
from typing import Any, Callable, Dict, List, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

from azure.ai.ml import Input, Output
from azure.ai.ml.constants._component import IOConstants
Expand Down Expand Up @@ -145,12 +145,12 @@ def parent_func(params: ParentClass):

def _create_fn(
name: str,
args: List[str],
body: List[str],
args: Union[List, str],
body: Union[List, str],
*,
globals: Dict[str, Any] = None,
locals: Dict[str, Any] = None,
return_type: Type[T2],
globals: Optional[Dict[str, Any]] = None,
locals: Optional[Dict[str, Any]] = None,
return_type: Optional[Type[T2]],
) -> Callable[..., T2]:
"""To generate function in class.
Expand Down Expand Up @@ -188,9 +188,10 @@ def _create_fn(
txt = f" def {name}({args}){return_annotation}:\n{body}"
local_vars = ", ".join(locals.keys())
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
ns = {}
ns: Dict = {}
exec(txt, globals, ns) # pylint: disable=exec-used # nosec
return ns["__create_fn__"](**locals)
res: Callable = ns["__create_fn__"](**locals)
return res

def _create_init_fn( # pylint: disable=unused-argument
cls: Type[T], fields: Dict[str, Union[Annotation, Input, Output]]
Expand All @@ -207,7 +208,7 @@ def _create_init_fn( # pylint: disable=unused-argument

# Reference code link:
# https://github.com/python/cpython/blob/17b16e13bb444001534ed6fccb459084596c8bcf/Lib/dataclasses.py#L523
def _get_data_type_from_annotation(anno: Input):
def _get_data_type_from_annotation(anno: Any) -> Any:
if isinstance(anno, GroupInput):
return anno._group_class
# keep original annotation for Outputs
Expand All @@ -220,7 +221,7 @@ def _get_data_type_from_annotation(anno: Input):
# otherwise, keep original annotation
return anno

def _get_default(key):
def _get_default(key: str) -> Any:
# Use None as the default value when no default exists, so the init params don't need to be reordered.
val = fields[key]
if hasattr(val, "default"):
Expand Down Expand Up @@ -254,20 +255,22 @@ def _create_repr_fn(fields: Dict[str, Union[Annotation, Input, Output]]) -> Call
# https://github.com/python/cpython/blob/17b16e13bb444001534ed6fccb459084596c8bcf/Lib/dataclasses.py#L582
fn = _create_fn(
"__repr__",
("self",),
[
"self",
],
['return self.__class__.__qualname__ + f"(' + ", ".join([f"{k}={{self.{k}!r}}" for k in fields]) + ')"'],
return_type=str,
)

# This function's logic is copied from "recursive_repr" function in
# reprlib module to avoid dependency.
def _recursive_repr(user_function):
def _recursive_repr(user_function: Any) -> Any:
# Decorator to make a repr function return "..." for a recursive
# call.
repr_running = set()

@functools.wraps(user_function)
def wrapper(self):
def wrapper(self: Any) -> Any:
key = id(self), _thread.get_ident()
if key in repr_running:
return "..."
Expand All @@ -280,7 +283,8 @@ def wrapper(self):

return wrapper

return _recursive_repr(fn)
res: Callable = _recursive_repr(fn)
return res

def _process_class(cls: Type[T], all_fields: Dict[str, Union[Annotation, Input, Output]]) -> Type[T]:
setattr(cls, "__init__", _create_init_fn(cls, all_fields))
Expand Down
7 changes: 4 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_load_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

# pylint: disable=protected-access

from typing import Callable
from typing import Any, Callable

from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
from azure.ai.ml.entities._builders import Command
from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin


# pylint: disable=unused-argument
def to_component(*, job: ComponentTranslatableMixin, **kwargs) -> Callable[..., Command]:
def to_component(*, job: ComponentTranslatableMixin, **kwargs: Any) -> Callable[..., Command]:
"""Translate a job object to a component function, provided job should be able to translate to a component.
For example:
Expand Down Expand Up @@ -41,4 +41,5 @@ def to_component(*, job: ComponentTranslatableMixin, **kwargs) -> Callable[...,

# Set the default base path to "./", because if the code path is relative and the base path is None, an error is
# raised when getting the arm id of Code.
return job._to_component(context={BASE_PATH_CONTEXT_KEY: Path("./")})
res: Callable = job._to_component(context={BASE_PATH_CONTEXT_KEY: Path("./")})
return res
12 changes: 6 additions & 6 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_mldesigner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
original function/module names the same as before, otherwise mldesigner will be broken by this change.
"""

__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from azure.ai.ml.entities._component.component_factory import component_factory
from azure.ai.ml.entities._job.pipeline._load_component import _generate_component_function
from azure.ai.ml.entities._inputs_outputs import _get_param_with_standard_annotation
from azure.ai.ml._internal.entities import InternalComponent
from azure.ai.ml._internal.entities._additional_includes import InternalAdditionalIncludes
from azure.ai.ml._utils._asset_utils import get_ignore_file
from azure.ai.ml._utils.utils import try_enable_internal_components
from azure.ai.ml._internal.entities import InternalComponent
from azure.ai.ml.dsl._condition import condition
from azure.ai.ml.dsl._do_while import do_while
from azure.ai.ml.dsl._parallel_for import parallel_for, ParallelFor
from azure.ai.ml.dsl._group_decorator import group
from azure.ai.ml.dsl._parallel_for import ParallelFor, parallel_for
from azure.ai.ml.entities._component.component_factory import component_factory
from azure.ai.ml.entities._inputs_outputs import _get_param_with_standard_annotation
from azure.ai.ml.entities._job.pipeline._load_component import _generate_component_function

from ._constants import V1_COMPONENT_TO_NODE

Expand Down
4 changes: 2 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_overrides_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from typing import Mapping
from typing import Mapping, Optional


class OverrideDefinition(dict):
Expand All @@ -11,7 +11,7 @@ class OverrideDefinition(dict):

def get_override_definition_from_schema(
schema: str, # pylint: disable=unused-argument
) -> Mapping[str, OverrideDefinition]:
) -> Optional[Mapping[str, OverrideDefinition]]:
"""Get override definition from a JSON schema.
:param schema: Json schema of component job.
Expand Down
Loading

0 comments on commit bcc1830

Please sign in to comment.