Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pickle error when attempting to pickle an application which contains websocket routes #1853

Merged
merged 2 commits into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 61 additions & 57 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,24 +117,12 @@ def add_task(self, task):
:param task: future, couroutine or awaitable
"""
try:
if callable(task):
try:
self.loop.create_task(task(self))
except TypeError:
self.loop.create_task(task())
else:
self.loop.create_task(task)
loop = self.loop # Will raise SanicError if loop is not started
self._loop_add_task(task, self, loop)
except SanicException:

@self.listener("before_server_start")
def run(app, loop):
if callable(task):
try:
loop.create_task(task(self))
except TypeError:
loop.create_task(task())
else:
loop.create_task(task)
self.listener("before_server_start")(
partial(self._loop_add_task, task)
)

# Decorator
def listener(self, event):
Expand Down Expand Up @@ -493,42 +481,12 @@ def response(handler):
routes, handler = handler
else:
routes = []

async def websocket_handler(request, *args, **kwargs):
request.app = self
if not getattr(handler, "__blueprintname__", False):
request.endpoint = handler.__name__
else:
request.endpoint = (
getattr(handler, "__blueprintname__", "")
+ handler.__name__
)

pass

if self.asgi:
ws = request.transport.get_websocket_connection()
else:
protocol = request.transport.get_protocol()
protocol.app = self

ws = await protocol.websocket_handshake(
request, subprotocols
)

# schedule the application handler
# its future is kept in self.websocket_tasks in case it
# needs to be cancelled due to the server being stopped
fut = ensure_future(handler(request, ws, *args, **kwargs))
self.websocket_tasks.add(fut)
try:
await fut
except (CancelledError, ConnectionClosed):
pass
finally:
self.websocket_tasks.remove(fut)
await ws.close()

websocket_handler = partial(
self._websocket_handler, handler, subprotocols=subprotocols
)
websocket_handler.__name__ = (
"websocket_handler_" + handler.__name__
)
routes.extend(
self.router.add(
uri=uri,
Expand Down Expand Up @@ -589,10 +547,7 @@ def enable_websocket(self, enable=True):
if not self.websocket_enabled:
# if the server is stopped, we want to cancel any ongoing
# websocket tasks, to allow the server to exit promptly
@self.listener("before_server_stop")
def cancel_websocket_tasks(app, loop):
for task in self.websocket_tasks:
task.cancel()
self.listener("before_server_stop")(self._cancel_websocket_tasks)

self.websocket_enabled = enable

Expand Down Expand Up @@ -1425,6 +1380,55 @@ def _build_endpoint_name(self, *parts):
parts = [self.name, *parts]
return ".".join(parts)

@classmethod
def _loop_add_task(cls, task, app, loop):
if callable(task):
try:
loop.create_task(task(app))
except TypeError:
loop.create_task(task())
else:
loop.create_task(task)

@classmethod
def _cancel_websocket_tasks(cls, app, loop):
for task in app.websocket_tasks:
task.cancel()

async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs
):
request.app = self
if not getattr(handler, "__blueprintname__", False):
request.endpoint = handler.__name__
else:
request.endpoint = (
getattr(handler, "__blueprintname__", "") + handler.__name__
)

pass

if self.asgi:
ws = request.transport.get_websocket_connection()
else:
protocol = request.transport.get_protocol()
protocol.app = self

ws = await protocol.websocket_handshake(request, subprotocols)

# schedule the application handler
# its future is kept in self.websocket_tasks in case it
# needs to be cancelled due to the server being stopped
fut = ensure_future(handler(request, ws, *args, **kwargs))
self.websocket_tasks.add(fut)
try:
await fut
except (CancelledError, ConnectionClosed):
pass
finally:
self.websocket_tasks.remove(fut)
await ws.close()

# -------------------------------------------------------------------- #
# ASGI
# -------------------------------------------------------------------- #
Expand Down
2 changes: 1 addition & 1 deletion sanic/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def sig_handler(signal, frame):

signal_func(SIGINT, lambda s, f: sig_handler(s, f))
signal_func(SIGTERM, lambda s, f: sig_handler(s, f))
mp = multiprocessing.get_context("fork")
mp = multiprocessing.get_context("spawn")

for _ in range(workers):
process = mp.Process(target=serve, kwargs=server_settings)
Expand Down