Skip to content

Commit

Permalink
Fix race conditional in Channel for many-to-many case
Browse files Browse the repository at this point in the history
  • Loading branch information
gleero committed Apr 19, 2023
1 parent 8e52b04 commit ff6a24a
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions tsasync/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Deque, Generic, Union
from typing import Any, Awaitable, Deque, Generic, Optional, Union

from ._event import Event
from ._utils import T, in_loop, wrapret
Expand All @@ -17,6 +17,7 @@ class ChannelState(Enum):
class OperationContext:
event: Event
value: Any
destination: Optional[Event] = None


class Channel(Generic[T]):
Expand Down Expand Up @@ -53,6 +54,7 @@ def send(self, value: T) -> Union[Awaitable, None]:
# Unlock for waiter in needed
try:
event = self._waiters.popleft()
ctx.destination = event
event.set()
except IndexError:
pass
Expand All @@ -74,7 +76,7 @@ def receive(self) -> Union[T, Awaitable[T]]:
"""
with self._ts_lock:
# Queue has pending objects, return it immediately
if len(self._queue) > 0:
if len(self._queue) > 0 and len(self._waiters) == 0:
return wrapret(self._next_item())

# Create event and wait for data
Expand All @@ -88,7 +90,7 @@ def receive(self) -> Union[T, Awaitable[T]]:
# Get item from the queue
event.wait()
with self._ts_lock:
return self._next_item()
return self._next_item(event)

async def _areceive(self, event: Event) -> T:
"""
Expand All @@ -99,12 +101,14 @@ async def _areceive(self, event: Event) -> T:
if inspect.isawaitable(waiter):
await waiter
with self._ts_lock:
return self._next_item()
return self._next_item(event)

def _next_item(self) -> T:
def _next_item(self, event: Optional[Event] = None) -> T:
"""
Get next item from the queue
"""
ctx = self._queue.popleft()
if ctx.destination is not None:
assert ctx.destination == event, f"{ctx.destination} != {event}"
ctx.event.set()
return ctx.value

0 comments on commit ff6a24a

Please sign in to comment.