-
Notifications
You must be signed in to change notification settings - Fork 14.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: helper functions for RLS #19055
Changes from 4 commits
eca77d7
11c2bb3
6942b9e
746a919
5b32f9b
b5bfb94
b21a19f
e816857
00598fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,8 +27,10 @@ | |
IdentifierList, | ||
Parenthesis, | ||
remove_quotes, | ||
Statement, | ||
Token, | ||
TokenList, | ||
Where, | ||
) | ||
from sqlparse.tokens import ( | ||
CTE, | ||
|
@@ -458,3 +460,178 @@ def validate_filter_clause(clause: str) -> None: | |
) | ||
if open_parens > 0: | ||
raise QueryClauseValidationException("Unclosed parenthesis in filter clause") | ||
|
||
|
||
def has_table_query(statement: Statement) -> bool:
    """
    Return whether a statement has a query reading from a table.

    >>> has_table_query(sqlparse.parse("COUNT(*)")[0])
    False
    >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
    True

    Note that queries reading from constant values return false:

    >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
    False

    :param statement: a parsed sqlparse statement
    :returns: ``True`` as soon as an identifier or keyword is found following
        a ``FROM``/``JOIN`` keyword, ``False`` otherwise
    """
    # Function-scope import so the block does not depend on the file header.
    from collections import deque

    seen_source = False

    # Walk the grouped parse tree instead of ``statement.flatten()``: flattening
    # yields only leaf tokens (a table shows up as a bare ``Name``), so the
    # ``Identifier`` groups tested below would never be seen.
    tokens = deque(statement.tokens)
    while tokens:
        token = tokens.popleft()
        if isinstance(token, TokenList):
            # Children are queued behind the tokens already pending.
            tokens.extend(token.tokens)

        if token.ttype == Keyword and token.value.lower() in ("from", "join"):
            seen_source = True
        elif seen_source and (
            # table names are sometimes classified as keywords by sqlparse,
            # so a keyword after FROM/JOIN also counts as a table reference
            isinstance(token, Identifier)
            or token.ttype == Keyword
        ):
            return True
        elif seen_source and token.ttype not in (Whitespace, Punctuation):
            # anything else (eg, a parenthesised subquery) ends the source
            seen_source = False

    # NOTE: ``seen_source`` is a single flag for the whole traversal, so the
    # link between FROM/JOIN and the token that follows is positional rather
    # than structural; a stricter tree traversal was suggested in review as a
    # possible follow-up.
    return False
|
||
|
||
def add_table_name(rls: TokenList, table: str) -> None:
    """
    Modify a RLS expression ensuring columns are fully qualified.

    The expression is updated in place: every ``Identifier`` without a parent
    name has its tokens replaced by ``<table>.<name>``.

    :param rls: parsed RLS predicate to be modified
    :param table: table name used to qualify unqualified identifiers
    """
    # Function-scope import so the block does not depend on the file header.
    from collections import deque

    # The grouped tree is walked explicitly rather than via ``rls.flatten()``:
    # flattening yields leaf tokens only, so ``Identifier`` groups would never
    # be matched.
    tokens = deque(rls.tokens)
    while tokens:
        token = tokens.popleft()

        if isinstance(token, Identifier) and token.get_parent_name() is None:
            # ``get_name()`` is evaluated before the assignment replaces the
            # identifier's tokens
            token.tokens = [
                Token(Name, table),
                Token(Punctuation, "."),
                Token(Name, token.get_name()),
            ]
        elif isinstance(token, TokenList):
            # descend into any other group (including qualified identifiers)
            tokens.extend(token.tokens)
|
||
|
||
class InsertRLSState(str, Enum):
    """
    State machine that scans for WHERE and ON clauses referencing tables.
    """

    # Initial state; also the state returned to after a predicate is handled.
    SCANNING = "SCANNING"
    # A FROM/JOIN keyword was seen; the next identifier/keyword may be a table.
    SEEN_SOURCE = "SEEN_SOURCE"
    # The token after FROM/JOIN matched the target table.
    FOUND_TABLE = "FOUND_TABLE"
|
||
|
||
def matches_table_name(token: Token, table: str) -> bool:
    """
    Return whether a token refers to the given table.

    Names are split on periods and compared part by part from right to left;
    only as many parts as the shorter name has are compared, so a bare
    ``table`` matches ``schema.table`` and vice versa. Parts are compared as
    exact strings.

    A table should be represented as an identifier, but due to sqlparse's aggressive list
    of keywords (spanning multiple dialects) often it gets classified as a keyword.

    :param token: token whose ``value`` is the candidate name
    :param table: table name, optionally qualified with periods
    :returns: ``True`` when every compared part is equal
    """
    candidate_parts = token.value.split(".")
    table_parts = table.split(".")

    # match from right to left, splitting on the period, eg, schema.table == table
    return all(
        left == right
        for left, right in zip(reversed(candidate_parts), reversed(table_parts))
    )
