Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change error handling strategy and add verbose errors #30

Merged
merged 4 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions cratedb_sqlparse_py/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ query = """
SELECT * FROM SYS.SHARDS;
INSERT INTO doc.tbl VALUES (1);
"""
statements = sqlparse(query)
statements = sqlparse(query, raise_exception=True)
surister marked this conversation as resolved.
Show resolved Hide resolved

print(len(statements))
# 2
Expand All @@ -43,17 +43,28 @@ print(select_query.type)
print(select_query.tree)
# (statement (query (queryNoWith (queryTerm (querySpec SELECT (selectItem *) FROM (relation (aliasedRelation (relationPrimary (table (qname (ident (unquotedIdent SYS)) . (ident (unquotedIdent (nonReserved SHARDS)))))))))))))

sqlparse('SUUULECT * FROM sys.shards')
# cratedb_sqlparse.parser.parser.ParsingException: line1:0 mismatched input 'SUUULECT' expecting {'SELECT', 'DEALLOCATE', ...}
sqlparse('SEEELECT * FROM sys.shards')
# cratedb_sqlparse.parser.parser.ParsingException: line1:0 mismatched input 'SEEELECT' expecting {'SELECT', 'DEALLOCATE', ...}
```


## Development
```shell
git clone https://github.com/crate/cratedb-sqlparse

