diff --git a/README.md b/README.md
index 0de3bc6..3e59572 100644
--- a/README.md
+++ b/README.md
@@ -99,6 +99,11 @@ engine = Engine()
 def add(a, b):
     return a + b
 
+async def stream():
+    for i in range(5):
+        await asyncio.sleep(0.5)
+        yield i
+
 async def main():
     job1 = ProcessJob(add, args=(1, 2))
     job2 = ProcessJob(add, args=(job1.future, 4))
@@ -107,6 +112,12 @@ async def main():
     print(job1.result())  # 3
     print(job2.result())  # 7
 
+    # generator
+    job3 = ProcessJob(stream)
+    await engine.submit_async(job3)  # do not call engine.wait here: a generator job's future is only done on StopIteration
+    async for x in job3.result():
+        print(x)
+
 asyncio.run(main())  # or just `await main()` in jupyter environment
 ```
 
diff --git a/executor/engine/job/local.py b/executor/engine/job/local.py
index d7818ce..d0b962c 100644
--- a/executor/engine/job/local.py
+++ b/executor/engine/job/local.py
@@ -6,3 +6,8 @@ async def run_function(self):
         """Run job in local thread."""
         res = self.func(*self.args, **self.kwargs)
         return res
+
+    async def run_generator(self):
+        """Run job as a generator."""
+        res = self.func(*self.args, **self.kwargs)
+        return res
diff --git a/executor/engine/job/thread.py b/executor/engine/job/thread.py
index 14cce82..25cb1e2 100644
--- a/executor/engine/job/thread.py
+++ b/executor/engine/job/thread.py
@@ -3,6 +3,7 @@
 from concurrent.futures import ThreadPoolExecutor
 
 from .base import Job
+from .utils import _gen_initializer, GeneratorWrapper
 
 
 class ThreadJob(Job):
@@ -42,13 +43,21 @@ def release_resource(self) -> bool:
 
     async def run_function(self):
         """Run job in thread pool."""
+        func = functools.partial(self.func, *self.args, **self.kwargs)
         self._executor = ThreadPoolExecutor(1)
         loop = asyncio.get_running_loop()
-        func = functools.partial(self.func, **self.kwargs)
-        fut = loop.run_in_executor(self._executor, func, *self.args)
+        fut = loop.run_in_executor(self._executor, func)
         result = await fut
         return result
 
+    async def run_generator(self):
+        """Run job as a generator."""
+        func = functools.partial(self.func, *self.args, **self.kwargs)
+        self._executor = ThreadPoolExecutor(
+            1, initializer=_gen_initializer, initargs=(func,))
+        result = GeneratorWrapper(self)
+        return result
+
     async def cancel(self):
         """Cancel job."""
         if self.status == "running":
diff --git a/executor/engine/job/utils.py b/executor/engine/job/utils.py
index 3f2ef79..02dca99 100644
--- a/executor/engine/job/utils.py
+++ b/executor/engine/job/utils.py
@@ -1,6 +1,7 @@
 import typing as T
 import asyncio
 from datetime import datetime
+import threading
 
 from ..utils import CheckAttrRange, ExecutorError
 
@@ -41,23 +42,25 @@ def __init__(self, job: "Job", valid_status: T.List[JobStatusType]):
 
 
 def _gen_initializer(gen_func, args=tuple(), kwargs={}):  # pragma: no cover
-    global _generator
-    _generator = gen_func(*args, **kwargs)
+    global _thread_locals
+    if "_thread_locals" not in globals():
+        _thread_locals = threading.local()  # avoid conflict between LocalJob & ThreadJob
+    _thread_locals._generator = gen_func(*args, **kwargs)
 
 
 def _gen_next():  # pragma: no cover
-    global _generator
-    return next(_generator)
+    global _thread_locals
+    return next(_thread_locals._generator)
 
 
 def _gen_anext():  # pragma: no cover
-    global _generator
-    return asyncio.run(_generator.__anext__())
+    global _thread_locals
+    return asyncio.run(_thread_locals._generator.__anext__())
 
 
-class GeneratorWrapper(T.Generic[_T]):
+class GeneratorWrapper(T.Generic[_T]):  # TODO: this may be extended to an Actor or ObjectProxy
     """
-    wrap a generator in executor pool
+    wrap a generator in executor pool.
     """
     def __init__(self, job: "Job"):
         self._job = job
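The `thread.py` / `utils.py` changes above share one pattern: `_gen_initializer` creates the generator inside the executor's single worker thread, `_gen_next` / `_gen_anext` advance it in that same thread, and `GeneratorWrapper` re-yields the values on the caller's side. The sketch below is a minimal, self-contained illustration of that pattern, not code from this diff: `consume` and `numbers` are hypothetical helpers standing in for `GeneratorWrapper`'s iteration protocol, which is defined elsewhere in `executor/engine/job/utils.py` and not shown here.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

_thread_locals = threading.local()


def _gen_initializer(gen_func):
    # Runs once in the pool's single worker thread: build the generator there,
    # so every subsequent next() call happens in the same thread.
    _thread_locals._generator = gen_func()


def _gen_next():
    # Advance the generator owned by the worker thread.
    return next(_thread_locals._generator)


def consume(gen_func):
    # Illustrative stand-in for GeneratorWrapper iteration: pull values one by
    # one by submitting _gen_next to the worker until StopIteration arrives.
    executor = ThreadPoolExecutor(
        1, initializer=_gen_initializer, initargs=(gen_func,))
    try:
        while True:
            try:
                # Future.result() re-raises StopIteration from the worker.
                yield executor.submit(_gen_next).result()
            except StopIteration:
                return
    finally:
        executor.shutdown()


def numbers():
    yield from range(3)


print(list(consume(numbers)))  # -> [0, 1, 2]
```

Because the generator is both created and advanced in the one worker thread, storing it on a `threading.local()` (as the `utils.py` hunk does) keeps `LocalJob` and `ThreadJob` from clobbering each other's generator when they share this module's globals.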