Skip to content

Commit

Permalink
Working on aiidateam#258. Added functionality to selectable descentan…
Browse files Browse the repository at this point in the history
…ts_beta and Declarative DbPathBeta to behave as DbPath does,

with queryable properties being descendant_id (instead of child_id), ancestor_id (parent_id) and depth.
Possibility to add traversed path as an array, though this will only work in Postgresql, so commented out for now.
To initialize the mapper correctly, I added imports in the __init__ of backends.sqlalchemy.model
This might be also of greater convenience when importing!
Added headers to all files in the querybuild module.
  • Loading branch information
lekah committed Sep 17, 2016
1 parent ab79236 commit bd1b52c
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 29 deletions.
28 changes: 15 additions & 13 deletions aiida/backends/querybuild/dummy_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# -*- coding: utf-8 -*-

__copyright__ = u"Copyright (c), This file is part of the AiiDA platform. For further information please visit http://www.aiida.net/. All rights reserved."
__license__ = "MIT license, see LICENSE.txt file."
__authors__ = "The AiiDA team."
__version__ = "0.7.0"


"""
The dummy model encodes the model defined by django in backends.djsite
using SQLAlchemy.
Expand Down Expand Up @@ -38,10 +44,6 @@



__copyright__ = u"Copyright (c), This file is part of the AiiDA platform. For further information please visit http://www.aiida.net/. All rights reserved."
__license__ = "MIT license, see LICENSE.txt file."
__authors__ = "The AiiDA team."
__version__ = "0.7.0"

Base = declarative_base()

Expand Down Expand Up @@ -386,22 +388,25 @@ def get_aiida_class(self):



# Recursive query that replaces DbPath
# Note: Does not work with sqlite
node_aliased = aliased(DbNode)

walk = select([
DbNode.id.label('start'),
DbNode.id.label('end'),
DbNode.id.label('ancestor_id'),
DbNode.id.label('descendant_id'),
cast(-1, Integer).label('depth'),
array([DbNode.id]).label('path')
# array([DbNode.id]).label('path') Arrays can only be used with postgres, so leave it out for now
]).select_from(DbNode).cte(recursive=True) #, name="incl_aliased3")


descendants_beta = walk.union_all(
select([
walk.c.start,
walk.c.ancestor_id,
node_aliased.id,
walk.c.depth + cast(1, Integer),
(walk.c.path+array([node_aliased.id])).label('path'),
# This is the way to reconstruct the path (the sequence of nodes traversed)
# (walk.c.path+array([node_aliased.id])).label('path'), As above, but if arrays are supported
]).select_from(
join(
node_aliased,
Expand All @@ -410,7 +415,7 @@ def get_aiida_class(self):
)
).where(
and_(
DbLink.input_id == walk.c.end,
DbLink.input_id == walk.c.descendant_id,
)
)
)
Expand All @@ -423,11 +428,8 @@ def __init__(self, start, end, depth):
self.out = end
self.depth = depth



mapper(DbPathBeta, descendants_beta)


#~ DbAttribute.value_str = column_property(
#~ case([
#~ (DbAttribute.datatype == 'txt', DbAttribute.tval),
Expand Down
84 changes: 69 additions & 15 deletions aiida/backends/querybuild/querybuilder_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# -*- coding: utf-8 -*-

__copyright__ = u"Copyright (c), This file is part of the AiiDA platform. For further information please visit http://www.aiida.net/. All rights reserved."
__license__ = "MIT license, see LICENSE.txt file."
__authors__ = "The AiiDA team."
__version__ = "0.7.0"


"""
The general functionalities that all querybuilders need to have
Expand Down Expand Up @@ -531,6 +536,9 @@ def append(self, cls=None, type=None, tag=None,
aliased_edge = aliased(self.Link)
elif joining_keyword in ('ancestor_of', 'descendant_of'):
aliased_edge = aliased(self.Path)
elif joining_keyword in ('ancestor_of_beta', 'descendant_of_beta'):
#~ aliased_edge = aliased(self.PathBeta)
aliased_edge = self.PathBeta

if aliased_edge is not None:

Expand Down Expand Up @@ -1261,10 +1269,10 @@ def _join_outputs(self, joined_entity, entity_to_join, aliased_edge):
self._query = self._query.join(
aliased_edge,
aliased_edge.input_id == joined_entity.id
).join(
).join(
entity_to_join,
aliased_edge.output_id == entity_to_join.id
)
)

