Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unexpected "Vary" #735

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion django_structlog/middlewares/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from django.core.exceptions import PermissionDenied
from django.core.signals import got_request_exception
from django.http import Http404, StreamingHttpResponse
from django.utils.functional import SimpleLazyObject

from .. import signals
from ..app_settings import app_settings
Expand All @@ -39,6 +40,7 @@
if TYPE_CHECKING: # pragma: no cover
from types import TracebackType

from django.contrib.auth.base_user import AbstractBaseUser
from django.http import HttpRequest, HttpResponse

logger = structlog.getLogger(__name__)
Expand Down Expand Up @@ -207,13 +209,30 @@ def format_request(request: "HttpRequest") -> str:
@staticmethod
def bind_user_id(request: "HttpRequest") -> None:
user_id_field = app_settings.USER_ID_FIELD
if hasattr(request, "user") and request.user is not None and user_id_field:
if not user_id_field:
return

session_was_accessed = (
request.session.accessed if hasattr(request, "session") else None
)

if hasattr(request, "user") and request.user is not None:
user_id = None
if hasattr(request.user, user_id_field):
user_id = getattr(request.user, user_id_field)
if isinstance(user_id, uuid.UUID):
user_id = str(user_id)
structlog.contextvars.bind_contextvars(user_id=user_id)
if session_was_accessed is not None and not session_was_accessed:
"""using SessionMiddleware but user was never accessed, must reset accessed state"""
user = request.user

def get_user() -> Any:
request.session.accessed = True
return user

request.user = cast("AbstractBaseUser", SimpleLazyObject(get_user))
request.session.accessed = False
jrobichaud marked this conversation as resolved.
Show resolved Hide resolved

def process_got_request_exception(
self, sender: Type[Any], request: "HttpRequest", **kwargs: Any
Expand Down
93 changes: 93 additions & 0 deletions test_app/tests/middlewares/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from unittest.mock import AsyncMock, Mock, patch

import structlog
from django.contrib.auth.middleware import AuthenticationMiddleware
from django.contrib.auth.models import AnonymousUser, User
from django.contrib.sessions.middleware import SessionMiddleware
from django.contrib.sites.models import Site
from django.contrib.sites.shortcuts import get_current_site
from django.core.exceptions import PermissionDenied
Expand Down Expand Up @@ -236,6 +238,97 @@ class SimpleUser:
self.assertIn("user_id", record.msg)
self.assertIsNone(record.msg["user_id"])

@override_settings(
SECRET_KEY="00000000000000000000000000000000",
)
def test_process_request_session_middleware_without_vary(self) -> None:
def get_response(_request: HttpRequest) -> HttpResponse:
with self.assertLogs(__name__, logging.INFO) as log_results:
self.logger.info("hello")
self.log_results = log_results
return HttpResponse()

request = self.factory.get("/foo")

# simulate SessionMiddleware, AuthenticationMiddleware, and RequestMiddleware called in that order
request_middleware = RequestMiddleware(get_response)
authentication_middleware = AuthenticationMiddleware(
cast(
Any,
lambda r: request_middleware(r),
)
)
session_middleware = SessionMiddleware(
cast(Any, lambda r: authentication_middleware(r))
)
response = session_middleware(request)

self.assertEqual(1, len(self.log_results.records))
record = self.log_results.records[0]
self.assertIsNone(cast(HttpResponse, response).headers.get("Vary"))

self.assertEqual("INFO", record.levelname)

self.assertIn("user_id", record.msg)
self.assertIsNone(record.msg["user_id"])

@override_settings(
SECRET_KEY="00000000000000000000000000000000",
)
def test_process_request_session_middleware_with_vary(self) -> None:
def get_response(_request: HttpRequest) -> HttpResponse:
assert isinstance(
request.user, AnonymousUser
) # force evaluate user to trigger session middleware
with self.assertLogs(__name__, logging.INFO) as log_results:
self.logger.info("hello")
self.log_results = log_results
return HttpResponse()

request = self.factory.get("/foo")

# simulate SessionMiddleware, AuthenticationMiddleware, and RequestMiddleware called in that order
request_middleware = RequestMiddleware(get_response)
authentication_middleware = AuthenticationMiddleware(
cast(Any, lambda r: request_middleware(r))
)
session_middleware = SessionMiddleware(
cast(Any, lambda r: authentication_middleware(r))
)
response = session_middleware(request)

self.assertEqual(1, len(self.log_results.records))
record = self.log_results.records[0]
self.assertIsNotNone(cast(HttpResponse, response).headers.get("Vary"))

self.assertEqual("INFO", record.levelname)

self.assertIn("user_id", record.msg)
self.assertIsNone(record.msg["user_id"])

@override_settings(
DJANGO_STRUCTLOG_USER_ID_FIELD=None,
)
def test_process_request_no_user_id_field(self) -> None:
def get_response(_request: HttpRequest) -> HttpResponse:
with self.assertLogs(__name__, logging.INFO) as log_results:
self.logger.info("hello")
self.log_results = log_results
return HttpResponse()

request = self.factory.get("/foo")

middleware = RequestMiddleware(get_response)
response = middleware(request)
self.assertEqual(200, cast(HttpResponse, response).status_code)

self.assertEqual(1, len(self.log_results.records))
record = self.log_results.records[0]

self.assertEqual("INFO", record.levelname)

self.assertNotIn("user_id", record.msg)

def test_log_user_in_request_finished(self) -> None:
mock_response = Mock()
mock_response.status_code = 200
Expand Down
Loading