diff --git a/rdflib/graph.py b/rdflib/graph.py
index 805bb7c64c..4d2fa47bf3 100644
--- a/rdflib/graph.py
+++ b/rdflib/graph.py
@@ -1,4 +1,15 @@
-from typing import Optional, Union, Type, cast, overload, Generator, Tuple
+from typing import (
+    IO,
+    Any,
+    Iterable,
+    Optional,
+    Union,
+    Type,
+    cast,
+    overload,
+    Generator,
+    Tuple,
+)
 import logging
 from warnings import warn
 import random
@@ -21,7 +32,7 @@
 import tempfile
 import pathlib
 
-from io import BytesIO, BufferedIOBase
+from io import BytesIO
 from urllib.parse import urlparse
 
 assert Literal  # avoid warning
@@ -313,15 +324,19 @@ class Graph(Node):
     """

     def __init__(
-        self, store="default", identifier=None, namespace_manager=None, base=None
+        self,
+        store: Union[Store, str] = "default",
+        identifier: Optional[Union[Node, str]] = None,
+        namespace_manager: Optional[NamespaceManager] = None,
+        base: Optional[str] = None,
     ):
         super(Graph, self).__init__()
         self.base = base
-        self.__identifier = identifier or BNode()
-
+        self.__identifier: Node
+        self.__identifier = identifier or BNode()  # type: ignore[assignment]
         if not isinstance(self.__identifier, Node):
-            self.__identifier = URIRef(self.__identifier)
-
+            self.__identifier = URIRef(self.__identifier)  # type: ignore[unreachable]
+        self.__store: Store
         if not isinstance(store, Store):
             # TODO: error handling
             self.__store = store = plugin.get(store, Store)()
@@ -404,7 +419,7 @@ def close(self, commit_pending_transaction=False):
         """
         return self.__store.close(commit_pending_transaction=commit_pending_transaction)

-    def add(self, triple):
+    def add(self, triple: Tuple[Node, Node, Node]):
         """Add a triple with self as context"""
         s, p, o = triple
         assert isinstance(s, Node), "Subject %s must be an rdflib term" % (s,)
@@ -413,7 +428,7 @@ def add(self, triple):
         self.__store.add((s, p, o), self, quoted=False)
         return self

-    def addN(self, quads):
+    def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]):
         """Add a sequence of triple with context"""

         self.__store.addN(
@@ -434,7 +449,9 @@ def remove(self, triple):
         self.__store.remove(triple, context=self)
         return self

-    def triples(self, triple):
+    def triples(
+        self, triple: Tuple[Optional[Node], Union[None, Path, Node], Optional[Node]]
+    ):
         """Generator over the triple store

         Returns triples that match the given triple pattern. If triple pattern
@@ -652,17 +669,17 @@ def set(self, triple):
         self.add((subject, predicate, object_))
         return self

-    def subjects(self, predicate=None, object=None):
+    def subjects(self, predicate=None, object=None) -> Iterable[Node]:
         """A generator of subjects with the given predicate and object"""
         for s, p, o in self.triples((None, predicate, object)):
             yield s

-    def predicates(self, subject=None, object=None):
+    def predicates(self, subject=None, object=None) -> Iterable[Node]:
         """A generator of predicates with the given subject and object"""
         for s, p, o in self.triples((subject, None, object)):
             yield p

-    def objects(self, subject=None, predicate=None):
+    def objects(self, subject=None, predicate=None) -> Iterable[Node]:
         """A generator of objects with the given subject and predicate"""
         for s, p, o in self.triples((subject, predicate, None)):
             yield o
@@ -1011,7 +1028,12 @@ def absolutize(self, uri, defrag=1):
     # no destination and non-None positional encoding
     @overload
     def serialize(
-        self, destination: None, format: str, base: Optional[str], encoding: str, **args
+        self,
+        destination: None,
+        format: str,
+        base: Optional[str],
+        encoding: str,
+        **args,
     ) -> bytes:
         ...

@@ -1019,45 +1041,32 @@ def serialize(
     @overload
     def serialize(
         self,
-        *,
         destination: None = ...,
         format: str = ...,
         base: Optional[str] = ...,
+        *,
         encoding: str,
         **args,
     ) -> bytes:
         ...

-    # no destination and None positional encoding
+    # no destination and None encoding
     @overload
     def serialize(
         self,
-        destination: None,
-        format: str,
-        base: Optional[str],
-        encoding: None,
-        **args,
-    ) -> str:
-        ...
-
-    # no destination and None keyword encoding
-    @overload
-    def serialize(
-        self,
-        *,
         destination: None = ...,
         format: str = ...,
         base: Optional[str] = ...,
-        encoding: None = None,
+        encoding: None = ...,
         **args,
     ) -> str:
         ...

-    # non-none destination
+    # non-None destination
     @overload
     def serialize(
         self,
-        destination: Union[str, BufferedIOBase, pathlib.PurePath],
+        destination: Union[str, pathlib.PurePath, IO[bytes]],
         format: str = ...,
         base: Optional[str] = ...,
         encoding: Optional[str] = ...,
@@ -1069,21 +1078,21 @@ def serialize(
     @overload
     def serialize(
         self,
-        destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None,
-        format: str = "turtle",
-        base: Optional[str] = None,
-        encoding: Optional[str] = None,
+        destination: Optional[Union[str, pathlib.PurePath, IO[bytes]]] = ...,
+        format: str = ...,
+        base: Optional[str] = ...,
+        encoding: Optional[str] = ...,
         **args,
     ) -> Union[bytes, str, "Graph"]:
         ...

     def serialize(
         self,
-        destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None,
+        destination: Optional[Union[str, pathlib.PurePath, IO[bytes]]] = None,
         format: str = "turtle",
         base: Optional[str] = None,
         encoding: Optional[str] = None,
-        **args,
+        **args: Any,
     ) -> Union[bytes, str, "Graph"]:
         """Serialize the Graph to destination
@@ -1104,7 +1113,7 @@ def serialize(
             base = self.base

         serializer = plugin.get(format, Serializer)(self)
-        stream: BufferedIOBase
+        stream: IO[bytes]
         if destination is None:
             stream = BytesIO()
             if encoding is None:
@@ -1114,7 +1123,7 @@ def serialize(
                 serializer.serialize(stream, base=base, encoding=encoding, **args)
                 return stream.getvalue()
         if hasattr(destination, "write"):
-            stream = cast(BufferedIOBase, destination)
+            stream = cast(IO[bytes], destination)
             serializer.serialize(stream, base=base, encoding=encoding, **args)
         else:
             if isinstance(destination, pathlib.PurePath):
@@ -1149,10 +1158,10 @@ def parse(
         self,
         source=None,
         publicID=None,
-        format=None,
+        format: Optional[str] = None,
         location=None,
         file=None,
-        data=None,
+        data: Optional[Union[str, bytes, bytearray]] = None,
         **args,
     ):
         """
@@ -1537,7 +1546,12 @@ class ConjunctiveGraph(Graph):
     All queries are carried out against the union of all graphs.
     """

-    def __init__(self, store="default", identifier=None, default_graph_base=None):
+    def __init__(
+        self,
+        store: Union[Store, str] = "default",
+        identifier: Optional[Union[Node, str]] = None,
+        default_graph_base: Optional[str] = None,
+    ):
         super(ConjunctiveGraph, self).__init__(store, identifier=identifier)
         assert self.store.context_aware, (
             "ConjunctiveGraph must be backed by" " a context aware store."
@@ -1555,7 +1569,31 @@ def __str__(self):
         )
         return pattern % self.store.__class__.__name__

-    def _spoc(self, triple_or_quad, default=False):
+    @overload
+    def _spoc(
+        self,
+        triple_or_quad: Union[
+            Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]
+        ],
+        default: bool = False,
+    ) -> Tuple[Node, Node, Node, Optional[Graph]]:
+        ...
+
+    @overload
+    def _spoc(
+        self,
+        triple_or_quad: None,
+        default: bool = False,
+    ) -> Tuple[None, None, None, Optional[Graph]]:
+        ...
+
+    def _spoc(
+        self,
+        triple_or_quad: Optional[
+            Union[Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]]
+        ],
+        default: bool = False,
+    ) -> Tuple[Optional[Node], Optional[Node], Optional[Node], Optional[Graph]]:
         """
         helper method for having methods that support either triples or quads
         """
@@ -1564,9 +1602,9 @@ def _spoc(self, triple_or_quad, default=False):
             return (None, None, None, self.default_context if default else None)
         if len(triple_or_quad) == 3:
             c = self.default_context if default else None
-            (s, p, o) = triple_or_quad
+            (s, p, o) = triple_or_quad  # type: ignore[misc]
         elif len(triple_or_quad) == 4:
-            (s, p, o, c) = triple_or_quad
+            (s, p, o, c) = triple_or_quad  # type: ignore[misc]
             c = self._graph(c)
         return s, p, o, c

@@ -1577,7 +1615,7 @@ def __contains__(self, triple_or_quad):
                 return True
         return False

-    def add(self, triple_or_quad):
+    def add(self, triple_or_quad: Union[Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]]) -> "ConjunctiveGraph":  # type: ignore[override]
         """
         Add a triple or quad to the store.

@@ -1591,7 +1629,15 @@ def add(self, triple_or_quad):
         self.store.add((s, p, o), context=c, quoted=False)
         return self

-    def _graph(self, c):
+    @overload
+    def _graph(self, c: Union[Graph, Node, str]) -> Graph:
+        ...
+
+    @overload
+    def _graph(self, c: None) -> None:
+        ...
+
+    def _graph(self, c: Optional[Union[Graph, Node, str]]) -> Optional[Graph]:
         if c is None:
             return None
         if not isinstance(c, Graph):
@@ -1599,7 +1645,7 @@ def _graph(self, c):
         else:
             return c

-    def addN(self, quads):
+    def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]):
         """Add a sequence of triples with context"""

         self.store.addN(
@@ -1689,13 +1735,19 @@ def contexts(self, triple=None):
             else:
                 yield self.get_context(context)

-    def get_context(self, identifier, quoted=False, base=None):
+    def get_context(
+        self,
+        identifier: Optional[Union[Node, str]],
+        quoted: bool = False,
+        base: Optional[str] = None,
+    ) -> Graph:
         """Return a context graph for the given identifier

         identifier must be a URIRef or BNode.
         """
+        # TODO: FIXME - why is ConjunctiveGraph passed as namespace_manager?
         return Graph(
-            store=self.store, identifier=identifier, namespace_manager=self, base=base
+            store=self.store, identifier=identifier, namespace_manager=self, base=base  # type: ignore[arg-type]
         )

     def remove_context(self, context):
@@ -1747,6 +1799,7 @@ def parse(
         context = Graph(store=self.store, identifier=g_id)
         context.remove((None, None, None))  # hmm ?
         context.parse(source, publicID=publicID, format=format, **args)
+        # TODO: FIXME: This should not return context, but self.
         return context

     def __reduce__(self):
@@ -1977,7 +2030,7 @@ class QuotedGraph(Graph):
     def __init__(self, store, identifier):
         super(QuotedGraph, self).__init__(store, identifier)

-    def add(self, triple):
+    def add(self, triple: Tuple[Node, Node, Node]):
         """Add a triple with self as context"""
         s, p, o = triple
         assert isinstance(s, Node), "Subject %s must be an rdflib term" % (s,)
@@ -1987,7 +2040,7 @@ def add(self, triple):
         self.store.add((s, p, o), self, quoted=True)
         return self

-    def addN(self, quads):
+    def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]) -> "QuotedGraph":  # type: ignore[override]
         """Add a sequence of triple with context"""

         self.store.addN(
@@ -2261,7 +2314,7 @@ class BatchAddGraph(object):

     """

-    def __init__(self, graph, batch_size=1000, batch_addn=False):
+    def __init__(self, graph: Graph, batch_size: int = 1000, batch_addn: bool = False):
         if not batch_size or batch_size < 2:
             raise ValueError("batch_size must be a positive number")
         self.graph = graph
@@ -2278,7 +2331,10 @@ def reset(self):
         self.count = 0
         return self

-    def add(self, triple_or_quad):
+    def add(
+        self,
+        triple_or_quad: Union[Tuple[Node, Node, Node], Tuple[Node, Node, Node, Any]],
+    ) -> "BatchAddGraph":
         """
         Add a triple to the buffer

@@ -2294,7 +2350,7 @@ def add(self, triple_or_quad):
         self.batch.append(triple_or_quad)
         return self

-    def addN(self, quads):
+    def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]):
         if self.__batch_addn:
             for q in quads:
                 self.add(q)
diff --git a/rdflib/parser.py b/rdflib/parser.py
index f0014150f6..1f8a490cde 100644
--- a/rdflib/parser.py
+++ b/rdflib/parser.py
@@ -16,6 +16,7 @@
 import sys

 from io import BytesIO, TextIOBase, TextIOWrapper, StringIO, BufferedIOBase
+from typing import Optional, Union
 from urllib.request import Request
 from urllib.request import url2pathname

@@ -44,7 +45,7 @@ class Parser(object):
     def __init__(self):
         pass

-    def parse(self, source, sink):
+    def parse(self, source, sink, **args):
         pass


@@ -214,7 +215,12 @@ def __repr__(self):


 def create_input_source(
-    source=None, publicID=None, location=None, file=None, data=None, format=None
+    source=None,
+    publicID=None,
+    location=None,
+    file=None,
+    data: Optional[Union[str, bytes, bytearray]] = None,
+    format=None,
 ):
     """
     Return an appropriate InputSource instance for the given
diff --git a/rdflib/plugin.py b/rdflib/plugin.py
index 719c7eaf55..ac3a7fbd06 100644
--- a/rdflib/plugin.py
+++ b/rdflib/plugin.py
@@ -36,7 +36,21 @@
     UpdateProcessor,
 )
 from rdflib.exceptions import Error
-from typing import Type, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    Iterator,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    overload,
+)
+
+if TYPE_CHECKING:
+    from pkg_resources import EntryPoint

 __all__ = ["register", "get", "plugins", "PluginException", "Plugin", "PKGPlugin"]

@@ -51,42 +65,47 @@
     "rdf.plugins.updateprocessor": UpdateProcessor,
 }

-_plugins = {}
+_plugins: Dict[Tuple[str, Type[Any]], "Plugin"] = {}


 class PluginException(Error):
     pass


-class Plugin(object):
-    def __init__(self, name, kind, module_path, class_name):
+PluginT = TypeVar("PluginT")
+
+
+class Plugin(Generic[PluginT]):
+    def __init__(
+        self, name: str, kind: Type[PluginT], module_path: str, class_name: str
+    ):
         self.name = name
         self.kind = kind
         self.module_path = module_path
         self.class_name = class_name
-        self._class = None
+        self._class: Optional[Type[PluginT]] = None

-    def getClass(self):
+    def getClass(self) -> Type[PluginT]:
         if self._class is None:
             module = __import__(self.module_path, globals(), locals(), [""])
             self._class = getattr(module, self.class_name)
         return self._class


-class PKGPlugin(Plugin):
-    def __init__(self, name, kind, ep):
+class PKGPlugin(Plugin[PluginT]):
+    def __init__(self, name: str, kind: Type[PluginT], ep: "EntryPoint"):
         self.name = name
         self.kind = kind
         self.ep = ep
-        self._class = None
+        self._class: Optional[Type[PluginT]] = None

-    def getClass(self):
+    def getClass(self) -> Type[PluginT]:
         if self._class is None:
             self._class = self.ep.load()
         return self._class


-def register(name: str, kind, module_path, class_name):
+def register(name: str, kind: Type[Any], module_path, class_name):
    """
     Register the plugin for (name, kind). The module_path and
     class_name should be the path to a plugin class.
@@ -95,16 +114,13 @@ def register(name: str, kind, module_path, class_name):
     _plugins[(name, kind)] = p


-PluginT = TypeVar("PluginT")
-
-
 def get(name: str, kind: Type[PluginT]) -> Type[PluginT]:
     """
     Return the class for the specified (name, kind). Raises a
     PluginException if unable to do so.
     """
     try:
-        p = _plugins[(name, kind)]
+        p: Plugin[PluginT] = _plugins[(name, kind)]
     except KeyError:
         raise PluginException("No plugin registered for (%s, %s)" % (name, kind))
     return p.getClass()
@@ -121,7 +137,21 @@ def get(name: str, kind: Type[PluginT]) -> Type[PluginT]:
         _plugins[(ep.name, kind)] = PKGPlugin(ep.name, kind, ep)


-def plugins(name=None, kind=None):
+@overload
+def plugins(
+    name: Optional[str] = ..., kind: Type[PluginT] = ...
+) -> Iterator[Plugin[PluginT]]:
+    ...
+
+
+@overload
+def plugins(name: Optional[str] = ..., kind: None = ...) -> Iterator[Plugin]:
+    ...
+
+
+def plugins(
+    name: Optional[str] = None, kind: Optional[Type[PluginT]] = None
+) -> Iterator[Plugin]:
     """
     A generator of the plugins.

diff --git a/rdflib/plugins/serializers/jsonld.py b/rdflib/plugins/serializers/jsonld.py
index 67f3b86232..f5067e2873 100644
--- a/rdflib/plugins/serializers/jsonld.py
+++ b/rdflib/plugins/serializers/jsonld.py
@@ -41,6 +41,7 @@
 from rdflib.graph import Graph
 from rdflib.term import URIRef, Literal, BNode
 from rdflib.namespace import RDF, XSD
+from typing import IO, Optional

 from ..shared.jsonld.context import Context, UNDEF
 from ..shared.jsonld.util import json
@@ -53,10 +54,16 @@


 class JsonLDSerializer(Serializer):
-    def __init__(self, store):
+    def __init__(self, store: Graph):
         super(JsonLDSerializer, self).__init__(store)

-    def serialize(self, stream, base=None, encoding=None, **kwargs):
+    def serialize(
+        self,
+        stream: IO[bytes],
+        base: Optional[str] = None,
+        encoding: Optional[str] = None,
+        **kwargs
+    ):
         # TODO: docstring w. args and return value
         encoding = encoding or "utf-8"
         if encoding not in ("utf-8", "utf-16"):
diff --git a/rdflib/plugins/serializers/n3.py b/rdflib/plugins/serializers/n3.py
index 6c4e2ec46d..806f445ef8 100644
--- a/rdflib/plugins/serializers/n3.py
+++ b/rdflib/plugins/serializers/n3.py
@@ -14,7 +14,7 @@ class N3Serializer(TurtleSerializer):

     short_name = "n3"

-    def __init__(self, store, parent=None):
+    def __init__(self, store: Graph, parent=None):
         super(N3Serializer, self).__init__(store)
         self.keywords.update({OWL.sameAs: "=", SWAP_LOG.implies: "=>"})
         self.parent = parent
diff --git a/rdflib/plugins/serializers/nquads.py b/rdflib/plugins/serializers/nquads.py
index 54ee42ba12..e76c747d49 100644
--- a/rdflib/plugins/serializers/nquads.py
+++ b/rdflib/plugins/serializers/nquads.py
@@ -1,5 +1,7 @@
+from typing import IO, Optional
 import warnings

+from rdflib.graph import ConjunctiveGraph, Graph
 from rdflib.term import Literal
 from rdflib.serializer import Serializer

@@ -9,15 +11,22 @@


 class NQuadsSerializer(Serializer):
-    def __init__(self, store):
+    def __init__(self, store: Graph):
         if not store.context_aware:
             raise Exception(
                 "NQuads serialization only makes "
                 "sense for context-aware stores!"
             )
         super(NQuadsSerializer, self).__init__(store)
-
-    def serialize(self, stream, base=None, encoding=None, **args):
+        self.store: ConjunctiveGraph
+
+    def serialize(
+        self,
+        stream: IO[bytes],
+        base: Optional[str] = None,
+        encoding: Optional[str] = None,
+        **args
+    ):
         if base is not None:
             warnings.warn("NQuadsSerializer does not support base.")
         if encoding is not None and encoding.lower() != self.encoding.lower():
diff --git a/rdflib/plugins/serializers/nt.py b/rdflib/plugins/serializers/nt.py
index bc265ee5f4..467de46134 100644
--- a/rdflib/plugins/serializers/nt.py
+++ b/rdflib/plugins/serializers/nt.py
@@ -3,6 +3,9 @@
 See <http://www.w3.org/TR/rdf-testcases/#ntriples> for details about the format.
 """
+from typing import IO, Optional
+
+from rdflib.graph import Graph
 from rdflib.term import Literal
 from rdflib.serializer import Serializer

@@ -17,11 +20,17 @@ class NTSerializer(Serializer):
     Serializes RDF graphs to NTriples format.
     """

-    def __init__(self, store):
+    def __init__(self, store: Graph):
         Serializer.__init__(self, store)
         self.encoding = "ascii"  # n-triples are ascii encoded

-    def serialize(self, stream, base=None, encoding=None, **args):
+    def serialize(
+        self,
+        stream: IO[bytes],
+        base: Optional[str] = None,
+        encoding: Optional[str] = None,
+        **args
+    ):
         if base is not None:
             warnings.warn("NTSerializer does not support base.")
         if encoding is not None and encoding.lower() != self.encoding.lower():
@@ -39,7 +48,7 @@ class NT11Serializer(NTSerializer):
     Exactly like nt - only utf8 encoded.
""" - def __init__(self, store): + def __init__(self, store: Graph): Serializer.__init__(self, store) # default to utf-8 diff --git a/rdflib/plugins/serializers/rdfxml.py b/rdflib/plugins/serializers/rdfxml.py index 72648afbac..901d911d91 100644 --- a/rdflib/plugins/serializers/rdfxml.py +++ b/rdflib/plugins/serializers/rdfxml.py @@ -1,9 +1,11 @@ +from typing import IO, Dict, Optional, Set, cast from rdflib.plugins.serializers.xmlwriter import XMLWriter from rdflib.namespace import Namespace, RDF, RDFS # , split_uri from rdflib.plugins.parsers.RDFVOC import RDFVOC -from rdflib.term import URIRef, Literal, BNode +from rdflib.graph import Graph +from rdflib.term import Identifier, URIRef, Literal, BNode from rdflib.util import first, more_than from rdflib.collection import Collection from rdflib.serializer import Serializer @@ -17,7 +19,7 @@ class XMLSerializer(Serializer): - def __init__(self, store): + def __init__(self, store: Graph): super(XMLSerializer, self).__init__(store) def __bindings(self): @@ -39,14 +41,20 @@ def __bindings(self): for prefix, namespace in bindings.items(): yield prefix, namespace - def serialize(self, stream, base=None, encoding=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): # if base is given here, use that, if not and a base is set for the graph use that if base is not None: self.base = base elif self.store.base is not None: self.base = self.store.base self.__stream = stream - self.__serialized = {} + self.__serialized: Dict[Identifier, int] = {} encoding = self.encoding self.write = write = lambda uni: stream.write(uni.encode(encoding, "replace")) @@ -154,12 +162,18 @@ def fix(val): class PrettyXMLSerializer(Serializer): - def __init__(self, store, max_depth=3): + def __init__(self, store: Graph, max_depth=3): super(PrettyXMLSerializer, self).__init__(store) - self.forceRDFAbout = set() - - def serialize(self, stream, base=None, encoding=None, **args): - self.__serialized = {} + self.forceRDFAbout: Set[URIRef] = set() + + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + **args + ): + self.__serialized: Dict[Identifier, int] = {} store = self.store # if base is given here, use that, if not and a base is set for the graph use that if base is not None: @@ -190,8 +204,9 @@ def serialize(self, stream, base=None, encoding=None, **args): writer.namespaces(namespaces.items()) + subject: Identifier # Write out subjects that can not be inline - for subject in store.subjects(): + for subject in store.subjects(): # type: ignore[assignment] if (None, None, subject) in store: if (subject, None, subject) in store: self.subject(subject, 1) @@ -202,7 +217,7 @@ def serialize(self, stream, base=None, encoding=None, **args): # write out BNodes last (to ensure they can be inlined where possible) bnodes = set() - for subject in store.subjects(): + for subject in store.subjects(): # type: ignore[assignment] if isinstance(subject, BNode): bnodes.add(subject) continue @@ -217,9 +232,9 @@ def serialize(self, stream, base=None, encoding=None, **args): stream.write("\n".encode("latin-1")) # Set to None so that the memory can get garbage collected. 
-        self.__serialized = None
+        self.__serialized = None  # type: ignore[assignment]

-    def subject(self, subject, depth=1):
+    def subject(self, subject: Identifier, depth: int = 1):
         store = self.store
         writer = self.writer

@@ -227,7 +242,7 @@ def subject(self, subject, depth=1):
             writer.push(RDFVOC.Description)
             writer.attribute(RDFVOC.about, self.relativize(subject))
             writer.pop(RDFVOC.Description)
-            self.forceRDFAbout.remove(subject)
+            self.forceRDFAbout.remove(subject)  # type: ignore[arg-type]

         elif subject not in self.__serialized:
             self.__serialized[subject] = 1
@@ -264,10 +279,11 @@ def subj_as_obj_more_than(ceil):
                 writer.pop(element)

         elif subject in self.forceRDFAbout:
+            # TODO FIXME?: this looks like a duplicate of first condition
             writer.push(RDFVOC.Description)
             writer.attribute(RDFVOC.about, self.relativize(subject))
             writer.pop(RDFVOC.Description)
-            self.forceRDFAbout.remove(subject)
+            self.forceRDFAbout.remove(subject)  # type: ignore[arg-type]

     def predicate(self, predicate, object, depth=1):
         writer = self.writer
diff --git a/rdflib/plugins/serializers/trig.py b/rdflib/plugins/serializers/trig.py
index cdaedd4892..5a606e401c 100644
--- a/rdflib/plugins/serializers/trig.py
+++ b/rdflib/plugins/serializers/trig.py
@@ -4,9 +4,12 @@
 """

 from collections import defaultdict
+from typing import IO, TYPE_CHECKING, Optional, Union

+from rdflib.graph import ConjunctiveGraph, Graph
 from rdflib.plugins.serializers.turtle import TurtleSerializer
-from rdflib.term import BNode
+from rdflib.term import BNode, Node
+

 __all__ = ["TrigSerializer"]

@@ -16,8 +19,11 @@ class TrigSerializer(TurtleSerializer):
     short_name = "trig"
     indentString = 4 * " "

-    def __init__(self, store):
+    def __init__(self, store: Union[Graph, ConjunctiveGraph]):
+        self.default_context: Optional[Node]
         if store.context_aware:
+            if TYPE_CHECKING:
+                assert isinstance(store, ConjunctiveGraph)
             self.contexts = list(store.contexts())
             self.default_context = store.default_context.identifier
             if store.default_context:
@@ -48,7 +54,14 @@ def reset(self):
         super(TrigSerializer, self).reset()
         self._contexts = {}

-    def serialize(self, stream, base=None, encoding=None, spacious=None, **args):
+    def serialize(
+        self,
+        stream: IO[bytes],
+        base: Optional[str] = None,
+        encoding: Optional[str] = None,
+        spacious: Optional[bool] = None,
+        **args
+    ):
         self.reset()
         self.stream = stream
         # if base is given here, use that, if not and a base is set for the graph use that
diff --git a/rdflib/plugins/serializers/trix.py b/rdflib/plugins/serializers/trix.py
index 05b6f528f3..1612d815cc 100644
--- a/rdflib/plugins/serializers/trix.py
+++ b/rdflib/plugins/serializers/trix.py
@@ -1,3 +1,4 @@
+from typing import IO, Optional
 from rdflib.serializer import Serializer
 from rdflib.plugins.serializers.xmlwriter import XMLWriter

@@ -15,14 +16,20 @@


 class TriXSerializer(Serializer):
-    def __init__(self, store):
+    def __init__(self, store: Graph):
         super(TriXSerializer, self).__init__(store)
         if not store.context_aware:
             raise Exception(
                 "TriX serialization only makes sense for context-aware stores"
             )

-    def serialize(self, stream, base=None, encoding=None, **args):
+    def serialize(
+        self,
+        stream: IO[bytes],
+        base: Optional[str] = None,
+        encoding: Optional[str] = None,
+        **args
+    ):

         nm = self.store.namespace_manager
diff --git a/rdflib/plugins/sparql/results/csvresults.py b/rdflib/plugins/sparql/results/csvresults.py
index c87b6ea760..11a0b38165 100644
--- a/rdflib/plugins/sparql/results/csvresults.py
+++ b/rdflib/plugins/sparql/results/csvresults.py
@@ -9,6 +9,7 @@

 import codecs
 import csv
+from typing import IO

 from rdflib import Variable, BNode, URIRef, Literal

@@ -61,7 +62,7 @@ def __init__(self, result):
         if result.type != "SELECT":
             raise Exception("CSVSerializer can only serialize select query results")

-    def serialize(self, stream, encoding="utf-8", **kwargs):
+    def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs):

         # the serialiser writes bytes in the given encoding
         # in py3 csv.writer is unicode aware and writes STRINGS,
@@ -69,15 +70,15 @@ def serialize(self, stream, encoding="utf-8", **kwargs):

         import codecs

-        stream = codecs.getwriter(encoding)(stream)
+        stream = codecs.getwriter(encoding)(stream)  # type: ignore[assignment]

         out = csv.writer(stream, delimiter=self.delim)

-        vs = [self.serializeTerm(v, encoding) for v in self.result.vars]
+        vs = [self.serializeTerm(v, encoding) for v in self.result.vars]  # type: ignore[union-attr]
         out.writerow(vs)
         for row in self.result.bindings:
             out.writerow(
-                [self.serializeTerm(row.get(v), encoding) for v in self.result.vars]
+                [self.serializeTerm(row.get(v), encoding) for v in self.result.vars]  # type: ignore[union-attr]
             )

     def serializeTerm(self, term, encoding):
diff --git a/rdflib/plugins/sparql/results/jsonresults.py b/rdflib/plugins/sparql/results/jsonresults.py
index 13a8da5eff..8ae67786a4 100644
--- a/rdflib/plugins/sparql/results/jsonresults.py
+++ b/rdflib/plugins/sparql/results/jsonresults.py
@@ -1,4 +1,5 @@
 import json
+from typing import IO, Any, Dict, Optional, TextIO, Union

 from rdflib.query import Result, ResultException, ResultSerializer, ResultParser
 from rdflib import Literal, URIRef, BNode, Variable

@@ -28,9 +29,9 @@ class JSONResultSerializer(ResultSerializer):
     def __init__(self, result):
         ResultSerializer.__init__(self, result)

-    def serialize(self, stream, encoding=None):
+    def serialize(self, stream: IO, encoding: Optional[str] = None, **kwargs):

-        res = {}
+        res: Dict[str, Any] = {}
         if self.result.type == "ASK":
             res["head"] = {}
             res["boolean"] = self.result.askAnswer
diff --git a/rdflib/plugins/sparql/results/txtresults.py b/rdflib/plugins/sparql/results/txtresults.py
index baa5316b48..3f41df9429 100644
--- a/rdflib/plugins/sparql/results/txtresults.py
+++ b/rdflib/plugins/sparql/results/txtresults.py
@@ -1,8 +1,11 @@
+from typing import IO, List, Optional
 from rdflib import URIRef, BNode, Literal
 from rdflib.query import ResultSerializer
+from rdflib.namespace import NamespaceManager
+from rdflib.term import Variable


-def _termString(t, namespace_manager):
+def _termString(t, namespace_manager: Optional[NamespaceManager]):
     if t is None:
         return "-"
     if namespace_manager:
@@ -21,7 +24,13 @@ class TXTResultSerializer(ResultSerializer):
     A write only QueryResult serializer for text/ascii tables
     """

-    def serialize(self, stream, encoding, namespace_manager=None):
+    # TODO FIXME: class specific args should be keyword only.
+    def serialize(  # type: ignore[override]
+        self,
+        stream: IO,
+        encoding: str,
+        namespace_manager: Optional[NamespaceManager] = None,
+    ):
         """
         return a text table of query results
         """
@@ -43,7 +52,7 @@ def c(s, w):
             return "(no results)\n"
         else:

-            keys = self.result.vars
+            keys: List[Variable] = self.result.vars  # type: ignore[assignment]
             maxlen = [0] * len(keys)
             b = [
                 [_termString(r[k], namespace_manager) for k in keys]
diff --git a/rdflib/plugins/sparql/results/xmlresults.py b/rdflib/plugins/sparql/results/xmlresults.py
index 8c77b50ad1..3869bc9e24 100644
--- a/rdflib/plugins/sparql/results/xmlresults.py
+++ b/rdflib/plugins/sparql/results/xmlresults.py
@@ -1,4 +1,5 @@
 import logging
+from typing import IO, Optional

 from xml.sax.saxutils import XMLGenerator
 from xml.dom import XML_NAMESPACE
@@ -28,15 +29,17 @@


 class XMLResultParser(ResultParser):
-    def parse(self, source, content_type=None):
+    # TODO FIXME: content_type should be a keyword only arg.
+    def parse(self, source, content_type: Optional[str] = None):  # type: ignore[override]
         return XMLResult(source)


 class XMLResult(Result):
-    def __init__(self, source, content_type=None):
+    def __init__(self, source, content_type: Optional[str] = None):
         try:
-            parser = etree.XMLParser(huge_tree=True)
+            # try to use etree as if it is lxml's etree and, if that fails, fall back to the stdlib etree.
+            parser = etree.XMLParser(huge_tree=True)  # type: ignore[call-arg]
             tree = etree.parse(source, parser)
         except TypeError:
             tree = etree.parse(source)
@@ -55,7 +58,7 @@ def __init__(self, source, content_type=None):

         if type_ == "SELECT":
             self.bindings = []
-            for result in results:
+            for result in results:  # type: ignore[union-attr]
                 r = {}
                 for binding in result:
                     r[Variable(binding.get("name"))] = parseTerm(binding[0])
@@ -69,7 +72,7 @@ def __init__(self, source, content_type=None):
             ]

         else:
-            self.askAnswer = boolean.text.lower().strip() == "true"
+            self.askAnswer = boolean.text.lower().strip() == "true"  # type: ignore[union-attr]


 def parseTerm(element):
@@ -101,7 +104,7 @@ class XMLResultSerializer(ResultSerializer):
     def __init__(self, result):
         ResultSerializer.__init__(self, result)

-    def serialize(self, stream, encoding="utf-8"):
+    def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs):
         writer = SPARQLXMLWriter(stream, encoding)
         if self.result.type == "ASK":
diff --git a/rdflib/query.py b/rdflib/query.py
index da174cd1f2..65ee141581 100644
--- a/rdflib/query.py
+++ b/rdflib/query.py
@@ -5,6 +5,7 @@
 import warnings
 import types
 from typing import Optional, Union, cast
+from typing import IO, TYPE_CHECKING, List, Optional, TextIO, Union, cast, overload

 from io import BytesIO, BufferedIOBase

@@ -12,6 +13,10 @@

 __all__ = ["Processor", "Result", "ResultParser", "ResultSerializer", "ResultException"]

+if TYPE_CHECKING:
+    from .graph import Graph
+    from .term import Variable
+

 class Processor(object):
     """
@@ -161,17 +166,17 @@ class Result(object):
     """

-    def __init__(self, type_):
+    def __init__(self, type_: str):

         if type_ not in ("CONSTRUCT", "DESCRIBE", "SELECT", "ASK"):
             raise ResultException("Unknown Result type: %s" % type_)

         self.type = type_
-        self.vars = None
+        self.vars: Optional[List[Variable]] = None
         self._bindings = None
         self._genbindings = None

-        self.askAnswer = None
-        self.graph = None
+        self.askAnswer: bool = None  # type: ignore[assignment]
+        self.graph: "Graph" = None  # type: ignore[assignment]

     def _get_bindings(self):
         if self._genbindings:
@@ -192,7 +197,12 @@ def _set_bindings(self, b):
     )

     @staticmethod
-    def parse(source=None, format=None, content_type=None, **kwargs):
+    def parse(
+        source=None,
+        format: Optional[str] = None,
+        content_type: Optional[str] = None,
+        **kwargs,
+    ):
         from rdflib import plugin

         if format:
@@ -208,7 +218,7 @@ def parse(source=None, format=None, content_type=None, **kwargs):

     def serialize(
         self,
-        destination: Optional[Union[str, BufferedIOBase]] = None,
+        destination: Optional[Union[str, IO]] = None,
         encoding: str = "utf-8",
         format: str = "xml",
         **args,
@@ -230,7 +240,7 @@ def serialize(
         :return: bytes
         """
         if self.type in ("CONSTRUCT", "DESCRIBE"):
-            return self.graph.serialize(
+            return self.graph.serialize(  # type: ignore[return-value]
                 destination, encoding=encoding, format=format, **args
             )

@@ -241,10 +251,10 @@ def serialize(
         if destination is None:
             streamb: BytesIO = BytesIO()
             stream2 = EncodeOnlyUnicode(streamb)
-            serializer.serialize(stream2, encoding=encoding, **args)
+            serializer.serialize(stream2, encoding=encoding, **args)  # type: ignore
             return streamb.getvalue()
         if hasattr(destination, "write"):
-            stream = cast(BufferedIOBase, destination)
+            stream = cast(IO[bytes], destination)
             serializer.serialize(stream, encoding=encoding, **args)
         else:
             location = cast(str, destination)
@@ -339,9 +349,14 @@ def parse(self, source, **kwargs):


 class ResultSerializer(object):
-    def __init__(self, result):
+    def __init__(self, result: Result):
         self.result = result

-    def serialize(self, stream, encoding="utf-8", **kwargs):
+    def serialize(
+        self,
+        stream: IO,
+        encoding: str = "utf-8",
+        **kwargs,
+    ):
         """return a string properly serialized"""
         pass  # abstract
diff --git a/rdflib/serializer.py b/rdflib/serializer.py
index ecb8da0a2b..16a47d55cd 100644
--- a/rdflib/serializer.py
+++ b/rdflib/serializer.py
@@ -10,21 +10,31 @@

 """

+from typing import IO, TYPE_CHECKING, Optional
 from rdflib.term import URIRef

+if TYPE_CHECKING:
+    from rdflib.graph import Graph
+
 __all__ = ["Serializer"]


-class Serializer(object):
-    def __init__(self, store):
-        self.store = store
-        self.encoding = "UTF-8"
-        self.base = None
+class Serializer:
+    def __init__(self, store: "Graph"):
+        self.store: "Graph" = store
+        self.encoding: str = "UTF-8"
+        self.base: Optional[str] = None

-    def serialize(self, stream, base=None, encoding=None, **args):
+    def serialize(
+        self,
+        stream: IO[bytes],
+        base: Optional[str] = None,
+        encoding: Optional[str] = None,
+        **args
+    ) -> None:
         """Abstract method"""

-    def relativize(self, uri):
+    def relativize(self, uri: str):
         base = self.base
         if base is not None and uri.startswith(base):
             uri = URIRef(uri.replace(base, "", 1))
diff --git a/rdflib/store.py b/rdflib/store.py
index a7aa8d0b09..d96f569eaf 100644
--- a/rdflib/store.py
+++ b/rdflib/store.py
@@ -1,6 +1,11 @@
 from io import BytesIO
 import pickle
 from rdflib.events import Dispatcher, Event
+from typing import Tuple, TYPE_CHECKING, Iterable, Optional
+
+if TYPE_CHECKING:
+    from .term import Node
+    from .graph import Graph

 """
 ============
@@ -172,7 +177,7 @@ def __get_node_pickler(self):
     def create(self, configuration):
         self.dispatcher.dispatch(StoreCreatedEvent(configuration=configuration))

-    def open(self, configuration, create=False):
+    def open(self, configuration, create: bool = False):
        """
         Opens the store specified by the configuration string.
         If create is True a store will be created if it does not already
@@ -204,7 +209,12 @@ def gc(self):
         pass

     # RDF APIs
-    def add(self, triple, context, quoted=False):
+    def add(
+        self,
+        triple: Tuple["Node", "Node", "Node"],
+        context: Optional["Graph"],
+        quoted: bool = False,
+    ):
         """
         Adds the given statement to a specific context or to the model. The
         quoted argument is interpreted by formula-aware stores to indicate
@@ -215,7 +225,7 @@ def add(self, triple, context, quoted=False):
         """
         self.dispatcher.dispatch(TripleAddedEvent(triple=triple, context=context))

-    def addN(self, quads):
+    def addN(self, quads: Iterable[Tuple["Node", "Node", "Node", "Graph"]]):
         """
         Adds each item in the list of statements to a specific context. The
         quoted argument is interpreted by formula-aware stores to indicate this
@@ -283,7 +293,11 @@ def triples_choices(self, triple, context=None):
                 for (s1, p1, o1), cg in self.triples((subject, None, object_), context):
                     yield (s1, p1, o1), cg

-    def triples(self, triple_pattern, context=None):
+    def triples(
+        self,
+        triple_pattern: Tuple[Optional["Node"], Optional["Node"], Optional["Node"]],
+        context=None,
+    ):
         """
         A generator over all the triples matching the pattern. Pattern can
         include any objects for used for comparing against nodes in the store,
diff --git a/rdflib/term.py b/rdflib/term.py
index eb1f2cb6ca..837e4a726e 100644
--- a/rdflib/term.py
+++ b/rdflib/term.py
@@ -64,7 +64,7 @@
 from urllib.parse import urlparse
 from decimal import Decimal

-from typing import TYPE_CHECKING, Dict, Callable, Union, Type
+from typing import TYPE_CHECKING, Dict, Callable, Optional, Union, Type

 if TYPE_CHECKING:
     from .paths import AlternativePath, InvPath, NegatedPath, SequencePath, Path

@@ -231,10 +231,10 @@ class URIRef(Identifier):
     __neg__: Callable[["URIRef"], "NegatedPath"]
     __truediv__: Callable[["URIRef", Union["URIRef", "Path"]], "SequencePath"]

-    def __new__(cls, value, base=None):
+    def __new__(cls, value: str, base: Optional[str] = None):
         if base is not None:
             ends_in_hash = value.endswith("#")
-            value = urljoin(base, value, allow_fragments=1)
+            value = urljoin(base, value, allow_fragments=True)
             if ends_in_hash:
                 if not value.endswith("#"):
                     value += "#"
@@ -248,7 +248,7 @@ def __new__(cls, value, base=None):
         try:
             rt = str.__new__(cls, value)
         except UnicodeDecodeError:
-            rt = str.__new__(cls, value, "utf-8")
+            rt = str.__new__(cls, value, "utf-8")  # type: ignore[call-overload]
         return rt

     def toPython(self):
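
Not part of the patch: a minimal sketch of how the Graph.serialize overloads typed above are meant to behave once this change lands. The return type follows destination/encoding as declared in the overloads; the file name "out.ttl" is illustrative only.

from rdflib import Graph

g = Graph()

# destination=None, encoding=None -> str
turtle_text = g.serialize(format="turtle")

# destination=None, encoding given -> bytes
turtle_bytes = g.serialize(format="turtle", encoding="utf-8")

# destination given (path string, pathlib.PurePath, or binary stream) -> the Graph itself
g.serialize(destination="out.ttl", format="turtle")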