def _join_inputs(self, joined_entity, entity_to_join, aliased_edge):
"""
Expand All @@ -1283,12 +1291,56 @@ def _join_inputs(self, joined_entity, entity_to_join, aliased_edge):
self._query = self._query.join(
aliased_edge,
aliased_edge.output_id == joined_entity.id
).join(
).join(
entity_to_join,
aliased_edge.input_id == entity_to_join.id
)
)

def _join_descendants_beta(self, joined_entity, entity_to_join, aliased_path):
"""
Beta version, joining descendants using the recursive functionality
"""
self._check_dbentities(
(joined_entity, self.Node),
(entity_to_join, self.Node),
'descendant_of_beta'
)

#~ def _join_descendants_beta(self, joined_entity, entity_to_join, aliased_path):
self._query = self._query.join(
aliased_path,
aliased_path.ancestor_id == joined_entity.id
).join(
entity_to_join,
aliased_path.descendant_id == entity_to_join.id
).filter(
# it is necessary to put this filter so that the
# the node does not include itself as a ancestor/descendant
aliased_path.depth > -1
)
def _join_ancestors_beta(self, joined_entity, entity_to_join, aliased_path):
"""
:param joined_entity: The (aliased) ORMclass that is a descendant
:param entity_to_join: The (aliased) ORMClass that is an ancestor.
:param aliased_path: An aliased instance of DbPath
"""
self._check_dbentities(
(joined_entity, self.Node),
(entity_to_join, self.Node),
'ancestor_of_beta'
)
#~ aliased_path = aliased(self.Path)
self._query = self._query.join(
aliased_path,
aliased_path.descendant_id == joined_entity.id
).join(
entity_to_join,
aliased_path.ancestor_id == entity_to_join.id
).filter(
# it is necessary to put this filter so that the
# the node does not include itself as a ancestor/descendant
aliased_path.depth > -1
)

def _join_descendants(self, joined_entity, entity_to_join, aliased_path):
"""
Expand All @@ -1310,10 +1362,10 @@ def _join_descendants(self, joined_entity, entity_to_join, aliased_path):
self._query = self._query.join(
aliased_path,
aliased_path.parent_id == joined_entity.id
).join(
).join(
entity_to_join,
aliased_path.child_id == entity_to_join.id
)
)

def _join_ancestors(self, joined_entity, entity_to_join, aliased_path):
"""
Expand All @@ -1335,10 +1387,10 @@ def _join_ancestors(self, joined_entity, entity_to_join, aliased_path):
self._query = self._query.join(
aliased_path,
aliased_path.child_id == joined_entity.id
).join(
).join(
entity_to_join,
aliased_path.parent_id == entity_to_join.id
)
)
def _join_group_members(self, joined_entity, entity_to_join):
"""
:param joined_entity:
Expand All @@ -1361,10 +1413,10 @@ def _join_group_members(self, joined_entity, entity_to_join):
self._query = self._query.join(
aliased_group_nodes,
aliased_group_nodes.c.dbgroup_id == joined_entity.id
).join(
).join(
entity_to_join,
entity_to_join.id == aliased_group_nodes.c.dbnode_id
)
)
def _join_groups(self, joined_entity, entity_to_join):
"""
:param joined_entity: The (aliased) node in the database
Expand All @@ -1384,10 +1436,10 @@ def _join_groups(self, joined_entity, entity_to_join):
self._query = self._query.join(
aliased_group_nodes,
aliased_group_nodes.c.dbnode_id == joined_entity.id
).join(
).join(
entity_to_join,
entity_to_join.id == aliased_group_nodes.c.dbgroup_id
)
)
def _join_creator_of(self, joined_entity, entity_to_join):
"""
:param joined_entity: the aliased node
Expand Down Expand Up @@ -1431,7 +1483,7 @@ def _join_to_computer_used(self, joined_entity, entity_to_join):
self._query = self._query.join(
entity_to_join,
entity_to_join.dbcomputer_id == joined_entity.id
)
)

