Skip to content

Commit

Permalink
General support for generators
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanguage committed Sep 20, 2024
1 parent 53dafbd commit 389af29
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 32 deletions.
32 changes: 26 additions & 6 deletions executor/engine/job/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .utils import (
JobStatusAttr, InvalidStateError, JobStatusType,
ExecutorError, valid_job_statuses
ExecutorError, valid_job_statuses, GeneratorWrapper
)
from .condition import Condition, AfterOthers, AllSatisfied
from ..middle.capture import CaptureOut
Expand Down Expand Up @@ -275,7 +275,10 @@ async def wait_and_run(self):
self.status = "running"
try:
res = await self.run()
await self.on_done(res)
if not isinstance(res, GeneratorWrapper):
await self.on_done(res)
else:
self.future.set_result(res)
return res
except Exception as e:
await self.on_failed(e)
Expand All @@ -285,7 +288,22 @@ async def wait_and_run(self):

async def run(self):
    """Run the job.

    Dispatches on the kind of callable stored in ``self.func``:
    generator functions and async generator functions are handled by
    ``run_generator``; every other callable is handled by
    ``run_function``.

    Returns:
        Whatever the selected ``run_generator`` / ``run_function``
        coroutine returns.
    """
    func = self.func
    if inspect.isgeneratorfunction(func) or inspect.isasyncgenfunction(func):
        return await self.run_generator()
    return await self.run_function()

async def run_function(self):  # pragma: no cover
    """Run the job as a function.

    Placeholder used when ``run`` dispatches a non-generator callable.
    Concrete job classes (e.g. ``LocalJob``) override this method.

    Raises:
        NotImplementedError: Always; the message names the concrete
            class that is missing the override.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement run_function method."
    )

async def run_generator(self):  # pragma: no cover
    """Run the job as a generator.

    Placeholder used when ``run`` dispatches a generator or async
    generator function. Job classes that support generators override
    this method.

    Raises:
        NotImplementedError: Always; the message names the concrete
            class that is missing the override.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement run_generator method."
    )

async def rerun(self, check_status: bool = True):
"""Rerun the job."""
Expand Down Expand Up @@ -351,9 +369,11 @@ def clear_context(self):

def result(self) -> T.Any:
"""Get the result of the job."""
if self.status != "done":
raise InvalidStateError(self, ['done'])
return self.future.result()
res = self.future.result()
if not isinstance(res, GeneratorWrapper):
if self.status != "done":
raise InvalidStateError(self, ['done'])
return res

def exception(self):
"""Get the exception of the job."""
Expand Down
2 changes: 1 addition & 1 deletion executor/engine/job/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def release_resource(self) -> bool:
True
)

async def run(self):
async def run_function(self):
"""Run job with Dask."""
client = self.engine.dask_client
func = functools.partial(self.func, **self.kwargs)
Expand Down
2 changes: 1 addition & 1 deletion executor/engine/job/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class LocalJob(Job):
async def run(self):
async def run_function(self):
"""Run job in local thread."""
res = self.func(*self.args, **self.kwargs)
return res
24 changes: 13 additions & 11 deletions executor/engine/job/process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import functools
import inspect

from loky.process_executor import ProcessPoolExecutor

Expand Down Expand Up @@ -43,18 +42,21 @@ def release_resource(self) -> bool:
True
)

async def run(self):
async def run_function(self):
"""Run job in process pool."""
func = functools.partial(self.func, *self.args, **self.kwargs)
if (inspect.isgeneratorfunction(self.func) or inspect.isasyncgenfunction(self.func)): # noqa: E501
self._executor = ProcessPoolExecutor(
1, initializer=_gen_initializer, initargs=(func,))
result = GeneratorWrapper(self)
else:
self._executor = ProcessPoolExecutor(1)
loop = asyncio.get_running_loop()
fut = loop.run_in_executor(self._executor, func)
result = await fut
self._executor = ProcessPoolExecutor(1)
loop = asyncio.get_running_loop()
fut = loop.run_in_executor(self._executor, func)
result = await fut
return result

