Skip to content

Commit

Permalink
SQLA v2 API: Do not build empty query filters
Browse files Browse the repository at this point in the history
Fixes #2475
  • Loading branch information
chrisjsewell committed Sep 16, 2021
1 parent 70a7018 commit 87c96a0
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 25 deletions.
28 changes: 18 additions & 10 deletions aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class SqlaJoiner:
"""A class containing the logic for SQLAlchemy entities joining entities."""

def __init__(
self, entity_mapper: _EntityMapper, filter_builder: Callable[[AliasedClass, FilterType], BooleanClauseList]
self, entity_mapper: _EntityMapper, filter_builder: Callable[[AliasedClass, FilterType],
Optional[BooleanClauseList]]
):
"""Initialise the class"""
self._entities = entity_mapper
Expand Down Expand Up @@ -185,7 +186,13 @@ def _join_descendants_recursive(
link1 = aliased(self._entities.Link)
link2 = aliased(self._entities.Link)
node1 = aliased(self._entities.Node)

link_filters = link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # follow input / create links
in_recursive_filters = self._build_filters(node1, filter_dict)
if in_recursive_filters is None:
filters = link_filters
else:
filters = and_(in_recursive_filters, link_filters)

selection_walk_list = [
link1.input_id.label('ancestor_id'),
Expand All @@ -195,12 +202,8 @@ def _join_descendants_recursive(
if expand_path:
selection_walk_list.append(array((link1.input_id, link1.output_id)).label('path'))

walk = select(*selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)).where(
and_(
in_recursive_filters, # I apply filters for speed here
link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # I follow input and create links
)
).cte(recursive=True)
walk = select(*selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)
).where(filters).cte(recursive=True)

aliased_walk = aliased(walk)

Expand Down Expand Up @@ -248,7 +251,13 @@ def _join_ancestors_recursive(
link1 = aliased(self._entities.Link)
link2 = aliased(self._entities.Link)
node1 = aliased(self._entities.Node)

link_filters = link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # follow input / create links
in_recursive_filters = self._build_filters(node1, filter_dict)
if in_recursive_filters is None:
filters = link_filters
else:
filters = and_(in_recursive_filters, link_filters)

selection_walk_list = [
link1.input_id.label('ancestor_id'),
Expand All @@ -258,9 +267,8 @@ def _join_ancestors_recursive(
if expand_path:
selection_walk_list.append(array((link1.output_id, link1.input_id)).label('path'))

walk = select(*selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where(
and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
).cte(recursive=True)
walk = select(*selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)
).where(filters).cte(recursive=True)

aliased_walk = aliased(walk)

Expand Down
33 changes: 19 additions & 14 deletions aiida/orm/implementation/sqlalchemy/querybuilder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,9 @@ def _build(self) -> Query:
alias = self._get_tag_alias(tag)
except KeyError:
raise ValueError(f'Unknown tag {tag!r} in filters, known: {list(self._tag_to_alias)}')
self._query = self._query.filter(self.build_filters(alias, filter_specs))
filters = self.build_filters(alias, filter_specs)
if filters is not None:
self._query = self._query.filter(filters)

# PROJECTIONS ##########################

Expand Down Expand Up @@ -601,7 +603,7 @@ def get_column(colname: str, alias: AliasedClass) -> InstrumentedAttribute:
'{}'.format(colname, alias, '\n'.join(alias._sa_class_manager.mapper.c.keys())) # pylint: disable=protected-access
) from exc

def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> BooleanClauseList:
def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Optional[BooleanClauseList]: # pylint: disable=too-many-branches
"""Recurse through the filter specification and apply filter operations.
:param alias: The alias of the ORM class the filter will be applied on
Expand All @@ -612,17 +614,20 @@ def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Boo
expressions: List[Any] = []
for path_spec, filter_operation_dict in filter_spec.items():
if path_spec in ('and', 'or', '~or', '~and', '!and', '!or'):
subexpressions = [
self.build_filters(alias, sub_filter_spec) for sub_filter_spec in filter_operation_dict
]
if path_spec == 'and':
expressions.append(and_(*subexpressions))
elif path_spec == 'or':
expressions.append(or_(*subexpressions))
elif path_spec in ('~and', '!and'):
expressions.append(not_(and_(*subexpressions)))
elif path_spec in ('~or', '!or'):
expressions.append(not_(or_(*subexpressions)))
subexpressions = []
for sub_filter_spec in filter_operation_dict:
filters = self.build_filters(alias, sub_filter_spec)
if filters is not None:
subexpressions.append(filters)
if subexpressions:
if path_spec == 'and':
expressions.append(and_(*subexpressions))
elif path_spec == 'or':
expressions.append(or_(*subexpressions))
elif path_spec in ('~and', '!and'):
expressions.append(not_(and_(*subexpressions)))
elif path_spec in ('~or', '!or'):
expressions.append(not_(or_(*subexpressions)))
else:
column_name = path_spec.split('.')[0]

Expand Down Expand Up @@ -650,7 +655,7 @@ def build_filters(self, alias: AliasedClass, filter_spec: Dict[str, Any]) -> Boo
alias=alias
)
)
return and_(*expressions)
return and_(*expressions) if expressions else None

def modify_expansions(self, alias: AliasedClass, expansions: List[str]) -> List[str]:
"""Modify names of projections if `**` was specified.
Expand Down
10 changes: 9 additions & 1 deletion tests/orm/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ class TestQueryBuilderCornerCases:
In this class corner cases of QueryBuilder are added.
"""

def test_computer_json(self): # pylint: disable=no-self-use
def test_computer_json(self):
"""
In this test we check the correct behavior of QueryBuilder when
retrieving the _metadata with no content.
Expand All @@ -818,6 +818,14 @@ def test_computer_json(self): # pylint: disable=no-self-use
qb.append(orm.Computer, project=['id', 'metadata'], outerjoin=True, with_node='calc')
qb.all()

def test_empty_filters(self):
"""Test that an empty filter is correctly handled."""
orm.Data().store()
qb = orm.QueryBuilder().append(orm.Data, filters={})
assert qb.count() == 1
qb = orm.QueryBuilder().append(orm.Data, filters={'or': [{}, {}]})
assert qb.count() == 1


@pytest.mark.usefixtures('clear_database_before_test')
class TestAttributes:
Expand Down

0 comments on commit 87c96a0

Please sign in to comment.