Skip to content

Commit

Permalink
feat: more type hints for rdflib.plugins.sparql (#2268)
Browse files Browse the repository at this point in the history
The reason this matters now is a bit roundabout, but basically:

I want to add examples for securing RDFLib with `sys.addaudithook`
and `urllib.request.install_opener`. I also want to be sure examples
are actually valid, and runnable, so I was adding static analysis
and simple execution of examples to our CI.

While doing this, I noticed that the examples use `initBindings` with
`Dict[str, ...]`, which mypy rejected as invalid. After some
investigation I realized that the type hints in some places were too
strict.

So the main impetus for this is actually to relax the type hints in
`rdflib.graph`, but to ensure this is valid I'm adding a bunch of type
hints I had saved up to `rdflib.plugins.sparql`.

Even though this PR looks big, it has no runtime changes.
  • Loading branch information
aucampia authored Mar 12, 2023
1 parent a44bd99 commit 7a7cc1f
Show file tree
Hide file tree
Showing 14 changed files with 453 additions and 182 deletions.
5 changes: 2 additions & 3 deletions rdflib/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
Node,
RDFLibGenid,
URIRef,
Variable,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1501,7 +1500,7 @@ def query(
processor: Union[str, query.Processor] = "sparql",
result: Union[str, Type[query.Result]] = "sparql",
initNs: Optional[Mapping[str, Any]] = None, # noqa: N803
initBindings: Optional[Mapping[Variable, Identifier]] = None,
initBindings: Optional[Mapping[str, Identifier]] = None,
use_store_provided: bool = True,
**kwargs: Any,
) -> query.Result:
Expand Down Expand Up @@ -1547,7 +1546,7 @@ def update(
update_object: Union[Update, str],
processor: Union[str, rdflib.query.UpdateProcessor] = "sparql",
initNs: Optional[Mapping[str, Any]] = None, # noqa: N803
initBindings: Optional[Mapping[Variable, Identifier]] = None,
initBindings: Optional[Mapping[str, Identifier]] = None,
use_store_provided: bool = True,
**kwargs: Any,
) -> None:
Expand Down
135 changes: 89 additions & 46 deletions rdflib/plugins/sparql/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,29 @@
from decimal import Decimal
from __future__ import annotations

from rdflib import XSD, Literal
from decimal import Decimal
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
MutableMapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)

from rdflib.namespace import XSD
from rdflib.plugins.sparql.datatypes import type_promotion
from rdflib.plugins.sparql.evalutils import NotBoundError, _eval, _val
from rdflib.plugins.sparql.evalutils import _eval, _val
from rdflib.plugins.sparql.operators import numeric
from rdflib.plugins.sparql.sparql import SPARQLTypeError
from rdflib.plugins.sparql.parserutils import CompValue
from rdflib.plugins.sparql.sparql import FrozenBindings, NotBoundError, SPARQLTypeError
from rdflib.term import BNode, Identifier, Literal, URIRef, Variable

