Skip to content

Commit

Permalink
Merge pull request #14 from Klavionik/make-typed
Browse files Browse the repository at this point in the history
Make pymitter type-safe
  • Loading branch information
riga authored Jan 4, 2025
2 parents 65d4ffd + 60639a8 commit 4bd3f77
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
max-line-length = 101

# codes of errors to ignore
ignore = E128, E306, E402, E722, E731, W504
ignore = E128, E306, E402, E722, E731, W504, E704

# enforce double quotes
inline-quotes = double
1 change: 1 addition & 0 deletions .github/workflows/lint_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ jobs:
fail-fast: false
matrix:
python-version:
- "3.7"
- "3.8"
- "3.9"
- "3.10"
Expand Down
125 changes: 74 additions & 51 deletions pymitter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@
import time
import fnmatch
import asyncio
from typing import Callable, Awaitable, Sequence, Any, Generator
from collections.abc import Iterator, Awaitable
from typing import Any, Callable, List, Optional, TypeVar, overload, Dict, Tuple

F = TypeVar("F", bound=Callable[..., Any])
T = TypeVar("T")

class EventEmitter(object):

class EventEmitter:
"""
The EventEmitter class, ported from Node.js EventEmitter 2.
Expand All @@ -44,8 +48,6 @@ def __init__(
new_listener: bool = False,
max_listeners: int = -1,
) -> None:
super().__init__()

# store attributes
self.new_listener = new_listener
self.max_listeners = max_listeners
Expand All @@ -54,24 +56,30 @@ def __init__(
self._event_tree = Tree(wildcard=wildcard, delimiter=delimiter)

# flat list of listeners triggerd on "any" event
self._any_listeners: list[Listener] = []
self._any_listeners: List[Listener] = []

@property
def num_listeners(self) -> int:
return self._event_tree.num_listeners() + len(self._any_listeners)

@overload
def on(self, event: str, func: F, *, ttl: int = -1) -> F: ...
@overload
def on(self, event: str, *, ttl: int = -1) -> Callable[[F], F]: ...

def on(
self,
event: str,
func: Callable | None = None,
func: Optional[F] = None,
*,
ttl: int = -1,
) -> Callable:
):
"""
Registers a function to an event. *ttl* defines the times to listen with negative values
meaning infinity. When *func* is *None*, decorator usage is assumed. Returns the wrapped
function.
"""
def on(func: Callable) -> Callable:
def on(func: F) -> F:
# do not register the function when the maximum would be exceeded
if 0 <= self.max_listeners <= self.num_listeners:
return func
Expand All @@ -84,22 +92,32 @@ def on(func: Callable) -> Callable:

return func

return on(func) if func else on # type: ignore[return-value]
return on(func) if func else on

def once(self, event: str, func: Callable | None = None) -> Callable:
@overload
def once(self, event: str, func: F) -> F: ...
@overload
def once(self, event: str) -> Callable[[F], F]: ...

def once(self, event: str, func: Optional[F] = None):
"""
Registers a function to an event that is called once. When *func* is *None*, decorator usage
is assumed. Returns the wrapped function.
"""
return self.on(event, func=func, ttl=1)
return self.on(event, func, ttl=1) if func else self.on(event, ttl=1)

@overload
def on_any(self, func: F, *, ttl: int = -1) -> F: ...
@overload
def on_any(self, *, ttl: int = -1) -> Callable[[F], F]: ...

def on_any(self, func: Callable | None = None, ttl: int = -1) -> Callable:
def on_any(self, func: Optional[F] = None, *, ttl: int = -1):
"""
Registers a function that is called every time an event is emitted. *ttl* defines the times
to listen with negative values meaning infinity. When *func* is *None*, decorator usage is
assumed. Returns the wrapped function.
"""
def on_any(func: Callable) -> Callable:
def on_any(func: F) -> F:
# do not register the function when the maximum would be exceeded
if 0 <= self.max_listeners <= self.num_listeners:
return func
Expand All @@ -112,9 +130,14 @@ def on_any(func: Callable) -> Callable:

return func

return on_any(func) if func else on_any # type: ignore[return-value]
return on_any(func) if func else on_any

@overload
def off(self, event: str, func: F) -> F: ...
@overload
def off(self, event: str) -> Callable[[F], F]: ...

def off(self, event: str, func: Callable | None = None) -> Callable:
def off(self, event: str, func: Optional[F] = None):
"""
Removes a function that is registered to an event. When *func* is *None*, decorator usage is
assumed. Returns the wrapped function.
Expand All @@ -124,14 +147,19 @@ def off(func: Callable) -> Callable:

