Skip to content

Commit

Permalink
make thread safe
Browse files Browse the repository at this point in the history
  • Loading branch information
jnak committed Dec 14, 2019
1 parent d80d791 commit b07260e
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 8 deletions.
3 changes: 2 additions & 1 deletion promise/async_.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Based on https://github.com/petkaantonov/bluebird/blob/master/src/promise.js
from collections import deque
from threading import local

if False:
from .promise import Promise
from typing import Any, Callable, Optional, Union # flake8: noqa


class Async(object):
class Async(local):
def __init__(self, trampoline_enabled=True):
self.is_tick_used = False
self.late_queue = deque() # type: ignore
Expand Down
14 changes: 7 additions & 7 deletions promise/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import Iterable, namedtuple
from functools import partial
from threading import local

from .promise import Promise, async_instance, get_default_scheduler

Expand Down Expand Up @@ -29,7 +30,7 @@ def get_chunks(iterable_obj, chunk_size=1):
Loader = namedtuple("Loader", "key,resolve,reject")


class DataLoader(object):
class DataLoader(local):

batch = True
max_batch_size = None # type: int
Expand Down Expand Up @@ -212,22 +213,21 @@ def prime(self, key, value):
# ensuring that it always occurs after "PromiseJobs" ends.

# Private: cached resolved Promise instance
resolved_promise = None # type: Optional[Promise[None]]

cache = local()

def enqueue_post_promise_job(fn, scheduler):
# type: (Callable, Any) -> None
global resolved_promise
if not resolved_promise:
resolved_promise = Promise.resolve(None)
global cache
if not hasattr(cache, 'resolved_promise'):
cache.resolved_promise = Promise.resolve(None)
if not scheduler:
scheduler = get_default_scheduler()

def on_promise_resolve(v):
# type: (Any) -> None
async_instance.invoke(fn, scheduler)

resolved_promise.then(on_promise_resolve) # type: Promise[None]
cache.resolved_promise.then(on_promise_resolve)


def dispatch_queue(loader):
Expand Down
115 changes: 115 additions & 0 deletions tests/test_thread_safety.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from promise import Promise
from promise.dataloader import DataLoader
import threading



def test_promise_thread_safety():
"""
Promise tasks should never be executed in a different thread from the one they are scheduled from,
unless the ThreadPoolExecutor is used.
Here we assert that the pending promise tasks on thread 1 are not executed on thread 2 as thread 2
resolves its own promise tasks.
"""
event_1 = threading.Event()
event_2 = threading.Event()

assert_object = {'is_same_thread': True}

def task_1():
thread_name = threading.current_thread().getName()

def then_1(value):
# Enqueue tasks to run later.
# This relies on the fact that `then` does not execute the function synchronously when called from
# within another `then` callback function.
promise = Promise.resolve(None).then(then_2)
assert promise.is_pending
event_1.set() # Unblock main thread
event_2.wait() # Wait for thread 2

def then_2(value):
assert_object['is_same_thread'] = (thread_name == threading.current_thread().getName())

promise = Promise.resolve(None).then(then_1)

def task_2():
promise = Promise.resolve(None).then(lambda v: None)
promise.get() # Drain task queue
event_2.set() # Unblock thread 1

thread_1 = threading.Thread(target=task_1)
thread_1.start()

event_1.wait() # Wait for Thread 1 to enqueue promise tasks

thread_2 = threading.Thread(target=task_2)
thread_2.start()

for thread in (thread_1, thread_2):
thread.join()

assert assert_object['is_same_thread']


def test_dataloader_thread_safety():
"""
Dataloader should only batch `load` calls that happened on the same thread.
Here we assert that `load` calls on thread 2 are not batched on thread 1 as
thread 1 batches its own `load` calls.
"""
def load_many(keys):
thead_name = threading.current_thread().getName()
return Promise.resolve([thead_name for key in keys])

thread_name_loader = DataLoader(load_many)

event_1 = threading.Event()
event_2 = threading.Event()
event_3 = threading.Event()

assert_object = {
'is_same_thread_1': True,
'is_same_thread_2': True,
}

def task_1():
@Promise.safe
def do():
promise = thread_name_loader.load(1)
event_1.set()
event_2.wait() # Wait for thread 2 to call `load`
assert_object['is_same_thread_1'] = (
promise.get() == threading.current_thread().getName()
)
event_3.set() # Unblock thread 2

do().get()

def task_2():
@Promise.safe
def do():
promise = thread_name_loader.load(2)
event_2.set()
event_3.wait() # Wait for thread 1 to run `dispatch_queue_batch`
assert_object['is_same_thread_2'] = (
promise.get() == threading.current_thread().getName()
)

do().get()

thread_1 = threading.Thread(target=task_1)
thread_1.start()

event_1.wait() # Wait for thread 1 to call `load`

thread_2 = threading.Thread(target=task_2)
thread_2.start()

for thread in (thread_1, thread_2):
thread.join()

assert assert_object['is_same_thread_1']
assert assert_object['is_same_thread_2']

0 comments on commit b07260e

Please sign in to comment.