"""
Aggregation functions
Expand All @@ -14,38 +33,43 @@
class Accumulator(object):
"""abstract base class for different aggregation functions"""

def __init__(self, aggregation):
def __init__(self, aggregation: CompValue):
self.get_value: Callable[[], Optional[Literal]]
self.update: Callable[[FrozenBindings, "Aggregator"], None]
self.var = aggregation.res
self.expr = aggregation.vars
if not aggregation.distinct:
self.use_row = self.dont_care
# type error: Cannot assign to a method
self.use_row = self.dont_care # type: ignore[assignment]
self.distinct = False
else:
self.distinct = aggregation.distinct
self.seen = set()
self.seen: Set[Any] = set()

def dont_care(self, row):
def dont_care(self, row: FrozenBindings) -> bool:
"""skips distinct test"""
return True

def use_row(self, row):
def use_row(self, row: FrozenBindings) -> bool:
"""tests distinct with set"""
return _eval(self.expr, row) not in self.seen

def set_value(self, bindings):
def set_value(self, bindings: MutableMapping[Variable, Identifier]) -> None:
"""sets final value in bindings"""
bindings[self.var] = self.get_value()
# type error: Incompatible types in assignment (expression has type "Optional[Literal]", target has type "Identifier")
bindings[self.var] = self.get_value() # type: ignore[assignment]


class Counter(Accumulator):
def __init__(self, aggregation):
def __init__(self, aggregation: CompValue):
super(Counter, self).__init__(aggregation)
self.value = 0
if self.expr == "*":
# cannot eval "*" => always use the full row
self.eval_row = self.eval_full_row
# type error: Cannot assign to a method
self.eval_row = self.eval_full_row # type: ignore[assignment]

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
val = self.eval_row(row)
except NotBoundError:
Expand All @@ -55,41 +79,54 @@ def update(self, row, aggregator):
if self.distinct:
self.seen.add(val)

def get_value(self):
def get_value(self) -> Literal:
return Literal(self.value)

def eval_row(self, row):
def eval_row(self, row: FrozenBindings) -> Identifier:
return _eval(self.expr, row)

def eval_full_row(self, row):
def eval_full_row(self, row: FrozenBindings) -> FrozenBindings:
return row

def use_row(self, row):
def use_row(self, row: FrozenBindings) -> bool:
return self.eval_row(row) not in self.seen


def type_safe_numbers(*args):
@overload
def type_safe_numbers(*args: int) -> Tuple[int]:
...


@overload
def type_safe_numbers(*args: Union[Decimal, float, int]) -> Tuple[Union[float, int]]:
...


def type_safe_numbers(*args: Union[Decimal, float, int]) -> Iterable[Union[float, int]]:
if any(isinstance(arg, float) for arg in args) and any(
isinstance(arg, Decimal) for arg in args
):
return map(float, args)
return args
# type error: Incompatible return value type (got "Tuple[Union[Decimal, float, int], ...]", expected "Iterable[Union[float, int]]")
# NOTE on type error: if args contains a Decimal it will not get here.
return args # type: ignore[return-value]


class Sum(Accumulator):
def __init__(self, aggregation):
def __init__(self, aggregation: CompValue):
super(Sum, self).__init__(aggregation)
self.value = 0
self.datatype = None
self.datatype: Optional[str] = None

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
value = _eval(self.expr, row)
dt = self.datatype
if dt is None:
dt = value.datatype
else:
dt = type_promotion(dt, value.datatype)
# type error: Argument 1 to "type_promotion" has incompatible type "str"; expected "URIRef"
dt = type_promotion(dt, value.datatype) # type: ignore[arg-type]
self.datatype = dt
self.value = sum(type_safe_numbers(self.value, numeric(value)))
if self.distinct:
Expand All @@ -98,26 +135,27 @@ def update(self, row, aggregator):
# skip UNDEF
pass

def get_value(self):
def get_value(self) -> Literal:
return Literal(self.value, datatype=self.datatype)


class Average(Accumulator):
def __init__(self, aggregation):
def __init__(self, aggregation: CompValue):
super(Average, self).__init__(aggregation)
self.counter = 0
self.sum = 0
self.datatype = None
self.datatype: Optional[str] = None

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
value = _eval(self.expr, row)
dt = self.datatype
self.sum = sum(type_safe_numbers(self.sum, numeric(value)))
if dt is None:
dt = value.datatype
else:
dt = type_promotion(dt, value.datatype)
# type error: Argument 1 to "type_promotion" has incompatible type "str"; expected "URIRef"
dt = type_promotion(dt, value.datatype) # type: ignore[arg-type]
self.datatype = dt
if self.distinct:
self.seen.add(value)
Expand All @@ -128,7 +166,7 @@ def update(self, row, aggregator):
except SPARQLTypeError:
pass

def get_value(self):
def get_value(self) -> Literal:
if self.counter == 0:
return Literal(0)
if self.datatype in (XSD.float, XSD.double):
Expand All @@ -140,18 +178,20 @@ def get_value(self):
class Extremum(Accumulator):
"""abstract base class for Minimum and Maximum"""

def __init__(self, aggregation):
def __init__(self, aggregation: CompValue):
self.compare: Callable[[Any, Any], Any]
super(Extremum, self).__init__(aggregation)
self.value = None
self.value: Any = None
# DISTINCT would not change the value for MIN or MAX
self.use_row = self.dont_care
# type error: Cannot assign to a method
self.use_row = self.dont_care # type: ignore[assignment]

def set_value(self, bindings):
def set_value(self, bindings: MutableMapping[Variable, Identifier]) -> None:
if self.value is not None:
# simply do not set if self.value is still None
bindings[self.var] = Literal(self.value)

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
if self.value is None:
self.value = _eval(self.expr, row)
Expand All @@ -165,13 +205,16 @@ def update(self, row, aggregator):
pass


_ValueT = TypeVar("_ValueT", Variable, BNode, URIRef, Literal)


class Minimum(Extremum):
def compare(self, val1, val2):
def compare(self, val1: _ValueT, val2: _ValueT) -> _ValueT:
return min(val1, val2, key=_val)


class Maximum(Extremum):
def compare(self, val1, val2):
def compare(self, val1: _ValueT, val2: _ValueT) -> _ValueT:
return max(val1, val2, key=_val)


Expand All @@ -183,7 +226,7 @@ def __init__(self, aggregation):
# DISTINCT would not change the value
self.use_row = self.dont_care

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
# set the value now
aggregator.bindings[self.var] = _eval(self.expr, row)
Expand All @@ -192,7 +235,7 @@ def update(self, row, aggregator):
except NotBoundError:
pass

def get_value(self):
def get_value(self) -> None:
# set None if no value was set
return None

Expand All @@ -204,7 +247,7 @@ def __init__(self, aggregation):
self.value = []
self.separator = aggregation.separator or " "

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
value = _eval(self.expr, row)
# skip UNDEF
Expand All @@ -221,7 +264,7 @@ def update(self, row, aggregator):
except NotBoundError:
pass

def get_value(self):
def get_value(self) -> Literal:
return Literal(self.separator.join(str(v) for v in self.value))


Expand All @@ -238,16 +281,16 @@ class Aggregator(object):
"Aggregate_GroupConcat": GroupConcat,
}

def __init__(self, aggregations):
self.bindings = {}
self.accumulators = {}
def __init__(self, aggregations: List[CompValue]):
self.bindings: Dict[Variable, Identifier] = {}
self.accumulators: Dict[str, Accumulator] = {}
for a in aggregations:
accumulator_class = self.accumulator_classes.get(a.name)
if accumulator_class is None:
raise Exception("Unknown aggregate function " + a.name)
self.accumulators[a.res] = accumulator_class(a)

def update(self, row):
def update(self, row: FrozenBindings) -> None:
"""update all own accumulators"""
# SAMPLE accumulators may delete themselves
# => iterate over list not generator
Expand All @@ -256,7 +299,7 @@ def update(self, row):
if acc.use_row(row):
acc.update(row, self)

def get_bindings(self):
def get_bindings(self) -> Mapping[Variable, Identifier]:
"""calculate and set last values"""
for acc in self.accumulators.values():
acc.set_value(self.bindings)
Expand Down
22 changes: 16 additions & 6 deletions rdflib/plugins/sparql/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Optional, Set

"""
Utility functions for supporting the XML Schema Datatypes hierarchy
"""

from rdflib import XSD
from rdflib.namespace import XSD

if TYPE_CHECKING:
from rdflib.term import URIRef

XSD_DTs = set(
XSD_DTs: Set[URIRef] = set(
(
XSD.integer,
XSD.decimal,
Expand Down Expand Up @@ -35,7 +42,7 @@

XSD_Duration_DTs = set((XSD.duration, XSD.dayTimeDuration, XSD.yearMonthDuration))

_sub_types = {
_sub_types: Dict[URIRef, List[URIRef]] = {
XSD.integer: [
XSD.nonPositiveInteger,
XSD.negativeInteger,
Expand All @@ -52,13 +59,13 @@
],
}

_super_types = {}
_super_types: Dict[URIRef, URIRef] = {}
for superdt in XSD_DTs:
for subdt in _sub_types.get(superdt, []):
_super_types[subdt] = superdt

# we only care about float, double, integer, decimal
_typePromotionMap = {
_typePromotionMap: Dict[URIRef, Dict[URIRef, URIRef]] = {
XSD.float: {XSD.integer: XSD.float, XSD.decimal: XSD.float, XSD.double: XSD.double},
XSD.double: {
XSD.integer: XSD.double,
Expand All @@ -78,14 +85,17 @@
}


def type_promotion(t1, t2):
def type_promotion(t1: URIRef, t2: Optional[URIRef]) -> URIRef:
if t2 is None:
return t1
t1 = _super_types.get(t1, t1)
t2 = _super_types.get(t2, t2)
if t1 == t2:
return t1 # matching super-types
try:
if TYPE_CHECKING:
# type assert because mypy is confused and thinks t2 can be None
assert t2 is not None
return _typePromotionMap[t1][t2]
except KeyError:
raise TypeError("Operators cannot combine datatypes %s and %s" % (t1, t2))
Loading

0 comments on commit 7a7cc1f

Please sign in to comment.