-
Notifications
You must be signed in to change notification settings - Fork 34
/
middleware.py
108 lines (81 loc) · 3.55 KB
/
middleware.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional
from uuid import UUID, uuid4
from starlette.datastructures import MutableHeaders
from asgi_correlation_id.context import correlation_id
from asgi_correlation_id.extensions.sentry import get_sentry_extension
if TYPE_CHECKING:
from starlette.types import ASGIApp, Message, Receive, Scope, Send
# Module logger, registered under the package name 'asgi_correlation_id'.
logger = logging.getLogger(name='asgi_correlation_id')
def is_valid_uuid4(uuid_: str) -> bool:
    """
    Return True when ``uuid_`` parses as a UUID whose version is 4.

    Any string that ``uuid.UUID`` refuses to parse (it raises ``ValueError``)
    is reported as invalid instead of propagating the error.
    """
    try:
        parsed = UUID(uuid_)
    except ValueError:
        return False
    return parsed.version == 4
# Log template emitted when an incoming request ID is rejected by the
# validator; the single %s placeholder receives the newly generated ID.
FAILED_VALIDATION_MESSAGE = (
    'Generated new request ID (%s), '
    'since request header value failed validation'
)
@dataclass
class CorrelationIdMiddleware:
    """
    ASGI middleware that assigns a correlation ID to each HTTP/websocket request.

    The ID is taken from the ``header_name`` request header when it is present
    and passes ``validator``; otherwise a fresh one is made by ``generator``.
    In both cases the ID is then passed through ``transformer``, stored in the
    ``correlation_id`` context variable, handed to the Sentry extension, and
    appended to the headers of the ``http.response.start`` message.
    """

    app: 'ASGIApp'
    header_name: str = 'X-Request-ID'
    # Overwrite the incoming request header with the resolved ID whenever the
    # two differ (missing, rejected, or transformed header values).
    update_request_header: bool = True

    # ID-generating callable
    generator: Callable[[], str] = field(default=lambda: uuid4().hex)

    # ID validator; when None, any non-empty header value is accepted as-is
    validator: Optional[Callable[[str], bool]] = field(default=is_valid_uuid4)

    # ID transformer - can be used to clean/mutate IDs; applied to both
    # received and generated IDs
    transformer: Optional[Callable[[str], str]] = field(default=lambda a: a)

    async def __call__(self, scope: 'Scope', receive: 'Receive', send: 'Send') -> None:
        """
        Load request ID from headers if present. Generate one otherwise.

        Scopes whose type is neither ``'http'`` nor ``'websocket'`` are
        forwarded to the wrapped app unchanged.
        """
        if scope['type'] not in ('http', 'websocket'):
            await self.app(scope, receive, send)
            return

        # Try to load request ID from the request headers
        headers = MutableHeaders(scope=scope)
        header_value = headers.get(self.header_name.lower())

        validation_failed = False
        if not header_value:
            # Generate request ID if none was found (missing or empty header)
            id_value = self.generator()
        elif self.validator and not self.validator(header_value):
            # Also generate a request ID if one was found, but it was deemed invalid
            validation_failed = True
            id_value = self.generator()
        else:
            # Otherwise, use the found request ID
            id_value = header_value

        # Clean/change the ID if needed
        if self.transformer:
            id_value = self.transformer(id_value)

        if validation_failed:
            # Logged after the transformer runs, so the message shows the final ID
            logger.warning(FAILED_VALIDATION_MESSAGE, id_value)

        # Update the request headers if needed
        if self.update_request_header and id_value != header_value:
            headers[self.header_name] = id_value

        correlation_id.set(id_value)
        self.sentry_extension(id_value)

        async def handle_outgoing_request(message: 'Message') -> None:
            # Only the response-start message carries headers we tag; the ID is
            # read from the context variable at send time, not from ``id_value``.
            if message['type'] == 'http.response.start':
                current_id = correlation_id.get()
                if current_id:
                    MutableHeaders(scope=message).append(self.header_name, current_id)
            await send(message)

        await self.app(scope, receive, handle_outgoing_request)

    def __post_init__(self) -> None:
        """
        Load extensions on initialization.

        If Sentry is installed, propagate correlation IDs to Sentry events.
        If Celery is installed, propagate correlation IDs to spawned worker processes.
        """
        # Callable invoked with every resolved request ID in ``__call__``.
        # It is not a dataclass field, so it is excluded from the generated
        # __init__/__repr__/__eq__.
        self.sentry_extension = get_sentry_extension()

        try:
            # Presence check only: an ImportError here (Celery not installed)
            # skips loading the Celery extension.
            import celery  # noqa: F401, TC002

            from asgi_correlation_id.extensions.celery import load_correlation_ids

            load_correlation_ids()
        except ImportError:  # pragma: no cover
            pass