Skip to content

Commit

Permalink
Add type hints
Browse files Browse the repository at this point in the history
This commit only adds type hints and comments and does not make any changes that
should affect runtime.

The type hints added here derive from work done for RDFLib#1418.
  • Loading branch information
aucampia committed Oct 16, 2021
1 parent 1729243 commit 885a115
Show file tree
Hide file tree
Showing 18 changed files with 350 additions and 144 deletions.
168 changes: 112 additions & 56 deletions rdflib/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from typing import Optional, Union, Type, cast, overload, Generator, Tuple
from typing import (
IO,
Any,
Iterable,
Optional,
Union,
Type,
cast,
overload,
Generator,
Tuple,
)
import logging
from warnings import warn
import random
Expand All @@ -21,7 +32,7 @@
import tempfile
import pathlib

from io import BytesIO, BufferedIOBase
from io import BytesIO
from urllib.parse import urlparse

assert Literal # avoid warning
Expand Down Expand Up @@ -313,15 +324,19 @@ class Graph(Node):
"""

def __init__(
self, store="default", identifier=None, namespace_manager=None, base=None
self,
store: Union[Store, str] = "default",
identifier: Optional[Union[Node, str]] = None,
namespace_manager: Optional[NamespaceManager] = None,
base: Optional[str] = None,
):
super(Graph, self).__init__()
self.base = base
self.__identifier = identifier or BNode()

self.__identifier: Node
self.__identifier = identifier or BNode() # type: ignore[assignment]
if not isinstance(self.__identifier, Node):
self.__identifier = URIRef(self.__identifier)

self.__identifier = URIRef(self.__identifier) # type: ignore[unreachable]
self.__store: Store
if not isinstance(store, Store):
# TODO: error handling
self.__store = store = plugin.get(store, Store)()
Expand Down Expand Up @@ -404,7 +419,7 @@ def close(self, commit_pending_transaction=False):
"""
return self.__store.close(commit_pending_transaction=commit_pending_transaction)

def add(self, triple):
def add(self, triple: Tuple[Node, Node, Node]):
"""Add a triple with self as context"""
s, p, o = triple
assert isinstance(s, Node), "Subject %s must be an rdflib term" % (s,)
Expand All @@ -413,7 +428,7 @@ def add(self, triple):
self.__store.add((s, p, o), self, quoted=False)
return self

def addN(self, quads):
def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]):
"""Add a sequence of triple with context"""

