diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46a627af87..9538733728 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,6 +68,7 @@ repos: aiida/engine/.*py| aiida/manage/manager.py| aiida/manage/database/delete/nodes.py| + aiida/orm/querybuilder.py| aiida/orm/nodes/data/jsonable.py| aiida/orm/nodes/node.py| aiida/orm/nodes/process/.*py| diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index f0dfd50fe2..f2463d1b9d 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -9,6 +9,16 @@ ########################################################################### """Generic backend related objects""" import abc +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + + from aiida.orm.implementation import ( + BackendAuthInfoCollection, BackendCommentCollection, BackendComputerCollection, BackendGroupCollection, + BackendLogCollection, BackendNodeCollection, BackendQueryBuilder, BackendUserCollection + ) + from aiida.backends.general.abstractqueries import AbstractQueryManager __all__ = ('Backend',) @@ -21,85 +31,40 @@ def migrate(self): """Migrate the database to the latest schema generation or version.""" @abc.abstractproperty - def authinfos(self): - """ - Return the collection of authorisation information objects - - :return: the authinfo collection - :rtype: :class:`aiida.orm.implementation.BackendAuthInfoCollection` - """ + def authinfos(self) -> 'BackendAuthInfoCollection': + """Return the collection of authorisation information objects""" @abc.abstractproperty - def comments(self): - """ - Return the collection of comments - - :return: the comment collection - :rtype: :class:`aiida.orm.implementation.BackendCommentCollection` - """ + def comments(self) -> 'BackendCommentCollection': + """Return the collection of comments""" @abc.abstractproperty - def computers(self): - """ - Return the collection 
of computers - - :return: the computers collection - :rtype: :class:`aiida.orm.implementation.BackendComputerCollection` - """ + def computers(self) -> 'BackendComputerCollection': + """Return the collection of computers""" @abc.abstractproperty - def groups(self): - """ - Return the collection of groups - - :return: the groups collection - :rtype: :class:`aiida.orm.implementation.BackendGroupCollection` - """ + def groups(self) -> 'BackendGroupCollection': + """Return the collection of groups""" @abc.abstractproperty - def logs(self): - """ - Return the collection of logs - - :return: the log collection - :rtype: :class:`aiida.orm.implementation.BackendLogCollection` - """ + def logs(self) -> 'BackendLogCollection': + """Return the collection of logs""" @abc.abstractproperty - def nodes(self): - """ - Return the collection of nodes - - :return: the nodes collection - :rtype: :class:`aiida.orm.implementation.BackendNodeCollection` - """ + def nodes(self) -> 'BackendNodeCollection': + """Return the collection of nodes""" @abc.abstractproperty - def query_manager(self): - """ - Return the query manager for the objects stored in the backend - - :return: The query manger - :rtype: :class:`aiida.backends.general.abstractqueries.AbstractQueryManager` - """ + def query_manager(self) -> 'AbstractQueryManager': + """Return the query manager for the objects stored in the backend""" @abc.abstractmethod - def query(self): - """ - Return an instance of a query builder implementation for this backend - - :return: a new query builder instance - :rtype: :class:`aiida.orm.implementation.BackendQueryBuilder` - """ + def query(self) -> 'BackendQueryBuilder': + """Return an instance of a query builder implementation for this backend""" @abc.abstractproperty - def users(self): - """ - Return the collection of users - - :return: the users collection - :rtype: :class:`aiida.orm.implementation.BackendUserCollection` - """ + def users(self) -> 'BackendUserCollection': + """Return the 
collection of users""" @abc.abstractmethod def transaction(self): @@ -112,7 +77,7 @@ def transaction(self): """ @abc.abstractmethod - def get_session(self): + def get_session(self) -> 'Session': """Return a database session that can be used by the `QueryBuilder` to perform its query. :return: an instance of :class:`sqlalchemy.orm.session.Session` diff --git a/aiida/orm/implementation/querybuilder.py b/aiida/orm/implementation/querybuilder.py index acb1fa9af9..de266e524f 100644 --- a/aiida/orm/implementation/querybuilder.py +++ b/aiida/orm/implementation/querybuilder.py @@ -16,6 +16,7 @@ likely be moved to a `SqlAlchemyBasedQueryBuilder` class and restore this abstract class to being a pure agnostic one. """ import abc +from typing import TYPE_CHECKING import uuid # pylint: disable=no-name-in-module,import-error @@ -25,6 +26,9 @@ from aiida.common.lang import type_check +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session # pylint: disable=ungrouped-imports + __all__ = ('BackendQueryBuilder',) @@ -111,7 +115,7 @@ def AiidaNode(self): from aiida.orm import Node return Node - def get_session(self): + def get_session(self) -> 'Session': """ :returns: a valid session, an instance of :class:`sqlalchemy.orm.session.Session` """ diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index 859180c508..b7376dccf8 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -22,6 +22,7 @@ from inspect import isclass as inspect_isclass import copy import logging +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Type, Union, TYPE_CHECKING import warnings from sqlalchemy import and_, or_, not_, func as sa_func, select, join @@ -43,6 +44,10 @@ from . import entities from . 
import convert +if TYPE_CHECKING: + from sqlalchemy.orm import Query # pylint: disable=ungrouped-imports + from aiida.orm.implementation import Backend # pylint: disable=ungrouped-imports + __all__ = ('QueryBuilder',) _LOGGER = logging.getLogger(__name__) @@ -55,6 +60,28 @@ # subclassing for any entity type. This workaround should then be able to be removed. GROUP_ENTITY_TYPE_PREFIX = 'group.' +# re-usable type annotations +NodeClsType = Type[Any] # pylint: disable=invalid-name +ProjectType = Union[str, dict, Sequence[Union[str, dict]]] # pylint: disable=invalid-name +FilterType = Dict[str, Any] # pylint: disable=invalid-name +RowType = Any # pylint: disable=invalid-name + +try: + # new in python 3.8 + from typing import TypedDict # pylint: disable=ungrouped-imports + + class PathItemType(TypedDict): + """An item on the query path""" + + entity_type: Any + tag: str + joining_keyword: str + joining_value: str + outerjoin: bool + edge_tag: str +except ImportError: + PathItemType = Dict[str, Any] # type: ignore + def get_querybuilder_classifiers_from_cls(cls, query): # pylint: disable=invalid-name """ @@ -73,7 +100,7 @@ def get_querybuilder_classifiers_from_cls(cls, query): # pylint: disable=invali from aiida.engine import Process from aiida.orm.utils.node import is_valid_node_type_string - classifiers = {} + classifiers: Dict[str, Optional[str]] = {} classifiers['process_type_string'] = None @@ -166,12 +193,12 @@ def get_querybuilder_classifiers_from_type(ormclass_type_string, query): # pyli Same as get_querybuilder_classifiers_from_cls, but accepts a string instead of a class. 
""" from aiida.orm.utils.node import is_valid_node_type_string - classifiers = {} + classifiers: Dict[str, Optional[str]] = {} classifiers['process_type_string'] = None classifiers['ormclass_type_string'] = ormclass_type_string.lower() - if classifiers['ormclass_type_string'].startswith(GROUP_ENTITY_TYPE_PREFIX): + if ormclass_type_string.lower().startswith(GROUP_ENTITY_TYPE_PREFIX): classifiers['ormclass_type_string'] = 'group.core' ormclass = query.Group elif classifiers['ormclass_type_string'] == 'computer': @@ -190,7 +217,7 @@ def get_querybuilder_classifiers_from_type(ormclass_type_string, query): # pyli return ormclass, classifiers -def get_node_type_filter(classifiers, subclassing): +def get_node_type_filter(classifiers: dict, subclassing: bool) -> dict: """ Return filter dictionaries given a set of classifiers. @@ -215,7 +242,7 @@ def get_node_type_filter(classifiers, subclassing): return filters -def get_process_type_filter(classifiers, subclassing): +def get_process_type_filter(classifiers: dict, subclassing: bool) -> dict: """ Return filter dictionaries given a set of classifiers. @@ -282,7 +309,7 @@ def get_process_type_filter(classifiers, subclassing): return filters -def get_group_type_filter(classifiers, subclassing): +def get_group_type_filter(classifiers: dict, subclassing: bool) -> dict: """Return filter dictionaries for `Group.type_string` given a set of classifiers. 
:param classifiers: a dictionary with classifiers (note: does *not* support lists) @@ -332,17 +359,28 @@ class QueryBuilder: _EDGE_TAG_DELIM = '--' _VALID_PROJECTION_KEYS = ('func', 'cast') - def __init__(self, backend=None, **kwargs): + def __init__( + self, + backend: Optional['Backend'] = None, + *, + debug: bool = False, + path: Optional[Sequence[Union[str, Dict[str, Any], NodeClsType]]] = (), + filters: Optional[Dict[str, FilterType]] = None, + project: Optional[Dict[str, ProjectType]] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + order_by: Optional[Any] = None, + ) -> None: """ Instantiates a QueryBuilder instance. Which backend is used decided here based on backend-settings (taken from the user profile). - This cannot be overriden so far by the user. + This cannot be overridden so far by the user. - :param bool debug: + :param debug: Turn on debug mode. This feature prints information on the screen about the stages of the QueryBuilder. Does not affect results. - :param list path: + :param path: A list of the vertices to traverse. Leave empty if you plan on using the method :func:`QueryBuilder.append`. :param filters: @@ -353,10 +391,10 @@ def __init__(self, backend=None, **kwargs): The projections to apply. You can specify the projections here, when appending to the query using :func:`QueryBuilder.append` or even later using :func:`QueryBuilder.add_projection`. Latter gives you API-details. - :param int limit: + :param limit: Limit the number of rows to this number. Check :func:`QueryBuilder.limit` for more information. - :param int offset: + :param offset: Set an offset for the results returned. Details in :func:`QueryBuilder.offset`. :param order_by: How to order the results. 
As the 2 above, can be set also at later stage, @@ -367,26 +405,23 @@ def __init__(self, backend=None, **kwargs): self._impl = backend.query() # A list storing the path being traversed by the query - self._path = [] - - # A list of unique aliases in same order as path - self._aliased_path = [] + self._path: List[PathItemType] = [] # A dictionary tag:alias of ormclass # redundant but makes life easier - self.tag_to_alias_map = {} - self.tag_to_projected_property_dict = {} + self.tag_to_alias_map: Dict[str, Any] = {} + self.tag_to_projected_property_dict: Dict[str, dict] = {} # A dictionary tag: filter specification for this alias - self._filters = {} + self._filters: Dict[str, FilterType] = {} # A dictionary tag: projections for this alias - self._projections = {} + self._projections: Dict[str, List[dict]] = {} self.nr_of_projections = 0 - self._attrkeys_as_in_sql_result = None + self._attrkeys_as_in_sql_result: Optional[dict] = None - self._query = None + self._query: 'Query' = None # A dictionary for classes passed to the tag given to them # Everything is specified with unique tags, which are strings. @@ -402,10 +437,10 @@ def __init__(self, backend=None, **kwargs): # {PwCalculation:'PwCalculation', StructureData:'StructureData'} # Keep in mind that it needs to be checked (and this is done) whether the class # is used twice. In that case, the user has to provide a tag! - self._cls_to_tag_map = {} + self._cls_to_tag_map: Dict[Any, str] = {} - # Hashing the the internal queryhelp allows me to avoid to build a query again - self._hash = None + # Hashing the internal queryhelp avoids rebuilding a query + self._hash: Optional[str] = None # The hash being None implies that the query will be build (Check the code in .get_query # The user can inject a query, this keyword stores whether this was done. 
@@ -413,11 +448,10 @@ def __init__(self, backend=None, **kwargs): self._injected = False # Setting debug levels: - self.set_debug(kwargs.pop('debug', False)) + self.set_debug(debug) # One can apply the path as a keyword. Allows for jsons to be given to the QueryBuilder. - path = kwargs.pop('path', []) - if not isinstance(path, (tuple, list)): + if not isinstance(path, (list, tuple)): raise TypeError('Path needs to be a tuple or a list') # If the user specified a path, I use the append method to analyze, see QueryBuilder.append for path_spec in path: @@ -433,43 +467,31 @@ def __init__(self, backend=None, **kwargs): # Projections. The user provides a dictionary, but the specific checks is # left to QueryBuilder.add_project. - projection_dict = kwargs.pop('project', {}) + projection_dict = project or {} if not isinstance(projection_dict, dict): raise TypeError('You need to provide the projections as dictionary') for key, val in projection_dict.items(): self.add_projection(key, val) # For filters, I also expect a dictionary, and the checks are done lower. - filter_dict = kwargs.pop('filters', {}) + filter_dict = filters or {} if not isinstance(filter_dict, dict): raise TypeError('You need to provide the filters as dictionary') for key, val in filter_dict.items(): self.add_filter(key, val) # The limit is caps the number of results returned, and can also be set with QueryBuilder.limit - self.limit(kwargs.pop('limit', None)) + self.limit(limit) # The offset returns results after the offset - self.offset(kwargs.pop('offset', None)) + self.offset(offset) # The user can also specify the order. 
- self._order_by = {} - order_spec = kwargs.pop('order_by', None) - if order_spec: - self.order_by(order_spec) - - # I've gone through all the keywords, popping each item - # If kwargs is not empty, there is a problem: - if kwargs: - valid_keys = ('path', 'filters', 'project', 'limit', 'offset', 'order_by') - raise ValueError( - 'Received additional keywords: {}' - '\nwhich I cannot process' - '\nValid keywords are: {}' - ''.format(list(kwargs.keys()), valid_keys) - ) + self._order_by: List[dict] = [] + if order_by: + self.order_by(order_by) - def __str__(self): + def __str__(self) -> str: """ When somebody hits: print(QueryBuilder) or print(str(QueryBuilder)) I want to print the SQL-query. Because it looks cool... @@ -531,11 +553,11 @@ def _get_ormclass(self, cls, ormclass_type_string): return ormclass, classifiers - def _get_unique_tag(self, classifiers): + def _get_unique_tag(self, classifiers) -> str: """ Using the function get_tag_from_type, I get a tag. I increment an index that is appended to that tag until I have an unused tag. - This function is called in :func:`QueryBuilder.append` when autotag is set to True. + This function is called in :func:`QueryBuilder.append` when no tag is given. :param dict classifiers: Classifiers, containing the string that defines the type of the AiiDA ORM class. 
@@ -580,18 +602,20 @@ def get_tag_from_type(classifiers): def append( self, - cls=None, - entity_type=None, - tag=None, - filters=None, - project=None, - subclassing=True, - edge_tag=None, - edge_filters=None, - edge_project=None, - outerjoin=False, - **kwargs - ): + cls: Optional[Union[NodeClsType, Sequence[NodeClsType]]] = None, + entity_type: Optional[Union[str, Sequence[str]]] = None, + tag: Optional[str] = None, + filters: Optional[FilterType] = None, + project: Optional[ProjectType] = None, + subclassing: bool = True, + edge_tag: Optional[str] = None, + edge_filters: Optional[FilterType] = None, + edge_project: Optional[ProjectType] = None, + outerjoin: bool = False, + joining_keyword: Optional[str] = None, + joining_value: Optional[Any] = None, + **kwargs: Any + ) -> 'QueryBuilder': """ Any iterative procedure to build the path for a graph query needs to invoke this method to append to the path. @@ -609,9 +633,7 @@ def append( cls=(Group, Node) :param entity_type: The node type of the class, if cls is not given. Also here, a tuple or list is accepted. - :type type: str - :param bool autotag: Whether to find automatically a unique tag. If this is set to True (default False), - :param str tag: + :param tag: A unique tag. If none is given, I will create a unique tag myself. :param filters: Filters to apply for this vertex. @@ -619,21 +641,28 @@ def append( :param project: Projections to apply. See usage examples for details. More information also in :meth:`.add_projection`. - :param bool subclassing: - Whether to include subclasses of the given class - (default **True**). - E.g. Specifying a ProcessNode as cls will include CalcJobNode, WorkChainNode, CalcFunctionNode, etc.. - :param bool outerjoin: - If True, (default is False), will do a left outerjoin - instead of an inner join - :param str edge_tag: + :param subclassing: + Whether to include subclasses of the given class (default **True**). + E.g. 
Specifying a ProcessNode as cls will include CalcJobNode, WorkChainNode, CalcFunctionNode, etc.. + :param edge_tag: The tag that the edge will get. If nothing is specified (and there is a meaningful edge) the default is tag1--tag2 with tag1 being the entity joining from and tag2 being the entity joining to (this entity). - :param str edge_filters: + :param edge_filters: The filters to apply on the edge. Also here, details in :meth:`.add_filter`. - :param str edge_project: + :param edge_project: The project from the edges. API-details in :meth:`.add_projection`. + :param outerjoin: + If True, (default is False), will do a left outerjoin + instead of an inner join + + Joining can be specified in two ways: + + - Specifying the 'joining_keyword' and 'joining_value' arguments + - Specify a single keyword argument + + The joining keyword will be ``with_*`` or ``direction``, depending on the joining entity type. + The joining value is the tag name or class of the entity to join to. A small usage example how this can be invoked:: @@ -648,14 +677,11 @@ def append( ) :return: self - :rtype: :class:`aiida.orm.QueryBuilder` """ # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements # INPUT CHECKS ########################## - # This function can be called by users, so I am checking the - # input now. - # First of all, let's make sure the specified - # the class or the type (not both) + # This function can be called by users, so I am checking the input now. 
+ # First of all, let's make sure the user specified the class or the type (not both) if cls is not None and entity_type is not None: raise ValueError(f'You cannot specify both a class ({cls}) and a entity_type ({entity_type})') @@ -665,21 +691,19 @@ def append( # Let's check if it is a valid class or type if cls: - if isinstance(cls, (tuple, list, set)): + if isinstance(cls, (list, tuple)): for sub_cls in cls: if not inspect_isclass(sub_cls): raise TypeError(f"{sub_cls} was passed with kw 'cls', but is not a class") - else: - if not inspect_isclass(cls): - raise TypeError(f"{cls} was passed with kw 'cls', but is not a class") + elif not inspect_isclass(cls): + raise TypeError(f"{cls} was passed with kw 'cls', but is not a class") elif entity_type is not None: - if isinstance(entity_type, (tuple, list, set)): + if isinstance(entity_type, (list, tuple)): for sub_type in entity_type: if not isinstance(sub_type, str): raise TypeError(f'{sub_type} was passed as entity_type, but is not a string') - else: - if not isinstance(entity_type, str): - raise TypeError(f'{entity_type} was passed as entity_type, but is not a string') + elif not isinstance(entity_type, str): + raise TypeError(f'{entity_type} was passed as entity_type, but is not a string') ormclass, classifiers = self._get_ormclass(cls, entity_type) @@ -711,7 +735,7 @@ def append( if isinstance(cls, (list, set)): tag_key = tuple(cls) else: - tag_key = cls + tag_key = cls # type: ignore[assignment] if tag_key in self._cls_to_tag_map.keys(): # In this case, this class already stands for another @@ -729,9 +753,7 @@ def append( try: self.tag_to_alias_map[tag] = aliased(ormclass) except Exception as exception: - if self._debug: - print('DEBUG: Exception caught in append, cleaning up') - print(' ', exception) + self.debug('Exception caught in append, cleaning up: %s', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag, None) @@ -755,9 +777,7 @@ def append( if filters is not None: 
self.add_filter(tag, filters) except Exception as exception: - if self._debug: - print('DEBUG: Exception caught in append (part filters), cleaning up') - print(' ', exception) + self.debug('Exception caught in append, cleaning up: %s', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag) @@ -770,9 +790,7 @@ def append( if project is not None: self.add_projection(tag, project) except Exception as exception: - if self._debug: - print('DEBUG: Exception caught in append (part projections), cleaning up') - print(' ', exception) + self.debug('Exception caught in append, cleaning up: %s', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag, None) @@ -784,57 +802,49 @@ def append( # pylint: disable=too-many-nested-blocks try: # Get the functions that are implemented: - spec_to_function_map = [] + # 'direction' was an old implementation, which is now converted below to with_outgoing or with_incoming + spec_to_function_map = {'direction'} for secondary_dict in self._get_function_map().values(): - for key in secondary_dict.keys(): - if key not in spec_to_function_map: - spec_to_function_map.append(key) - joining_keyword = kwargs.pop('joining_keyword', None) - joining_value = kwargs.pop('joining_value', None) + spec_to_function_map.update(secondary_dict.keys()) for key, val in kwargs.items(): if key not in spec_to_function_map: raise ValueError( - '{} is not a valid keyword ' - 'for joining specification\n' - 'Valid keywords are: ' - '{}'.format( - key, spec_to_function_map + ['cls', 'type', 'tag', 'autotag', 'filters', 'project'] - ) + f"'{key}' is not a valid keyword for joining specification\n" + f'Valid keywords are: {spec_to_function_map!r}' ) - elif joining_keyword: + if joining_keyword: raise ValueError( 'You already specified joining specification {}\n' 'But you now also want to specify {}' ''.format(joining_keyword, key) ) + + joining_keyword = key + if joining_keyword == 
'direction': + if not isinstance(val, int): + raise TypeError('direction=n expects n to be an integer') + try: + if val < 0: + joining_keyword = 'with_outgoing' + elif val > 0: + joining_keyword = 'with_incoming' + else: + raise ValueError('direction=0 is not valid') + joining_value = self._path[-abs(val)]['tag'] + except IndexError as exc: + raise ValueError( + f'You have specified a non-existent entity with\ndirection={joining_value}\n{exc}\n' + ) else: - joining_keyword = key - if joining_keyword == 'direction': - if not isinstance(val, int): - raise TypeError('direction=n expects n to be an integer') - try: - if val < 0: - joining_keyword = 'with_outgoing' - elif val > 0: - joining_keyword = 'with_incoming' - else: - raise ValueError('direction=0 is not valid') - joining_value = self._path[-abs(val)]['tag'] - except IndexError as exc: - raise ValueError( - f'You have specified a non-existent entity with\ndirection={joining_value}\n{exc}\n' - ) - else: - joining_value = self._get_tag_from_specification(val) + joining_value = self._get_tag_from_specification(val) + # the default is that this vertex is 'with_incoming' as the previous one if joining_keyword is None and len(self._path) > 0: joining_keyword = 'with_incoming' joining_value = self._path[-1]['tag'] except Exception as exception: - if self._debug: - print('DEBUG: Exception caught in append (part joining), cleaning up') - print(' ', exception) + self.debug('Exception caught in append (part filters), cleaning up: %s', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag, None) @@ -846,16 +856,13 @@ def append( # EDGES ################################# if len(self._path) > 0: try: - if self._debug: - print('DEBUG: Choosing an edge_tag') if edge_tag is None: edge_destination_tag = self._get_tag_from_specification(joining_value) edge_tag = edge_destination_tag + self._EDGE_TAG_DELIM + tag else: if edge_tag in self.tag_to_alias_map.keys(): raise ValueError(f'The tag 
{edge_tag} is already in use') - if self._debug: - print('I have chosen', edge_tag) + self.debug('edge_tag chosen: %s', edge_tag) # My edge is None for now, since this is created on the FLY, # the _tag_to_alias_map will be updated later (in _build) @@ -873,11 +880,7 @@ def append( if edge_project is not None: self.add_projection(edge_tag, edge_project) except Exception as exception: - - if self._debug: - print('DEBUG: Exception caught in append (part joining), cleaning up') - import traceback - print(traceback.format_exc()) + self.debug('Exception caught in append (part joining), cleaning up %s', exception) if l_class_added_to_map: self._cls_to_tag_map.pop(cls) self.tag_to_alias_map.pop(tag, None) @@ -902,16 +905,19 @@ def append( dict( entity_type=path_type, tag=tag, - joining_keyword=joining_keyword, - joining_value=joining_value, + # for the first item joining_keyword/joining_value can be None, + # but after they always default to 'with_incoming' of the previous item + joining_keyword=joining_keyword, # type: ignore + joining_value=joining_value, # type: ignore + # same for edge_tag for which a default is applied + edge_tag=edge_tag, # type: ignore outerjoin=outerjoin, - edge_tag=edge_tag ) ) return self - def order_by(self, order_by): + def order_by(self, order_by: Union[dict, List[dict], Tuple[dict, ...]]) -> 'QueryBuilder': """ Set the entity to order by @@ -961,7 +967,7 @@ def order_by(self, order_by): '[columns to sort]' ''.format(order_spec) ) - _order_spec = {} + _order_spec: dict = {} for tagspec, items_to_order_by in order_spec.items(): if not isinstance(items_to_order_by, (tuple, list)): items_to_order_by = [items_to_order_by] @@ -1013,7 +1019,7 @@ def order_by(self, order_by): self._order_by.append(_order_spec) return self - def add_filter(self, tagspec, filter_spec): + def add_filter(self, tagspec: str, filter_spec: FilterType) -> None: """ Adding a filter to my filters. 
@@ -1036,7 +1042,7 @@ def add_filter(self, tagspec, filter_spec): self._filters[tag].update(filters) @staticmethod - def _process_filters(filters): + def _process_filters(filters: FilterType) -> dict: """Process filters.""" if not isinstance(filters, dict): raise TypeError('Filters have to be passed as dictionaries') @@ -1063,7 +1069,7 @@ def _add_node_type_filter(self, tagspec, classifiers, subclassing): """ if isinstance(classifiers, list): # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - entity_type_filter = {'or': []} + entity_type_filter: dict = {'or': []} for classifier in classifiers: entity_type_filter['or'].append(get_node_type_filter(classifier, subclassing)) else: @@ -1083,7 +1089,7 @@ def _add_process_type_filter(self, tagspec, classifiers, subclassing): """ if isinstance(classifiers, list): # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - process_type_filter = {'or': []} + process_type_filter: dict = {'or': []} for classifier in classifiers: if classifier['process_type_string'] is not None: process_type_filter['or'].append(get_process_type_filter(classifier, subclassing)) @@ -1106,7 +1112,7 @@ def _add_group_type_filter(self, tagspec, classifiers, subclassing): """ if isinstance(classifiers, list): # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - type_string_filter = {'or': []} + type_string_filter: dict = {'or': []} for classifier in classifiers: type_string_filter['or'].append(get_group_type_filter(classifier, subclassing)) else: @@ -1114,7 +1120,7 @@ def _add_group_type_filter(self, tagspec, classifiers, subclassing): self.add_filter(tagspec, {'type_string': type_string_filter}) - def add_projection(self, tag_spec, projection_spec): + def add_projection(self, tag_spec: str, projection_spec: ProjectType) -> None: r""" Adds a projection @@ -1163,11 +1169,9 @@ def add_projection(self, tag_spec, projection_spec): """ tag 
= self._get_tag_from_specification(tag_spec) _projections = [] - if self._debug: - print('DEBUG: Adding projection of', tag_spec) - print(' projection', projection_spec) + self.debug('Adding projection of %s: %s', tag_spec, projection_spec) if not isinstance(projection_spec, (list, tuple)): - projection_spec = [projection_spec] + projection_spec = [projection_spec] # type: ignore for projection in projection_spec: if isinstance(projection, dict): _thisprojection = projection @@ -1187,8 +1191,7 @@ def add_projection(self, tag_spec, projection_spec): if not isinstance(val, str): raise TypeError(f'{val} has to be a string') _projections.append(_thisprojection) - if self._debug: - print(' projections have become:', _projections) + self.debug('projections have become: %s', _projections) self._projections[tag] = _projections def _get_projectable_entity(self, alias, column_name, attrpath, **entityspec): @@ -1241,8 +1244,7 @@ def _build_projections(self, tag, items_to_project=None): items_to_project = self._projections.get(tag, []) # Return here if there is nothing to project, reduces number of key in return dictionary - if self._debug: - print(tag, items_to_project) + self.debug('projection for %s: %s', tag, items_to_project) if not items_to_project: return @@ -1289,12 +1291,12 @@ def _get_tag_from_specification(self, specification): ) return tag - def set_debug(self, debug): + def set_debug(self, debug: bool) -> 'QueryBuilder': """ Run in debug mode. This does not affect functionality, but prints intermediate stages when creating a query on screen. - :param bool debug: Turn debug on or off + :param debug: Turn debug on or off """ if not isinstance(debug, bool): return TypeError('I expect a boolean') @@ -1302,19 +1304,26 @@ def set_debug(self, debug): return self - def limit(self, limit): + def debug(self, msg: str, *objects: Any) -> None: + """Log debug message. + + objects will be passed to the format string, e.g.
``msg % objects`` """ - Set the limit (nr of rows to return) + if self._debug: + print(f'DEBUG: {msg}' % objects) - :param int limit: integers of number of rows of rows to return + def limit(self, limit: Optional[int]) -> 'QueryBuilder': """ + Set the limit (nr of rows to return) + :param limit: the number of rows to return + """ if (limit is not None) and (not isinstance(limit, int)): raise TypeError('The limit has to be an integer, or None') self._limit = limit return self - def offset(self, offset): + def offset(self, offset: Optional[int]) -> 'QueryBuilder': """ Set the offset. If offset is set, that many rows are skipped before returning. *offset* = 0 is the same as omitting setting the offset. @@ -1322,7 +1331,7 @@ def offset(self, offset): then *offset* rows are skipped before starting to count the *limit* rows that are returned. - :param int offset: integers of nr of rows to skip + :param offset: the number of rows to skip """ if (offset is not None) and (not isinstance(offset, int)): raise TypeError('offset has to be an integer, or None') @@ -1716,7 +1725,7 @@ def _join_comment_user(self, joined_entity, entity_to_join, isouterjoin): self._check_dbentities((joined_entity, self._impl.Comment), (entity_to_join, self._impl.User), 'with_comment') self._query = self._query.join(entity_to_join, joined_entity.user_id == entity_to_join.id, isouter=isouterjoin) - def _get_function_map(self): + def _get_function_map(self) -> Dict[str, Dict[str, Callable[[Any, Any, bool], None]]]: """ Map relationship type keywords to functions The new mapping (since 1.0.0a5) is a two level dictionary.
The first level defines the entity which has been @@ -1733,37 +1742,31 @@ def _get_function_map(self): 'with_computer': self._join_to_computer_used, 'with_user': self._join_created_by, 'with_group': self._join_group_members, - 'direction': None, }, 'computer': { 'with_node': self._join_computer, - 'direction': None, }, 'user': { 'with_comment': self._join_comment_user, 'with_node': self._join_creator_of, 'with_group': self._join_group_user, - 'direction': None, }, 'group': { 'with_node': self._join_groups, 'with_user': self._join_user_group, - 'direction': None, }, 'comment': { 'with_user': self._join_user_comment, 'with_node': self._join_node_comment, - 'direction': None }, 'log': { 'with_node': self._join_node_log, - 'direction': None } } - return mapping + return mapping # type: ignore - def _get_connecting_node(self, index, joining_keyword=None, joining_value=None, **kwargs): + def _get_connecting_node(self, index: int, joining_keyword: str, joining_value: str, **_: Any): """ :param querydict: A dictionary specifying how the current node @@ -1783,32 +1786,22 @@ def _get_connecting_node(self, index, joining_keyword=None, joining_value=None, else: calling_entity = entity_type - if joining_keyword == 'direction': - if joining_value > 0: - returnval = self._aliased_path[index - joining_value], self._join_outputs - elif joining_value < 0: - returnval = self._aliased_path[index + joining_value], self._join_inputs - else: - raise Exception('Direction 0 is not valid') - else: + try: + func = self._get_function_map()[calling_entity][joining_keyword] + except KeyError: + raise ValueError(f"'{joining_keyword}' is not a valid joining keyword for a '{calling_entity}' type entity") + + if isinstance(joining_value, str): try: - func = self._get_function_map()[calling_entity][joining_keyword] + return self.tag_to_alias_map[self._get_tag_from_specification(joining_value)], func except KeyError: raise ValueError( - f"'{joining_keyword}' is not a valid joining keyword for a 
'{calling_entity}' type entity" + 'Key {} is unknown to the types I know about:\n' + '{}'.format(self._get_tag_from_specification(joining_value), self.tag_to_alias_map.keys()) ) - - if isinstance(joining_value, int): - returnval = (self._aliased_path[joining_value], func) - elif isinstance(joining_value, str): - try: - returnval = self.tag_to_alias_map[self._get_tag_from_specification(joining_value)], func - except KeyError: - raise ValueError( - 'Key {} is unknown to the types I know about:\n' - '{}'.format(self._get_tag_from_specification(joining_value), self.tag_to_alias_map.keys()) - ) - return returnval + raise ValueError( + f'Key {self._get_tag_from_specification(joining_value)} value is not a string:\n{joining_value}' + ) @property def queryhelp(self): @@ -1917,11 +1910,7 @@ def _build(self): self.tag_to_projected_property_dict = {} self.nr_of_projections = 0 - if self._debug: - print('DEBUG:') - print(' Printing the content of self._projections') - print(' ', self._projections) - print() + self.debug('self._projections: %s', self._projections) if not any(self._projections.values()): # If user has not set projection, @@ -1936,14 +1925,12 @@ def _build(self): # LINK-PROJECTIONS ######################### for vertex in self._path[1:]: - edge_tag = vertex.get('edge_tag', None) - if self._debug: - print('DEBUG: Checking projections for edges:') - print( - ' This is edge {} from {}, {} of {}'.format( - edge_tag, vertex.get('tag'), vertex.get('joining_keyword'), vertex.get('joining_value') - ) - ) + edge_tag = vertex.get('edge_tag', None) # type: ignore + + self.debug( + 'Checking projections for edges: This is edge %s from %s, %s of %s', edge_tag, vertex.get('tag'), + vertex.get('joining_keyword'), vertex.get('joining_value') + ) if edge_tag is not None: self._build_projections(edge_tag) @@ -1990,13 +1977,7 @@ def _build(self): return self._query - def get_aliases(self): - """ - :returns: the list of aliases - """ - return self._aliased_path - - def 
get_alias(self, tag): + def get_alias(self, tag: str): """ In order to continue a query by the user, this utility function returns the aliased ormclasses. @@ -2007,12 +1988,12 @@ def get_alias(self, tag): tag = self._get_tag_from_specification(tag) return self.tag_to_alias_map[tag] - def get_used_tags(self, vertices=True, edges=True): + def get_used_tags(self, vertices: bool = True, edges: bool = True) -> List[str]: """ Returns a list of all the vertices that are being used. Some parameter allow to select only subsets. - :param bool vertices: Defaults to True. If True, adds the tags of vertices to the returned list - :param bool edges: Defaults to True. If True, adds the tags of edges to the returnend list. + :param vertices: Defaults to True. If True, adds the tags of vertices to the returned list + :param edges: Defaults to True. If True, adds the tags of edges to the returned list. :returns: A list of all tags, including (if there is) also the tag give for the edges """ @@ -2068,7 +2049,7 @@ def get_query(self): return query @staticmethod - def get_aiida_entity_res(value): + def get_aiida_entity_res(value) -> RowType: """Convert a projected query result to front end class if it is an instance of a `BackendEntity`. Values that are not an `BackendEntity` instance will be returned unaltered @@ -2081,7 +2062,7 @@ def get_aiida_entity_res(value): except TypeError: return value - def inject_query(self, query): + def inject_query(self, query: 'Query') -> None: """ Manipulate the query an inject it back. This can be done to add custom filters using SQLA. @@ -2093,7 +2074,7 @@ def inject_query(self, query): self._query = query self._injected = True - def distinct(self): + def distinct(self) -> 'QueryBuilder': """ Asks for distinct rows, which is the same as asking the backend to remove duplicates.
@@ -2114,7 +2095,7 @@ def distinct(self): self._query = self.get_query().distinct() return self - def first(self): + def first(self) -> Optional[List[RowType]]: """ Executes query asking for one instance. Use as follows:: @@ -2134,12 +2115,12 @@ def first(self): if not isinstance(result, (list, tuple)): result = [result] - if len(result) != len(self._attrkeys_as_in_sql_result): + if not self._attrkeys_as_in_sql_result or len(result) != len(self._attrkeys_as_in_sql_result): raise Exception('length of query result does not match the number of specified projections') return [self.get_aiida_entity_res(self._impl.get_aiida_res(rowitem)) for colindex, rowitem in enumerate(result)] - def one(self): + def one(self) -> RowType: """ Executes the query asking for exactly one results. Will raise an exception if this is not the case :raises: MultipleObjectsError if more then one row can be returned @@ -2154,7 +2135,7 @@ def one(self): raise NotExistent('No result was found') return res[0] - def count(self): + def count(self) -> int: """ Counts the number of rows returned by the backend. @@ -2163,15 +2144,14 @@ def count(self): query = self.get_query() return self._impl.count(query) - def iterall(self, batch_size=100): + def iterall(self, batch_size: Optional[int] = 100) -> Iterator[List[RowType]]: """ Same as :meth:`.all`, but returns a generator. Be aware that this is only safe if no commit will take place during this transaction. You might also want to read the SQLAlchemy documentation on https://docs.sqlalchemy.org/en/14/orm/query.html#sqlalchemy.orm.Query.yield_per - - :param int batch_size: + :param batch_size: The size of the batches to ask the backend to batch results in subcollections. You can optimize the speed of the query by tuning this parameter. 
@@ -2186,7 +2166,7 @@ def iterall(self, batch_size=100): yield item - def iterdict(self, batch_size=100): + def iterdict(self, batch_size: Optional[int] = 100) -> Iterable[Dict[str, RowType]]: """ Same as :meth:`.dict`, but returns a generator. Be aware that this is only safe if no commit will take place during this @@ -2194,7 +2174,7 @@ def iterdict(self, batch_size=100): https://docs.sqlalchemy.org/en/14/orm/query.html#sqlalchemy.orm.Query.yield_per - :param int batch_size: + :param batch_size: The size of the batches to ask the backend to batch results in subcollections. You can optimize the speed of the query by tuning this parameter. @@ -2208,16 +2188,16 @@ def iterdict(self, batch_size=100): yield item - def all(self, batch_size=None, flat=False): + def all(self, batch_size: Optional[int] = None, flat: bool = False) -> Union[List[List[RowType]], List[RowType]]: """Executes the full query with the order of the rows as returned by the backend. The order inside each row is given by the order of the vertices in the path and the order of the projections for each vertex in the path. - :param int batch_size: the size of the batches to ask the backend to batch results in subcollections. You can + :param batch_size: the size of the batches to ask the backend to batch results in subcollections. You can optimize the speed of the query by tuning this parameter. Leave the default `None` if speed is not critical or if you don't know what you're doing. - :param bool flat: return the result as a flat list of projected entities without sub lists. + :param flat: return the result as a flat list of projected entities without sub lists. :returns: a list of lists of all projected entities. 
""" matches = list(self.iterall(batch_size=batch_size)) @@ -2227,13 +2207,13 @@ def all(self, batch_size=None, flat=False): return [projection for entry in matches for projection in entry] - def dict(self, batch_size=None): + def dict(self, batch_size: Optional[int] = None) -> List[Dict[str, RowType]]: """ Executes the full query with the order of the rows as returned by the backend. the order inside each row is given by the order of the vertices in the path and the order of the projections for each vertice in the path. - :param int batch_size: + :param batch_size: The size of the batches to ask the backend to batch results in subcollections. You can optimize the speed of the query by tuning this parameter. Leave the default (*None*) if speed is not critical or if you don't know @@ -2281,7 +2261,7 @@ def dict(self, batch_size=None): """ return list(self.iterdict(batch_size=batch_size)) - def inputs(self, **kwargs): + def inputs(self, **kwargs: Any) -> 'QueryBuilder': """ Join to inputs of previous vertice in path. @@ -2290,10 +2270,10 @@ def inputs(self, **kwargs): from aiida.orm import Node join_to = self._path[-1]['tag'] cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_outgoing=join_to, autotag=True, **kwargs) + self.append(cls=cls, with_outgoing=join_to, **kwargs) return self - def outputs(self, **kwargs): + def outputs(self, **kwargs: Any) -> 'QueryBuilder': """ Join to outputs of previous vertice in path. @@ -2302,10 +2282,10 @@ def outputs(self, **kwargs): from aiida.orm import Node join_to = self._path[-1]['tag'] cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_incoming=join_to, autotag=True, **kwargs) + self.append(cls=cls, with_incoming=join_to, **kwargs) return self - def children(self, **kwargs): + def children(self, **kwargs: Any) -> 'QueryBuilder': """ Join to children/descendants of previous vertice in path. 
@@ -2314,10 +2294,10 @@ def children(self, **kwargs): from aiida.orm import Node join_to = self._path[-1]['tag'] cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_ancestors=join_to, autotag=True, **kwargs) + self.append(cls=cls, with_ancestors=join_to, **kwargs) return self - def parents(self, **kwargs): + def parents(self, **kwargs: Any) -> 'QueryBuilder': """ Join to parents/ancestors of previous vertice in path. @@ -2326,5 +2306,5 @@ def parents(self, **kwargs): from aiida.orm import Node join_to = self._path[-1]['tag'] cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_descendants=join_to, autotag=True, **kwargs) + self.append(cls=cls, with_descendants=join_to, **kwargs) return self diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 76c37de928..77e228fb28 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -163,4 +163,7 @@ py:meth pgsu.PGSU.__init__ py:class jsonschema.exceptions._Error +py:class Session +py:class Query +py:class BackendQueryBuilder py:class importlib_metadata.EntryPoint