diff --git a/rdfproxy/utils/sparql_utils.py b/rdfproxy/utils/sparql_utils.py index b76cb45..5398cd2 100644 --- a/rdfproxy/utils/sparql_utils.py +++ b/rdfproxy/utils/sparql_utils.py @@ -22,18 +22,21 @@ def replace_query_select_clause(query: str, repl: str) -> str: """Replace the SELECT clause of a query with repl.""" - if re.search(r"select\s.+", query, re.I) is None: + pattern: re.Pattern = re.compile( + r"select\s+.*?(?=\s+where)", flags=re.IGNORECASE | re.DOTALL + ) + + if re.search(pattern=pattern, string=query) is None: raise Exception("Unable to obtain SELECT clause.") - count_query = re.sub( - pattern=r"select\s.+", + modified_query = re.sub( + pattern=pattern, repl=repl, string=query, count=1, - flags=re.I, ) - return count_query + return modified_query def construct_count_query(query: str, model: type[_TModelInstance]) -> str: diff --git a/tests/unit/test_replace_query_select_clause.py b/tests/unit/test_replace_query_select_clause.py new file mode 100644 index 0000000..16130a0 --- /dev/null +++ b/tests/unit/test_replace_query_select_clause.py @@ -0,0 +1,75 @@ +"""Unit tests for rdfproxy.utils.sparql_utils.replace_query_select_clause.""" + +import re +from textwrap import dedent + +import pytest + +from rdfproxy.utils.sparql_utils import replace_query_select_clause +from tests.utils._types import QueryConstructionParameter + + +def _normalize_whitespace(string: str) -> str: + return re.sub(r"\s+", " ", string.strip()) + + +expected_simple_query = "select where { ?s ?p ?o . }" + +query_construction_parameters = [ + QueryConstructionParameter( + input_query=""" + select ?s ?p ?o + where { + ?s ?p ?o . + } + """, + expected_query=expected_simple_query, + ), + QueryConstructionParameter( + input_query=""" + select ?s + ?p ?o + where { + ?s ?p ?o . + } + """, + expected_query=expected_simple_query, + ), + QueryConstructionParameter( + input_query=""" + select ?s ?p + ?o + where { + ?s ?p ?o . + } + """, + expected_query=expected_simple_query, + ), + QueryConstructionParameter( + input_query=""" + select ?s ?p ?o where { + ?s ?p ?o . + } + """, + expected_query=expected_simple_query, + ), + QueryConstructionParameter( + input_query="select ?s ?p ?o where { ?s ?p ?o . }", + expected_query=expected_simple_query, + ), +] + + +@pytest.mark.parametrize( + ["input_query", "expected_query"], query_construction_parameters +) +def test_basic_replace_query_select_clause(input_query, expected_query): + _constructed_indent: str = replace_query_select_clause(input_query, "select ") + _constructed_dedent: str = replace_query_select_clause( + dedent(input_query), "select " + ) + + constructed_indent = _normalize_whitespace(_constructed_indent) + constructed_dedent = _normalize_whitespace(_constructed_dedent) + + assert constructed_dedent == constructed_indent == expected_query diff --git a/tests/utils/_types.py b/tests/utils/_types.py index d51993d..cf8b2da 100644 --- a/tests/utils/_types.py +++ b/tests/utils/_types.py @@ -19,3 +19,10 @@ class CountQueryParameter(NamedTuple): query: str model: type[BaseModel] expected: int + + +class QueryConstructionParameter(NamedTuple): + """Parameter type for query constructors.""" + + input_query: str + expected_query: str