diff --git a/rdflib/graph.py b/rdflib/graph.py index d662bf998..e934aa754 100644 --- a/rdflib/graph.py +++ b/rdflib/graph.py @@ -351,15 +351,15 @@ def __init__( self.default_union = False @property - def store(self): + def store(self) -> Store: # read-only attr return self.__store @property - def identifier(self): + def identifier(self) -> Node: # read-only attr return self.__identifier @property - def namespace_manager(self): + def namespace_manager(self) -> NamespaceManager: """ this graph's namespace-manager """ @@ -368,8 +368,9 @@ def namespace_manager(self): return self.__namespace_manager @namespace_manager.setter - def namespace_manager(self, nm): - self.__namespace_manager = nm + def namespace_manager(self, value: NamespaceManager): + """this graph's namespace-manager""" + self.__namespace_manager = value def __repr__(self): return "" % (self.identifier, type(self)) @@ -1096,18 +1097,37 @@ def serialize( encoding: Optional[str] = None, **args: Any, ) -> Union[bytes, str, "Graph"]: - """Serialize the Graph to destination - - If destination is None serialize method returns the serialization as - bytes or string. - - If encoding is None and destination is None, returns a string - If encoding is set, and Destination is None, returns bytes - - Format defaults to turtle. - - Format support can be extended with plugins, - but "xml", "n3", "turtle", "nt", "pretty-xml", "trix", "trig" and "nquads" are built in. + """ + Serialize the graph. + + :param destination: + The destination to serialize the graph to. This can be a path as a + :class:`str` or :class:`~pathlib.PurePath` object, or it can be a + :class:`~typing.IO[bytes]` like object. If this parameter is not + supplied the serialized graph will be returned. + :type destination: Optional[Union[str, typing.IO[bytes], pathlib.PurePath]] + :param format: + The format that the output should be written in. This value + references a :class:`~rdflib.serializer.Serializer` plugin. Format + support can be extended with plugins, but `"xml"`, `"n3"`, + `"turtle"`, `"nt"`, `"pretty-xml"`, `"trix"`, `"trig"`, `"nquads"` + and `"json-ld"` are built in. Defaults to `"turtle"`. + :type format: str + :param base: + The base IRI for formats that support it. For the turtle format this + will be used as the `@base` directive. + :type base: Optional[str] + :param encoding: Encoding of output. + :type encoding: Optional[str] + :param **args: + Additional arguments to pass to the + :class:`~rdflib.serializer.Serializer` that will be used. + :type **args: Any + :return: The serialized graph if `destination` is `None`. + :rtype: :class:`bytes` if `destination` is `None` and `encoding` is not `None`. + :rtype: :class:`bytes` if `destination` is `None` and `encoding` is `None`. + :return: `self` (i.e. the :class:`~rdflib.graph.Graph` instance) if `destination` is not None. + :rtype: :class:`~rdflib.graph.Graph` if `destination` is not None. """ # if base is not given as attribute use the base set for the graph @@ -1298,7 +1318,7 @@ def query( if none are given, the namespaces from the graph's namespace manager are used. - :returntype: rdflib.query.Result + :returntype: :class:`~rdflib.query.Result` """ diff --git a/rdflib/plugins/serializers/n3.py b/rdflib/plugins/serializers/n3.py index 806f445ef..032c779f0 100644 --- a/rdflib/plugins/serializers/n3.py +++ b/rdflib/plugins/serializers/n3.py @@ -109,7 +109,7 @@ def p_clause(self, node, position): self.write("{") self.depth += 1 serializer = N3Serializer(node, parent=self) - serializer.serialize(self.stream) + serializer.serialize(self.stream.buffer) self.depth -= 1 self.write(self.indent() + "}") return True diff --git a/rdflib/plugins/serializers/nt.py b/rdflib/plugins/serializers/nt.py index 188db78fc..ea96176f8 100644 --- a/rdflib/plugins/serializers/nt.py +++ b/rdflib/plugins/serializers/nt.py @@ -12,6 +12,8 @@ import warnings import codecs +from rdflib.util import as_textio + __all__ = ["NTSerializer"] @@ -38,9 +40,15 @@ def serialize( f"Given encoding was: {encoding}" ) - for triple in self.store: - stream.write(_nt_row(triple).encode()) - stream.write("\n".encode()) + with as_textio( + stream, + encoding=encoding, # TODO: CHECK: self.encoding set removed, why? + errors="_rdflib_nt_escape", + write_through=True, + ) as text_stream: + for triple in self.store: + text_stream.write(_nt_row(triple)) + text_stream.write("\n") class NT11Serializer(NTSerializer): diff --git a/rdflib/plugins/serializers/rdfxml.py b/rdflib/plugins/serializers/rdfxml.py index 91765d2bc..ad7b90566 100644 --- a/rdflib/plugins/serializers/rdfxml.py +++ b/rdflib/plugins/serializers/rdfxml.py @@ -1,4 +1,4 @@ -from typing import IO, Dict, Optional, Set +from typing import IO, Dict, Optional, Set, cast from rdflib.plugins.serializers.xmlwriter import XMLWriter from rdflib.namespace import Namespace, RDF, RDFS # , split_uri @@ -173,6 +173,8 @@ def serialize( encoding: Optional[str] = None, **args, ): + # TODO FIXME: this should be Optional, but it's not because nothing + # treats it as such. self.__serialized: Dict[Identifier, int] = {} store = self.store # if base is given here, use that, if not and a base is set for the graph use that @@ -241,6 +243,7 @@ def subject(self, subject: IdentifiedNode, depth: int = 1): writer = self.writer if subject in self.forceRDFAbout: + subject = cast(URIRef, subject) writer.push(RDFVOC.Description) writer.attribute(RDFVOC.about, self.relativize(subject)) writer.pop(RDFVOC.Description) @@ -282,6 +285,7 @@ def subj_as_obj_more_than(ceil): elif subject in self.forceRDFAbout: # TODO FIXME?: this looks like a duplicate of first condition + subject = cast(URIRef, subject) writer.push(RDFVOC.Description) writer.attribute(RDFVOC.about, self.relativize(subject)) writer.pop(RDFVOC.Description) diff --git a/rdflib/plugins/serializers/trig.py b/rdflib/plugins/serializers/trig.py index e4b7e55a6..db1fa8cf4 100644 --- a/rdflib/plugins/serializers/trig.py +++ b/rdflib/plugins/serializers/trig.py @@ -62,53 +62,45 @@ def serialize( spacious: Optional[bool] = None, **args, ): - self.reset() - self.stream = stream - # if base is given here, use that, if not and a base is set for the graph use that - if base is not None: - self.base = base - elif self.store.base is not None: - self.base = self.store.base - - if spacious is not None: - self._spacious = spacious - - self.preprocess() - - self.startDocument() - - firstTime = True - for store, (ordered_subjects, subjects, ref) in self._contexts.items(): - if not ordered_subjects: - continue - - self._references = ref - self._serialized = {} - self.store = store - self._subjects = subjects - - if self.default_context and store.identifier == self.default_context: - self.write(self.indent() + "\n{") - else: - iri: Optional[str] - if isinstance(store.identifier, BNode): - iri = store.identifier.n3() - else: - iri = self.getQName(store.identifier) - if iri is None: - iri = store.identifier.n3() - self.write(self.indent() + "\n%s {" % iri) + self._serialize_init(stream, base, encoding, spacious) + try: + self.preprocess() - self.depth += 1 - for subject in ordered_subjects: - if self.isDone(subject): + self.startDocument() + + firstTime = True + for store, (ordered_subjects, subjects, ref) in self._contexts.items(): + if not ordered_subjects: continue - if firstTime: - firstTime = False - if self.statement(subject) and not firstTime: - self.write("\n") - self.depth -= 1 - self.write("}\n") - - self.endDocument() - stream.write("\n".encode("latin-1")) + + self._references = ref + self._serialized = {} + self.store = store + self._subjects = subjects + + if self.default_context and store.identifier == self.default_context: + self.write(self.indent() + "\n{") + else: + if isinstance(store.identifier, BNode): + iri = store.identifier.n3() + else: + iri = self.getQName(store.identifier) + if iri is None: + iri = store.identifier.n3() + self.write(self.indent() + "\n%s {" % iri) + + self.depth += 1 + for subject in ordered_subjects: + if self.isDone(subject): + continue + if firstTime: + firstTime = False + if self.statement(subject) and not firstTime: + self.write("\n") + self.depth -= 1 + self.write("}\n") + + self.endDocument() + self.write("\n") + finally: + self._serialize_end() diff --git a/rdflib/plugins/serializers/turtle.py b/rdflib/plugins/serializers/turtle.py index a75f15243..ebf10f075 100644 --- a/rdflib/plugins/serializers/turtle.py +++ b/rdflib/plugins/serializers/turtle.py @@ -6,10 +6,13 @@ from collections import defaultdict from functools import cmp_to_key +from rdflib.graph import Graph from rdflib.term import BNode, Literal, URIRef from rdflib.exceptions import Error from rdflib.serializer import Serializer from rdflib.namespace import RDF, RDFS +from io import TextIOWrapper +from typing import IO, Dict, Optional __all__ = ["RecursiveSerializer", "TurtleSerializer"] @@ -44,10 +47,13 @@ class RecursiveSerializer(Serializer): indentString = " " roundtrip_prefixes = () - def __init__(self, store): + def __init__(self, store: Graph): super(RecursiveSerializer, self).__init__(store) - self.stream = None + # TODO FIXME: Ideally stream should be optional, but nothing treats it + # as such, so least weird solution is to just type it as not optional + # even thoug it can sometimes be null. + self.stream: IO[str] = None # type: ignore[assignment] self.reset() def addNamespace(self, prefix, uri): @@ -166,9 +172,9 @@ def indent(self, modifier=0): """Returns indent string multiplied by the depth""" return (self.depth + modifier) * self.indentString - def write(self, text): - """Write text in given encoding.""" - self.stream.write(text.encode(self.encoding, "replace")) + def write(self, text: str): + """Write text""" + self.stream.write(text) SUBJECT = 0 @@ -184,15 +190,15 @@ class TurtleSerializer(RecursiveSerializer): short_name = "turtle" indentString = " " - def __init__(self, store): - self._ns_rewrite = {} + def __init__(self, store: Graph): + self._ns_rewrite: Dict[str, str] = {} super(TurtleSerializer, self).__init__(store) self.keywords = {RDF.type: "a"} self.reset() - self.stream = None + self.stream: TextIOWrapper = None # type: ignore[assignment] self._spacious = _SPACIOUS_OUTPUT - def addNamespace(self, prefix, namespace): + def addNamespace(self, prefix: str, namespace: str): # Turtle does not support prefix that start with _ # if they occur in the graph, rewrite to p_blah # this is more complicated since we need to make sure p_blah @@ -223,36 +229,60 @@ def reset(self): self._started = False self._ns_rewrite = {} - def serialize(self, stream, base=None, encoding=None, spacious=None, **args): + def _serialize_init( + self, + stream: IO[bytes], + base: Optional[str], + encoding: Optional[str], + spacious: Optional[bool], + ) -> None: self.reset() - self.stream = stream + if encoding is not None: + self.encoding = encoding + self.stream = TextIOWrapper( + stream, self.encoding, errors="replace", write_through=True + ) # if base is given here, use that, if not and a base is set for the graph use that if base is not None: self.base = base elif self.store.base is not None: self.base = self.store.base - if spacious is not None: self._spacious = spacious - self.preprocess() - subjects_list = self.orderSubjects() - - self.startDocument() - - firstTime = True - for subject in subjects_list: - if self.isDone(subject): - continue - if firstTime: - firstTime = False - if self.statement(subject) and not firstTime: - self.write("\n") - - self.endDocument() - stream.write("\n".encode("latin-1")) - - self.base = None + def _serialize_end(self) -> None: + self.stream.flush() + self.stream.detach() + self.stream = None # type: ignore[assignment] + + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + spacious: Optional[bool] = None, + **args, + ): + self._serialize_init(stream, base, encoding, spacious) + try: + self.preprocess() + subjects_list = self.orderSubjects() + + self.startDocument() + + firstTime = True + for subject in subjects_list: + if self.isDone(subject): + continue + if firstTime: + firstTime = False + if self.statement(subject) and not firstTime: + self.write("\n") + + self.endDocument() + self.stream.write("\n") + finally: + self._serialize_end() def preprocessTriple(self, triple): super(TurtleSerializer, self).preprocessTriple(triple) diff --git a/rdflib/plugins/sparql/results/csvresults.py b/rdflib/plugins/sparql/results/csvresults.py index 11a0b3816..aba7ac058 100644 --- a/rdflib/plugins/sparql/results/csvresults.py +++ b/rdflib/plugins/sparql/results/csvresults.py @@ -9,12 +9,14 @@ import codecs import csv -from typing import IO +from typing import IO, TYPE_CHECKING, Optional, TextIO, Union from rdflib import Variable, BNode, URIRef, Literal from rdflib.query import Result, ResultSerializer, ResultParser +from rdflib.util import as_textio + class CSVResultParser(ResultParser): def __init__(self): @@ -62,24 +64,24 @@ def __init__(self, result): if result.type != "SELECT": raise Exception("CSVSerializer can only serialize select query results") - def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs): + def serialize( + self, stream: Union[IO[bytes], TextIO], encoding: Optional[str] = None, **kwargs + ): # the serialiser writes bytes in the given encoding # in py3 csv.writer is unicode aware and writes STRINGS, # so we encode afterwards - import codecs - - stream = codecs.getwriter(encoding)(stream) # type: ignore[assignment] - - out = csv.writer(stream, delimiter=self.delim) - - vs = [self.serializeTerm(v, encoding) for v in self.result.vars] # type: ignore[union-attr] - out.writerow(vs) - for row in self.result.bindings: - out.writerow( - [self.serializeTerm(row.get(v), encoding) for v in self.result.vars] # type: ignore[union-attr] - ) + with as_textio(stream, encoding=encoding) as stream: + out = csv.writer(stream, delimiter=self.delim) + if TYPE_CHECKING: + assert self.result.vars is not None + vs = [self.serializeTerm(v, encoding) for v in self.result.vars] + out.writerow(vs) + for row in self.result.bindings: + out.writerow( + [self.serializeTerm(row.get(v), encoding) for v in self.result.vars] + ) def serializeTerm(self, term, encoding): if term is None: diff --git a/rdflib/plugins/sparql/results/jsonresults.py b/rdflib/plugins/sparql/results/jsonresults.py index 562f0ec07..5e933c1c7 100644 --- a/rdflib/plugins/sparql/results/jsonresults.py +++ b/rdflib/plugins/sparql/results/jsonresults.py @@ -3,6 +3,7 @@ from rdflib.query import Result, ResultException, ResultSerializer, ResultParser from rdflib import Literal, URIRef, BNode, Variable +from rdflib.util import as_textio """A Serializer for SPARQL results in JSON: @@ -29,8 +30,9 @@ class JSONResultSerializer(ResultSerializer): def __init__(self, result): ResultSerializer.__init__(self, result) - def serialize(self, stream: IO, encoding: str = None): # type: ignore[override] - + def serialize( + self, stream: Union[IO[bytes], TextIO], encoding: Optional[str] = None, **kwargs + ): res: Dict[str, Any] = {} if self.result.type == "ASK": res["head"] = {} @@ -45,9 +47,7 @@ def serialize(self, stream: IO, encoding: str = None): # type: ignore[override] ] r = json.dumps(res, allow_nan=False, ensure_ascii=False) - if encoding is not None: - stream.write(r.encode(encoding)) - else: + with as_textio(stream, encoding=encoding) as stream: stream.write(r) def _bindingToJSON(self, b): diff --git a/rdflib/plugins/sparql/results/txtresults.py b/rdflib/plugins/sparql/results/txtresults.py index 3f41df942..8810c9e5d 100644 --- a/rdflib/plugins/sparql/results/txtresults.py +++ b/rdflib/plugins/sparql/results/txtresults.py @@ -1,7 +1,8 @@ -from typing import IO, List, Optional +from typing import IO, TYPE_CHECKING, List, Optional, TextIO, Union from rdflib import URIRef, BNode, Literal from rdflib.query import ResultSerializer from rdflib.namespace import NamespaceManager +from rdflib.util import as_textio from rdflib.term import Variable @@ -27,8 +28,8 @@ class TXTResultSerializer(ResultSerializer): # TODO FIXME: class specific args should be keyword only. def serialize( # type: ignore[override] self, - stream: IO, - encoding: str, + stream: Union[IO[bytes], TextIO], + encoding: Optional[str], namespace_manager: Optional[NamespaceManager] = None, ): """ @@ -62,9 +63,13 @@ def c(s, w): for i in range(len(keys)): maxlen[i] = max(maxlen[i], len(r[i])) - stream.write("|".join([c(k, maxlen[i]) for i, k in enumerate(keys)]) + "\n") - stream.write("-" * (len(maxlen) + sum(maxlen)) + "\n") - for r in sorted(b): + with as_textio(stream) as stream: stream.write( - "|".join([t + " " * (i - len(t)) for i, t in zip(maxlen, r)]) + "\n" + "|".join([c(k, maxlen[i]) for i, k in enumerate(keys)]) + "\n" ) + stream.write("-" * (len(maxlen) + sum(maxlen)) + "\n") + for r in sorted(b): + stream.write( + "|".join([t + " " * (i - len(t)) for i, t in zip(maxlen, r)]) + + "\n" + ) diff --git a/rdflib/plugins/sparql/results/xmlresults.py b/rdflib/plugins/sparql/results/xmlresults.py index 6f55ddfe7..d2e7eb5fd 100644 --- a/rdflib/plugins/sparql/results/xmlresults.py +++ b/rdflib/plugins/sparql/results/xmlresults.py @@ -1,5 +1,5 @@ import logging -from typing import IO, Optional +from typing import IO, Optional, TextIO, Union from xml.sax.saxutils import XMLGenerator from xml.dom import XML_NAMESPACE @@ -112,8 +112,11 @@ class XMLResultSerializer(ResultSerializer): def __init__(self, result): ResultSerializer.__init__(self, result) - def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs): - + def serialize( + self, stream: Union[IO[bytes], TextIO], encoding: Optional[str] = None, **kwargs + ): + if encoding is None: + encoding = "utf-8" writer = SPARQLXMLWriter(stream, encoding) if self.result.type == "ASK": writer.write_header([]) diff --git a/rdflib/query.py b/rdflib/query.py index 1312f180b..50b512356 100644 --- a/rdflib/query.py +++ b/rdflib/query.py @@ -4,7 +4,8 @@ import tempfile import warnings import types -from typing import IO, TYPE_CHECKING, List, Optional, Union, cast +import pathlib +from typing import IO, TYPE_CHECKING, List, Optional, TextIO, Union, cast, overload from io import BytesIO @@ -216,54 +217,176 @@ def parse( return parser.parse(source, content_type=content_type, **kwargs) + # None destination and non-None positional encoding + @overload def serialize( self, - destination: Optional[Union[str, IO]] = None, - encoding: str = "utf-8", - format: str = "xml", + destination: None, + encoding: str, + format: Optional[str] = ..., **args, - ) -> Optional[bytes]: - """ - Serialize the query result. + ) -> bytes: + ... + + # None destination and non-None keyword encoding + @overload + def serialize( + self, + *, + destination: None = ..., + encoding: str, + format: Optional[str] = ..., + **args, + ) -> bytes: + ... + + # None destination and None positional encoding + @overload + def serialize( + self, + destination: None, + encoding: None = None, + format: Optional[str] = ..., + **args, + ) -> str: + ... + + # None destination and None keyword encoding + @overload + def serialize( + self, + *, + destination: None = ..., + encoding: None = None, + format: Optional[str] = ..., + **args, + ) -> str: + ... + + # non-none binary destination + @overload + def serialize( + self, + destination: Union[str, pathlib.PurePath, IO[bytes]], + encoding: Optional[str] = ..., + format: Optional[str] = ..., + **args, + ) -> None: + ... + + # non-none text destination + @overload + def serialize( + self, + destination: TextIO, + encoding: None = ..., + format: Optional[str] = ..., + **args, + ) -> None: + ... - The :code:`format` argument determines the Serializer class to use. + # fallback + @overload + def serialize( + self, + destination: Optional[Union[str, pathlib.PurePath, IO[bytes], TextIO]] = ..., + encoding: Optional[str] = ..., + format: Optional[str] = ..., + **args, + ) -> Union[bytes, str, None]: + ... - - csv: :class:`~rdflib.plugins.sparql.results.csvresults.CSVResultSerializer` - - json: :class:`~rdflib.plugins.sparql.results.jsonresults.JSONResultSerializer` - - txt: :class:`~rdflib.plugins.sparql.results.txtresults.TXTResultSerializer` - - xml: :class:`~rdflib.plugins.sparql.results.xmlresults.XMLResultSerializer` + # NOTE: Using TextIO as opposed to IO[str] because I want to be able to use buffer. + def serialize( + self, + destination: Optional[Union[str, pathlib.PurePath, IO[bytes], TextIO]] = None, + encoding: Optional[str] = None, + format: Optional[str] = None, + **args, + ) -> Union[bytes, str, None]: + """ + Serialize the query result. - :param destination: Path of file output or BufferedIOBase object to write the output to. + :param destination: + The destination to serialize the result to. This can be a path as a + :class:`str` or :class:`~pathlib.PurePath` object, or it can be a + :class:`~typing.IO[bytes]` or :class:`~typing.TextIO` like object. If this parameter is not + supplied the serialized result will be returned. + :type destination: Optional[Union[str, typing.IO[bytes], pathlib.PurePath]] :param encoding: Encoding of output. - :param format: One of ['csv', 'json', 'txt', xml'] - :param args: - :return: bytes + :type encoding: Optional[str] + :param format: + The format that the output should be written in. + + For tabular results, the value refers to a + :class:`rdflib.query.ResultSerializer` plugin.Support for the + following tabular formats are built in: + + - `"csv"`: :class:`~rdflib.plugins.sparql.results.csvresults.CSVResultSerializer` + - `"json"`: :class:`~rdflib.plugins.sparql.results.jsonresults.JSONResultSerializer` + - `"txt"`: :class:`~rdflib.plugins.sparql.results.txtresults.TXTResultSerializer` + - `"xml"`: :class:`~rdflib.plugins.sparql.results.xmlresults.XMLResultSerializer` + + For tabular results, the default format is `"txt"`. + + For graph results, the value refers to a + :class:`~rdflib.serializer.Serializer` plugin and is passed to + :func:`~rdflib.graph.Graph.serialize`. Graph format support can be + extended with plugins, but support for `"xml"`, `"n3"`, `"turtle"`, `"nt"`, + `"pretty-xml"`, `"trix"`, `"trig"`, `"nquads"` and `"json-ld"` are + built in. The default graph format is `"turtle"`. + :type format: str """ if self.type in ("CONSTRUCT", "DESCRIBE"): - return self.graph.serialize( # type: ignore[return-value] - destination, encoding=encoding, format=format, **args + if format is None: + format = "turtle" + if ( + destination is not None + and hasattr(destination, "encoding") + and hasattr(destination, "buffer") + ): + # rudimentary check for TextIO-like objects. + destination = cast(TextIO, destination).buffer + destination = cast( + Optional[Union[str, pathlib.PurePath, IO[bytes]]], destination + ) + result = self.graph.serialize( + destination=destination, format=format, encoding=encoding, **args ) + from rdflib.graph import Graph + + if isinstance(result, Graph): + return None + return result """stolen wholesale from graph.serialize""" from rdflib import plugin + if format is None: + format = "txt" serializer = plugin.get(format, ResultSerializer)(self) + stream: IO[bytes] if destination is None: - streamb: BytesIO = BytesIO() - stream2 = EncodeOnlyUnicode(streamb) - serializer.serialize(stream2, encoding=encoding, **args) # type: ignore - return streamb.getvalue() + stream = BytesIO() + if encoding is None: + serializer.serialize(stream, encoding="utf-8", **args) + return stream.getvalue().decode("utf-8") + else: + serializer.serialize(stream, encoding=encoding, **args) + return stream.getvalue() if hasattr(destination, "write"): stream = cast(IO[bytes], destination) serializer.serialize(stream, encoding=encoding, **args) else: - location = cast(str, destination) + if isinstance(destination, pathlib.PurePath): + location = str(destination) + else: + location = cast(str, destination) scheme, netloc, path, params, query, fragment = urlparse(location) if netloc != "": - print( - "WARNING: not saving as location" + "is not a local file reference" + raise ValueError( + f"destination {destination} is not a local file reference" ) - return None fd, name = tempfile.mkstemp() stream = os.fdopen(fd, "wb") serializer.serialize(stream, encoding=encoding, **args) @@ -351,6 +474,19 @@ class ResultSerializer(object): def __init__(self, result: Result): self.result = result - def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs): + @overload + def serialize(self, stream: IO[bytes], encoding: Optional[str] = ..., **kwargs): + ... + + @overload + def serialize(self, stream: TextIO, encoding: None = ..., **kwargs): + ... + + def serialize( + self, + stream: Union[IO[bytes], TextIO], + encoding: Optional[str] = None, + **kwargs, + ): """return a string properly serialized""" pass # abstract diff --git a/rdflib/util.py b/rdflib/util.py index 7341a12fd..9824d8311 100644 --- a/rdflib/util.py +++ b/rdflib/util.py @@ -31,13 +31,14 @@ from calendar import timegm from time import altzone -from typing import Optional +from typing import IO, Generator, Optional, TextIO, Union, cast # from time import daylight from time import gmtime from time import localtime from time import time from time import timezone +from io import TextIOWrapper from os.path import splitext @@ -52,6 +53,7 @@ from rdflib.term import Literal from rdflib.term import URIRef from rdflib.compat import sign +from contextlib import contextmanager __all__ = [ "list2set", @@ -506,6 +508,27 @@ def get_tree( return (mapper(root), sorted(tree, key=sortkey)) +@contextmanager +def as_textio( + anyio: Union[IO[bytes], TextIO], + encoding: Optional[str] = None, + errors: Union[str, None] = None, + write_through: bool = False, +) -> Generator[TextIO, None, None]: + if hasattr(anyio, "encoding"): + yield cast(TextIO, anyio) + else: + textio_wrapper = TextIOWrapper( + cast(IO[bytes], anyio), + encoding=encoding, + errors=errors, + write_through=write_through, + ) + yield textio_wrapper + textio_wrapper.flush() + textio_wrapper.detach() + + def test(): import doctest diff --git a/test/test_conjunctivegraph/test_conjunctive_graph.py b/test/test_conjunctivegraph/test_conjunctive_graph.py index 0cbe00771..5662fee90 100644 --- a/test/test_conjunctivegraph/test_conjunctive_graph.py +++ b/test/test_conjunctivegraph/test_conjunctive_graph.py @@ -5,9 +5,11 @@ import pytest from rdflib import ConjunctiveGraph, Graph +from rdflib.namespace import Namespace from rdflib.term import Identifier, URIRef, BNode from rdflib.parser import StringInputSource -from os import path + +from .testutils import GraphHelper DATA = """ @@ -16,6 +18,17 @@ PUBLIC_ID = "http://example.org/record/1" +EG = Namespace("http://example.com/") + + +def test_add() -> None: + quad = (EG["subject"], EG["predicate"], EG["object"], EG["graph"]) + g = ConjunctiveGraph() + g.add(quad) + quad_set = GraphHelper.quad_set(g) + assert len(quad_set) == 1 + assert next(iter(quad_set)) == quad + def test_bnode_publicid(): diff --git a/test/test_issues/test_issue523.py b/test/test_issues/test_issue523.py index 2910cdd71..54210171e 100644 --- a/test/test_issues/test_issue523.py +++ b/test/test_issues/test_issue523.py @@ -9,9 +9,10 @@ def test_issue523(): "SELECT (<../baz> as ?test) WHERE {}", base=rdflib.URIRef("http://example.org/foo/bar"), ) - res = r.serialize(format="csv") + res = r.serialize(format="csv", encoding="utf-8") assert res == b"test\r\nhttp://example.org/baz\r\n", repr(res) - + res = r.serialize(format="csv") + assert res == "test\r\nhttp://example.org/baz\r\n", repr(res) # expected result: # test # http://example.org/baz diff --git a/test/test_serializers/test_serializer.py b/test/test_serializers/test_serializer.py index 5f99bb12f..6d90e6654 100644 --- a/test/test_serializers/test_serializer.py +++ b/test/test_serializers/test_serializer.py @@ -1,8 +1,198 @@ import logging +import enum +import inspect +import itertools +import sys import unittest from rdflib import RDF, Graph, Literal, Namespace, URIRef from tempfile import TemporaryDirectory +from contextlib import ExitStack +from io import IOBase from pathlib import Path, PurePath +from tempfile import TemporaryDirectory +from test.testutils import GraphHelper, get_unique_plugins +from typing import ( + IO, + Any, + Dict, + Iterable, + NamedTuple, + Optional, + Set, + TextIO, + Tuple, + Union, + cast, +) + +from rdflib import Graph +from rdflib.graph import ConjunctiveGraph +from rdflib.namespace import Namespace +from rdflib.plugin import PluginException +from rdflib.serializer import Serializer + +EG = Namespace("http://example.com/") + + +class DestinationType(str, enum.Enum): + PATH = enum.auto() + PURE_PATH = enum.auto() + PATH_STR = enum.auto() + IO_BYTES = enum.auto() + TEXT_IO = enum.auto() + + +class DestinationFactory: + _counter: int = 0 + + def __init__(self, tmpdir: Path) -> None: + self.tmpdir = tmpdir + + def make( + self, + type: DestinationType, + stack: Optional[ExitStack] = None, + ) -> Tuple[Union[str, Path, PurePath, IO[bytes], TextIO], Path]: + self._counter += 1 + count = self._counter + path = self.tmpdir / f"file-{type}-{count:05d}" + if type is DestinationType.PATH: + return (path, path) + if type is DestinationType.PURE_PATH: + return (PurePath(path), path) + if type is DestinationType.PATH_STR: + return (f"{path}", path) + if type is DestinationType.IO_BYTES: + return ( + path.open("wb") + if stack is None + else stack.enter_context(path.open("wb")), + path, + ) + if type is DestinationType.TEXT_IO: + return ( + path.open("w") + if stack is None + else stack.enter_context(path.open("w")), + path, + ) + raise ValueError(f"unsupported type {type}") + + +class GraphType(str, enum.Enum): + QUAD = enum.auto() + TRIPLE = enum.auto() + + +class FormatInfo(NamedTuple): + serializer_name: str + deserializer_name: str + graph_types: Set[GraphType] + encodings: Set[str] + + +class FormatInfos(Dict[str, FormatInfo]): + def add_format( + self, + serializer_name: str, + *, + deserializer_name: Optional[str] = None, + graph_types: Set[GraphType], + encodings: Set[str], + ) -> None: + self[serializer_name] = FormatInfo( + serializer_name, + serializer_name if deserializer_name is None else deserializer_name, + {GraphType.QUAD, GraphType.TRIPLE} if graph_types is None else graph_types, + encodings, + ) + + def select( + self, + *, + name: Optional[Set[str]] = None, + graph_type: Optional[Set[GraphType]] = None, + ) -> Iterable[FormatInfo]: + for format in self.values(): + if graph_type is not None and not graph_type.isdisjoint(format.graph_types): + yield format + if name is not None and format.serializer_name in name: + yield format + + @classmethod + def make_graph(self, format_info: FormatInfo) -> Graph: + if GraphType.QUAD in format_info.graph_types: + return ConjunctiveGraph() + else: + return Graph() + + @classmethod + def make(cls) -> "FormatInfos": + result = cls() + + flexible_formats = { + "trig", + } + for format in flexible_formats: + result.add_format( + format, + graph_types={GraphType.TRIPLE, GraphType.QUAD}, + encodings={"utf-8"}, + ) + + triple_only_formats = { + "turtle", + "nt11", + "xml", + "n3", + } + for format in triple_only_formats: + result.add_format( + format, graph_types={GraphType.TRIPLE}, encodings={"utf-8"} + ) + + quad_only_formats = { + "nquads", + "trix", + "json-ld", + } + for format in quad_only_formats: + result.add_format(format, graph_types={GraphType.QUAD}, encodings={"utf-8"}) + + result.add_format( + "pretty-xml", + deserializer_name="xml", + graph_types={GraphType.TRIPLE}, + encodings={"utf-8"}, + ) + result.add_format( + "ntriples", + graph_types={GraphType.TRIPLE}, + encodings={"ascii"}, + ) + + return result + + +format_infos = FormatInfos.make() + + +def assert_graphs_equal( + test_case: unittest.TestCase, lhs: Graph, rhs: Graph, check_context: bool = True +) -> None: + lhs_has_quads = hasattr(lhs, "quads") + rhs_has_quads = hasattr(rhs, "quads") + lhs_set: Set[Any] + rhs_set: Set[Any] + if lhs_has_quads and rhs_has_quads and check_context: + lhs = cast(ConjunctiveGraph, lhs) + rhs = cast(ConjunctiveGraph, rhs) + lhs_set, rhs_set = GraphHelper.quad_sets([lhs, rhs]) + else: + lhs_set, rhs_set = GraphHelper.triple_sets([lhs, rhs]) + test_case.assertEqual(lhs_set, rhs_set) + test_case.assertTrue(len(lhs_set) > 0) + test_case.assertTrue(len(rhs_set) > 0) from typing import Tuple, cast @@ -51,39 +241,49 @@ def test_rdf_type(format: str, tuple_index: int, is_keyword: bool) -> None: class TestSerialize(unittest.TestCase): def setUp(self) -> None: - - graph = Graph() - subject = URIRef("example:subject") - predicate = URIRef("example:predicate") - object = Literal("日本語の表記体系", lang="jpx") self.triple = ( - subject, - predicate, - object, + EG["subject"], + EG["predicate"], + Literal("日本語の表記体系", lang="jpx"), ) - graph.add(self.triple) - self.graph = graph - return super().setUp() + self.context = EG["graph"] + self.quad = (*self.triple, self.context) - def test_serialize_to_purepath(self): - with TemporaryDirectory() as td: - tfpath = PurePath(td) / "out.nt" - self.graph.serialize(destination=tfpath, format="nt", encoding="utf-8") - graph_check = Graph() - graph_check.parse(source=tfpath, format="nt") + conjunctive_graph = ConjunctiveGraph() + conjunctive_graph.add(self.quad) + self.graph = conjunctive_graph - self.assertEqual(self.triple, next(iter(graph_check))) + query = """ + CONSTRUCT { ?subject ?predicate ?object } WHERE { + ?subject ?predicate ?object + } ORDER BY ?object + """ + self.result = self.graph.query(query) + self.assertIsNotNone(self.result.graph) - def test_serialize_to_path(self): - with TemporaryDirectory() as td: - tfpath = Path(td) / "out.nt" - self.graph.serialize(destination=tfpath, format="nt", encoding="utf-8") - graph_check = Graph() - graph_check.parse(source=tfpath, format="nt") + self._tmpdir = TemporaryDirectory() + self.tmpdir = Path(self._tmpdir.name) - self.assertEqual(self.triple, next(iter(graph_check))) + return super().setUp() + + def tearDown(self) -> None: + self._tmpdir.cleanup() - def test_serialize_to_neturl(self): + def test_graph(self) -> None: + quad_set = GraphHelper.quad_set(self.graph) + self.assertEqual(quad_set, {self.quad}) + + def test_all_formats_specified(self) -> None: + plugins = get_unique_plugins(Serializer) + for plugin_refs in plugins.values(): + names = {plugin_ref.name for plugin_ref in plugin_refs} + self.assertNotEqual( + names.intersection(format_infos.keys()), + set(), + f"serializers does not include any of {names}", + ) + + def test_serialize_to_neturl(self) -> None: with self.assertRaises(ValueError) as raised: self.graph.serialize( destination="http://example.com/", format="nt", encoding="utf-8" @@ -101,3 +301,232 @@ def test_serialize_to_fileurl(self): graph_check = Graph() graph_check.parse(source=tfpath, format="nt") self.assertEqual(self.triple, next(iter(graph_check))) + + def test_serialize_badformat(self) -> None: + with self.assertRaises(PluginException) as raised: + self.graph.serialize(destination="http://example.com/", format="badformat") + self.assertIn("badformat", f"{raised.exception}") + + def test_str(self) -> None: + """ + This function tests serialization of graphs to strings, either directly + or from query results. + + This function also checks that the various string serialization + overloads are correct. + """ + for format in format_infos.keys(): + format_info = format_infos[format] + + def check(data: str, check_context: bool = True) -> None: + with self.subTest(format=format, caller=inspect.stack()[1]): + self.assertIsInstance(data, str) + + graph_check = FormatInfos.make_graph(format_info) + graph_check.parse(data=data, format=format_info.deserializer_name) + assert_graphs_equal(self, self.graph, graph_check, check_context) + + if format == "turtle": + check(self.graph.serialize()) + check(self.graph.serialize(None)) + check(self.graph.serialize(None, format)) + check(self.graph.serialize(None, format, encoding=None)) + check(self.graph.serialize(None, format, None, None)) + check(self.graph.serialize(None, format=format)) + check(self.graph.serialize(None, format=format, encoding=None)) + + if GraphType.TRIPLE not in format_info.graph_types: + # tests below are only for formats that can work with context-less graphs. + continue + + if format == "turtle": + check(self.result.serialize(), False) + check(self.result.serialize(None), False) + check(self.result.serialize(None, format=format), False) + check(self.result.serialize(None, None, format), False) + check(self.result.serialize(None, None, format=format), False) + check(self.result.serialize(None, encodin=None, format=format), False) + check( + self.result.serialize(destination=None, encoding=None, format=format), + False, + ) + + def test_bytes(self) -> None: + """ + This function tests serialization of graphs to bytes, either directly or + from query results. + + This function also checks that the various bytes serialization overloads + are correct. + """ + for (format, encoding) in itertools.chain( + *( + itertools.product({format_info.serializer_name}, format_info.encodings) + for format_info in format_infos.values() + ) + ): + format_info = format_infos[format] + + def check(data: bytes, check_context: bool = True) -> None: + with self.subTest( + format=format, encoding=encoding, caller=inspect.stack()[1] + ): + # self.check_data_bytes(data, format=format, encoding=encoding) + self.assertIsInstance(data, bytes) + + # double check that encoding is right + data_str = data.decode(encoding) + graph_check = FormatInfos.make_graph(format_info) + graph_check.parse( + data=data_str, format=format_info.deserializer_name + ) + assert_graphs_equal(self, self.graph, graph_check, check_context) + + # actual check + # TODO FIXME : handle other encodings + if encoding == "utf-8": + graph_check = FormatInfos.make_graph(format_info) + graph_check.parse( + data=data, format=format_info.deserializer_name + ) + assert_graphs_equal( + self, self.graph, graph_check, check_context + ) + + if format == "turtle": + check(self.graph.serialize(encoding=encoding)) + check(self.graph.serialize(None, format, encoding=encoding)) + check(self.graph.serialize(None, format, None, encoding=encoding)) + check(self.graph.serialize(None, format, encoding=encoding)) + check(self.graph.serialize(None, format=format, encoding=encoding)) + + if GraphType.TRIPLE not in format_info.graph_types: + # tests below are only for formats that can work with context-less graphs. + continue + + if format == "turtle": + check(self.result.serialize(encoding=encoding), False) + check(self.result.serialize(None, encoding), False) + check(self.result.serialize(encoding=encoding, format=format), False) + check(self.result.serialize(None, encoding, format), False) + check(self.result.serialize(None, encoding=encoding, format=format), False) + check( + self.result.serialize( + destination=None, encoding=encoding, format=format + ), + False, + ) + + def test_file(self) -> None: + """ + This function tests serialization of graphs to destinations, either directly or + from query results. + + This function also checks that the various bytes serialization overloads + are correct. + """ + dest_factory = DestinationFactory(self.tmpdir) + + for (format, encoding, dest_type) in itertools.chain( + *( + itertools.product( + {format_info.serializer_name}, + format_info.encodings, + set(DestinationType).difference({DestinationType.TEXT_IO}), + ) + for format_info in format_infos.values() + ) + ): + format_info = format_infos[format] + with ExitStack() as stack: + dest_path: Path + _dest: Union[str, Path, PurePath, IO[bytes]] + + def dest() -> Union[str, Path, PurePath, IO[bytes]]: + nonlocal dest_path + nonlocal _dest + _dest, dest_path = cast( + Tuple[Union[str, Path, PurePath, IO[bytes]], Path], + dest_factory.make(dest_type, stack), + ) + return _dest + + def _check(check_context: bool = True) -> None: + with self.subTest( + format=format, + encoding=encoding, + dest_type=dest_type, + caller=inspect.stack()[2], + ): + if isinstance(_dest, IOBase): # type: ignore[unreachable] + _dest.flush() + + source = Path(dest_path) + + # double check that encoding is right + data_str = source.read_text(encoding=encoding) + graph_check = FormatInfos.make_graph(format_info) + graph_check.parse( + data=data_str, format=format_info.deserializer_name + ) + assert_graphs_equal( + self, self.graph, graph_check, check_context + ) + + self.assertTrue(source.exists()) + # actual check + # TODO FIXME : This should work for all encodings, not just utf-8 + if encoding == "utf-8": + graph_check = FormatInfos.make_graph(format_info) + graph_check.parse( + source=source, format=format_info.deserializer_name + ) + assert_graphs_equal( + self, self.graph, graph_check, check_context + ) + + dest_path.unlink() + + def check_a(graph: Graph) -> None: + _check() + + if (format, encoding) == ("turtle", "utf-8"): + check_a(self.graph.serialize(dest())) + check_a(self.graph.serialize(dest(), encoding=None)) + if format == "turtle": + check_a(self.graph.serialize(dest(), encoding=encoding)) + if encoding == sys.getdefaultencoding(): + check_a(self.graph.serialize(dest(), format)) + check_a(self.graph.serialize(dest(), format, None)) + check_a(self.graph.serialize(dest(), format, None, None)) + + check_a(self.graph.serialize(dest(), format, encoding=encoding)) + check_a(self.graph.serialize(dest(), format, None, encoding)) + + if GraphType.TRIPLE not in format_info.graph_types: + # tests below are only for formats that can work with context-less graphs. + continue + + def check_b(none: None) -> None: + _check(False) + + if format == "turtle": + check_b(self.result.serialize(dest(), encoding)) + check_b( + self.result.serialize( + destination=cast(str, dest()), + encoding=encoding, + ) + ) + check_b(self.result.serialize(dest(), encoding=encoding, format=format)) + check_b( + self.result.serialize( + destination=dest(), encoding=encoding, format=format + ) + ) + check_b( + self.result.serialize( + destination=dest(), encoding=None, format=format + ) + ) + check_b(self.result.serialize(destination=dest(), format=format)) diff --git a/test/test_sparql/test_sparql.py b/test/test_sparql/test_sparql.py index 7fe410ae5..16d8078b6 100644 --- a/test/test_sparql/test_sparql.py +++ b/test/test_sparql/test_sparql.py @@ -240,8 +240,15 @@ def test_txtresult(): assert result.type == "SELECT" assert len(result) == 1 assert result.vars == vars - txtresult = result.serialize(format="txt") - lines = txtresult.decode().splitlines() + + bytesresult = result.serialize(format="txt", encoding="utf-8") + lines = bytesresult.decode().splitlines() + assert len(lines) == 3 + vars_check = [Variable(var.strip()) for var in lines[0].split("|")] + assert vars_check == vars + + strresult = result.serialize(format="txt") + lines = strresult.splitlines() assert len(lines) == 3 vars_check = [Variable(var.strip()) for var in lines[0].split("|")] assert vars_check == vars diff --git a/test/test_sparql_result_serialize.py b/test/test_sparql_result_serialize.py new file mode 100644 index 000000000..26567d928 --- /dev/null +++ b/test/test_sparql_result_serialize.py @@ -0,0 +1,281 @@ +from contextlib import ExitStack +import itertools +from typing import ( + IO, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Set, + TextIO, + Union, + cast, +) +from rdflib.query import Result, ResultRow +from test.test_serializer import DestinationFactory, DestinationType +from test.testutils import GraphHelper +from rdflib.term import Node +import unittest +from rdflib import Graph, Namespace +from tempfile import TemporaryDirectory +from pathlib import Path, PurePath +from io import BytesIO, IOBase, StringIO +import inspect + +EG = Namespace("http://example.com/") + + +class FormatInfo(NamedTuple): + serializer_name: str + deserializer_name: str + encodings: Set[str] + + +class FormatInfos(Dict[str, FormatInfo]): + def add_format( + self, + serializer_name: str, + deserializer_name: str, + *, + encodings: Set[str], + ) -> None: + self[serializer_name] = FormatInfo( + serializer_name, + deserializer_name, + encodings, + ) + + def select( + self, + *, + name: Optional[Set[str]] = None, + ) -> Iterable[FormatInfo]: + for format in self.values(): + if name is not None and format.serializer_name in name: + yield format + + @classmethod + def make(cls) -> "FormatInfos": + result = cls() + result.add_format("csv", "csv", encodings={"utf-8"}) + result.add_format("json", "json", encodings={"utf-8"}) + result.add_format("xml", "xml", encodings={"utf-8"}) + result.add_format("txt", "txt", encodings={"utf-8"}) + + return result + + +format_infos = FormatInfos.make() + + +class ResultHelper: + @classmethod + def to_list(cls, result: Result) -> List[Dict[str, Node]]: + output: List[Dict[str, Node]] = [] + row: ResultRow + for row in result: + output.append(row.asdict()) + return output + + +def check_txt(test_case: unittest.TestCase, result: Result, data: str) -> None: + """ + This does somewhat of a smoke tests that data is the txt serialization of the + given result. This is by no means perfect but better than nothing. + """ + txt_lines = data.splitlines() + test_case.assertEqual(len(txt_lines) - 2, len(result)) + test_case.assertRegex(txt_lines[1], r"^[-]+$") + header = txt_lines[0] + test_case.assertIsNotNone(result.vars) + assert result.vars is not None + for var in result.vars: + test_case.assertIn(var, header) + for row_index, row in enumerate(result): + txt_row = txt_lines[row_index + 2] + value: Node + for key, value in row.asdict().items(): + test_case.assertIn(f"{value}", txt_row) + + +class TestSerializeSelect(unittest.TestCase): + def setUp(self) -> None: + graph = Graph() + triples = [ + (EG["e0"], EG["a0"], EG["e1"]), + (EG["e0"], EG["a0"], EG["e2"]), + (EG["e0"], EG["a0"], EG["e3"]), + (EG["e1"], EG["a1"], EG["e2"]), + (EG["e1"], EG["a1"], EG["e3"]), + (EG["e2"], EG["a2"], EG["e3"]), + ] + GraphHelper.add_triples(graph, triples) + + query = """ + PREFIX eg: + SELECT ?subject ?predicate ?object WHERE { + VALUES ?predicate { eg:a1 } + ?subject ?predicate ?object + } ORDER BY ?object + """ + self.result = graph.query(query) + self.result_table = [ + ["subject", "predicate", "object"], + ["http://example.com/e1", "http://example.com/a1", "http://example.com/e2"], + ["http://example.com/e1", "http://example.com/a1", "http://example.com/e3"], + ] + + self._tmpdir = TemporaryDirectory() + self.tmpdir = Path(self._tmpdir.name) + + return super().setUp() + + def tearDown(self) -> None: + self._tmpdir.cleanup() + + def test_str(self) -> None: + for format in format_infos.keys(): + + def check(data: str) -> None: + with self.subTest(format=format, caller=inspect.stack()[1]): + self.assertIsInstance(data, str) + format_info = format_infos[format] + if format_info.deserializer_name == "txt": + check_txt(self, self.result, data) + else: + result_check = Result.parse( + StringIO(data), format=format_info.deserializer_name + ) + self.assertEqual(self.result, result_check) + + if format == "txt": + check(self.result.serialize()) + check(self.result.serialize(None, None, None)) + check(self.result.serialize(None, None, format)) + check(self.result.serialize(format=format)) + check(self.result.serialize(destination=None, format=format)) + check(self.result.serialize(destination=None, encoding=None, format=format)) + + def test_bytes(self) -> None: + for (format, encoding) in itertools.chain( + *( + itertools.product({format_info.serializer_name}, format_info.encodings) + for format_info in format_infos.values() + ) + ): + + def check(data: bytes) -> None: + with self.subTest(format=format, caller=inspect.stack()[1]): + self.assertIsInstance(data, bytes) + format_info = format_infos[format] + if format_info.deserializer_name == "txt": + check_txt(self, self.result, data.decode(encoding)) + else: + result_check = Result.parse( + BytesIO(data), format=format_info.deserializer_name + ) + self.assertEqual(self.result, result_check) + + if format == "txt": + check(self.result.serialize(encoding=encoding)) + check(self.result.serialize(None, encoding, None)) + check(self.result.serialize(None, encoding)) + check(self.result.serialize(None, encoding, format)) + check(self.result.serialize(format=format, encoding=encoding)) + check( + self.result.serialize( + destination=None, format=format, encoding=encoding + ) + ) + check( + self.result.serialize( + destination=None, encoding=encoding, format=format + ) + ) + + def test_file(self) -> None: + + dest_factory = DestinationFactory(self.tmpdir) + + for (format, encoding, dest_type) in itertools.chain( + *( + itertools.product( + {format_info.serializer_name}, + format_info.encodings, + set(DestinationType), + ) + for format_info in format_infos.values() + ) + ): + with ExitStack() as stack: + dest_path: Path + _dest: Union[str, Path, PurePath, IO[bytes], TextIO] + + def dest() -> Union[str, Path, PurePath, IO[bytes], TextIO]: + nonlocal dest_path + nonlocal _dest + _dest, dest_path = dest_factory.make(dest_type, stack) + return _dest + + def check(none: None) -> None: + with self.subTest( + format=format, + encoding=encoding, + dest_type=dest_type, + caller=inspect.stack()[1], + ): + if isinstance(_dest, IOBase): # type: ignore[unreachable] + _dest.flush() + format_info = format_infos[format] + data_str = dest_path.read_text(encoding=encoding) + if format_info.deserializer_name == "txt": + check_txt(self, self.result, data_str) + else: + result_check = Result.parse( + StringIO(data_str), format=format_info.deserializer_name + ) + self.assertEqual(self.result, result_check) + dest_path.unlink() + + if dest_type == DestinationType.IO_BYTES: + check( + self.result.serialize( + cast(IO[bytes], dest()), + encoding, + format, + ) + ) + check( + self.result.serialize( + cast(IO[bytes], dest()), + encoding, + format=format, + ) + ) + check( + self.result.serialize( + cast(IO[bytes], dest()), + encoding=encoding, + format=format, + ) + ) + check( + self.result.serialize( + destination=cast(IO[bytes], dest()), + encoding=encoding, + format=format, + ) + ) + check( + self.result.serialize( + destination=dest(), encoding=None, format=format + ) + ) + check(self.result.serialize(destination=dest(), format=format)) + check(self.result.serialize(dest(), format=format)) + check(self.result.serialize(dest(), None, format)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_trig.py b/test/test_trig.py index 71f5ec308..2685303d3 100644 --- a/test/test_trig.py +++ b/test/test_trig.py @@ -1,7 +1,12 @@ -import re + import unittest +from unittest.case import expectedFailure +import pytest import rdflib +import re +from rdflib import Namespace +from .testutils import GraphHelper TRIPLE = ( rdflib.URIRef("http://example.com/s"), @@ -16,6 +21,55 @@ def test_empty(): assert s is not None +EG = Namespace("http://example.com/") + + +def test_single_quad(self) -> None: + graph = rdflib.ConjunctiveGraph() + quad = (EG["subject"], EG["predicate"], EG["object"], EG["graph"]) + graph.add(quad) + check_graph = rdflib.ConjunctiveGraph() + data_str = graph.serialize(format="trig") + check_graph.parse(data=data_str, format="trig") + quad_set, check_quad_set = GraphHelper.quad_sets([graph, check_graph]) + assert quad_set == check_quad_set + + +@pytest.mark.xfail +def test_default_identifier(self) -> None: + """ + This should pass, but for some reason when the default identifier is + set, trig serializes quads inside this default indentifier to an + anonymous graph. + + So in this test, data_str is: + + @base . + @prefix ns1: . + + { + ns1:subject ns1:predicate ns1:object . + } + + instead of: + @base . + @prefix ns1: . + + ns1:graph { + ns1:subject ns1:predicate ns1:object . + } + """ + graph_id = EG["graph"] + graph = rdflib.ConjunctiveGraph(identifier=EG["graph"]) + quad = (EG["subject"], EG["predicate"], EG["object"], graph_id) + graph.add(quad) + check_graph = rdflib.ConjunctiveGraph() + data_str = graph.serialize(format="trig") + check_graph.parse(data=data_str, format="trig") + quad_set, check_quad_set = GraphHelper.quad_sets([graph, check_graph]) + self.assertEqual(quad_set, check_quad_set) + + def test_repeat_triples(): g = rdflib.ConjunctiveGraph() g.get_context("urn:a").add( diff --git a/test/test_util.py b/test/test_util.py index 76b6c51a0..aa9663adf 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +from io import BufferedIOBase, RawIOBase, TextIOBase +from typing import BinaryIO, TextIO import unittest import time from unittest.case import expectedFailure @@ -15,6 +17,9 @@ from rdflib.exceptions import PredicateTypeError from rdflib.exceptions import ObjectTypeError from rdflib.exceptions import ContextTypeError +from pathlib import Path +from tempfile import TemporaryDirectory +from rdflib.util import as_textio n3source = """\ @prefix : . @@ -44,7 +49,7 @@ :sister a rdf:Property. -:sister rdfs:domain :Person; +:sister rdfs:domain :Person; rdfs:range :Woman. :Woman = foo:FemaleAdult . @@ -391,5 +396,48 @@ def test_util_check_pattern(self): self.assertTrue(res == None) +class TestIO(unittest.TestCase): + def setUp(self) -> None: + self._tmpdir = TemporaryDirectory() + self.tmpdir = Path(self._tmpdir.name) + + return super().setUp() + + def tearDown(self) -> None: + self._tmpdir.cleanup() + + def test_as_textio_text(self) -> None: + tmp_file = self.tmpdir / "file" + with tmp_file.open("w") as text_stream: + text_io: TextIO = text_stream + assert text_io is text_stream + with as_textio(text_stream) as text_io: + assert text_io is text_stream + text_io.write("Test") + text_stream.flush() + self.assertEqual(tmp_file.read_text(), "Test") + self.assertIsInstance(text_io, TextIOBase) + + def test_as_textio_buffered_stream(self) -> None: + tmp_file = self.tmpdir / "file" + with tmp_file.open("wb") as buffered_stream: + binary_io: BinaryIO = buffered_stream + assert binary_io is buffered_stream + with as_textio(buffered_stream) as text_io: + text_io.write("Test") + self.assertEqual(tmp_file.read_text(), "Test") + self.assertIsInstance(buffered_stream, BufferedIOBase) + + def test_as_textio_raw_stream(self) -> None: + tmp_file = self.tmpdir / "file" + with tmp_file.open("wb", buffering=0) as raw_stream: + binary_io: BinaryIO = raw_stream + assert binary_io is raw_stream + with as_textio(raw_stream) as text_io: + text_io.write("Test") + self.assertEqual(tmp_file.read_text(), "Test") + self.assertIsInstance(binary_io, RawIOBase) + + if __name__ == "__main__": unittest.main() diff --git a/test/testutils.py b/test/testutils.py index 2d4500700..44efc09f2 100644 --- a/test/testutils.py +++ b/test/testutils.py @@ -1,4 +1,6 @@ from __future__ import print_function +from rdflib.graph import Dataset +from rdflib.plugin import Plugin import os import sys @@ -10,6 +12,7 @@ from contextlib import AbstractContextManager, contextmanager from typing import ( Callable, + Generic, Iterable, List, Optional, @@ -40,11 +43,44 @@ from pathlib import PurePath, PureWindowsPath from nturl2path import url2pathname as nt_url2pathname import rdflib.compare +import rdflib.plugin + if TYPE_CHECKING: import typing_extensions as te +# TODO: make an introspective version (like this one) of +# rdflib.graphutils.isomorphic and use instead. +from test import TEST_DIR + +PluginT = TypeVar("PluginT") + + +class PluginWithNames(Generic[PluginT]): + def __init__(self, plugin: Plugin[PluginT], names: Set[str]) -> None: + self.plugin = plugin + self.names = names + + +def get_unique_plugins( + type: Type[PluginT], +) -> Dict[Type[PluginT], Set[Plugin[PluginT]]]: + result: Dict[Type[PluginT], Set[Plugin[PluginT]]] = {} + for plugin in rdflib.plugin.plugins(None, type): + cls = plugin.getClass() + plugins = result.setdefault(cls, set()) + plugins.add(plugin) + return result + + +def get_unique_plugin_names(type: Type[PluginT]) -> Set[str]: + result: Set[str] = set() + unique_plugins = get_unique_plugins(type) + for type, plugin_set in unique_plugins.items(): + result.add(next(iter(plugin_set)).name) + return result + def get_random_ip(parts: List[str] = None) -> str: if parts is None: parts = ["127"] @@ -77,6 +113,14 @@ class GraphHelper: """ Provides methods which are useful for working with graphs. """ + @classmethod + def add_triples( + cls, graph: Graph, triples: Iterable[Tuple[Node, Node, Node]] + ) -> Graph: + for triple in triples: + graph.add(triple) + return graph + @classmethod def identifier(self, node: Node) -> Identifier: @@ -235,7 +279,6 @@ def strip_literal_datatypes(cls, graph: Graph, datatypes: Set[URIRef]) -> None: if object.datatype in datatypes: object._datatype = None - GenericT = TypeVar("GenericT", bound=Any)