self.__store.addN(
Expand All @@ -434,7 +449,9 @@ def remove(self, triple):
self.__store.remove(triple, context=self)
return self

def triples(self, triple):
def triples(
self, triple: Tuple[Optional[Node], Union[None, Path, Node], Optional[Node]]
):
"""Generator over the triple store
Returns triples that match the given triple pattern. If triple pattern
Expand Down Expand Up @@ -652,17 +669,17 @@ def set(self, triple):
self.add((subject, predicate, object_))
return self

def subjects(self, predicate=None, object=None):
def subjects(self, predicate=None, object=None) -> Iterable[Node]:
"""A generator of subjects with the given predicate and object"""
for s, p, o in self.triples((None, predicate, object)):
yield s

def predicates(self, subject=None, object=None):
def predicates(self, subject=None, object=None) -> Iterable[Node]:
"""A generator of predicates with the given subject and object"""
for s, p, o in self.triples((subject, None, object)):
yield p

def objects(self, subject=None, predicate=None):
def objects(self, subject=None, predicate=None) -> Iterable[Node]:
"""A generator of objects with the given subject and predicate"""
for s, p, o in self.triples((subject, predicate, None)):
yield o
Expand Down Expand Up @@ -1011,53 +1028,45 @@ def absolutize(self, uri, defrag=1):
# no destination and non-None positional encoding
@overload
def serialize(
self, destination: None, format: str, base: Optional[str], encoding: str, **args
self,
destination: None,
format: str,
base: Optional[str],
encoding: str,
**args,
) -> bytes:
...

# no destination and non-None keyword encoding
@overload
def serialize(
self,
*,
destination: None = ...,
format: str = ...,
base: Optional[str] = ...,
*,
encoding: str,
**args,
) -> bytes:
...

# no destination and None positional encoding
# no destination and None encoding
@overload
def serialize(
self,
destination: None,
format: str,
base: Optional[str],
encoding: None,
**args,
) -> str:
...

# no destination and None keyword encoding
@overload
def serialize(
self,
*,
destination: None = ...,
format: str = ...,
base: Optional[str] = ...,
encoding: None = None,
encoding: None = ...,
**args,
) -> str:
...

# non-none destination
# non-None destination
@overload
def serialize(
self,
destination: Union[str, BufferedIOBase, pathlib.PurePath],
destination: Union[str, pathlib.PurePath, IO[bytes]],
format: str = ...,
base: Optional[str] = ...,
encoding: Optional[str] = ...,
Expand All @@ -1069,21 +1078,21 @@ def serialize(
@overload
def serialize(
self,
destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None,
format: str = "turtle",
base: Optional[str] = None,
encoding: Optional[str] = None,
destination: Optional[Union[str, pathlib.PurePath, IO[bytes]]] = ...,
format: str = ...,
base: Optional[str] = ...,
encoding: Optional[str] = ...,
**args,
) -> Union[bytes, str, "Graph"]:
...

def serialize(
self,
destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None,
destination: Optional[Union[str, pathlib.PurePath, IO[bytes]]] = None,
format: str = "turtle",
base: Optional[str] = None,
encoding: Optional[str] = None,
**args,
**args: Any,
) -> Union[bytes, str, "Graph"]:
"""Serialize the Graph to destination
Expand All @@ -1104,7 +1113,7 @@ def serialize(
base = self.base

serializer = plugin.get(format, Serializer)(self)
stream: BufferedIOBase
stream: IO[bytes]
if destination is None:
stream = BytesIO()
if encoding is None:
Expand All @@ -1114,7 +1123,7 @@ def serialize(
serializer.serialize(stream, base=base, encoding=encoding, **args)
return stream.getvalue()
if hasattr(destination, "write"):
stream = cast(BufferedIOBase, destination)
stream = cast(IO[bytes], destination)
serializer.serialize(stream, base=base, encoding=encoding, **args)
else:
if isinstance(destination, pathlib.PurePath):
Expand Down Expand Up @@ -1149,10 +1158,10 @@ def parse(
self,
source=None,
publicID=None,
format=None,
format: Optional[str] = None,
location=None,
file=None,
data=None,
data: Optional[Union[str, bytes, bytearray]] = None,
**args,
):
"""
Expand Down Expand Up @@ -1537,7 +1546,12 @@ class ConjunctiveGraph(Graph):
All queries are carried out against the union of all graphs.
"""

def __init__(self, store="default", identifier=None, default_graph_base=None):
def __init__(
self,
store: Union[Store, str] = "default",
identifier: Optional[Union[Node, str]] = None,
default_graph_base: Optional[str] = None,
):
super(ConjunctiveGraph, self).__init__(store, identifier=identifier)
assert self.store.context_aware, (
"ConjunctiveGraph must be backed by" " a context aware store."
Expand All @@ -1555,7 +1569,31 @@ def __str__(self):
)
return pattern % self.store.__class__.__name__

def _spoc(self, triple_or_quad, default=False):
@overload
def _spoc(
self,
triple_or_quad: Union[
Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]
],
default: bool = False,
) -> Tuple[Node, Node, Node, Optional[Graph]]:
...

@overload
def _spoc(
self,
triple_or_quad: None,
default: bool = False,
) -> Tuple[None, None, None, Optional[Graph]]:
...

def _spoc(
self,
triple_or_quad: Optional[
Union[Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]]
],
default: bool = False,
) -> Tuple[Optional[Node], Optional[Node], Optional[Node], Optional[Graph]]:
"""
helper method for having methods that support
either triples or quads
Expand All @@ -1564,9 +1602,9 @@ def _spoc(self, triple_or_quad, default=False):
return (None, None, None, self.default_context if default else None)
if len(triple_or_quad) == 3:
c = self.default_context if default else None
(s, p, o) = triple_or_quad
(s, p, o) = triple_or_quad # type: ignore[misc]
elif len(triple_or_quad) == 4:
(s, p, o, c) = triple_or_quad
(s, p, o, c) = triple_or_quad # type: ignore[misc]
c = self._graph(c)
return s, p, o, c

Expand All @@ -1577,7 +1615,7 @@ def __contains__(self, triple_or_quad):
return True
return False

def add(self, triple_or_quad):
def add(self, triple_or_quad: Union[Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]]) -> "ConjunctiveGraph": # type: ignore[override]
"""
Add a triple or quad to the store.
Expand All @@ -1591,15 +1629,23 @@ def add(self, triple_or_quad):
self.store.add((s, p, o), context=c, quoted=False)
return self

def _graph(self, c):
@overload
def _graph(self, c: Union[Graph, Node, str]) -> Graph:
...

@overload
def _graph(self, c: None) -> None:
...

def _graph(self, c: Optional[Union[Graph, Node, str]]) -> Optional[Graph]:
if c is None:
return None
if not isinstance(c, Graph):
return self.get_context(c)
else:
return c

def addN(self, quads):
def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]):
"""Add a sequence of triples with context"""