|
||
|
||
def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
    """
    Update a statement in place applying an RLS associated with a given table.

    The token list is scanned for ``FROM``/``JOIN`` keywords followed by a
    reference to ``table``. When one is found the predicate is added to the
    following ``WHERE`` or ``ON`` clause, or a new ``WHERE`` clause is created
    if there is none. Child token lists are processed recursively.

    :param token_list: parsed statement (or group) to modify
    :param table: name of the table the RLS predicate applies to
    :param rls: parsed RLS predicate; its identifiers are qualified with
        ``table`` in place before insertion
    :returns: the same ``token_list`` object, modified
    """
    # make sure the identifier has the table name
    add_table_name(rls, table)

    state = InsertRLSState.SCANNING

    # Iterate over a snapshot: the branches below insert into
    # ``token_list.tokens``, and mutating a list while iterating over it makes
    # the loop revisit the current token and walk the freshly inserted ones.
    for token in list(token_list.tokens):

        # Recurse into child token list
        if isinstance(token, TokenList):
            i = token_list.tokens.index(token)
            token_list.tokens[i] = insert_rls(token, table, rls)

        # Found a source keyword (FROM/JOIN)
        if token.ttype == Keyword and token.value.lower() in ("from", "join"):
            state = InsertRLSState.SEEN_SOURCE

        # Found identifier/keyword after FROM/JOIN, test for table
        elif state == InsertRLSState.SEEN_SOURCE and (
            isinstance(token, Identifier) or token.ttype == Keyword
        ):
            if matches_table_name(token, table):
                state = InsertRLSState.FOUND_TABLE

        # Found WHERE clause, insert RLS if not present
        elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
            # NOTE(review): "present" means a direct child of the WHERE clause
            # renders exactly like the predicate, and the predicate is appended
            # without parenthesising the existing condition. With a top-level
            # OR the appended AND binds only to the last operand -- confirm
            # whether callers need stricter handling.
            rendered = str(rls)
            if not any(str(child) == rendered for child in token.tokens):
                token.tokens.extend(
                    [
                        Token(Whitespace, " "),
                        Token(Keyword, "AND"),
                        Token(Whitespace, " "),
                    ]
                    + rls.tokens
                )
            state = InsertRLSState.SCANNING

        # Found ON clause, insert RLS right after the keyword
        elif (
            state == InsertRLSState.FOUND_TABLE
            and token.ttype == Keyword
            and token.value.upper() == "ON"
        ):
            # Index and insertion both use ``token_list.tokens`` so they always
            # refer to the same list.
            i = token_list.tokens.index(token)
            token_list.tokens[i + 1 : i + 1] = [
                Token(Whitespace, " "),
                rls,
                Token(Whitespace, " "),
                Token(Keyword, "AND"),
            ]
            state = InsertRLSState.SCANNING

        # Found table but no WHERE clause found, insert one
        elif state == InsertRLSState.FOUND_TABLE and token.ttype != Whitespace:
            i = token_list.tokens.index(token)

            # sqlparse keeps whitespace as tokens (presumably so the tree can
            # be rendered back to a string), hence the explicit padding.

            # Left pad with space, if needed
            if i > 0 and token_list.tokens[i - 1].ttype != Whitespace:
                token_list.tokens.insert(i, Token(Whitespace, " "))
                i += 1

            # Insert predicate
            token_list.tokens.insert(
                i,
                Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]),
            )

            # Right pad with space, if needed. The token now following the
            # WHERE clause is the current one, which this branch guarantees is
            # not whitespace; only punctuation such as ")" or ";" may follow
            # the predicate directly.
            if token.ttype != Punctuation:
                token_list.tokens.insert(i + 1, Token(Whitespace, " "))

            state = InsertRLSState.SCANNING

        # Found nothing, leaving source
        elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
            state = InsertRLSState.SCANNING

    # Found table at the end of the statement; append a WHERE clause. Checking
    # after the loop also covers a table followed only by trailing whitespace,
    # which would otherwise be left without its predicate.
    if state == InsertRLSState.FOUND_TABLE:
        if token_list.tokens and token_list.tokens[-1].ttype != Whitespace:
            token_list.tokens.append(Token(Whitespace, " "))
        token_list.tokens.append(
            Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls])
        )

    return token_list
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@betodealmeida there's also this example which has logic for identifying tables.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't remember the details, but I've had issues with that example code before — I think it failed to identify table names when they were considered keywords (even though the example calls it out).