Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: more type hints for rdflib.plugins.sparql #2268

Merged
merged 1 commit into from
Mar 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions rdflib/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
Node,
RDFLibGenid,
URIRef,
Variable,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1501,7 +1500,7 @@ def query(
processor: Union[str, query.Processor] = "sparql",
result: Union[str, Type[query.Result]] = "sparql",
initNs: Optional[Mapping[str, Any]] = None, # noqa: N803
initBindings: Optional[Mapping[Variable, Identifier]] = None,
initBindings: Optional[Mapping[str, Identifier]] = None,
use_store_provided: bool = True,
**kwargs: Any,
) -> query.Result:
Expand Down Expand Up @@ -1547,7 +1546,7 @@ def update(
update_object: Union[Update, str],
processor: Union[str, rdflib.query.UpdateProcessor] = "sparql",
initNs: Optional[Mapping[str, Any]] = None, # noqa: N803
initBindings: Optional[Mapping[Variable, Identifier]] = None,
initBindings: Optional[Mapping[str, Identifier]] = None,
use_store_provided: bool = True,
**kwargs: Any,
) -> None:
Expand Down
135 changes: 89 additions & 46 deletions rdflib/plugins/sparql/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,29 @@
from decimal import Decimal
from __future__ import annotations

from rdflib import XSD, Literal
from decimal import Decimal
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
MutableMapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)

from rdflib.namespace import XSD
from rdflib.plugins.sparql.datatypes import type_promotion
from rdflib.plugins.sparql.evalutils import NotBoundError, _eval, _val
from rdflib.plugins.sparql.evalutils import _eval, _val
from rdflib.plugins.sparql.operators import numeric
from rdflib.plugins.sparql.sparql import SPARQLTypeError
from rdflib.plugins.sparql.parserutils import CompValue
from rdflib.plugins.sparql.sparql import FrozenBindings, NotBoundError, SPARQLTypeError
from rdflib.term import BNode, Identifier, Literal, URIRef, Variable

