diff --git a/README.md b/README.md index 0de3bc6..99e02e6 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,11 @@ from executor.engine import Engine, ProcessJob def add(a, b): return a + b +async def stream(): + for i in range(5): + await asyncio.sleep(0.5) + yield i + with Engine() as engine: # job1 and job2 will be executed in parallel job1 = ProcessJob(add, args=(1, 2)) @@ -86,6 +91,13 @@ with Engine() as engine: engine.submit(job1, job2, job3) engine.wait_job(job3) # wait for job3 done print(job3.result()) # 10 + + # generator + job4 = ProcessJob(stream) + # do not do engine.wait because the generator job's future is done only when StopIteration + await engine.submit_async(job4) + async for x in job4.result(): + print(x) ``` Async mode example: diff --git a/executor/engine/job/dask.py b/executor/engine/job/dask.py index 5312d68..e5ab88f 100644 --- a/executor/engine/job/dask.py +++ b/executor/engine/job/dask.py @@ -3,6 +3,7 @@ from dask.distributed import Client, LocalCluster from .base import Job +from .utils import GeneratorWrapper from ..utils import PortManager @@ -56,12 +57,21 @@ def release_resource(self) -> bool: async def run_function(self): """Run job with Dask.""" client = self.engine.dask_client - func = functools.partial(self.func, **self.kwargs) - fut = client.submit(func, *self.args) + func = functools.partial(self.func, *self.args, **self.kwargs) + fut = client.submit(func) self._executor = fut result = await fut return result + async def run_generator(self): + """Run job as a generator.""" + client = self.engine.dask_client + func = functools.partial(self.func, *self.args, **self.kwargs) + fut = client.submit(func) + self._executor = client.get_executor(pure=False) + result = GeneratorWrapper(self, fut) + return result + async def cancel(self): """Cancel job.""""" if self.status == "running": diff --git a/executor/engine/job/local.py b/executor/engine/job/local.py index d7818ce..d0b962c 100644 --- a/executor/engine/job/local.py +++ 
b/executor/engine/job/local.py @@ -6,3 +6,8 @@ async def run_function(self): """Run job in local thread.""" res = self.func(*self.args, **self.kwargs) return res + + async def run_generator(self): + """Run job as a generator.""" + res = self.func(*self.args, **self.kwargs) + return res diff --git a/executor/engine/job/thread.py b/executor/engine/job/thread.py index 14cce82..25cb1e2 100644 --- a/executor/engine/job/thread.py +++ b/executor/engine/job/thread.py @@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor from .base import Job +from .utils import _gen_initializer, GeneratorWrapper class ThreadJob(Job): @@ -42,13 +43,21 @@ def release_resource(self) -> bool: async def run_function(self): """Run job in thread pool.""" + func = functools.partial(self.func, *self.args, **self.kwargs) self._executor = ThreadPoolExecutor(1) loop = asyncio.get_running_loop() - func = functools.partial(self.func, **self.kwargs) - fut = loop.run_in_executor(self._executor, func, *self.args) + fut = loop.run_in_executor(self._executor, func) result = await fut return result + async def run_generator(self): + """Run job as a generator.""" + func = functools.partial(self.func, *self.args, **self.kwargs) + self._executor = ThreadPoolExecutor( + 1, initializer=_gen_initializer, initargs=(func,)) + result = GeneratorWrapper(self) + return result + async def cancel(self): """Cancel job.""" if self.status == "running": diff --git a/executor/engine/job/utils.py b/executor/engine/job/utils.py index 3f2ef79..94a730a 100644 --- a/executor/engine/job/utils.py +++ b/executor/engine/job/utils.py @@ -1,6 +1,8 @@ import typing as T import asyncio from datetime import datetime +from concurrent.futures import Future +import threading from ..utils import CheckAttrRange, ExecutorError @@ -8,7 +10,6 @@ if T.TYPE_CHECKING: from .base import Job - JobStatusType = T.Literal['pending', 'running', 'failed', 'done', 'cancelled'] valid_job_statuses: T.List[JobStatusType] = [ 'pending', 'running', 
'failed', 'done', 'cancelled'] @@ -38,36 +39,48 @@ def __init__(self, job: "Job", valid_status: T.List[JobStatusType]): _T = T.TypeVar("_T") +_thread_locals = threading.local() def _gen_initializer(gen_func, args=tuple(), kwargs={}): # pragma: no cover - global _generator - _generator = gen_func(*args, **kwargs) + global _thread_locals + if "_thread_locals" not in globals(): + # avoid conflict for ThreadJob + _thread_locals = threading.local() + _thread_locals._generator = gen_func(*args, **kwargs) -def _gen_next(): # pragma: no cover - global _generator - return next(_generator) +def _gen_next(fut=None): # pragma: no cover + global _thread_locals + if fut is None: + return next(_thread_locals._generator) + else: + return next(fut) -def _gen_anext(): # pragma: no cover - global _generator - return asyncio.run(_generator.__anext__()) +def _gen_anext(fut=None): # pragma: no cover + global _thread_locals + if fut is None: + return asyncio.run(_thread_locals._generator.__anext__()) + else: + return asyncio.run(fut.__anext__()) class GeneratorWrapper(T.Generic[_T]): """ wrap a generator in executor pool """ - def __init__(self, job: "Job"): + + def __init__(self, job: "Job", fut: T.Optional[Future] = None): self._job = job + self._fut = fut def __iter__(self): return self def __next__(self) -> _T: try: - return self._job._executor.submit(_gen_next).result() + return self._job._executor.submit(_gen_next, self._fut).result() except Exception as e: engine = self._job.engine if engine is None: @@ -87,7 +100,7 @@ def __aiter__(self): async def __anext__(self) -> _T: try: - fut = self._job._executor.submit(_gen_anext) + fut = self._job._executor.submit(_gen_anext, self._fut) res = await asyncio.wrap_future(fut) return res except Exception as e: