Skip to content

Commit

Permalink
fix: add more type-hinting for SPARQL plugin (#2265)
Browse files Browse the repository at this point in the history
Here, adding type-hints to some of the SPARQL parser plugin code.

Includes a couple of small consequent changes:

1. Minor refactor of `prettify_parsetree()`, separating the public-facing callable from the internal code that does not need to be public-facing.  That allows the public-facing callable to have more informative and restrictive type-hints for its arguments.
2. Added some test-coverage for `expandUnicodeEscapes()` - initially for my own understanding, but seems useful to leave it in place since I didn't see test-coverage for that function.

There should be no backwards-incompatible changes in this PR - at least, not intentionally.

---------

Co-authored-by: Iwan Aucamp <[email protected]>
  • Loading branch information
jclerman and aucampia authored Mar 12, 2023
1 parent 8c48549 commit a44bd99
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 24 deletions.
25 changes: 15 additions & 10 deletions rdflib/plugins/sparql/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import re
import sys
from typing import Any, BinaryIO
from typing import Optional as OptionalType
from typing import TextIO, Tuple, Union

from pyparsing import CaselessKeyword as Keyword # watch out :)
from pyparsing import (
Expand Down Expand Up @@ -37,15 +40,15 @@
# ---------------- ACTIONS


def neg(literal: rdflib.Literal) -> rdflib.Literal:
    """Return a new Literal holding the arithmetic negation of *literal*.

    The original literal's datatype is preserved on the result.
    """
    return rdflib.Literal(-literal, datatype=literal.datatype)


def setLanguage(terms: Tuple[Any, OptionalType[str]]) -> rdflib.Literal:
    """Build a language-tagged Literal from a (value, language-tag) pair.

    ``terms[0]`` is the lexical value; ``terms[1]`` is the language tag,
    or ``None`` for an untagged literal.
    """
    return rdflib.Literal(terms[0], lang=terms[1])


def setDataType(terms: Tuple[Any, OptionalType[str]]) -> rdflib.Literal:
    """Build a datatyped Literal from a (value, datatype-IRI) pair.

    ``terms[0]`` is the lexical value; ``terms[1]`` is the datatype IRI,
    or ``None`` for a plain literal.
    """
    return rdflib.Literal(terms[0], datatype=terms[1])


Expand Down Expand Up @@ -1508,25 +1511,27 @@ def expandCollection(terms):
UpdateUnit.ignore("#" + restOfLine)


# Matches \uXXXX (4 hex digits) or \uXXXXXXXX (8 hex digits), case-insensitive.
expandUnicodeEscapes_re: re.Pattern = re.compile(
    r"\\u([0-9a-f]{4}(?:[0-9a-f]{4})?)", flags=re.I
)


def expandUnicodeEscapes(q: str) -> str:
    r"""Expand SPARQL unicode escape sequences in *q* to their code points.

    The syntax of the SPARQL Query Language is expressed over code points in
    Unicode [UNICODE]. The encoding is always UTF-8 [RFC3629].
    Unicode code points may also be expressed using an \ uXXXX (U+0 to U+FFFF)
    or \ UXXXXXXXX syntax (for U+10000 onwards) where X is a hexadecimal
    digit [0-9A-F].

    :raises ValueError: if an escape denotes a value outside the Unicode
        code-point range (``chr`` rejects values above U+10FFFF).
    """

    def expand(m: re.Match) -> str:
        # Convert the captured hex digits to the corresponding character.
        try:
            return chr(int(m.group(1), 16))
        except (ValueError, OverflowError) as e:
            # Re-raise with the offending escape text, chaining the cause.
            raise ValueError("Invalid unicode code point: " + m.group(1)) from e

    return expandUnicodeEscapes_re.sub(expand, q)


def parseQuery(q):
def parseQuery(q: Union[str, bytes, TextIO, BinaryIO]) -> ParseResults:
if hasattr(q, "read"):
q = q.read()
if isinstance(q, bytes):
Expand All @@ -1536,7 +1541,7 @@ def parseQuery(q):
return Query.parseString(q, parseAll=True)


def parseUpdate(q):
def parseUpdate(q: Union[str, bytes, TextIO, BinaryIO]):
if hasattr(q, "read"):
q = q.read()

Expand Down
35 changes: 22 additions & 13 deletions rdflib/plugins/sparql/parserutils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections import OrderedDict
from types import MethodType
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, List, Tuple, Union

from pyparsing import ParseResults, TokenConverter, originalTextFor

from rdflib import BNode, Variable
from rdflib.term import Identifier

if TYPE_CHECKING:
from rdflib.plugins.sparql.sparql import FrozenBindings
Expand Down Expand Up @@ -252,26 +253,34 @@ def setEvalFn(self, evalfn):
return self


def prettify_parsetree(t: ParseResults, indent: str = "", depth: int = 0) -> str:
    """Render a pyparsing ``ParseResults`` tree as an indented, multi-line string.

    Public entry point: accepts only ``ParseResults``; the recursion over the
    heterogeneous child nodes is delegated to ``_prettify_sub_parsetree``.
    """
    out: List[str] = []
    for e in t.asList():
        out.append(_prettify_sub_parsetree(e, indent, depth + 1))
    for k, v in sorted(t.items()):
        out.append("%s%s- %s:\n" % (indent, " " * depth, k))
        out.append(_prettify_sub_parsetree(v, indent, depth + 1))
    return "".join(out)


def _prettify_sub_parsetree(
    t: Union[Identifier, CompValue, set, list, dict, Tuple, bool, None],
    indent: str = "",
    depth: int = 0,
) -> str:
    """Recursively render one sub-node of a parse tree (internal helper)."""
    out: List[str] = []
    if isinstance(t, CompValue):
        # CompValue: print its name, then its key/value children one level deeper.
        out.append("%s%s> %s:\n" % (indent, " " * depth, t.name))
        for k, v in t.items():
            out.append("%s%s- %s:\n" % (indent, " " * (depth + 1), k))
            out.append(_prettify_sub_parsetree(v, indent, depth + 2))
    elif isinstance(t, dict):
        for k, v in t.items():
            out.append("%s%s- %s:\n" % (indent, " " * (depth + 1), k))
            out.append(_prettify_sub_parsetree(v, indent, depth + 2))
    elif isinstance(t, list):
        for e in t:
            out.append(_prettify_sub_parsetree(e, indent, depth + 1))
    else:
        # Leaf node: show its repr.
        out.append("%s%s- %r\n" % (indent, " " * depth, t))
    return "".join(out)
Expand Down
24 changes: 23 additions & 1 deletion test/test_sparql/test_sparql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from rdflib.plugins.sparql.algebra import translateQuery
from rdflib.plugins.sparql.evaluate import evalPart
from rdflib.plugins.sparql.evalutils import _eval
from rdflib.plugins.sparql.parser import parseQuery
from rdflib.plugins.sparql.parser import expandUnicodeEscapes, parseQuery
from rdflib.plugins.sparql.parserutils import prettify_parsetree
from rdflib.plugins.sparql.sparql import SPARQLError
from rdflib.query import Result, ResultRow
Expand Down Expand Up @@ -957,3 +957,25 @@ def test_sparql_describe(
subjects = {s for s in r.graph.subjects() if not isinstance(s, BNode)}
assert subjects == expected_subjects
assert len(r.graph) == expected_size


@pytest.mark.parametrize(
    "arg, expected_result, expected_valid",
    [
        ("abc", "abc", True),
        ("1234", "1234", True),
        (r"1234\u0050", "1234P", True),
        (r"1234\u00e3", "1234\u00e3", True),
        (r"1234\u00e3\u00e5", "1234ãå", True),
        (r"1234\u900000e5", "", False),
        (r"1234\u010000e5", "", False),
        (r"1234\u001000e5", "1234\U001000e5", True),
    ],
)
def test_expand_unicode_escapes(arg: str, expected_result: str, expected_valid: bool):
    """Check expandUnicodeEscapes on 4- and 8-digit escapes, valid and invalid."""
    if expected_valid:
        actual_result = expandUnicodeEscapes(arg)
        assert actual_result == expected_result
    else:
        # Escapes beyond U+10FFFF must raise, with the offending digits reported.
        with pytest.raises(ValueError, match="Invalid unicode code point"):
            _ = expandUnicodeEscapes(arg)

0 comments on commit a44bd99

Please sign in to comment.