def _join_computer(self, joined_entity, entity_to_join):
"""
Expand All @@ -1448,7 +1500,7 @@ def _join_computer(self, joined_entity, entity_to_join):
self._query = self._query.join(
entity_to_join,
joined_entity.dbcomputer_id == entity_to_join.id
)
)

def _get_function_map(self):
d = {
Expand All @@ -1458,6 +1510,8 @@ def _get_function_map(self):
'master_of' : self._join_masters,# not implemented
'ancestor_of': self._join_ancestors,
'descendant_of': self._join_descendants,
'ancestor_of_beta': self._join_ancestors_beta, #not implemented
'descendant_of_beta': self._join_descendants_beta,
'direction' : None,
'group_of' : self._join_groups,
'member_of' : self._join_group_members,
Expand Down
8 changes: 8 additions & 0 deletions aiida/backends/querybuild/querybuilder_django.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# -*- coding: utf-8 -*-

__copyright__ = u"Copyright (c), This file is part of the AiiDA platform. For further information please visit http://www.aiida.net/. All rights reserved."
__license__ = "MIT license, see LICENSE.txt file."
__authors__ = "The AiiDA team."
__version__ = "0.7.0"


from datetime import datetime
from json import loads as json_loads

Expand All @@ -15,6 +21,7 @@
DbLink as DummyLink,
DbCalcState as DummyState,
DbPath as DummyPath,
DbPathBeta as DummyPathBeta,
DbUser as DummyUser,
DbComputer as DummyComputer,
DbGroup as DummyGroup,
Expand Down Expand Up @@ -50,6 +57,7 @@ def __init__(self, *args, **kwargs):

self.Link = DummyLink
self.Path = DummyPath
self.PathBeta = DummyPathBeta
self.Node = DummyNode
self.Computer = DummyComputer
self.User = DummyUser
Expand Down
8 changes: 8 additions & 0 deletions aiida/backends/querybuild/querybuilder_sa.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# -*- coding: utf-8 -*-

__copyright__ = u"Copyright (c), This file is part of the AiiDA platform. For further information please visit http://www.aiida.net/. All rights reserved."
__license__ = "MIT license, see LICENSE.txt file."
__authors__ = "The AiiDA team."
__version__ = "0.7.0"


from datetime import datetime

__copyright__ = u"Copyright (c), This file is part of the AiiDA platform. For further information please visit http://www.aiida.net/. All rights reserved."
Expand All @@ -26,6 +32,7 @@
from sqlalchemy_utils.types.choice import Choice
from aiida.backends.sqlalchemy import session as sa_session
from aiida.backends.sqlalchemy.models.node import DbNode, DbLink, DbPath
from aiida.backends.sqlalchemy.models import DbPathBeta
from aiida.backends.sqlalchemy.models.computer import DbComputer
from aiida.backends.sqlalchemy.models.group import DbGroup, table_groups_nodes
from aiida.backends.sqlalchemy.models.user import DbUser
Expand All @@ -47,6 +54,7 @@ def __init__(self, *args, **kwargs):
from aiida.orm.implementation.sqlalchemy.user import User as AiidaUser
self.Link = DbLink
self.Path = DbPath
self.PathBeta = DbPathBeta
self.Node = DbNode
self.Computer = DbComputer
self.User = DbUser
Expand Down
7 changes: 7 additions & 0 deletions aiida/backends/querybuild/sa_init.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# -*- coding: utf-8 -*-

__copyright__ = u"Copyright (c), This file is part of the AiiDA platform. For further information please visit http://www.aiida.net/. All rights reserved."
__license__ = "MIT license, see LICENSE.txt file."
__authors__ = "The AiiDA team."
__version__ = "0.7.0"


"""
Imports used for the QueryBuilder.
See
Expand Down
60 changes: 60 additions & 0 deletions aiida/backends/sqlalchemy/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,63 @@
# is stored in the DbSetting table and the check is done in the
# load_dbenv() function).
SCHEMA_VERSION = 0.1



# This is convenience so that one can import all ORM classes from one module
# from aiida.backends.sqlalchemy.models import *
# Also, only by import
from comment import DbComment
from computer import DbComputer
from group import DbGroup
from lock import DbLock
from log import DbLog
from node import DbNode, DbLink, DbPath, DbCalcState
from settings import DbSetting
from user import DbUser
from workflow import DbWorkflow, DbWorkflowData, DbWorkflowStep

from sqlalchemy.orm import aliased, mapper
from sqlalchemy import select, func, join, and_
from sqlalchemy.sql.expression import cast
from sqlalchemy.types import Integer

node_aliased = aliased(DbNode)

walk = select([
DbNode.id.label('ancestor_id'),
DbNode.id.label('descendant_id'),
cast(-1, Integer).label('depth'),
# array([DbNode.id]).label('path') Arrays can only be used with postgres, so leave it out for now
]).select_from(DbNode).cte(recursive=True) #, name="incl_aliased3")


descendants_beta = walk.union_all(
select([
walk.c.ancestor_id,
node_aliased.id,
walk.c.depth + cast(1, Integer),
# (walk.c.path+array([node_aliased.id])).label('path'), As above, but if arrays are supported
# This is the way to reconstruct the path (the sequence of nodes traversed)
]).select_from(
join(
node_aliased,
DbLink,
DbLink.output_id==node_aliased.id,
)
).where(
and_(
DbLink.input_id == walk.c.descendant_id,
)
)
)


class DbPathBeta(object):

def __init__(self, start, end, depth):
self.start = start
self.out = end
self.depth = depth

mapper(DbPathBeta, descendants_beta)
4 changes: 3 additions & 1 deletion aiida/backends/sqlalchemy/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlalchemy import ForeignKey, select, func, join, and_
from sqlalchemy.orm import (
relationship, backref, Query, mapper,
foreign, column_property,
foreign, column_property, aliased
)
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy.schema import Column, UniqueConstraint
Expand Down Expand Up @@ -382,3 +382,5 @@ class DbCalcState(Base):
select([recent_states.c.state]).
where(recent_states.c.dbnode_id == foreign(DbNode.id))
)


0 comments on commit bd1b52c

Please sign in to comment.