"""
Aggregation functions
Expand All @@ -14,38 +33,43 @@
class Accumulator(object):
"""abstract base class for different aggregation functions"""

def __init__(self, aggregation):
def __init__(self, aggregation: CompValue):
self.get_value: Callable[[], Optional[Literal]]
self.update: Callable[[FrozenBindings, "Aggregator"], None]
self.var = aggregation.res
self.expr = aggregation.vars
if not aggregation.distinct:
self.use_row = self.dont_care
# type error: Cannot assign to a method
self.use_row = self.dont_care # type: ignore[assignment]
self.distinct = False
else:
self.distinct = aggregation.distinct
self.seen = set()
self.seen: Set[Any] = set()

def dont_care(self, row):
def dont_care(self, row: FrozenBindings) -> bool:
"""skips distinct test"""
return True

def use_row(self, row):
def use_row(self, row: FrozenBindings) -> bool:
"""tests distinct with set"""
return _eval(self.expr, row) not in self.seen

def set_value(self, bindings):
def set_value(self, bindings: MutableMapping[Variable, Identifier]) -> None:
"""sets final value in bindings"""
bindings[self.var] = self.get_value()
# type error: Incompatible types in assignment (expression has type "Optional[Literal]", target has type "Identifier")
bindings[self.var] = self.get_value() # type: ignore[assignment]


class Counter(Accumulator):
def __init__(self, aggregation):
def __init__(self, aggregation: CompValue):
super(Counter, self).__init__(aggregation)
self.value = 0
if self.expr == "*":
# cannot eval "*" => always use the full row
self.eval_row = self.eval_full_row
# type error: Cannot assign to a method
self.eval_row = self.eval_full_row # type: ignore[assignment]

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
val = self.eval_row(row)
except NotBoundError:
Expand All @@ -55,41 +79,54 @@ def update(self, row, aggregator):
if self.distinct:
self.seen.add(val)

def get_value(self):
def get_value(self) -> Literal:
return Literal(self.value)

def eval_row(self, row):
def eval_row(self, row: FrozenBindings) -> Identifier:
return _eval(self.expr, row)

def eval_full_row(self, row):
def eval_full_row(self, row: FrozenBindings) -> FrozenBindings:
return row

def use_row(self, row):
def use_row(self, row: FrozenBindings) -> bool:
return self.eval_row(row) not in self.seen


def type_safe_numbers(*args):
@overload
def type_safe_numbers(*args: int) -> Tuple[int]:
...


@overload
def type_safe_numbers(*args: Union[Decimal, float, int]) -> Tuple[Union[float, int]]:
...


def type_safe_numbers(*args: Union[Decimal, float, int]) -> Iterable[Union[float, int]]:
if any(isinstance(arg, float) for arg in args) and any(
isinstance(arg, Decimal) for arg in args
):
return map(float, args)
return args
# type error: Incompatible return value type (got "Tuple[Union[Decimal, float, int], ...]", expected "Iterable[Union[float, int]]")
# NOTE on type error: if args contains a Decimal it will not get here.
return args # type: ignore[return-value]


class Sum(Accumulator):
def __init__(self, aggregation):
def __init__(self, aggregation: CompValue):
super(Sum, self).__init__(aggregation)
self.value = 0
self.datatype = None
self.datatype: Optional[str] = None

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
value = _eval(self.expr, row)
dt = self.datatype
if dt is None:
dt = value.datatype
else:
dt = type_promotion(dt, value.datatype)
# type error: Argument 1 to "type_promotion" has incompatible type "str"; expected "URIRef"
dt = type_promotion(dt, value.datatype) # type: ignore[arg-type]
self.datatype = dt
self.value = sum(type_safe_numbers(self.value, numeric(value)))
if self.distinct:
Expand All @@ -98,26 +135,27 @@ def update(self, row, aggregator):
# skip UNDEF
pass

def get_value(self):
def get_value(self) -> Literal:
return Literal(self.value, datatype=self.datatype)


class Average(Accumulator):
def __init__(self, aggregation):
def __init__(self, aggregation: CompValue):
super(Average, self).__init__(aggregation)
self.counter = 0
self.sum = 0
self.datatype = None
self.datatype: Optional[str] = None

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
value = _eval(self.expr, row)
dt = self.datatype
self.sum = sum(type_safe_numbers(self.sum, numeric(value)))
if dt is None:
dt = value.datatype
else:
dt = type_promotion(dt, value.datatype)
# type error: Argument 1 to "type_promotion" has incompatible type "str"; expected "URIRef"
dt = type_promotion(dt, value.datatype) # type: ignore[arg-type]
self.datatype = dt
if self.distinct:
self.seen.add(value)
Expand All @@ -128,7 +166,7 @@ def update(self, row, aggregator):
except SPARQLTypeError:
pass

def get_value(self):
def get_value(self) -> Literal:
if self.counter == 0:
return Literal(0)
if self.datatype in (XSD.float, XSD.double):
Expand All @@ -140,18 +178,20 @@ def get_value(self):
class Extremum(Accumulator):
"""abstract base class for Minimum and Maximum"""

def __init__(self, aggregation):
def __init__(self, aggregation: CompValue):
self.compare: Callable[[Any, Any], Any]
super(Extremum, self).__init__(aggregation)
self.value = None
self.value: Any = None
# DISTINCT would not change the value for MIN or MAX
self.use_row = self.dont_care
# type error: Cannot assign to a method
self.use_row = self.dont_care # type: ignore[assignment]

def set_value(self, bindings):
def set_value(self, bindings: MutableMapping[Variable, Identifier]) -> None:
if self.value is not None:
# simply do not set if self.value is still None
bindings[self.var] = Literal(self.value)

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
if self.value is None:
self.value = _eval(self.expr, row)
Expand All @@ -165,13 +205,16 @@ def update(self, row, aggregator):
pass


_ValueT = TypeVar("_ValueT", Variable, BNode, URIRef, Literal)


class Minimum(Extremum):
def compare(self, val1, val2):
def compare(self, val1: _ValueT, val2: _ValueT) -> _ValueT:
return min(val1, val2, key=_val)


class Maximum(Extremum):
def compare(self, val1, val2):
def compare(self, val1: _ValueT, val2: _ValueT) -> _ValueT:
return max(val1, val2, key=_val)


Expand All @@ -183,7 +226,7 @@ def __init__(self, aggregation):
# DISTINCT would not change the value
self.use_row = self.dont_care

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
# set the value now
aggregator.bindings[self.var] = _eval(self.expr, row)
Expand All @@ -192,7 +235,7 @@ def update(self, row, aggregator):
except NotBoundError:
pass

def get_value(self):
def get_value(self) -> None:
# set None if no value was set
return None

Expand All @@ -204,7 +247,7 @@ def __init__(self, aggregation):
self.value = []
self.separator = aggregation.separator or " "

def update(self, row, aggregator):
def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
try:
value = _eval(self.expr, row)
# skip UNDEF
Expand All @@ -221,7 +264,7 @@ def update(self, row, aggregator):
except NotBoundError:
pass

def get_value(self):
def get_value(self) -> Literal:
return Literal(self.separator.join(str(v) for v in self.value))


Expand All @@ -238,16 +281,16 @@ class Aggregator(object):
"Aggregate_GroupConcat": GroupConcat,
}

def __init__(self, aggregations):
self.bindings = {}
self.accumulators = {}
def __init__(self, aggregations: List[CompValue]):
self.bindings: Dict[Variable, Identifier] = {}
self.accumulators: Dict[str, Accumulator] = {}
for a in aggregations:
accumulator_class = self.accumulator_classes.get(a.name)
if accumulator_class is None:
raise Exception("Unknown aggregate function " + a.name)
self.accumulators[a.res] = accumulator_class(a)

def update(self, row):
def update(self, row: FrozenBindings) -> None:
"""update all own accumulators"""
# SAMPLE accumulators may delete themselves
# => iterate over list not generator
Expand All @@ -256,7 +299,7 @@ def update(self, row):
if acc.use_row(row):
acc.update(row, self)

def get_bindings(self):
def get_bindings(self) -> Mapping[Variable, Identifier]:
"""calculate and set last values"""
for acc in self.accumulators.values():
acc.set_value(self.bindings)
Expand Down
22 changes: 16 additions & 6 deletions rdflib/plugins/sparql/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Optional, Set

"""
Utility functions for supporting the XML Schema Datatypes hierarchy
"""

from rdflib import XSD
from rdflib.namespace import XSD

if TYPE_CHECKING:
from rdflib.term import URIRef

XSD_DTs = set(
XSD_DTs: Set[URIRef] = set(
(
XSD.integer,
XSD.decimal,
Expand Down Expand Up @@ -35,7 +42,7 @@

XSD_Duration_DTs = set((XSD.duration, XSD.dayTimeDuration, XSD.yearMonthDuration))

_sub_types = {
_sub_types: Dict[URIRef, List[URIRef]] = {
XSD.integer: [
XSD.nonPositiveInteger,
XSD.negativeInteger,
Expand All @@ -52,13 +59,13 @@
],
}

_super_types = {}
_super_types: Dict[URIRef, URIRef] = {}
for superdt in XSD_DTs:
for subdt in _sub_types.get(superdt, []):
_super_types[subdt] = superdt

# we only care about float, double, integer, decimal
_typePromotionMap = {
_typePromotionMap: Dict[URIRef, Dict[URIRef, URIRef]] = {
XSD.float: {XSD.integer: XSD.float, XSD.decimal: XSD.float, XSD.double: XSD.double},
XSD.double: {
XSD.integer: XSD.double,
Expand All @@ -78,14 +85,17 @@
}


def type_promotion(t1, t2):
def type_promotion(t1: URIRef, t2: Optional[URIRef]) -> URIRef:
if t2 is None:
return t1
t1 = _super_types.get(t1, t1)
t2 = _super_types.get(t2, t2)
if t1 == t2:
return t1 # matching super-types
try:
if TYPE_CHECKING:
# type assert because mypy is confused and thinks t2 can be None
assert t2 is not None
return _typePromotionMap[t1][t2]
except KeyError:
raise TypeError("Operators cannot combine datatypes %s and %s" % (t1, t2))
Loading