async def run_generator(self):
    """Run job as a generator.

    Binds ``self.args`` / ``self.kwargs`` to ``self.func``, creates a
    single-worker process pool whose initializer receives the bound
    callable, and returns a ``GeneratorWrapper`` around this job. No
    item is requested from the generator here.

    Returns:
        GeneratorWrapper: Wrapper bound to this job.
    """
    func = functools.partial(self.func, *self.args, **self.kwargs)
    # NOTE(review): presumably _gen_initializer instantiates the
    # generator inside the worker process -- confirm against utils.
    self._executor = ProcessPoolExecutor(
        1, initializer=_gen_initializer, initargs=(func,))
    return GeneratorWrapper(self)

async def cancel(self):
Expand Down
2 changes: 1 addition & 1 deletion executor/engine/job/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def release_resource(self) -> bool:
True
)

async def run(self):
async def run_function(self):
"""Run job in thread pool."""
self._executor = ThreadPoolExecutor(1)
loop = asyncio.get_running_loop()
Expand Down
28 changes: 25 additions & 3 deletions executor/engine/job/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,33 @@ def __iter__(self):
return self

def __next__(self) -> _T:
return self._job._executor.submit(_gen_next).result()
try:
return self._job._executor.submit(_gen_next).result()
except Exception as e:
engine = self._job.engine
if engine is None:
loop = asyncio.get_event_loop()
else:
loop = engine.loop
if isinstance(e, StopIteration):
cor = self._job.on_done(self)
else:
cor = self._job.on_failed(e)
fut = asyncio.run_coroutine_threadsafe(cor, loop)
fut.result()
raise e

def __aiter__(self):
return self

async def __anext__(self) -> _T:
fut = self._job._executor.submit(_gen_anext)
return (await asyncio.wrap_future(fut))
try:
fut = self._job._executor.submit(_gen_anext)
res = await asyncio.wrap_future(fut)
return res
except Exception as e:
if isinstance(e, StopAsyncIteration):
await self._job.on_done(self)
else:
await self._job.on_failed(e)
raise e
4 changes: 4 additions & 0 deletions executor/engine/launcher/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def _fetch_result(job: 'Job') -> T.Any:
else:
return job.result()

def __call__(
        self, *args: T.Any, **kwargs: T.Any) -> T.Any:  # pragma: no cover
    """Placeholder call hook; subclasses provide the real behaviour.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("Subclasses must implement __call__")


class SyncLauncher(LauncherBase):

Expand Down
51 changes: 42 additions & 9 deletions tests/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,24 +188,57 @@ def add(a, b):

@pytest.mark.asyncio
async def test_generator():
def gen():
for i in range(10):
yield i

async def gen_async(n):
for i in range(n):
yield i

with Engine() as engine:
def gen():
for i in range(10):
yield i

job = ProcessJob(gen)
await engine.submit_async(job)
await job.join()
assert list(job.result()) == list(range(10))
assert job.status == "running"
g = job.result()
assert list(g) == list(range(10))
assert job.status == "done"

async def gen_async(n):
for i in range(n):
yield i

job = ProcessJob(gen_async, (10,))
await engine.submit_async(job)
await job.join()
res = []
async for i in job.result():
assert job.status == "running"
res.append(i)
assert job.status == "done"
assert res == list(range(10))

def gen_error():
for i in range(2):
print(i)
yield i
raise ValueError("error")

job = ProcessJob(gen_error)
await engine.submit_async(job)
await job.join()
with pytest.raises(ValueError):
for i in job.result():
assert job.status == "running"
assert job.status == "failed"

async def gen_error():
for i in range(2):
print(i)
yield i
raise ValueError("error")

job = ProcessJob(gen_error)
await engine.submit_async(job)
await job.join()
with pytest.raises(ValueError):
async for i in job.result():
assert job.status == "running"
assert job.status == "failed"

0 comments on commit 389af29

Please sign in to comment.