return func

return off(func) if func else off # type: ignore[return-value]
return off(func) if func else off

def off_any(self, func: Callable | None = None) -> Callable:
@overload
def off_any(self, func: F) -> F: ...
@overload
def off_any(self) -> Callable[[F], F]: ...

def off_any(self, func: Optional[F] = None):
"""
Removes a function that was registered via :py:meth:`on_any`. When *func* is *None*,
decorator usage is assumed. Returns the wrapped function.
"""
def off_any(func: Callable) -> Callable:
def off_any(func: F) -> F:
self._any_listeners[:] = [
listener
for listener in self._any_listeners
Expand All @@ -140,7 +168,7 @@ def off_any(func: Callable) -> Callable:

return func

return off_any(func) if func else off_any # type: ignore[return-value]
return off_any(func) if func else off_any

def off_all(self) -> None:
"""
Expand All @@ -149,19 +177,19 @@ def off_all(self) -> None:
self._event_tree.clear()
del self._any_listeners[:]

def listeners(self, event: str) -> list[Callable]:
def listeners(self, event: str) -> List[Callable[..., Any]]:
"""
Returns all functions that are registered to an event.
"""
return [listener.func for listener in self._event_tree.find_listeners(event)]

def listeners_any(self) -> list[Callable]:
def listeners_any(self) -> List[Callable[..., Any]]:
"""
Returns all functions that were registered using :py:meth:`on_any`.
"""
return [listener.func for listener in self._any_listeners]

def listeners_all(self) -> list[Callable]:
def listeners_all(self) -> List[Callable[..., Any]]:
"""
Returns all registered functions, ordered by their registration time.
"""
Expand All @@ -177,7 +205,7 @@ def listeners_all(self) -> list[Callable]:

return [listener.func for listener in listeners]

def _emit(self, event: str, *args, **kwargs) -> list[Awaitable]:
def _emit(self, event: str, *args: Any, **kwargs: Any) -> List[Awaitable]:
listeners = self._event_tree.find_listeners(event)
if event != self.new_listener_event:
listeners.extend(self._any_listeners)
Expand All @@ -197,7 +225,7 @@ def _emit(self, event: str, *args, **kwargs) -> list[Awaitable]:

return awaitables

def emit(self, event: str, *args, **kwargs) -> None:
def emit(self, event: str, *args: Any, **kwargs: Any) -> None:
"""
Emits an *event*. All functions of events that match *event* are invoked with *args* and
*kwargs* in the exact order of their registration, with the exception of async functions
Expand All @@ -208,11 +236,11 @@ def emit(self, event: str, *args, **kwargs) -> None:

# handle awaitables
if awaitables:
async def start():
async def start() -> None:
await asyncio.gather(*awaitables)
asyncio.run(start())

async def emit_async(self, event: str, *args, **kwargs) -> None:
async def emit_async(self, event: str, *args: Any, **kwargs: Any) -> None:
"""
Awaitable version of :py:meth:`emit`. However, this method does not start a new event loop
but uses the existing one.
Expand All @@ -224,7 +252,7 @@ async def emit_async(self, event: str, *args, **kwargs) -> None:
if awaitables:
await asyncio.gather(*awaitables)

def emit_future(self, event: str, *args, **kwargs) -> None:
def emit_future(self, event: str, *args: Any, **kwargs: Any) -> None:
"""
Deferred version of :py:meth:`emit` with all awaitable events being places at the end of the
existing event loop (using :py:func:`asyncio.ensure_future`).
Expand All @@ -237,15 +265,12 @@ def emit_future(self, event: str, *args, **kwargs) -> None:
asyncio.ensure_future(asyncio.gather(*awaitables))


class BaseNode(object):

class BaseNode:
def __init__(self, wildcard: bool, delimiter: str) -> None:
super().__init__()

self.wildcard = wildcard
self.delimiter = delimiter
self.parent = None
self.nodes: dict[str, Node] = {}
self.parent: "Optional[BaseNode]" = None
self.nodes: Dict[str, "Node"] = {}

def clear(self) -> None:
self.nodes.clear()
Expand All @@ -259,11 +284,11 @@ def add_node(self, node: "Node") -> "Node":

# otherwise add it and set its parent
self.nodes[node.name] = node
node.parent = self # type: ignore[assignment]
node.parent = self

return node

def walk_nodes(self) -> Generator[tuple[str, tuple[str, ...], list[str]], None, None]:
def walk_nodes(self) -> Iterator[Tuple[str, Tuple[str, ...], List[str]]]:
queue = [
(name, [name], node)
for name, node in self.nodes.items()
Expand Down Expand Up @@ -294,11 +319,11 @@ class Node(BaseNode):
def str_is_pattern(cls, s: str) -> bool:
return "*" in s or "?" in s

def __init__(self, name: str, *args) -> None:
def __init__(self, name: str, *args: Any) -> None:
super().__init__(*args)

self.name = name
self.listeners: list[Listener] = []
self.listeners: List[Listener] = []

def num_listeners(self, recursive: bool = True) -> int:
n = len(self.listeners)
Expand All @@ -308,7 +333,7 @@ def num_listeners(self, recursive: bool = True) -> int:

return n

def remove_listeners_by_func(self, func: Callable) -> None:
def remove_listeners_by_func(self, func: Callable[..., Any]) -> None:
self.listeners[:] = [listener for listener in self.listeners if listener.func != func]

def add_listener(self, listener: Listener) -> None:
Expand All @@ -323,16 +348,16 @@ def check_name(self, pattern: str) -> bool:

return self.name == pattern

def find_nodes(self, event: str | Sequence[str]) -> list[Node]:
def find_nodes(self, event: str | List[str]) -> List[Node]:
# trivial case
if not event:
return []

# parse event
if isinstance(event, (list, tuple)):
pattern, sub_patterns = event[0], event[1:]
if isinstance(event, str):
pattern, *sub_patterns = event.split(self.delimiter)
else:
pattern, *sub_patterns = event.split(self.delimiter) # type: ignore[attr-defined]
pattern, sub_patterns = event[0], event[1:]

# first make sure that pattern matches _this_ name
if not self.check_name(pattern):
Expand All @@ -354,7 +379,7 @@ class Tree(BaseNode):
def num_listeners(self) -> int:
return sum(node.num_listeners(recursive=True) for node in self.nodes.values())

def find_nodes(self, *args, **kwargs) -> list[Node]:
def find_nodes(self, *args: Any, **kwargs: Any) -> List[Node]:
return sum((node.find_nodes(*args, **kwargs) for node in self.nodes.values()), [])

def add_listener(self, event: str, listener: Listener) -> None:
Expand All @@ -375,12 +400,12 @@ def add_listener(self, event: str, listener: Listener) -> None:
# add the listeners
node.add_listener(listener) # type: ignore[arg-type, call-arg]

def remove_listeners_by_func(self, event: str, func: Callable) -> None:
def remove_listeners_by_func(self, event: str, func: Callable[..., Any]) -> None:
for node in self.find_nodes(event):
node.remove_listeners_by_func(func)

def find_listeners(self, event: str, sort: bool = True) -> list[Listener]:
listeners = sum((node.listeners for node in self.find_nodes(event)), [])
def find_listeners(self, event: str, sort: bool = True) -> List[Listener]:
listeners: List[Listener] = sum((node.listeners for node in self.find_nodes(event)), [])

# sort by registration time
if sort:
Expand All @@ -389,15 +414,13 @@ def find_listeners(self, event: str, sort: bool = True) -> list[Listener]:
return listeners


class Listener(object):
class Listener:
"""
A simple event listener class that wraps a function *func* for a specific *event* and that keeps
track of the times to listen left.
"""

def __init__(self: Listener, func: Callable, event: str, ttl: int) -> None:
super().__init__()

def __init__(self, func: Callable[..., Any], event: str, ttl: int) -> None:
self.func = func
self.event = event
self.ttl = ttl
Expand All @@ -413,7 +436,7 @@ def is_coroutine(self) -> bool:
def is_async_callable(self) -> bool:
return asyncio.iscoroutinefunction(getattr(self.func, "__call__", None))

def __call__(self, *args, **kwargs) -> Any:
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""
Invokes the wrapped function when ttl is non-zero, decreases the ttl value when positive and
returns its return value.
Expand Down
Empty file added pymitter/py.typed
Empty file.
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,6 @@ dependencies = {file = ["requirements.txt"]}
optional-dependencies = {dev = {file = ["requirements_dev.txt"]}}


[tool.setuptools]

include-package-data = false


[tool.setuptools.packages.find]

include = ["pymitter"]
Expand Down
3 changes: 2 additions & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
mypy~=1.9.0;python_version>="3.8"
flake8~=7.0.0;python_version>="3.8"
flake8~=5.0.0;python_version<"3.8"
flake8-commas~=2.1.0
flake8-quotes~=3.3.2
types-docutils~=0.20.0
pytest-cov>=3.0
mypy>=1.4.1
typing-extensions>=4.7.1

0 comments on commit 4bd3f77

Please sign in to comment.