Add retry logic on all calls #206

Merged · 3 commits · Dec 5, 2023
53 changes: 41 additions & 12 deletions cimsparql/async_sparql_wrapper.py
@@ -1,9 +1,10 @@
from __future__ import annotations

from http import HTTPStatus
from collections.abc import Callable, Coroutine
from typing import Any

import httpx
import tenacity
from SPARQLWrapper import JSON, SPARQLWrapper
from SPARQLWrapper.SPARQLExceptions import (
EndPointInternalError,
@@ -13,6 +14,8 @@
URITooLong,
)

from cimsparql.sparql_result_json import SparqlResultJson

exceptions = {
400: QueryBadFormed,
401: Unauthorized,
@@ -21,13 +24,39 @@
500: EndPointInternalError,
}

http_task = Coroutine[Any, Any, httpx.Response]


async def retry_task(
task_generator: Callable[[], http_task], num_retries: int, max_delay_seconds: int
) -> httpx.Response:
async for attempt in tenacity.AsyncRetrying(
stop=tenacity.stop_after_attempt(num_retries + 1),
wait=tenacity.wait_exponential(max=max_delay_seconds),
):
with attempt:
resp = await task_generator()
resp.raise_for_status()
return resp


class AsyncSparqlWrapper(SPARQLWrapper):
def __init__(self, *args: Any, **kwargs: dict[str, str | None]) -> None:
self.ca_bundle: str | None = kwargs.pop("ca_bundle", None)
def __init__(
self,
*args: Any,
ca_bundle: str | None = None,
num_retries: int = 0,
max_delay_seconds: int = 60,
validate: bool = False,
**kwargs: dict[str, str | None],
) -> None:
super().__init__(*args, **kwargs)
self.ca_bundle = ca_bundle
self.num_retries = num_retries
self.max_delay_seconds = max_delay_seconds
self.validate = validate

async def queryAndConvert(self) -> dict: # noqa N802
async def query_and_convert(self) -> SparqlResultJson:
if self.returnFormat != JSON:
raise NotImplementedError("Async client only supports JSON return format")

@@ -37,12 +66,12 @@ async def queryAndConvert(self) -> dict: # noqa N802

kwargs = {"verify": self.ca_bundle} if self.ca_bundle else {}
async with httpx.AsyncClient(timeout=self.timeout, **kwargs) as client:
response = await client.request(
method, url, headers=request.headers, content=request.data
response = await retry_task(
lambda: client.request(method, url, headers=request.headers, content=request.data),
self.num_retries,
self.max_delay_seconds,
)

status = response.status_code
if status != HTTPStatus.OK:
raise exceptions.get(status, Exception)(response.content)

return response.json()
result = SparqlResultJson(**response.json())
if self.validate:
result.validate_column_consistency()
return result
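
Note on the retry helper above: a minimal sketch of how retry_task behaves, not part of the diff. The flaky endpoint is hypothetical; raise_for_status() turns a non-2xx response into httpx.HTTPStatusError, which tenacity.AsyncRetrying catches and retries with capped exponential backoff.

import asyncio

import httpx

from cimsparql.async_sparql_wrapper import retry_task


async def main() -> None:
    attempts = 0

    async def flaky() -> httpx.Response:
        # Hypothetical endpoint: returns 503 twice, then 200.
        nonlocal attempts
        attempts += 1
        request = httpx.Request("GET", "https://example.com")
        return httpx.Response(503 if attempts < 3 else 200, request=request)

    # Up to num_retries + 1 = 4 attempts, backoff capped at 1 second.
    resp = await retry_task(flaky, num_retries=3, max_delay_seconds=1)
    print(resp.status_code, attempts)  # 200 3


asyncio.run(main())
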
45 changes: 34 additions & 11 deletions cimsparql/graphdb.py
@@ -12,10 +12,12 @@

import httpx
import pandas as pd
import tenacity
from SPARQLWrapper import JSON, POST, SPARQLWrapper
from strenum import StrEnum

from cimsparql.async_sparql_wrapper import AsyncSparqlWrapper
from cimsparql.sparql_result_json import SparqlResultJson
from cimsparql.url import service, service_blazegraph


@@ -76,6 +78,9 @@ class ServiceConfig:
limit: int | None = None
offset: int | None = None
timeout: int | None = None
num_retries: int = 0
max_delay_seconds: int = 60
validate: bool = False

def __post_init__(self) -> None:
if self.rest_api not in RestApi:
@@ -154,15 +159,13 @@ class GraphDBClient:
Where row is the output of graphdb.data_row
"""

sparql_wrapper = SPARQLWrapper

def __init__(
self,
service_cfg: ServiceConfig | None = None,
custom_headers: dict[str, str] | None = None,
) -> None:
self.service_cfg = service_cfg or ServiceConfig()
self.sparql = self.sparql_wrapper(self.service_cfg.url, self.service_cfg.ca_bundle)
self.sparql = self.create_sparql_wrapper()
self.sparql.setReturnFormat(JSON)
self.sparql.setMethod(POST)
self.sparql.setCredentials(self.service_cfg.user, self.service_cfg.passwd)
@@ -174,6 +177,9 @@ def __init__(
self.sparql.addCustomHttpHeader(name, value)
self._prefixes = None

def create_sparql_wrapper(self) -> SPARQLWrapper:
return SPARQLWrapper(self.service_cfg.url)

def _update_sparql_parameters(self) -> None:
for key, value in self.service_cfg.parameters.items():
if value is not None:
@@ -206,16 +212,27 @@ def _prep_query(self, query: str) -> None:
self.sparql.setQuery(query)
self._update_sparql_parameters()

def _process_result(self, results: dict) -> dict:
cols = results["head"]["vars"]
data = results["results"]["bindings"]
out = [{c: row.get(c, {}).get("value") for c in cols} for row in data]
@staticmethod
def _process_result(results: SparqlResultJson) -> dict:
cols = results.head.variables
data = results.results.bindings
out = [{c: row[c].value if c in row else None for c in cols} for row in data]
return {"out": out, "cols": cols, "data": data}

def _exec_query(self, query: str) -> SparqlResult:
self._prep_query(query)
results = self.sparql.queryAndConvert()
return self._process_result(results)

for attempt in tenacity.Retrying(
stop=tenacity.stop_after_attempt(self.service_cfg.num_retries + 1),
wait=tenacity.wait_exponential(max=self.service_cfg.max_delay_seconds),
):
with attempt:
results = self.sparql.queryAndConvert()

sparql_result = SparqlResultJson(**results)
if self.service_cfg.validate:
sparql_result.validate_column_consistency()
return self._process_result(sparql_result)

def exec_query(self, query: str) -> list[dict[str, str]]:
out = self._exec_query(query)
@@ -458,11 +475,17 @@ class AsyncGraphDBClient(GraphDBClient):
Where row is the output of graphdb.data_row
"""

sparql_wrapper = AsyncSparqlWrapper
def create_sparql_wrapper(self) -> AsyncSparqlWrapper:
return AsyncSparqlWrapper(
self.service_cfg.url,
ca_bundle=self.service_cfg.ca_bundle,
num_retries=self.service_cfg.num_retries,
max_delay_seconds=self.service_cfg.max_delay_seconds,
)

async def _exec_query(self, query: str) -> SparqlResult:
self._prep_query(query)
results = await self.sparql.queryAndConvert()
results = await self.sparql.query_and_convert()
return self._process_result(results)

async def exec_query(self, query: str) -> list[dict[str, str]]:
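
A sketch of how the new ServiceConfig knobs would be used from the synchronous client. This assumes a reachable GraphDB endpoint (the remaining ServiceConfig fields, such as the server URL, are resolved from their defaults); the query is a placeholder.

from cimsparql.graphdb import GraphDBClient, ServiceConfig

cfg = ServiceConfig(
    num_retries=5,  # up to 6 attempts in total
    max_delay_seconds=30,  # cap on the exponential backoff between attempts
    validate=True,  # run the column-consistency check on every result
)
client = GraphDBClient(cfg)
rows = client.exec_query("SELECT * WHERE { ?s ?p ?o } LIMIT 10")
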
2 changes: 1 addition & 1 deletion cimsparql/model.py
@@ -95,7 +95,7 @@ def client(self, default_client: GraphDBClient) -> None:
@staticmethod
def _col_map(data_row: dict[str, str]) -> dict[str, str]:
return {
column: data.get("datatype", data.get("type", None))
column: data.datatype if data.datatype else data.value_type
for column, data in data_row.items()
}

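
For illustration, what the updated _col_map produces for one data row built from the new SparqlResultValue model (the values are made up): the datatype wins when present, otherwise the binding's type is used.

from cimsparql.sparql_result_json import SparqlResultValue

row = {
    "mrid": SparqlResultValue(type="uri", value="urn:uuid:1"),
    "p": SparqlResultValue(
        type="literal",
        value="1.0",
        datatype="http://www.w3.org/2001/XMLSchema#double",
    ),
}
# _col_map(row) == {"mrid": "uri", "p": "http://www.w3.org/2001/XMLSchema#double"}
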
71 changes: 71 additions & 0 deletions cimsparql/sparql_result_json.py
@@ -0,0 +1,71 @@
from polyfactory.decorators import post_generated
from polyfactory.factories import pydantic_factory
from pydantic import BaseModel, Field


class SparqlResultHead(BaseModel):
link: list[str] = Field(default_factory=list)
variables: list[str] = Field(default_factory=list, alias="vars")


class SparqlResultValue(BaseModel):
value_type: str = Field(alias="type")
value: str
datatype: str = ""


class SparqlData(BaseModel):
bindings: list[dict[str, SparqlResultValue]]


class SparqlResultJson(BaseModel):
"""
Data model for the REST API response of MIME type

application/sparql-results+json

https://www.w3.org/TR/sparql11-results-json/
"""

head: SparqlResultHead
results: SparqlData

def validate_column_consistency(self) -> None:
"""
This is a fairly expensive validation since it iterates over the entire result.
It is therefore not implemented as a validator; it must be called explicitly
when validation is desired.
"""
column_set = set(self.head.variables)
for item in self.results.bindings:
if set(item.keys()) != column_set:
raise ValueError(f"Missing variables for {item}. Expected {column_set}")


class SparqlResultValueFactory(pydantic_factory.ModelFactory):
__model__ = SparqlResultValue


def build_sparql_result(variables: list[str]) -> SparqlData:
return SparqlData(
bindings=[
{variable: SparqlResultValueFactory.build() for variable in variables}
for _ in range(10)
]
)


class SparqlResultJsonFactory(pydantic_factory.ModelFactory):
__model__ = SparqlResultJson

@post_generated
@classmethod
def results(cls, head: SparqlResultHead) -> SparqlData:
return build_sparql_result(head.variables)

@classmethod
def build(cls) -> SparqlResultJson:
result: SparqlResultJson = super().build()
result.validate_column_consistency()
return result
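
A minimal sketch of parsing a SPARQL JSON payload with the new model (the payload is handwritten for illustration):

from cimsparql.sparql_result_json import SparqlResultJson

payload = {
    "head": {"vars": ["s", "p"]},
    "results": {
        "bindings": [
            {
                "s": {"type": "uri", "value": "http://example.com/a"},
                "p": {"type": "literal", "value": "42"},
            }
        ]
    },
}
result = SparqlResultJson(**payload)
result.validate_column_consistency()  # raises ValueError if a row misses a column
print(result.results.bindings[0]["p"].value)  # 42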