cd cratedb-sqlparse/cratedb_sqlparse_py
python3 -m venv .venv
source .venv/bin/activate
pip install --editable='.[develop,generate,release,test]'
poe check
```

### Run only tests
```shell
poe test
```

surister marked this conversation as resolved.
Show resolved Hide resolved
### Run only one test
surister marked this conversation as resolved.
Show resolved Hide resolved
```shell
poe test -k test_sqlparse_collects_exceptions_2
```
amotl marked this conversation as resolved.
Show resolved Hide resolved
140 changes: 128 additions & 12 deletions cratedb_sqlparse_py/cratedb_sqlparse/parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import List

from antlr4 import CommonTokenStream, InputStream, Token
from antlr4 import CommonTokenStream, InputStream, RecognitionException, Token
from antlr4.error.ErrorListener import ErrorListener

from cratedb_sqlparse.generated_parser.SqlBaseLexer import SqlBaseLexer
Expand Down Expand Up @@ -30,7 +31,51 @@ def END_DOLLAR_QUOTED_STRING_sempred(self, localctx, predIndex) -> bool:


class ParsingException(Exception):
    """
    Syntax error reported by the parser for one offending token.

    Attributes:
        message: The error message produced by the parser.
        offending_token: The token the parser could not handle.
        e: The underlying recognition exception, or ``None`` when the parser
            reported the error without one.
        query: The piece of query text associated with the error.
    """

    def __init__(self, *, query: str, msg: str, offending_token: Token, e: RecognitionException):
        # Initialise the base class so that `args` carries the parser message.
        super().__init__(msg)
        self.message = msg
        self.offending_token = offending_token
        self.e = e
        self.query = query

    @property
    def error_message(self):
        """Return ``<ExceptionName>[line <line>:<column> <message>]``."""
        return f"{self!r}[line {self.line}:{self.column} {self.message}]"

    @property
    def original_query_with_error_marked(self):
        """
        Return the full input text with a line of ``^`` inserted below the offending token.

        The full text is read from the token's input stream (``source[1].strdata``).
        """
        token = self.offending_token
        query_lines: list = token.source[1].strdata.split("\n")

        # Place the marker using the token's own column. Searching the line for
        # the token text would hit the first occurrence, which is wrong whenever
        # the same text appears earlier in that line.
        marker = " " * token.column + "^" * (token.stop - token.start + 1)
        query_lines[self.line - 1] += "\n" + marker

        return "\n".join(query_lines)

    @property
    def column(self):
        """Column of the offending token, as reported by the token itself."""
        return self.offending_token.column

    @property
    def line(self):
        """Line of the offending token, as reported by the token itself."""
        return self.offending_token.line

    def __repr__(self):
        # Named after the underlying exception type; fall back to this class'
        # name when there is no underlying exception.
        named_after = self if self.e is None else self.e
        return f"{type(named_after).__qualname__}"

    def __str__(self):
        return repr(self)


class CaseInsensitiveStream(InputStream):
Expand All @@ -47,16 +92,44 @@ class ExceptionErrorListener(ErrorListener):
"""

def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
    """
    Raise a `ParsingException` describing the reported syntax error.

    :param recognizer: The recognizer that reported the error (unused).
    :param offendingSymbol: The token the error was reported on.
    :param line: Line of the error (unused, read from the token instead).
    :param column: Column of the error (unused, read from the token instead).
    :param msg: The parser's error message.
    :param e: The recognition exception, or ``None``.
    :raises ParsingException: always.
    """
    if e is not None and e.ctx is not None:
        # Text from the start of the failing rule up to the offending token.
        query = e.ctx.parser.getTokenStream().getText(e.ctx.start, e.offendingToken.tokenIndex)
    else:
        # Some errors (e.g. extraneous or missing tokens) are reported without
        # a recognition exception; use the offending token's own text rather
        # than failing with an AttributeError on `e.ctx`.
        query = offendingSymbol.text

    raise ParsingException(
        msg=msg,
        offending_token=offendingSymbol,
        e=e,
        query=query,
    )


class ExceptionCollectorListener(ErrorListener):
    """
    Error listener that stores every reported syntax error in `self.errors`
    instead of raising, so callers can process them afterwards.

    Based partially on https://github.com/antlr/antlr4/issues/396
    """

    def __init__(self):
        super().__init__()
        # `ParsingException` instances, in the order they were reported.
        self.errors = []

    def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
        """
        Record the reported syntax error as a `ParsingException`.

        :param recognizer: The recognizer that reported the error (unused).
        :param offendingSymbol: The token the error was reported on.
        :param line: Line of the error (unused, read from the token instead).
        :param column: Column of the error (unused, read from the token instead).
        :param msg: The parser's error message.
        :param e: The recognition exception, or ``None``.
        """
        if e is not None and e.ctx is not None:
            # Text from the start of the failing rule up to the offending token.
            query = e.ctx.parser.getTokenStream().getText(e.ctx.start, e.offendingToken.tokenIndex)
        else:
            # Some errors (e.g. extraneous or missing tokens) are reported
            # without a recognition exception; use the offending token's own
            # text rather than failing with an AttributeError on `e.ctx`.
            query = offendingSymbol.text

        self.errors.append(
            ParsingException(
                msg=msg,
                offending_token=offendingSymbol,
                e=e,
                query=query,
            )
        )


class Statement:
"""
Represents a CrateDB SQL statement.
"""

def __init__(self, ctx: SqlBaseParser.StatementContext):
def __init__(self, ctx: SqlBaseParser.StatementContext, exception: ParsingException = None):
self.ctx: SqlBaseParser.StatementContext = ctx
self.exception = exception

@property
def tree(self):
Expand All @@ -77,7 +150,7 @@ def query(self) -> str:
"""
Returns the query, comments and ';' are not included.
"""
return self.ctx.parser.getTokenStream().getText(start=self.ctx.start.tokenIndex, stop=self.ctx.stop.tokenIndex)
return self.ctx.parser.getTokenStream().getText(start=self.ctx.start, stop=self.ctx.stop)

@property
def type(self):
Expand All @@ -90,7 +163,20 @@ def __repr__(self):
return f'{self.__class__.__qualname__}<{self.query if len(self.query) < 15 else self.query[:15] + "..."}>'


def sqlparse(query: str) -> List[Statement]:
def find_suitable_error(statement, errors):
    """
    Attach to `statement` the first collected error that belongs to its query.

    An error belongs to the statement when its `query`, with one trailing ';'
    and surrounding whitespace removed, equals `statement.query`.

    The matched error is removed from `errors` (the list is mutated in place).
    At most one error is consumed per call, so identical faulty statements
    each receive their own error instead of the first one swallowing them all.

    :param statement: Object exposing `query` and a writable `exception`.
    :param errors: List of collected errors, each exposing `query`.
    """
    for error in errors:
        # We clean the error query of ';' and spaces because, ironically, we can
        # get the full query in the error handler but not in the context.
        error_query = error.query
        if error_query.endswith(";"):
            error_query = error_query[:-1]

        if error_query.strip() == statement.query:
            statement.exception = error
            errors.remove(error)
            # Returning right away also makes removing while iterating safe.
            return


def sqlparse(query: str, raise_exception: bool = False) -> List[Statement]:
"""
Parses a string into SQL `Statement`.
"""
Expand All @@ -101,12 +187,42 @@ def sqlparse(query: str) -> List[Statement]:

parser = SqlBaseParser(stream)
parser.removeErrorListeners()
parser.addErrorListener(ExceptionErrorListener())
error_listener = ExceptionErrorListener() if raise_exception else ExceptionCollectorListener()
parser.addErrorListener(error_listener)

tree = parser.statements()

# At this point, all errors are already raised; it's reasonably safe to assume
# that the statements are valid.
statements = list(filter(lambda children: isinstance(children, SqlBaseParser.StatementContext), tree.children))

return [Statement(statement) for statement in statements]
statements_context: list[SqlBaseParser.StatementContext] = list(
filter(lambda children: isinstance(children, SqlBaseParser.StatementContext), tree.children)
)

statements = []
for statement_context in statements_context:
_stmt = Statement(statement_context)
find_suitable_error(_stmt, error_listener.errors)
statements.append(_stmt)

else:
# We might still have error(s) that we couldn't match with their origin statement,
# this happens when the query is composed of only one keyword, e.g. 'SELCT 1'
# the error.query will be 'SELCT' instead of 'SELCT 1'.
if len(error_listener.errors) == 1:
# This case has an edge case where we hypothetically assign the
# wrong error to a statement, for example:
# SELECT A FROM tbl1;
# SELEC 1;
# This would match both conditionals, this however is protected
# by https://github.com/crate/cratedb-sqlparse/issues/28, but might
# change in the future.
error = error_listener.errors[0]
for _stmt in statements:
if _stmt.exception is None and error.query in _stmt.query:
_stmt.exception = error
break

if len(error_listener.errors) > 1:
logging.error(
"Could not match errors to queries, too much ambiguity, open an issue with this error and the query."
)

return statements
2 changes: 1 addition & 1 deletion cratedb_sqlparse_py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ check = [
format = [
{ cmd = "ruff format ." },
# Configure Ruff not to auto-fix (remove!) unused variables (F841) and `print` statements (T201).
{ cmd = "ruff check --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 ." },
{ cmd = "ruff check --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 --ignore=E501 ." },
{ cmd = "pyproject-fmt --keep-full-version pyproject.toml" },
]

Expand Down
79 changes: 79 additions & 0 deletions cratedb_sqlparse_py/tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest


def test_exception_message():
# Checks, for the first (invalid) statement of a multi-statement input, both the
# formatted `error_message` and the caret-marked `original_query_with_error_marked`.
from cratedb_sqlparse import sqlparse

# Default mode (no `raise_exception`): the result is indexed and its
# `.exception` attribute is inspected below.
r = sqlparse("""
SELEC 1;
surister marked this conversation as resolved.
Show resolved Hide resolved
SELECT A, B, C, D FROM tbl1;
SELECT D, A FROM tbl1 WHERE;
""")
# NOTE(review): the expected values below encode an exact position (line 2, column 9)
# and exact leading whitespace; they only hold if the SQL literal above carries the
# matching indentation -- confirm against the repository copy of this test.
expected_message = "InputMismatchException[line 2:9 mismatched input 'SELEC' expecting {'SELECT', 'DEALLOCATE', 'FETCH', 'END', 'WITH', 'CREATE', 'ALTER', 'KILL', 'CLOSE', 'BEGIN', 'START', 'COMMIT', 'ANALYZE', 'DISCARD', 'EXPLAIN', 'SHOW', 'OPTIMIZE', 'REFRESH', 'RESTORE', 'DROP', 'INSERT', 'VALUES', 'DELETE', 'UPDATE', 'SET', 'RESET', 'COPY', 'GRANT', 'DENY', 'REVOKE', 'DECLARE'}]" # noqa
surister marked this conversation as resolved.
Show resolved Hide resolved
expected_message_2 = "\n SELEC 1;\n ^^^^^\n SELECT A, B, C, D FROM tbl1;\n SELECT D, A FROM tbl1 WHERE;\n " # noqa
assert r[0].exception.error_message == expected_message
assert r[0].exception.original_query_with_error_marked == expected_message_2


def test_sqlparse_raises_exception():
    """Invalid SQL must raise `ParsingException` when `raise_exception=True`."""
    from cratedb_sqlparse import ParsingException, sqlparse

    with pytest.raises(ParsingException):
        sqlparse("SELCT 2", raise_exception=True)


def test_sqlparse_collects_exception():
    """Without `raise_exception`, invalid SQL still yields a first statement."""
    from cratedb_sqlparse import sqlparse

    parsed = sqlparse("SELCT 2")
    assert parsed[0]


def test_sqlparse_collects_exceptions():
    """Only the two invalid statements out of three carry an exception."""
    from cratedb_sqlparse import sqlparse

    parsed = sqlparse("""
SELECT A FROM tbl1 where ;
SELECT 1;
SELECT D, A FROM tbl1 WHERE;
""")

    assert len(parsed) == 3

    # First and last statements are invalid, the middle one is fine.
    has_error = [parsed[position].exception is not None for position in range(3)]
    assert has_error == [True, False, True]


def test_sqlparse_collects_exceptions_2():
    """Same check as above with a different mix of valid and invalid statements."""
    from cratedb_sqlparse import sqlparse

    parsed = sqlparse("""
SELEC 1;
SELECT A, B, C, D FROM tbl1;
SELECT D, A FROM tbl1 WHERE;
""")

    # First and last statements are invalid, the middle one is fine.
    has_error = [parsed[position].exception is not None for position in range(3)]
    assert has_error == [True, False, True]


def test_sqlparse_collects_exceptions_3():
    """Three valid statements: none of them carries an exception."""
    from cratedb_sqlparse import sqlparse

    parsed = sqlparse("""
SELECT 1;
SELECT A, B, C, D FROM tbl1;
INSERT INTO doc.tbl VALUES (1,2, 'three', ['four']);
""")

    is_clean = [parsed[position].exception is None for position in range(3)]
    assert is_clean == [True, True, True]
20 changes: 12 additions & 8 deletions cratedb_sqlparse_py/tests/test_lexer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import pytest


def test_sqlparser_one_statement(query=None):
from cratedb_sqlparse import sqlparse

Expand Down Expand Up @@ -44,13 +41,20 @@ def test_sqlparse_dollar_string():
assert r[0].query == query


def test_sqlparse_raises_exception():
from cratedb_sqlparse import ParsingException, sqlparse
def test_sqlparse_multiquery_edge_case():
    """
    Pin the behaviour tracked in https://github.com/crate/cratedb-sqlparse/issues/28.

    If this ends up parsing 3 statements, the expectation can be changed; the
    test exists so a change in behaviour is noticed programmatically.
    """
    from cratedb_sqlparse import sqlparse

    parsed = sqlparse("""
SELECT A FROM tbl1 where ;
SELEC 1;
SELECT D, A FROM tbl1 WHERE;
""")
    assert len(parsed) == 1


def test_sqlparse_is_case_insensitive():
Expand Down