Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Add middlewares and one specific to expose /metrics endpoint for prometheus #629

Merged
merged 9 commits into from
Mar 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions mlem/contrib/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
FastAPIServer implementation
"""
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from types import ModuleType
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type
Expand All @@ -24,6 +25,7 @@
InterfaceArgument,
InterfaceMethod,
)
from mlem.runtime.middleware import Middleware, Middlewares
from mlem.runtime.server import Server
from mlem.ui import EMOJI_NAILS, echo
from mlem.utils.module import get_object_requirements
Expand All @@ -48,6 +50,12 @@ def _create_schema_route(app: FastAPI, interface: Interface):
app.add_api_route("/interface.json", lambda: schema, tags=["schema"])


class FastAPIMiddleware(Middleware, ABC):
@abstractmethod
def on_app_init(self, app: FastAPI):
raise NotImplementedError


class FastAPIServer(Server, LibRequirementsMixin):
"""Serves model with http"""

Expand All @@ -70,6 +78,7 @@ def _create_handler_executor(
arg_serializers: Dict[str, DataTypeSerializer],
executor: Callable,
response_serializer: DataTypeSerializer,
middlewares: Middlewares,
):
deserialized_model = create_model(
"Model", **{a: (Any, ...) for a in args}
Expand Down Expand Up @@ -99,7 +108,9 @@ def serializer_validator(_, values):

def bin_handler(model: schema_model): # type: ignore[valid-type]
values = {a: getattr(model, a) for a in args}
values = middlewares.on_request(values)
result = executor(**values)
result = middlewares.on_response(values, result)
with response_serializer.dump(result) as buffer:
return StreamingResponse(
buffer, media_type="application/octet-stream"
Expand All @@ -113,7 +124,9 @@ def bin_handler(model: schema_model): # type: ignore[valid-type]

def handler(model: schema_model): # type: ignore[valid-type]
values = {a: getattr(model, a) for a in args}
values = middlewares.on_request(values)
result = executor(**values)
result = middlewares.on_response(values, result)
response = response_serializer.serialize(result)
return parse_obj_as(response_model, response)

Expand All @@ -127,12 +140,15 @@ def _create_handler_executor_binary(
arg_name: str,
executor: Callable,
response_serializer: DataTypeSerializer,
middlewares: Middlewares,
):
if response_serializer.serializer.is_binary:

def bin_handler(file: UploadFile):
arg = serializer.deserialize(_SpooledFileIOWrapper(file.file))
arg = middlewares.on_request(arg)
result = executor(**{arg_name: arg})
result = middlewares.on_response(arg, result)
with response_serializer.dump(result) as buffer:
return StreamingResponse(
buffer, media_type="application/octet-stream"
Expand All @@ -146,15 +162,20 @@ def bin_handler(file: UploadFile):

def handler(file: UploadFile):
arg = serializer.deserialize(file.file)
arg = middlewares.on_request(arg)
result = executor(**{arg_name: arg})

result = middlewares.on_response(arg, result)
response = response_serializer.serialize(result)
return parse_obj_as(response_model, response)

return handler, response_model, None

def _create_handler(
self, method_name: str, signature: InterfaceMethod, executor: Callable
self,
method_name: str,
signature: InterfaceMethod,
executor: Callable,
middlewares: Middlewares,
) -> Tuple[Optional[Callable], Optional[Type], Optional[Response]]:
serializers, response_serializer = self._get_serializers(signature)
echo(EMOJI_NAILS + f"Adding route for /{method_name}")
Expand All @@ -170,13 +191,15 @@ def _create_handler(
arg_name,
executor,
response_serializer,
middlewares,
)
return self._create_handler_executor(
method_name,
{a.name: a for a in signature.args},
serializers,
executor,
response_serializer,
middlewares,
)

def app_init(self, interface: Interface):
Expand All @@ -185,11 +208,15 @@ def app_init(self, interface: Interface):
app.add_api_route(
"/", lambda: RedirectResponse("/docs"), include_in_schema=False
)
for mid in self.middlewares.__root__:
mid.on_init()
if isinstance(mid, FastAPIMiddleware):
mid.on_app_init(app)

for method, signature in interface.iter_methods():
executor = interface.get_method_executor(method)
handler, response_model, response_class = self._create_handler(
method, signature, executor
method, signature, executor, self.middlewares
)

app.add_api_route(
Expand Down
65 changes: 65 additions & 0 deletions mlem/contrib/prometheus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Instrumenting FastAPI app to expose metrics for prometheus
Extension type: middleware

Exposes /metrics endpoint
"""
from typing import ClassVar, List, Optional

