diff --git a/streamz/core.py b/streamz/core.py index f381b632..6e9f4209 100644 --- a/streamz/core.py +++ b/streamz/core.py @@ -119,6 +119,22 @@ def __str__(self): class APIRegisterMixin(object): + def _new_node(self, cls, args, kwargs): + """ Constructor for downstream nodes. + + Examples + -------- + To provide inheritance through nodes : + + >>> class MyStream(Stream): + >>> + >>> def _new_node(self, cls, args, kwargs): + >>> if not issubclass(cls, MyStream): + >>> cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__)) + >>> return cls(*args, **kwargs) + """ + return cls(*args, **kwargs) + @classmethod def register_api(cls, modifier=identity, attribute_name=None): """ Add callable to Stream API @@ -158,6 +174,10 @@ def register_api(cls, modifier=identity, attribute_name=None): def _(func): @functools.wraps(func) def wrapped(*args, **kwargs): + if identity is not staticmethod and args: + self = args[0] + if isinstance(self, APIRegisterMixin): + return self._new_node(func, args, kwargs) return func(*args, **kwargs) name = attribute_name if attribute_name else func.__name__ setattr(cls, name, modifier(wrapped)) diff --git a/streamz/tests/test_core.py b/streamz/tests/test_core.py index 56a661d7..336d8500 100644 --- a/streamz/tests/test_core.py +++ b/streamz/tests/test_core.py @@ -1367,6 +1367,35 @@ class foo(NewStream): assert not hasattr(Stream(), 'foo') +def test_subclass_node(): + + def add(x) : return x + 1 + + class MyStream(Stream): + def _new_node(self, cls, args, kwargs): + if not issubclass(cls, MyStream): + cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__)) + return cls(*args, **kwargs) + + @MyStream.register_api() + class foo(sz.sinks.sink): + pass + + stream = MyStream() + lst = list() + + node = stream.map(add) + assert isinstance(node, sz.core.map) + assert isinstance(node, MyStream) + + node = node.foo(lst.append) + assert isinstance(node, sz.sinks.sink) + assert isinstance(node, MyStream) + + stream.emit(100) + assert lst == [ 101 ] + + @gen_test() def test_latest(): source = Stream(asynchronous=True)