Skip to content

Commit

Permalink
Fix pickle error when attempting to pickle an application which conta…
Browse files Browse the repository at this point in the history
…ins websocket routes. (sanic-org#1853)

Moves the websocket_handler subfunction out to a class-level method, which can be more easily pickled by the built-in python Pickler.
Also includes a similar fix for the add_task deferred task scheduler subfunction.

Co-authored-by: Adam Hopkins <[email protected]>
  • Loading branch information
ashleysommer and ahopkins authored Jun 28, 2020
1 parent 83511a0 commit 761eef7
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 58 deletions.
118 changes: 61 additions & 57 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,24 +117,12 @@ def add_task(self, task):
:param task: future, couroutine or awaitable
"""
try:
if callable(task):
try:
self.loop.create_task(task(self))
except TypeError:
self.loop.create_task(task())
else:
self.loop.create_task(task)
loop = self.loop # Will raise SanicError if loop is not started
self._loop_add_task(task, self, loop)
except SanicException:

@self.listener("before_server_start")
def run(app, loop):
if callable(task):
try:
loop.create_task(task(self))
except TypeError:
loop.create_task(task())
else:
loop.create_task(task)
self.listener("before_server_start")(
partial(self._loop_add_task, task)
)

# Decorator
def listener(self, event):
Expand Down Expand Up @@ -499,42 +487,12 @@ def response(handler):
routes, handler = handler
else:
routes = []

async def websocket_handler(request, *args, **kwargs):
request.app = self
if not getattr(handler, "__blueprintname__", False):
request.endpoint = handler.__name__
else:
request.endpoint = (
getattr(handler, "__blueprintname__", "")
+ handler.__name__
)

pass

if self.asgi:
ws = request.transport.get_websocket_connection()
else:
protocol = request.transport.get_protocol()
protocol.app = self

ws = await protocol.websocket_handshake(
request, subprotocols
)

# schedule the application handler
# its future is kept in self.websocket_tasks in case it
# needs to be cancelled due to the server being stopped
fut = ensure_future(handler(request, ws, *args, **kwargs))
self.websocket_tasks.add(fut)
try:
await fut
except (CancelledError, ConnectionClosed):
pass
finally:
self.websocket_tasks.remove(fut)
await ws.close()

websocket_handler = partial(
self._websocket_handler, handler, subprotocols=subprotocols
)
websocket_handler.__name__ = (
"websocket_handler_" + handler.__name__
)
routes.extend(
self.router.add(
uri=uri,
Expand Down Expand Up @@ -598,10 +556,7 @@ def enable_websocket(self, enable=True):
if not self.websocket_enabled:
# if the server is stopped, we want to cancel any ongoing
# websocket tasks, to allow the server to exit promptly
@self.listener("before_server_stop")
def cancel_websocket_tasks(app, loop):
for task in self.websocket_tasks:
task.cancel()
self.listener("before_server_stop")(self._cancel_websocket_tasks)

self.websocket_enabled = enable

Expand Down Expand Up @@ -1425,6 +1380,55 @@ def _build_endpoint_name(self, *parts):
parts = [self.name, *parts]
return ".".join(parts)

@classmethod
def _loop_add_task(cls, task, app, loop):
if callable(task):
try:
loop.create_task(task(app))
except TypeError:
loop.create_task(task())
else:
loop.create_task(task)

@classmethod
def _cancel_websocket_tasks(cls, app, loop):
for task in app.websocket_tasks:
task.cancel()

async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs
):
request.app = self
if not getattr(handler, "__blueprintname__", False):
request.endpoint = handler.__name__
else:
request.endpoint = (
getattr(handler, "__blueprintname__", "") + handler.__name__
)

pass

if self.asgi:
ws = request.transport.get_websocket_connection()
else:
protocol = request.transport.get_protocol()
protocol.app = self

ws = await protocol.websocket_handshake(request, subprotocols)

# schedule the application handler
# its future is kept in self.websocket_tasks in case it
# needs to be cancelled due to the server being stopped
fut = ensure_future(handler(request, ws, *args, **kwargs))
self.websocket_tasks.add(fut)
try:
await fut
except (CancelledError, ConnectionClosed):
pass
finally:
self.websocket_tasks.remove(fut)
await ws.close()

# -------------------------------------------------------------------- #
# ASGI
# -------------------------------------------------------------------- #
Expand Down
2 changes: 1 addition & 1 deletion sanic/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def sig_handler(signal, frame):

signal_func(SIGINT, lambda s, f: sig_handler(s, f))
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
mp = multiprocessing.get_context("fork")
mp = multiprocessing.get_context("spawn")

for _ in range(workers):
process = mp.Process(target=serve, kwargs=server_settings)
Expand Down

0 comments on commit 761eef7

Please sign in to comment.