Skip to content

Commit

Permalink
Make show toolbar callback function async/sync compatible.
Browse files Browse the repository at this point in the history
This checks if the SHOW_TOOLBAR_CALLBACK is a coroutine
if we're in async mode and the reverse if it's not. It
will automatically wrap the function with sync_to_async
or async_to_sync when necessary.
  • Loading branch information
tim-schilling committed Jan 30, 2025
1 parent f8fed2c commit 2a227c0
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 53 deletions.
2 changes: 1 addition & 1 deletion debug_toolbar/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def require_show_toolbar(view):
def inner(request, *args, **kwargs):
from debug_toolbar.middleware import get_show_toolbar

show_toolbar = get_show_toolbar()
show_toolbar = get_show_toolbar(async_mode=False)
if not show_toolbar(request):
raise Http404

Expand Down
35 changes: 29 additions & 6 deletions debug_toolbar/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import socket
from functools import cache

from asgiref.sync import iscoroutinefunction, markcoroutinefunction
from asgiref.sync import iscoroutinefunction, markcoroutinefunction, sync_to_async, async_to_sync
from django.conf import settings
from django.utils.module_loading import import_string

Expand Down Expand Up @@ -45,9 +45,13 @@ def show_toolbar(request):
# No test passed
return False


@cache
def get_show_toolbar():
def show_toolbar_func_or_path():
"""
Fetch the show toolbar callback from settings
Cached to avoid importing multiple times.
"""
# If SHOW_TOOLBAR_CALLBACK is a string, which is the recommended
# setup, resolve it to the corresponding callable.
func_or_path = dt_settings.get_config()["SHOW_TOOLBAR_CALLBACK"]
Expand All @@ -57,6 +61,23 @@ def get_show_toolbar():
return func_or_path


def get_show_toolbar(async_mode):
"""
Get the callback function to show the toolbar.
Will wrap the function with sync_to_async or
async_to_sync depending on the status of async_mode
and whether the underlying function is a coroutine.
"""
show_toolbar = show_toolbar_func_or_path()
is_coroutine = iscoroutinefunction(show_toolbar)
if is_coroutine and not async_mode:
show_toolbar = async_to_sync(show_toolbar)
elif not is_coroutine and async_mode:
show_toolbar = sync_to_async(show_toolbar)
return show_toolbar


class DebugToolbarMiddleware:
"""
Middleware to set up Debug Toolbar on incoming request and render toolbar
Expand All @@ -82,7 +103,8 @@ def __call__(self, request):
if self.async_mode:
return self.__acall__(request)
# Decide whether the toolbar is active for this request.
show_toolbar = get_show_toolbar()
show_toolbar = get_show_toolbar(async_mode=self.async_mode)

if not show_toolbar(request) or DebugToolbar.is_toolbar_request(request):
return self.get_response(request)
toolbar = DebugToolbar(request, self.get_response)
Expand All @@ -103,8 +125,9 @@ def __call__(self, request):

async def __acall__(self, request):
# Decide whether the toolbar is active for this request.
show_toolbar = get_show_toolbar()
if not show_toolbar(request) or DebugToolbar.is_toolbar_request(request):
show_toolbar = get_show_toolbar(async_mode=self.async_mode)

if not await show_toolbar(request) or DebugToolbar.is_toolbar_request(request):
response = await self.get_response(request)
return response

Expand Down
85 changes: 85 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import asyncio
from unittest.mock import patch

from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import AsyncRequestFactory, RequestFactory, TestCase, override_settings

from debug_toolbar.middleware import DebugToolbarMiddleware


def show_toolbar_if_staff(request):
# Hit the database, but always return True
return User.objects.exists() or True


async def ashow_toolbar_if_staff(request):
# Hit the database, but always return True
has_users = await User.objects.afirst()
return has_users or True


class MiddlewareSyncAsyncCompatibilityTestCase(TestCase):
def setUp(self):
self.factory = RequestFactory()
self.async_factory = AsyncRequestFactory()

@override_settings(DEBUG=True)
def test_sync_mode(self):
"""
test middleware switches to sync (__call__) based on get_response type
"""

request = self.factory.get("/")
middleware = DebugToolbarMiddleware(
lambda x: HttpResponse("<html><body>Test app</body></html>")
)

self.assertFalse(asyncio.iscoroutinefunction(middleware))

response = middleware(request)
self.assertEqual(response.status_code, 200)
self.assertIn(b"djdt", response.content)

@override_settings(DEBUG=True)
async def test_async_mode(self):
"""
test middleware switches to async (__acall__) based on get_response type
and returns a coroutine
"""

async def get_response(request):
return HttpResponse("<html><body>Test app</body></html>")

middleware = DebugToolbarMiddleware(get_response)
request = self.async_factory.get("/")

self.assertTrue(asyncio.iscoroutinefunction(middleware))

response = await middleware(request)
self.assertEqual(response.status_code, 200)
self.assertIn(b"djdt", response.content)

@override_settings(DEBUG=True)
@patch("debug_toolbar.middleware.show_toolbar_func_or_path", return_value=ashow_toolbar_if_staff)
def test_async_show_toolbar_callback_sync_middleware(self, mocked_show):
def get_response(request):
return HttpResponse("<html><body>Hello world</body></html>")
middleware = DebugToolbarMiddleware(get_response)

request = self.factory.get("/")
response = middleware(request)
self.assertEqual(response.status_code, 200)
self.assertIn(b"djdt", response.content)

@override_settings(DEBUG=True)
@patch("debug_toolbar.middleware.show_toolbar_func_or_path", return_value=show_toolbar_if_staff)
async def test_sync_show_toolbar_callback_async_middleware(self, mocked_show):
async def get_response(request):
return HttpResponse("<html><body>Hello world</body></html>")
middleware = DebugToolbarMiddleware(get_response)

request = self.async_factory.get("/")
response = await middleware(request)
self.assertEqual(response.status_code, 200)
self.assertIn(b"djdt", response.content)
46 changes: 0 additions & 46 deletions tests/test_middleware_compatibility.py

This file was deleted.

0 comments on commit 2a227c0

Please sign in to comment.