Skip to content

Commit

Permalink
Improve FunctionWrapper (#130)
Browse files Browse the repository at this point in the history
* Improve FunctionWrapper

* update

* update
  • Loading branch information
goodwanghan authored Jun 11, 2024
1 parent 61eed0a commit cdfd2bb
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 42 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ pip install triad

## Release History

### 0.9.7

* Make FunctionWrapper compare annotation origins by default

### 0.9.6

* Add `is_like` to Schema to compare similar schemas
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ addopts =
-vvv

[flake8]
ignore = E24,E203,W503
ignore = A005,E24,E203,W503
max-line-length = 88
format = pylint
exclude = .svc,CVS,.bzr,.hg,.git,__pycache__,venv,tests/*,docs/*
Expand Down
31 changes: 21 additions & 10 deletions tests/collections/test_function_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import Any, Callable, Dict, Iterable, List, Optional
from __future__ import annotations

from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional

import pandas as pd
from pytest import raises
import sys

from triad.collections.function_wrapper import (
AnnotatedParam,
FunctionWrapper,
NoneParam,
OtherParam,
function_wrapper,
)
from triad.exceptions import InvalidOperationError
from triad import to_uuid
from triad.collections.function_wrapper import (AnnotatedParam,
FunctionWrapper, NoneParam,
OtherParam, function_wrapper)
from triad.exceptions import InvalidOperationError


class _Dummy:
Expand All @@ -38,6 +38,11 @@ class SeriesParam(AnnotatedParam):
pass


@MockFunctionWrapper.annotated_param(List[List[int]], "l")
class ListParam(AnnotatedParam):
pass


def test_registration():
with raises(InvalidOperationError):

Expand Down Expand Up @@ -113,13 +118,15 @@ def _parse_function(f, params_re, return_re):
_parse_function(f4, "^0x$", "d")
_parse_function(f6, "^d$", "n")
_parse_function(f7, "^yz$", "n")
if sys.version_info >= (3, 9):
_parse_function(f8, "^l$", "n")


def f1(a: pd.DataFrame, b: pd.Series) -> None:
pass


def f2(e: int, a, b: int, c):
def f2(e: "int", a, b: int, c):
return e + a + b - c


Expand All @@ -141,3 +148,7 @@ def f6(a: _Dummy) -> None:

def f7(*args: Any, **kwargs: int):
pass


def f8(a: list[list[int]]) -> None:
pass
78 changes: 61 additions & 17 deletions tests/utils/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,28 @@
from __future__ import annotations

import builtins
import urllib # must keep for testing purpose
import urllib.request # must keep for testing purpose
from datetime import date, datetime, timedelta

from typing import Any, Callable, Dict, List, Union, get_type_hints
import pytest
import sys
import numpy as np
import pandas as pd
import tests.utils.convert_examples as ex
from pytest import raises

import tests.utils.convert_examples as ex
from tests.utils.convert_examples import BaseClass, Class2
from tests.utils.convert_examples import SubClass
from tests.utils.convert_examples import SubClass as SubClassSame
from triad.utils.convert import (
_parse_value_and_unit,
as_type,
get_caller_global_local_vars,
get_full_type_path,
str_to_instance,
str_to_object,
str_to_type,
to_bool,
to_datetime,
to_function,
to_instance,
to_size,
to_timedelta,
to_type,
)
from triad.utils.convert import (_parse_value_and_unit, as_type,
compare_annotations,
get_caller_global_local_vars,
get_full_type_path, str_to_instance,
str_to_object, str_to_type, to_bool,
to_datetime, to_function, to_instance,
to_size, to_timedelta, to_type)

_GLOBAL_DUMMY = 1

Expand Down Expand Up @@ -348,6 +345,53 @@ def f4():
f1()


@pytest.mark.skipif(sys.version_info < (3, 9), reason="python<3.9")
def test_compare_annotations():
def _assert(f, arg_a, arg_b, expected=True, **kwargs):
# get the argument type annoptation of name arg_a in function f
sig = get_type_hints(f)
a = sig.get(arg_a, Any)
b = sig.get(arg_b, Any)
assert compare_annotations(a, b, **kwargs) == expected

def f1(a: int, b: str, c, d: None, e: Any):
pass

_assert(f1, "a", "a")
_assert(f1, "a", "b", False)
_assert(f1, "a", "c", False)
_assert(f1, "c", "c")
_assert(f1, "c", "d", False)
_assert(f1, "c", "e")
_assert(f1, "e", "e")

def f2(a: List, b: Dict, c: Union[int, str], d: Callable):
pass

for o in [True, False]:
kwargs = dict(compare_origin=o)
_assert(f2, "a", "a", **kwargs)
_assert(f2, "a", "b", False, **kwargs)
_assert(f2, "c", "c", **kwargs)
_assert(f2, "c", "d", False, **kwargs)

def f3(a: List[Dict[str, Any]], b: list[dict[str, Any]], c: List):
pass

_assert(f3, "a", "a")
_assert(f3, "a", "b", True)
_assert(f3, "a", "b", False, compare_origin=False)
_assert(f3, "a", "c", False)

def f4(a: Callable[..., Dict[str, Any]], b: Callable[..., dict[str, Any]], c: callable):
pass

_assert(f4, "a", "a")
_assert(f4, "a", "b", True)
_assert(f4, "a", "b", False, compare_origin=False)
_assert(f3, "a", "c", False)


# This is for test_obj_to_function
def dummy_for_test():
pass
Expand Down
4 changes: 2 additions & 2 deletions triad/collections/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ..exceptions import InvalidOperationError
from ..utils.assertion import assert_or_throw
from ..utils.convert import get_full_type_path
from ..utils.convert import compare_annotations, get_full_type_path
from ..utils.entry_points import load_entry_point
from ..utils.hash import to_uuid
from .dict import IndexedOrderedDict
Expand Down Expand Up @@ -165,7 +165,7 @@ def _func(tp: Type["AnnotatedParam"]) -> Type["AnnotatedParam"]:
anno = annotation

def _m(a: Any) -> bool:
return a == anno
return compare_annotations(a, anno, compare_origin=True)

_matcher = _m

Expand Down
37 changes: 26 additions & 11 deletions triad/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import importlib
import inspect
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, get_args, get_origin

import numpy as np
import pandas as pd
Expand All @@ -16,7 +16,6 @@
_HAS_CISO8601 = False
from triad.utils.assertion import assert_or_throw


EMPTY_ARGS: List[Any] = []
EMPTY_KWARGS: Dict[str, Any] = {}

Expand Down Expand Up @@ -403,6 +402,7 @@ def to_timedelta(obj: Any) -> datetime.timedelta:
:param obj: object
:raises TypeError: if failed to convert
:return: timedelta value
"""
if obj is None:
Expand Down Expand Up @@ -449,15 +449,9 @@ def to_size(exp: Any) -> int:
default unit is byte if not provided. Unit can be `b`, `byte`,
`k`, `kb`, `m`, `mb`, `g`, `gb`, `t`, `tb`.
Args:
exp (Any): expression string or numerical value
Raises:
ValueError: for invalid expression
ValueError: for negative values
Returns:
int: size in byte
:param exp: expression string or numerical value
:raises ValueError: for invalid expression and negative values
:return: size in byte
"""
n, u = _parse_value_and_unit(exp)
assert n >= 0.0, "Size can't be negative"
Expand All @@ -474,6 +468,27 @@ def to_size(exp: Any) -> int:
raise ValueError(f"Invalid size expression {exp}")


def compare_annotations(a: Any, b: Any, compare_origin: bool = True) -> bool:
"""Compare two type annotations
:param a: first type annotation
:param b: second type annotation
:param compare_origin: whether to compare the origin of the type annotation
:return: whether the two type annotations are equal
"""
if compare_origin:
ta = get_origin(a) or a
tb = get_origin(b) or b
if ta != tb:
return False
aa = get_args(a)
ba = get_args(b)
if len(aa) != len(ba):
return False
return all(compare_annotations(x, y, compare_origin) for x, y in zip(aa, ba))
return a == b


def _parse_value_and_unit(exp: Any) -> Tuple[float, str]:
try:
assert exp is not None
Expand Down
2 changes: 1 addition & 1 deletion triad_version/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# flake8: noqa
__version__ = "0.9.6"
__version__ = "0.9.7"

0 comments on commit cdfd2bb

Please sign in to comment.