Skip to content

Commit

Permalink
GIT-37: fix blueprint middleware application (#1690)
Browse files Browse the repository at this point in the history
* GIT-37: fix blueprint middleware application

1. If you register a middleware via `@blueprint.middleware` then it will apply only to the routes defined by the blueprint.
2. If you register a middleware via `@blueprint_group.middleware` then it will apply to all blueprint based routes that are part of the group.
3. If you define a middleware via `@app.middleware` then it will be applied on all available routes

Fixes #37

Signed-off-by: Harsha Narayana <[email protected]>

* GIT-37: add changelog

Signed-off-by: Harsha Narayana <[email protected]>
  • Loading branch information
harshanarayana authored and sjsadowski committed Dec 20, 2019
1 parent 179a079 commit a6077a1
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 34 deletions.
11 changes: 11 additions & 0 deletions changelogs/37.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Fix blueprint middleware application

Currently, any blueprint middleware registered, irrespective of which blueprint was used to do so, was
being applied to all of the routes created by the :code:`@app` and :code:`@blueprint` alike.

As part of this change, the blueprint based middleware application is enforced based on where they are
registered.

- If you register a middleware via :code:`@blueprint.middleware` then it will apply only to the routes defined by the blueprint.
- If you register a middleware via :code:`@blueprint_group.middleware` then it will apply to all blueprint based routes that are part of the group.
- If you define a middleware via :code:`@app.middleware` then it will be applied on all available routes
67 changes: 49 additions & 18 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def __init__(
self.is_request_stream = False
self.websocket_enabled = False
self.websocket_tasks = set()

self.named_request_middleware = {}
self.named_response_middleware = {}
# Register alternative method names
self.go_fast = self.run

Expand Down Expand Up @@ -178,7 +179,7 @@ def route(
:param stream:
:param version:
:param name: user defined route name for url_for
:return: decorated function
:return: tuple of routes, decorated function
"""

# Fix case where the user did not prefix the URL with a /
Expand All @@ -204,7 +205,7 @@ def response(handler):
if stream:
handler.is_stream = stream

self.router.add(
routes = self.router.add(
uri=uri,
methods=methods,
handler=handler,
Expand All @@ -213,7 +214,7 @@ def response(handler):
version=version,
name=name,
)
return handler
return routes, handler

return response

Expand Down Expand Up @@ -462,7 +463,7 @@ def websocket(
:param subprotocols: optional list of str with supported subprotocols
:param name: A unique name assigned to the URL so that it can
be used with :func:`url_for`
:return: decorated function
:return: tuple of routes, decorated function
"""
self.enable_websocket()

Expand Down Expand Up @@ -515,15 +516,15 @@ async def websocket_handler(request, *args, **kwargs):
self.websocket_tasks.remove(fut)
await ws.close()

self.router.add(
routes = self.router.add(
uri=uri,
handler=websocket_handler,
methods=frozenset({"GET"}),
host=host,
strict_slashes=strict_slashes,
name=name,
)
return handler
return routes, handler

return response

Expand All @@ -544,6 +545,7 @@ def add_websocket_route(
:param host: Host IP or FQDN details
:param uri: URL path that will be mapped to the websocket
handler
handler
:param strict_slashes: If the API endpoint needs to terminate
with a "/" or not
:param subprotocols: Subprotocols to be used with websocket
Expand Down Expand Up @@ -645,6 +647,22 @@ def register_middleware(self, middleware, attach_to="request"):
self.response_middleware.appendleft(middleware)
return middleware

def register_named_middleware(
self, middleware, route_names, attach_to="request"
):
if attach_to == "request":
for _rn in route_names:
if _rn not in self.named_request_middleware:
self.named_request_middleware[_rn] = deque()
if middleware not in self.named_request_middleware[_rn]:
self.named_request_middleware[_rn].append(middleware)
if attach_to == "response":
for _rn in route_names:
if _rn not in self.named_response_middleware:
self.named_response_middleware[_rn] = deque()
if middleware not in self.named_response_middleware[_rn]:
self.named_response_middleware[_rn].append(middleware)

# Decorator
def middleware(self, middleware_or_request):
"""
Expand Down Expand Up @@ -916,20 +934,23 @@ async def handle_request(self, request, write_callback, stream_callback):
# allocation before assignment below.
response = None
cancelled = False
name = None
try:
# Fetch handler from router
handler, args, kwargs, uri, name = self.router.get(request)

# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(request)
response = await self._run_request_middleware(
request, request_name=name
)
# No middleware results
if not response:
# -------------------------------------------- #
# Execute Handler
# -------------------------------------------- #

# Fetch handler from router
handler, args, kwargs, uri = self.router.get(request)

request.uri_template = uri
if handler is None:
raise ServerError(
Expand Down Expand Up @@ -993,7 +1014,7 @@ async def handle_request(self, request, write_callback, stream_callback):
if response is not None:
try:
response = await self._run_response_middleware(
request, response
request, response, request_name=name
)
except CancelledError:
# Response middleware can timeout too, as above.
Expand Down Expand Up @@ -1265,20 +1286,30 @@ async def trigger_events(self, events, loop):
if isawaitable(result):
await result

async def _run_request_middleware(self, request):
async def _run_request_middleware(self, request, request_name=None):
# The if improves speed. I don't know why
if self.request_middleware:
for middleware in self.request_middleware:
named_middleware = self.named_request_middleware.get(
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
response = middleware(request)
if isawaitable(response):
response = await response
if response:
return response
return None

async def _run_response_middleware(self, request, response):
if self.response_middleware:
for middleware in self.response_middleware:
async def _run_response_middleware(
self, request, response, request_name=None
):
named_middleware = self.named_response_middleware.get(
request_name, deque()
)
applicable_middleware = self.response_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
Expand Down
20 changes: 15 additions & 5 deletions sanic/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def register(self, app, options):

url_prefix = options.get("url_prefix", self.url_prefix)

routes = []

# Routes
for future in self.routes:
# attach the blueprint name to the handler so that it can be
Expand All @@ -114,7 +116,7 @@ def register(self, app, options):

version = future.version or self.version

app.route(
_routes, _ = app.route(
uri=uri[1:] if uri.startswith("//") else uri,
methods=future.methods,
host=future.host or self.host,
Expand All @@ -123,28 +125,36 @@ def register(self, app, options):
version=version,
name=future.name,
)(future.handler)
if _routes:
routes += _routes

for future in self.websocket_routes:
# attach the blueprint name to the handler so that it can be
# prefixed properly in the router
future.handler.__blueprintname__ = self.name
# Prepend the blueprint URI prefix if available
uri = url_prefix + future.uri if url_prefix else future.uri
app.websocket(
_routes, _ = app.websocket(
uri=uri,
host=future.host or self.host,
strict_slashes=future.strict_slashes,
name=future.name,
)(future.handler)
if _routes:
routes += _routes

route_names = [route.name for route in routes]
# Middleware
for future in self.middlewares:
if future.args or future.kwargs:
app.register_middleware(
future.middleware, *future.args, **future.kwargs
app.register_named_middleware(
future.middleware,
route_names,
*future.args,
**future.kwargs
)
else:
app.register_middleware(future.middleware)
app.register_named_middleware(future.middleware, route_names)

# Exceptions
for future in self.exceptions:
Expand Down
17 changes: 11 additions & 6 deletions sanic/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,21 +140,22 @@ def add(
docs for further details.
:return: Nothing
"""
routes = []
if version is not None:
version = re.escape(str(version).strip("/").lstrip("v"))
uri = "/".join(["/v{}".format(version), uri.lstrip("/")])
# add regular version
self._add(uri, methods, handler, host, name)
routes.append(self._add(uri, methods, handler, host, name))

if strict_slashes:
return
return routes

if not isinstance(host, str) and host is not None:
# we have gotten back to the top of the recursion tree where the
# host was originally a list. By now, we've processed the strict
# slashes logic on the leaf nodes (the individual host strings in
# the list of host)
return
return routes

# Add versions with and without trailing /
slashed_methods = self.routes_all.get(uri + "/", frozenset({}))
Expand All @@ -176,10 +177,12 @@ def add(
)
# add version with trailing slash
if slash_is_missing:
self._add(uri + "/", methods, handler, host, name)
routes.append(self._add(uri + "/", methods, handler, host, name))
# add version without trailing slash
elif without_slash_is_missing:
self._add(uri[:-1], methods, handler, host, name)
routes.append(self._add(uri[:-1], methods, handler, host, name))

return routes

def _add(self, uri, methods, handler, host=None, name=None):
"""Add a handler to the route list
Expand Down Expand Up @@ -328,6 +331,7 @@ def merge_route(route, methods, handler):
self.routes_dynamic[url_hash(uri)].append(route)
else:
self.routes_static[uri] = route
return route

@staticmethod
def check_dynamic_route_exists(pattern, routes_to_check, parameters):
Expand Down Expand Up @@ -442,6 +446,7 @@ def _get(self, url, method, host):
method=method,
allowed_methods=self.get_supported_methods(url),
)

if route:
if route.methods and method not in route.methods:
raise method_not_supported
Expand Down Expand Up @@ -476,7 +481,7 @@ def _get(self, url, method, host):
route_handler = route.handler
if hasattr(route_handler, "handlers"):
route_handler = route_handler.handlers[method]
return route_handler, [], kwargs, route.uri
return route_handler, [], kwargs, route.uri, route.name

def is_stream_handler(self, request):
""" Handler for request is stream or not.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def handler():

def test_app_handle_request_handler_is_none(app, monkeypatch):
def mockreturn(*args, **kwargs):
return None, [], {}, ""
return None, [], {}, "", ""

# Not sure how to make app.router.get() return None, so use mock here.
monkeypatch.setattr(app.router, "get", mockreturn)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_blueprint_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def enhance_response_middleware(request: Request, response: HTTPResponse):
_, response = app.test_client.patch("/api/bp2/route/bp2", headers=header)
assert response.text == "PATCH_bp2"

_, response = app.test_client.get("/v2/api/bp1/request_path")
_, response = app.test_client.put("/v2/api/bp1/request_path")
assert response.status == 401


Expand Down Expand Up @@ -141,8 +141,8 @@ def app_default_route(request):
_, response = app.test_client.get("/api/bp3")
assert response.text == "BP3_OK"

assert MIDDLEWARE_INVOKE_COUNTER["response"] == 4
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4
assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3
assert MIDDLEWARE_INVOKE_COUNTER["request"] == 2


def test_bp_group_list_operations(app: Sanic):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ async def handler(request):
request, response = app.test_client.get("/")

assert response.status == 200
assert response.text == "OK"
assert response.text == "FAIL"


def test_bp_exception_handler(app):
Expand Down

0 comments on commit a6077a1

Please sign in to comment.