Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ml endpoints clients #20

Merged
merged 3 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
29 changes: 29 additions & 0 deletions cmsdials/clients/ml_bad_lumisection/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from ...utils.api_client import BaseAuthorizedAPIClient
from .models import MLBadLumisection, MLBadLumisectionFilters, PaginatedMLBadLumisectionList


class MLBadLumisectionClient(BaseAuthorizedAPIClient):
data_model = MLBadLumisection
pagination_model = PaginatedMLBadLumisectionList
filter_class = MLBadLumisectionFilters
lookup_url = "ml-bad-lumisection/"

def get(self, model_id: int, dataset_id: int, run_number: int, ls_number: int, me_id: int, **kwargs):
edp = f"{model_id}/{dataset_id}/{run_number}/{ls_number}/{me_id}/"
return super().get(edp, **kwargs)

def cert_json(self, model_id__in: list[int], dataset_id__in: list[int], run_number__in: list[int], **kwargs):
edp = "cert-json/"
midin = ",".join(str(v) for v in model_id__in)
didin = ",".join(str(v) for v in dataset_id__in)
ridin = ",".join(str(v) for v in run_number__in)
params = {"model_id__in": midin, "dataset_id__in": didin, "run_number__in": ridin}
return super().get(edp, params=params, return_raw_json=True, **kwargs)

def golden_json(self, model_id__in: list[int], dataset_id__in: list[int], run_number__in: list[int], **kwargs):
edp = "golden-json/"
midin = ",".join(str(v) for v in model_id__in)
didin = ",".join(str(v) for v in dataset_id__in)
ridin = ",".join(str(v) for v in run_number__in)
params = {"model_id__in": midin, "dataset_id__in": didin, "run_number__in": ridin}
return super().get(edp, params=params, return_raw_json=True, **kwargs)
38 changes: 38 additions & 0 deletions cmsdials/clients/ml_bad_lumisection/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Optional

from pydantic import AnyUrl, BaseModel

from ...utils.base_model import OBaseModel, PaginatedBaseModel


class MLBadLumisection(BaseModel):
model_id: int
dataset_id: int
file_id: int
run_number: int
ls_number: int
me_id: int
mse: float


class PaginatedMLBadLumisectionList(PaginatedBaseModel):
next: Optional[AnyUrl]
previous: Optional[AnyUrl]
results: list[MLBadLumisection]


class MLBadLumisectionFilters(OBaseModel):
next_token: Optional[str] = None
page_size: Optional[int] = None
model_id: Optional[int] = None
model_id__in: Optional[list[int]] = None
dataset_id: Optional[int] = None
dataset_id__in: Optional[list[int]] = None
dataset: Optional[str] = None
dataset__regex: Optional[str] = None
me_id: Optional[int] = None
me: Optional[str] = None
me__regex: Optional[str] = None
run_number: Optional[int] = None
run_number__in: Optional[list[int]] = None
ls_number: Optional[int] = None
Empty file.
13 changes: 13 additions & 0 deletions cmsdials/clients/ml_models_index/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ...utils.api_client import BaseAuthorizedAPIClient
from .models import MLModelsIndex, MLModelsIndexFilters, PaginatedMLModelsIndexList


class MLModelsIndexClient(BaseAuthorizedAPIClient):
data_model = MLModelsIndex
pagination_model = PaginatedMLModelsIndexList
filter_class = MLModelsIndexFilters
lookup_url = "ml-models-index/"

def get(self, model_id: int, **kwargs):
edp = f"{model_id}/"
return super().get(edp, **kwargs)
30 changes: 30 additions & 0 deletions cmsdials/clients/ml_models_index/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Optional

from pydantic import AnyUrl, BaseModel

from ...utils.base_model import OBaseModel, PaginatedBaseModel


class MLModelsIndex(BaseModel):
model_id: int
filename: str
target_me: str
active: bool


class PaginatedMLModelsIndexList(PaginatedBaseModel):
next: Optional[AnyUrl]
previous: Optional[AnyUrl]
results: list[MLModelsIndex]


class MLModelsIndexFilters(OBaseModel):
next_token: Optional[str] = None
page_size: Optional[int] = None
model_id: Optional[int] = None
model_id__in: Optional[list[int]] = None
filename: Optional[str] = None
filename__regex: Optional[str] = None
target_me: Optional[str] = None
target_me__regex: Optional[str] = None
active: Optional[bool] = None
4 changes: 4 additions & 0 deletions cmsdials/cmsdials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from .clients.h2d.client import LumisectionHistogram2DClient
from .clients.lumisection.client import LumisectionClient
from .clients.mes.client import MonitoringElementClient
from .clients.ml_bad_lumisection.client import MLBadLumisectionClient
from .clients.ml_models_index.client import MLModelsIndexClient
from .clients.oms_proxy.client import OMSProxyClient
from .clients.run.client import RunClient

Expand All @@ -19,3 +21,5 @@ def __init__(self, creds: BaseCredentials, workspace: Optional[str] = None, *arg
self.run = RunClient(creds, workspace, *args, **kwargs)
self.mes = MonitoringElementClient(creds, workspace, *args, **kwargs)
self.oms = OMSProxyClient(creds, *args, **kwargs)
self.ml_bad_lumis = MLBadLumisectionClient(creds, workspace, *args, **kwargs)
self.ml_models_index = MLModelsIndexClient(creds, workspace, *args, **kwargs)
4 changes: 4 additions & 0 deletions cmsdials/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from .clients.h2d.models import LumisectionHistogram2DFilters
from .clients.lumisection.models import LumisectionFilters
from .clients.mes.models import MEFilters
from .clients.ml_bad_lumisection.models import MLBadLumisectionFilters
from .clients.ml_models_index.models import MLModelsIndexFilters
from .clients.oms_proxy.models import OMSFilter, OMSPage
from .clients.run.models import RunFilters

Expand All @@ -14,6 +16,8 @@
"LumisectionFilters",
"RunFilters",
"MEFilters",
"MLBadLumisectionFilters",
"MLModelsIndexFilters",
"OMSFilter",
"OMSPage",
]
5 changes: 4 additions & 1 deletion cmsdials/utils/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,13 @@ def _build_headers(self) -> dict:
self.creds.before_request(base)
return base

def get(self, edp: str, retries=DEFAULT_RETRIES): # noqa: A002
def get(self, edp: str, retries=DEFAULT_RETRIES, params: Optional[dict] = None, return_raw_json: bool = False):
endpoint_url = self.api_url + self.lookup_url + edp
headers = self._build_headers()
response = self._requests_get_retriable(
endpoint_url,
headers=headers,
params=params,
timeout=self.default_timeout,
retries=retries,
)
Expand All @@ -98,6 +99,8 @@ def get(self, edp: str, retries=DEFAULT_RETRIES): # noqa: A002
raise err

response = response.json()
if return_raw_json:
return response
return self.data_model(**response)

def list(self, filters=None, retries=DEFAULT_RETRIES):
Expand Down
6 changes: 5 additions & 1 deletion cmsdials/utils/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

class OBaseModel(BaseModel):
def cleandict(self):
return {key: value for key, value in self.dict().items() if value is not None}
return {
key: ",".join(str(v) for v in value) if isinstance(value, (list, tuple)) else value
for key, value in self.dict().items()
if value is not None
}


class PaginatedBaseModel(BaseModel):
Expand Down