from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator

from mlem.contrib.fastapi import FastAPIMiddleware
from mlem.utils.importing import import_string_with_local
from mlem.utils.module import get_object_requirements


class PrometheusFastAPIMiddleware(FastAPIMiddleware):
"""Middleware for FastAPI server that exposes /metrics endpoint to be scraped by Prometheus"""

type: ClassVar = "prometheus_fastapi"

metrics: List[str] = []
"""Instrumentator instance to use. If not provided, a new one will be created"""
instrumentator_cache: Optional[Instrumentator] = None

class Config:
arbitrary_types_allowed = True
exclude = {"instrumentator_cache"}

@property
def instrumentator(self):
if self.instrumentator_cache is None:
self.instrumentator_cache = self.get_instrumentator()
return self.instrumentator_cache

def on_app_init(self, app: FastAPI):
@app.on_event("startup")
async def _startup():
self.instrumentator.expose(app)

def on_init(self):
pass

def on_request(self, request):
return request

def on_response(self, request, response):
return response

def get_instrumentator(self):
instrumentator = Instrumentator()
for metric in self._iter_metric_objects():
# todo: check object type
instrumentator.add(metric)
return instrumentator

def _iter_metric_objects(self):
for metric in self.metrics:
# todo: meaningful error on import error
yield import_string_with_local(metric)

