Skip to content

Commit

Permalink
chore(rls): Remove passing global username (#20344)
Browse files Browse the repository at this point in the history
* chore(rls): Remove passing global username

* Update manager.py

* Update manager.py

* Update manager.py

* Update manager.py

Co-authored-by: John Bodley <[email protected]>
  • Loading branch information
john-bodley and John Bodley authored Jul 5, 2022
1 parent 92bf1b8 commit ad308fb
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 24 deletions.
5 changes: 1 addition & 4 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,22 +1169,19 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool:
def get_sqla_row_level_filters(
self,
template_processor: BaseTemplateProcessor,
username: Optional[str] = None,
) -> List[TextClause]:
"""
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
Flask global namespace.
:param template_processor: The template processor to apply to the filters.
:param username: Optional username if there's no user in the Flask global
namespace.
:returns: A list of SQL clauses to be ANDed together.
"""
all_filters: List[TextClause] = []
filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self, username):
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(filter_.clause)})"
)
Expand Down
19 changes: 5 additions & 14 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,25 +1147,16 @@ def get_guest_rls_filters(
]
return []

def get_rls_filters(
self,
table: "BaseDatasource",
username: Optional[str] = None,
) -> List[SqlaQuery]:
def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]:
"""
Retrieves the appropriate row level security filters for the current user and
the passed table.
:param BaseDatasource table: The table to check against.
:param Optional[str] username: Optional username if there's no user in the Flask
global namespace.
:param table: The table to check against
:returns: A list of filters
"""
if hasattr(g, "user"):
user = g.user
elif username:
user = self.find_user(username=username)
else:

if not (hasattr(g, "user") and g.user is not None):
return []

# pylint: disable=import-outside-toplevel
Expand All @@ -1175,7 +1166,7 @@ def get_rls_filters(
RowLevelSecurityFilter,
)

user_roles = [role.id for role in self.get_user_roles(user)]
user_roles = [role.id for role in self.get_user_roles(g.user)]
regular_filter_roles = (
self.get_session()
.query(RLSFilterRoles.c.rls_filter_id)
Expand Down
1 change: 0 additions & 1 deletion superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem
parsed_query._parsed[0], # pylint: disable=protected-access
database.id,
query.schema,
username=get_username(),
)
)
)
Expand Down
6 changes: 2 additions & 4 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,6 @@ def get_rls_for_table(
candidate: Token,
database_id: int,
default_schema: Optional[str],
username: Optional[str] = None,
) -> Optional[TokenList]:
"""
Given a table name, return any associated RLS predicates.
Expand Down Expand Up @@ -586,7 +585,7 @@ def get_rls_for_table(
template_processor = dataset.get_template_processor()
predicate = " AND ".join(
str(filter_)
for filter_ in dataset.get_sqla_row_level_filters(template_processor, username)
for filter_ in dataset.get_sqla_row_level_filters(template_processor)
)
if not predicate:
return None
Expand All @@ -601,7 +600,6 @@ def insert_rls(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
username: Optional[str] = None,
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
Expand All @@ -623,7 +621,7 @@ def insert_rls(
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
rls = get_rls_for_table(token, database_id, default_schema, username)
rls = get_rls_for_table(token, database_id, default_schema)
if rls:
state = InsertRLSState.FOUND_TABLE

Expand Down
1 change: 0 additions & 1 deletion tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,7 +1409,6 @@ def get_rls_for_table(
candidate: Token,
database_id: int,
default_schema: str,
username: Optional[str] = None,
) -> Optional[TokenList]:
"""
Return the RLS ``condition`` if ``candidate`` matches ``table``.
Expand Down

0 comments on commit ad308fb

Please sign in to comment.