Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Add middleware requirements, metrics path instead of instrumentator f…
Browse files Browse the repository at this point in the history
…or prometheus middleware
  • Loading branch information
mike0sv committed Mar 21, 2023
1 parent a39f74d commit a01c051
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 16 deletions.
35 changes: 31 additions & 4 deletions mlem/contrib/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,39 @@
Exposes /metrics endpoint
"""
from typing import ClassVar, Optional
from typing import ClassVar, List, Optional

from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator

from mlem.contrib.fastapi import FastAPIMiddleware
from mlem.utils.importing import import_string_with_local
from mlem.utils.module import get_object_requirements


class PrometheusFastAPIMiddleware(FastAPIMiddleware):
"""Middleware for FastAPI server that exposes /metrics endpoint to be scraped by Prometheus"""

type: ClassVar = "prometheus_fastapi"

instrumentator: Optional[Instrumentator]
metrics: List[str] = []
"""Instrumentator instance to use. If not provided, a new one will be created"""
instrumentator_cache: Optional[Instrumentator] = None

class Config:
arbitrary_types_allowed = True
exclude = {"instrumentator"}
exclude = {"instrumentator_cache"}

@property
def instrumentator(self):
if self.instrumentator_cache is None:
self.instrumentator_cache = self.get_instrumentator()
return self.instrumentator_cache

def on_app_init(self, app: FastAPI):
@app.on_event("startup")
async def _startup():
(self.instrumentator or Instrumentator()).expose(app)
self.instrumentator.expose(app)

def on_init(self):
pass
Expand All @@ -36,3 +45,21 @@ def on_request(self, request):

def on_response(self, request, response):
return response

def get_instrumentator(self):
instrumentator = Instrumentator()
for metric in self._iter_metric_objects():
# todo: check object type
instrumentator.add(metric)
return instrumentator

def _iter_metric_objects(self):
for metric in self.metrics:
# todo: meaningful error on import error
yield import_string_with_local(metric)

def get_requirements(self):
reqs = super().get_requirements()
for metric in self._iter_metric_objects():
reqs += get_object_requirements(metric)
return reqs
11 changes: 2 additions & 9 deletions mlem/core/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import shlex
import sys
from collections import defaultdict
from inspect import isabstract
from typing import (
Expand All @@ -22,7 +21,7 @@

from mlem.core.errors import ExtensionRequirementError, UnknownImplementation
from mlem.polydantic import PolyModel
from mlem.utils.importing import import_string
from mlem.utils.importing import import_string_with_local
from mlem.utils.path import make_posix


Expand Down Expand Up @@ -64,18 +63,12 @@ def load_impl_ext(

if type_name is not None and "." in type_name:
try:
# this is needed because if run from cli curdir is not checked for
# modules to import
sys.path.append(".")

obj = import_string(type_name)
obj = import_string_with_local(type_name)
if not issubclass(obj, MlemABC):
raise ValueError(f"{obj} is not subclass of MlemABC")
return obj
except ImportError:
pass
finally:
sys.path.remove(".")

eps = load_entrypoints()
for ep in eps.values():
Expand Down
9 changes: 8 additions & 1 deletion mlem/runtime/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from pydantic import BaseModel

from mlem.core.base import MlemABC
from mlem.core.requirements import Requirements, WithRequirements


class Middleware(MlemABC):
class Middleware(MlemABC, WithRequirements):
abs_name: ClassVar = "middleware"

class Config:
Expand Down Expand Up @@ -42,3 +43,9 @@ def on_response(self, request, response):
for middleware in reversed(self.__root__):
response = middleware.on_response(request, response)
return response

def get_requirements(self) -> Requirements:
reqs = Requirements.new()
for m in self.__root__:
reqs += m.get_requirements()
return reqs
12 changes: 10 additions & 2 deletions mlem/runtime/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,16 @@ def _get_serializers(
return arg_serializers, returns

def get_requirements(self) -> Requirements:
return super().get_requirements() + get_object_requirements(
[self.request_serializer, self.response_serializer, self.methods]
return (
super().get_requirements()
+ get_object_requirements(
[
self.request_serializer,
self.response_serializer,
self.methods,
]
)
+ self.middlewares.get_requirements()
)

def get_ports(self) -> List[int]:
Expand Down
10 changes: 10 additions & 0 deletions mlem/utils/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def module_imported(module_name):
return sys.modules.get(module_name) is not None


def import_string_with_local(path):
try:
# this is needed because if run from cli curdir is not checked for
# modules to import
sys.path.append(".")
return import_string(path)
finally:
sys.path.remove(".")


# Copyright 2019 Zyfra
# Copyright 2021 Iterative
#
Expand Down

0 comments on commit a01c051

Please sign in to comment.