Skip to content

Commit

Permalink
Fix pickle error when attempting to pickle an application which conta…
Browse files Browse the repository at this point in the history
…ins websocket routes. (#1853)

Moves the websocket_handler subfunction out to a class-level method, which can be more easily pickled by the built-in python Pickler.
Also includes a similar fix for the add_task deferred task scheduler subfunction.

Co-authored-by: Adam Hopkins <[email protected]>
  • Loading branch information
ashleysommer and ahopkins authored Jun 28, 2020
1 parent 83511a0 commit 761eef7
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 58 deletions.
118 changes: 61 additions & 57 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,24 +117,12 @@ def add_task(self, task):
:param task: future, couroutine or awaitable
"""
try:
if callable(task):
try:
self.loop.create_task(task(self))
except TypeError:
self.loop.create_task(task())
else:
self.loop.create_task(task)
loop = self.loop # Will raise SanicError if loop is not started
self._loop_add_task(task, self, loop)
except SanicException:

@self.listener("before_server_start")
def run(app, loop):
if callable(task):
try:
loop.create_task(task(self))
except TypeError:
loop.create_task(task())
else:
loop.create_task(task)
self.listener("before_server_start")(
partial(self._loop_add_task, task)
)

# Decorator
def listener(self, event):
Expand Down Expand Up @@ -499,42 +487,12 @@ def response(handler):
routes, handler = handler
else:
routes = []

async def websocket_handler(request, *args, **kwargs):
request.app = self
if not getattr(handler, "__blueprintname__", False):
request.endpoint = handler.__name__
else:
request.endpoint = (
getattr(handler, "__blueprintname__", "")
+ handler.__name__
)

pass

if self.asgi:
ws = request.transport.get_websocket_connection()
else:
protocol = request.transport.get_protocol()
protocol.app = self

ws = await protocol.websocket_handshake(
request, subprotocols
)

# schedule the application handler
# its future is kept in self.websocket_tasks in case it
# needs to be cancelled due to the server being stopped
fut = ensure_future(handler(request, ws, *args, **kwargs))
self.websocket_tasks.add(fut)
try:
await fut
except (CancelledError, ConnectionClosed):
pass
finally:
self.websocket_tasks.remove(fut)
await ws.close()

websocket_handler = partial(
self._websocket_handler, handler, subprotocols=subprotocols
)
websocket_handler.__name__ = (
"websocket_handler_" + handler.__name__
)
routes.extend(
self.router.add(
uri=uri,
Expand Down Expand Up @@ -598,10 +556,7 @@ def enable_websocket(self, enable=True):
if not self.websocket_enabled:
# if the server is stopped, we want to cancel any ongoing
# websocket tasks, to allow the server to exit promptly
@self.listener("before_server_stop")
def cancel_websocket_tasks(app, loop):
for task in self.websocket_tasks:
task.cancel()
self.listener("before_server_stop")(self._cancel_websocket_tasks)

self.websocket_enabled = enable

Expand Down Expand Up @@ -1425,6 +1380,55 @@ def _build_endpoint_name(self, *parts):
parts = [self.name, *parts]
return ".".join(parts)

@classmethod
def _loop_add_task(cls, task, app, loop):
if callable(task):
try:
loop.create_task(task(app))
except TypeError:
loop.create_task(task())
else:
loop.create_task(task)

@classmethod
def _cancel_websocket_tasks(cls, app, loop):
for task in app.websocket_tasks:
task.cancel()

async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs
):
request.app = self
if not getattr(handler, "__blueprintname__", False):
request.endpoint = handler.__name__
else:
request.endpoint = (
getattr(handler, "__blueprintname__", "") + handler.__name__
)

pass

if self.asgi:
ws = request.transport.get_websocket_connection()
else:
protocol = request.transport.get_protocol()
protocol.app = self

ws = await protocol.websocket_handshake(request, subprotocols)

# schedule the application handler
# its future is kept in self.websocket_tasks in case it
# needs to be cancelled due to the server being stopped
fut = ensure_future(handler(request, ws, *args, **kwargs))
self.websocket_tasks.add(fut)
try:
await fut
except (CancelledError, ConnectionClosed):
pass
finally:
self.websocket_tasks.remove(fut)
await ws.close()

# -------------------------------------------------------------------- #
# ASGI
# -------------------------------------------------------------------- #
Expand Down
2 changes: 1 addition & 1 deletion sanic/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def sig_handler(signal, frame):

signal_func(SIGINT, lambda s, f: sig_handler(s, f))
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
mp = multiprocessing.get_context("fork")
mp = multiprocessing.get_context("spawn")

for _ in range(workers):
process = mp.Process(target=serve, kwargs=server_settings)
Expand Down

0 comments on commit 761eef7

Please sign in to comment.