-
Notifications
You must be signed in to change notification settings - Fork 87
/
Copy pathmiddleware.py
243 lines (205 loc) · 8.46 KB
/
middleware.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
from __future__ import annotations
import asyncio
import re
from http import HTTPStatus
from timeit import default_timer
from typing import Awaitable, Callable, Optional, Sequence, Tuple, Union
from fastapi import FastAPI
from prometheus_client import REGISTRY, CollectorRegistry, Gauge
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send
from prometheus_fastapi_instrumentator import metrics, routing
class PrometheusInstrumentatorMiddleware:
    """ASGI middleware that collects request data and feeds it to instrumentations.

    For every HTTP request that is not excluded, the middleware measures the
    elapsed time, records the response status and headers (plus the response
    body for handlers matching ``body_handlers``), wraps the result in a
    ``metrics.Info`` object and hands it to every configured instrumentation.
    Non-HTTP scopes are passed straight through to the wrapped application.
    """

    def __init__(
        self,
        app: FastAPI,
        *,
        should_group_status_codes: bool = True,
        should_ignore_untemplated: bool = False,
        should_group_untemplated: bool = True,
        should_round_latency_decimals: bool = False,
        should_respect_env_var: bool = False,
        should_instrument_requests_inprogress: bool = False,
        excluded_handlers: Sequence[str] = (),
        body_handlers: Sequence[str] = (),
        round_latency_decimals: int = 4,
        env_var_name: str = "ENABLE_METRICS",
        inprogress_name: str = "http_requests_inprogress",
        inprogress_labels: bool = False,
        instrumentations: Sequence[Callable[[metrics.Info], None]] = (),
        async_instrumentations: Sequence[
            Callable[[metrics.Info], Awaitable[None]]
        ] = (),
        metric_namespace: str = "",
        metric_subsystem: str = "",
        should_only_respect_2xx_for_highr: bool = False,
        latency_highr_buckets: Sequence[Union[float, str]] = (
            0.01,
            0.025,
            0.05,
            0.075,
            0.1,
            0.25,
            0.5,
            0.75,
            1,
            1.5,
            2,
            2.5,
            3,
            3.5,
            4,
            4.5,
            5,
            7.5,
            10,
            30,
            60,
        ),
        latency_lowr_buckets: Sequence[Union[float, str]] = (0.1, 0.5, 1),
        registry: CollectorRegistry = REGISTRY,
    ) -> None:
        """Store the configuration and create the metrics this middleware owns.

        Args:
            app: The wrapped ASGI application.
            should_group_status_codes: If true, the reported status keeps only
                its first digit followed by ``"xx"`` (for example ``"2xx"``).
            should_ignore_untemplated: If true, requests for which no route
                name could be resolved are excluded from instrumentation.
            should_group_untemplated: If true, requests for which no route
                name could be resolved are reported with handler ``"none"``.
            should_round_latency_decimals: If true, the measured duration is
                rounded to ``round_latency_decimals`` decimals.
            should_respect_env_var: Stored on the instance; not read anywhere
                else in this class.
            should_instrument_requests_inprogress: If true, a gauge tracking
                the number of requests in progress is created and maintained.
            excluded_handlers: Regular expressions; a handler for which any
                pattern finds a match is excluded from instrumentation.
            body_handlers: Regular expressions; the response body is collected
                only for handlers for which any pattern finds a match.
            round_latency_decimals: Number of decimals used when rounding.
            env_var_name: Stored on the instance; not read anywhere else in
                this class.
            inprogress_name: Name of the in-progress gauge.
            inprogress_labels: If true, the in-progress gauge is labelled with
                ``method`` and ``handler``.
            instrumentations: Synchronous callables invoked with the
                ``metrics.Info`` of each request. When empty, the result of
                ``metrics.default(...)`` is used instead (if it is truthy).
            async_instrumentations: Asynchronous callables awaited together
                after the synchronous instrumentations have run.
            metric_namespace: Forwarded to ``metrics.default``.
            metric_subsystem: Forwarded to ``metrics.default``.
            should_only_respect_2xx_for_highr: Forwarded to ``metrics.default``.
            latency_highr_buckets: Forwarded to ``metrics.default``.
            latency_lowr_buckets: Forwarded to ``metrics.default``.
            registry: Registry forwarded to ``metrics.default`` and used for
                the in-progress gauge.
        """
        self.app = app

        self.should_group_status_codes = should_group_status_codes
        self.should_ignore_untemplated = should_ignore_untemplated
        self.should_group_untemplated = should_group_untemplated
        self.should_round_latency_decimals = should_round_latency_decimals
        self.should_respect_env_var = should_respect_env_var
        self.should_instrument_requests_inprogress = (
            should_instrument_requests_inprogress
        )

        self.round_latency_decimals = round_latency_decimals
        self.env_var_name = env_var_name
        self.inprogress_name = inprogress_name
        self.inprogress_labels = inprogress_labels
        self.registry = registry

        # Compile once here so that request handling only runs the matching.
        self.excluded_handlers = [re.compile(path) for path in excluded_handlers]
        self.body_handlers = [re.compile(path) for path in body_handlers]

        if instrumentations:
            self.instrumentations = instrumentations
        else:
            default_instrumentation = metrics.default(
                metric_namespace=metric_namespace,
                metric_subsystem=metric_subsystem,
                should_only_respect_2xx_for_highr=should_only_respect_2xx_for_highr,
                latency_highr_buckets=latency_highr_buckets,
                latency_lowr_buckets=latency_lowr_buckets,
                registry=self.registry,
            )
            self.instrumentations = (
                [default_instrumentation] if default_instrumentation else []
            )

        self.async_instrumentations = async_instrumentations

        self.inprogress: Optional[Gauge] = None
        if self.should_instrument_requests_inprogress:
            labels = ("method", "handler") if self.inprogress_labels else ()
            self.inprogress = Gauge(
                name=self.inprogress_name,
                documentation="Number of HTTP requests in progress.",
                labelnames=labels,
                multiprocess_mode="livesum",
                # Register on the configured registry, like the default
                # instrumentation, instead of always using the global one.
                registry=self.registry,
            )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle one ASGI connection and instrument it if it is an HTTP request.

        Args:
            scope: ASGI connection scope.
            receive: ASGI receive callable, passed through unchanged.
            send: ASGI send callable; wrapped to observe response messages.

        Raises:
            Exception: Whatever the wrapped application raises propagates
                unchanged; instrumentation still runs beforehand and reports
                the last observed status (``500`` if no response was started).
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope)
        start_time = default_timer()

        handler, is_templated = self._get_handler(request)
        # Exclusion is decided on the real handler, before it is grouped.
        is_excluded = self._is_handler_excluded(handler, is_templated)
        if not is_templated and self.should_group_untemplated:
            handler = "none"

        # Stays None unless this request was counted as in progress, so the
        # cleanup below decrements exactly what was incremented.
        inprogress = None
        if not is_excluded and self.inprogress is not None:
            inprogress = (
                self.inprogress.labels(request.method, handler)
                if self.inprogress_labels
                else self.inprogress
            )
            inprogress.inc()

        # Reported if the application fails before starting a response.
        status_code = 500
        headers = []
        body_chunks = []

        # Message body collected for handlers matching body_handlers patterns.
        should_collect_body = any(
            pattern.search(handler) for pattern in self.body_handlers
        )

        async def send_wrapper(message: Message) -> None:
            nonlocal status_code, headers
            message_type = message["type"]
            if message_type == "http.response.start":
                # "headers" is optional in the ASGI spec (defaults to empty).
                headers = message.get("headers", [])
                status_code = message["status"]
            elif should_collect_body and message_type == "http.response.body":
                # "body" is optional in the ASGI spec (defaults to b"").
                chunk = message.get("body")
                if chunk:
                    body_chunks.append(chunk)
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            if inprogress is not None:
                inprogress.dec()

            if not is_excluded:
                duration = max(default_timer() - start_time, 0)
                if self.should_round_latency_decimals:
                    duration = round(duration, self.round_latency_decimals)

                # str() of an HTTPStatus member is not its numeric value on
                # every Python version, so unwrap it explicitly.
                status = (
                    str(status_code.value)
                    if isinstance(status_code, HTTPStatus)
                    else str(status_code)
                )
                if self.should_group_status_codes:
                    status = status[0] + "xx"

                response = Response(
                    content=b"".join(body_chunks),
                    headers=Headers(raw=headers),
                    status_code=status_code,
                )

                info = metrics.Info(
                    request=request,
                    response=response,
                    method=request.method,
                    modified_handler=handler,
                    modified_status=status,
                    modified_duration=duration,
                )

                for instrumentation in self.instrumentations:
                    instrumentation(info)

                await asyncio.gather(
                    *[
                        instrumentation(info)
                        for instrumentation in self.async_instrumentations
                    ]
                )

    def _get_handler(self, request: Request) -> Tuple[str, bool]:
        """Extracts either template or (if no template) path.

        Args:
            request (Request): Starlette request object.

        Returns:
            Tuple[str, bool]: Tuple with two elements. First element is either
                template or if no template the path. Second element tells you
                if the path is templated or not.
        """
        route_name = routing.get_route_name(request)
        return route_name or request.url.path, bool(route_name)

    def _is_handler_excluded(self, handler: str, is_templated: bool) -> bool:
        """Determines if the handler should be ignored.

        Args:
            handler (str): Handler that handles the request.
            is_templated (bool): Shows if the request is templated.

        Returns:
            bool: `True` if excluded, `False` if not.
        """
        if self.should_ignore_untemplated and not is_templated:
            return True
        return any(pattern.search(handler) for pattern in self.excluded_handlers)