Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make promise thread safety #70

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions promise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
promisify,
is_thenable,
async_instance,
async_lock,
get_default_scheduler,
set_default_scheduler,
)
Expand All @@ -32,6 +33,7 @@
"promisify",
"is_thenable",
"async_instance",
"async_lock",
"get_default_scheduler",
"set_default_scheduler",
"ImmediateScheduler",
Expand Down
4 changes: 2 additions & 2 deletions promise/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import Iterable, namedtuple
from functools import partial

from .promise import Promise, async_instance, get_default_scheduler
from .promise import Promise, async_lock, get_default_scheduler

if False:
from typing import (
Expand Down Expand Up @@ -225,7 +225,7 @@ def enqueue_post_promise_job(fn, scheduler):

def on_promise_resolve(v):
# type: (Any) -> None
async_instance.invoke(fn, scheduler)
async_lock.async_instance.invoke(fn, scheduler)

resolved_promise.then(on_promise_resolve) # type: Promise[None]

Expand Down
32 changes: 23 additions & 9 deletions promise/promise.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import namedtuple
from functools import partial, wraps
from sys import version_info, exc_info
from threading import RLock
import threading
from types import TracebackType

from six import reraise # type: ignore
Expand Down Expand Up @@ -37,10 +37,24 @@
)


default_scheduler = ImmediateScheduler()
class AsyncThreadLocal(threading.local):
"""
The thread local class that make `async_instance` safe for threads.
"""
_async_instance = None

async_instance = Async()
@property
def async_instance(self):
# type: () -> Async
if not getattr(self, '_async_instance'):
self._async_instance = Async()

return self._async_instance

async_lock = AsyncThreadLocal()
async_instance = async_lock.async_instance

default_scheduler = ImmediateScheduler()

def get_default_scheduler():
# type: () -> ImmediateScheduler
Expand Down Expand Up @@ -237,7 +251,7 @@ def _fulfill(self, value):
if self._is_async_guaranteed:
self._settle_promises()
else:
async_instance.settle_promises(self)
async_lock.async_instance.settle_promises(self)

def _reject(self, reason, traceback=None):
# type: (Exception, Optional[TracebackType]) -> None
Expand All @@ -247,18 +261,18 @@ def _reject(self, reason, traceback=None):

if self._is_final:
assert self._length == 0
async_instance.fatal_error(reason, self.scheduler)
async_lock.async_instance.fatal_error(reason, self.scheduler)
return

if self._length > 0:
async_instance.settle_promises(self)
async_lock.async_instance.settle_promises(self)
else:
self._ensure_possible_rejection_handled()

if self._is_async_guaranteed:
self._settle_promises()
else:
async_instance.settle_promises(self)
async_lock.async_instance.settle_promises(self)

def _ensure_possible_rejection_handled(self):
# type: () -> None
Expand Down Expand Up @@ -497,7 +511,7 @@ def reject(reason, traceback=None):
@classmethod
def wait(cls, promise, timeout=None):
# type: (Promise, Optional[float]) -> None
async_instance.wait(promise, timeout)
async_lock.async_instance.wait(promise, timeout)

def _wait(self, timeout=None):
# type: (Optional[float]) -> None
Expand Down Expand Up @@ -583,7 +597,7 @@ def _then(
traceback = target._traceback
handler = did_reject # type: ignore
# target._rejection_is_unhandled = False
async_instance.invoke(
async_lock.async_instance.invoke(
partial(target._settle_promise, promise, handler, value, traceback),
promise.scheduler
# target._settle_promise instead?
Expand Down
6 changes: 3 additions & 3 deletions tests/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pytest import raises

from promise import Promise, async_instance
from promise import Promise, async_lock
from promise.dataloader import DataLoader


Expand Down Expand Up @@ -426,9 +426,9 @@ def do_resolve(x):

with raises(Exception):
a_loader.load("A1").get()
assert async_instance.have_drained_queues
assert async_lock.async_instance.have_drained_queues
with raises(Exception):
a_loader.load("A2").get()
assert async_instance.have_drained_queues
assert async_lock.async_instance.have_drained_queues

do().get()