def get_requirements(self):
reqs = super().get_requirements()
for metric in self._iter_metric_objects():
reqs += get_object_requirements(metric)
return reqs
1 change: 1 addition & 0 deletions mlem/contrib/sagemaker/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def app_init(self, interface: Interface):
"invocations",
interface.get_method_signature(self.method),
interface.get_method_executor(self.method),
self.middlewares,
)
app.add_api_route(
"/invocations",
Expand Down
11 changes: 2 additions & 9 deletions mlem/core/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import shlex
import sys
from collections import defaultdict
from inspect import isabstract
from typing import (
Expand All @@ -22,7 +21,7 @@

from mlem.core.errors import ExtensionRequirementError, UnknownImplementation
from mlem.polydantic import PolyModel
from mlem.utils.importing import import_string
from mlem.utils.importing import import_string_with_local
from mlem.utils.path import make_posix


Expand Down Expand Up @@ -64,18 +63,12 @@ def load_impl_ext(

if type_name is not None and "." in type_name:
try:
# this is needed because if run from cli curdir is not checked for
# modules to import
sys.path.append(".")

obj = import_string(type_name)
obj = import_string_with_local(type_name)
if not issubclass(obj, MlemABC):
raise ValueError(f"{obj} is not subclass of MlemABC")
return obj
except ImportError:
pass
finally:
sys.path.remove(".")

eps = load_entrypoints()
for ep in eps.values():
Expand Down
5 changes: 5 additions & 0 deletions mlem/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ class ExtensionLoader:
Extension("mlem.contrib.xgboost", ["xgboost"], False),
Extension("mlem.contrib.docker", ["docker"], False),
Extension("mlem.contrib.fastapi", ["fastapi", "uvicorn"], False),
Extension(
"mlem.contrib.prometheus",
["prometheus-fastapi-instrumentator"],
False,
),
Extension("mlem.contrib.callable", [], True),
Extension("mlem.contrib.rabbitmq", ["pika"], False, extra="rmq"),
Extension("mlem.contrib.github", [], True),
Expand Down
51 changes: 51 additions & 0 deletions mlem/runtime/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from abc import abstractmethod
from typing import ClassVar, List

from pydantic import BaseModel

from mlem.core.base import MlemABC
from mlem.core.requirements import Requirements, WithRequirements


class Middleware(MlemABC, WithRequirements):
abs_name: ClassVar = "middleware"

class Config:
type_root = True

@abstractmethod
def on_init(self):
raise NotImplementedError

@abstractmethod
def on_request(self, request):
raise NotImplementedError

@abstractmethod
def on_response(self, request, response):
raise NotImplementedError


class Middlewares(BaseModel):
__root__: List[Middleware] = []
"""Middlewares to add to server"""

def on_init(self):
for middleware in self.__root__:
middleware.on_init()

def on_request(self, request):
for middleware in self.__root__:
request = middleware.on_request(request)
return request

def on_response(self, request, response):
for middleware in reversed(self.__root__):
response = middleware.on_response(request, response)
return response

def get_requirements(self) -> Requirements:
reqs = Requirements.new()
for m in self.__root__:
reqs += m.get_requirements()
return reqs
16 changes: 14 additions & 2 deletions mlem/runtime/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
InterfaceDescriptor,
InterfaceMethod,
)
from mlem.runtime.middleware import Middlewares
from mlem.utils.module import get_object_requirements

MethodMapping = Dict[str, str]
Expand Down Expand Up @@ -120,6 +121,9 @@ class Config:
additional_source_files: ClassVar[Optional[List[str]]] = None
port_field: ClassVar[Optional[str]] = None

middlewares: Middlewares = Middlewares()
"""Middlewares to add to server"""

# @validator("interface")
# @classmethod
# def validate_interface(cls, value):
Expand Down Expand Up @@ -155,8 +159,16 @@ def _get_serializers(
return arg_serializers, returns

def get_requirements(self) -> Requirements:
return super().get_requirements() + get_object_requirements(
[self.request_serializer, self.response_serializer, self.methods]
return (
super().get_requirements()
+ get_object_requirements(
[
self.request_serializer,
self.response_serializer,
self.methods,
]
)
+ self.middlewares.get_requirements()
)

def get_ports(self) -> List[int]:
Expand Down
10 changes: 10 additions & 0 deletions mlem/utils/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def module_imported(module_name):
return sys.modules.get(module_name) is not None


def import_string_with_local(path):
try:
# this is needed because if run from cli curdir is not checked for
# modules to import
sys.path.append(".")
return import_string(path)
finally:
sys.path.remove(".")


# Copyright 2019 Zyfra
# Copyright 2021 Iterative
#
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"xgboost": ["xgboost"],
"lightgbm": ["lightgbm"],
"fastapi": ["uvicorn", "fastapi"],
"prometheus": ["prometheus-fastapi-instrumentator"],
"streamlit": ["uvicorn", "fastapi", "streamlit", "streamlit_pydantic"],
"sagemaker": ["docker", "boto3", "sagemaker"],
"torch": ["torch"],
Expand Down Expand Up @@ -214,6 +215,7 @@
"serializer.pil_numpy = mlem.contrib.pil:PILImageSerializer",
"builder.pip = mlem.contrib.pip.base:PipBuilder",
"builder.whl = mlem.contrib.pip.base:WhlBuilder",
"middleware.prometheus_fastapi = mlem.contrib.prometheus:PrometheusFastAPIMiddleware",
"client.rmq = mlem.contrib.rabbitmq:RabbitMQClient",
"server.rmq = mlem.contrib.rabbitmq:RabbitMQServer",
"builder.requirements = mlem.contrib.requirements:RequirementsBuilder",
Expand Down
Loading