Skip to content

Commit

Permalink
Improve conditional dispatcher (#115)
Browse files Browse the repository at this point in the history
* Remove enforce_type

* update

* Improve conditional_dispatcher
  • Loading branch information
goodwanghan authored Oct 22, 2023
1 parent b8615e2 commit dc6e704
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions triad/utils/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,17 @@ def get_int_len(obj:int) -> int:
:param entry_point: the entry point to preload dispatchers, defaults to None
"""

def _run(_func: Callable) -> "ConditionalDispatcher":
class _Dispatcher(ConditionalDispatcher):
def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.run_top(*args, **kwds)

return _Dispatcher(_func, entry_point=entry_point)

return _run if default_func is None else _run(default_func) # type:ignore
return (
( # type: ignore
lambda func: ConditionalDispatcher(
func, is_broadcast=False, entry_point=entry_point
)
)
if default_func is None
else ConditionalDispatcher(
default_func, is_broadcast=False, entry_point=entry_point
)
)


def conditional_broadcaster(
Expand Down Expand Up @@ -191,14 +194,17 @@ def myprintb(obj:str) -> None:
:param entry_point: the entry point to preload dispatchers, defaults to None
"""

def _run(_func: Callable) -> "ConditionalDispatcher":
class _Dispatcher(ConditionalDispatcher):
def __call__(self, *args: Any, **kwds: Any) -> None:
list(self.run(*args, **kwds))

return _Dispatcher(_func, entry_point=entry_point)

return _run if default_func is None else _run(default_func) # type:ignore
return (
( # type: ignore
lambda func: ConditionalDispatcher(
func, is_broadcast=True, entry_point=entry_point
)
)
if default_func is None
else ConditionalDispatcher(
default_func, is_broadcast=True, entry_point=entry_point
)
)


class ConditionalDispatcher:
Expand All @@ -219,20 +225,24 @@ class ConditionalDispatcher:
"""

def __init__(
self, default_func: Callable[..., Any], entry_point: Optional[str] = None
self,
default_func: Callable[..., Any],
is_broadcast: bool,
entry_point: Optional[str] = None,
):
self._func = default_func
self._funcs: List[
Tuple[float, int, Callable[..., bool], Callable[..., Any]]
] = []
self._entry_point = entry_point
self._is_broadcast = is_broadcast
update_wrapper(self, default_func)

def __getstate__(self) -> Dict[str, Any]:
return {
k: v
for k, v in self.__dict__.items()
if k in ["_func", "_funcs", "_entry_point"]
if k in ["_func", "_funcs", "_entry_point", "_is_broadcast"]
}

def __setstate__(self, data: Dict[str, Any]) -> None:
Expand All @@ -241,7 +251,9 @@ def __setstate__(self, data: Dict[str, Any]) -> None:

def __call__(self, *args: Any, **kwds: Any) -> Any:
"""The abstract method to mimic the function call"""
raise NotImplementedError # pragma: no cover
if self._is_broadcast:
return list(self.run(*args, **kwds))
return self.run_top(*args, **kwds)

def run(self, *args: Any, **kwargs: Any) -> Iterable[Any]:
"""Execute all matching children functions as a generator.
Expand Down

0 comments on commit dc6e704

Please sign in to comment.