self.store.addN(
Expand Down Expand Up @@ -1689,13 +1735,19 @@ def contexts(self, triple=None):
else:
yield self.get_context(context)

def get_context(self, identifier, quoted=False, base=None):
def get_context(
self,
identifier: Optional[Union[Node, str]],
quoted: bool = False,
base: Optional[str] = None,
) -> Graph:
"""Return a context graph for the given identifier
identifier must be a URIRef or BNode.
"""
# TODO: FIXME - why is ConjunctiveGraph passed as namespace_manager?
return Graph(
store=self.store, identifier=identifier, namespace_manager=self, base=base
store=self.store, identifier=identifier, namespace_manager=self, base=base # type: ignore[arg-type]
)

def remove_context(self, context):
Expand Down Expand Up @@ -1747,6 +1799,7 @@ def parse(
context = Graph(store=self.store, identifier=g_id)
context.remove((None, None, None)) # hmm ?
context.parse(source, publicID=publicID, format=format, **args)
# TODO: FIXME: This should not return context, but self.
return context

def __reduce__(self):
Expand Down Expand Up @@ -1977,7 +2030,7 @@ class QuotedGraph(Graph):
def __init__(self, store, identifier):
super(QuotedGraph, self).__init__(store, identifier)

def add(self, triple):
def add(self, triple: Tuple[Node, Node, Node]):
"""Add a triple with self as context"""
s, p, o = triple
assert isinstance(s, Node), "Subject %s must be an rdflib term" % (s,)
Expand All @@ -1987,7 +2040,7 @@ def add(self, triple):
self.store.add((s, p, o), self, quoted=True)
return self

def addN(self, quads):
def addN(self, quads: Tuple[Node, Node, Node, Any]) -> "QuotedGraph": # type: ignore[override]
"""Add a sequence of triple with context"""

self.store.addN(
Expand Down Expand Up @@ -2261,7 +2314,7 @@ class BatchAddGraph(object):
"""

def __init__(self, graph, batch_size=1000, batch_addn=False):
def __init__(self, graph: Graph, batch_size: int = 1000, batch_addn: bool = False):
if not batch_size or batch_size < 2:
raise ValueError("batch_size must be a positive number")
self.graph = graph
Expand All @@ -2278,7 +2331,10 @@ def reset(self):
self.count = 0
return self

def add(self, triple_or_quad):
def add(
self,
triple_or_quad: Union[Tuple[Node, Node, Node], Tuple[Node, Node, Node, Any]],
) -> "BatchAddGraph":
"""
Add a triple to the buffer
Expand All @@ -2294,7 +2350,7 @@ def add(self, triple_or_quad):
self.batch.append(triple_or_quad)
return self

def addN(self, quads):
def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]):
if self.__batch_addn:
for q in quads:
self.add(q)
Expand Down
Loading

0 comments on commit 885a115

Please sign in to comment.