Skip to content

Commit

Permalink
SQLA v2 API: Replace use of lists in select()
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell committed Sep 16, 2021
1 parent c48d617 commit 70a7018
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def upgrade():
column('attributes', JSONB))

nodes = connection.execute(
select([DbNode.c.id, DbNode.c.uuid]).where(
select(DbNode.c.id, DbNode.c.uuid).where(
DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall()

for pk, uuid in nodes:
Expand All @@ -64,7 +64,7 @@ def downgrade():
column('attributes', JSONB))

nodes = connection.execute(
select([DbNode.c.id, DbNode.c.uuid]).where(
select(DbNode.c.id, DbNode.c.uuid).where(
DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall()

for pk, _ in nodes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def export_workflow_data(connection):
DbWorkflowData = table('db_dbworkflowdata')
DbWorkflowStep = table('db_dbworkflowstep')

count_workflow = connection.execute(select([func.count()]).select_from(DbWorkflow)).scalar()
count_workflow_data = connection.execute(select([func.count()]).select_from(DbWorkflowData)).scalar()
count_workflow_step = connection.execute(select([func.count()]).select_from(DbWorkflowStep)).scalar()
count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar()
count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar()
count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar()

# Nothing to do if all tables are empty
if count_workflow == 0 and count_workflow_data == 0 and count_workflow_step == 0:
Expand All @@ -78,9 +78,9 @@ def export_workflow_data(connection):
delete_on_close = configuration.PROFILE.is_test_profile

data = {
'workflow': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflow))],
'workflow_data': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflowData))],
'workflow_step': [dict(row) for row in connection.execute(select(['*']).select_from(DbWorkflowStep))],
'workflow': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflow))],
'workflow_data': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflowData))],
'workflow_step': [dict(row) for row in connection.execute(select('*').select_from(DbWorkflowStep))],
}

with NamedTemporaryFile(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def upgrade():
)

profile = get_profile()
node_count = connection.execute(select([func.count()]).select_from(DbNode)).scalar()
node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar()
missing_repo_folder = []
shard_count = 256

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def migrate_infer_calculation_entry_point(connection):
column('process_type', String)
)

query_set = connection.execute(select([DbNode.c.type]).where(DbNode.c.type.like('calculation.%'))).fetchall()
query_set = connection.execute(select(DbNode.c.type).where(DbNode.c.type.like('calculation.%'))).fetchall()
type_strings = set(entry[0] for entry in query_set)
mapping_node_type_to_entry_point = infer_calculation_entry_point(type_strings=type_strings)

Expand All @@ -54,7 +54,7 @@ def migrate_infer_calculation_entry_point(connection):
# All affected entries should be logged to file that the user can consult.
if ENTRY_POINT_STRING_SEPARATOR not in entry_point_string:
query_set = connection.execute(
select([DbNode.c.uuid]).where(DbNode.c.type == op.inline_literal(type_string))
select(DbNode.c.uuid).where(DbNode.c.type == op.inline_literal(type_string))
).fetchall()

uuids = [str(entry.uuid) for entry in query_set]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def upgrade():
column('attributes', JSONB))

nodes = connection.execute(
select([DbNode.c.id, DbNode.c.uuid]).where(
select(DbNode.c.id, DbNode.c.uuid).where(
DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall()

for pk, uuid in nodes:
Expand All @@ -61,11 +61,11 @@ def downgrade():
column('attributes', JSONB))

nodes = connection.execute(
select([DbNode.c.id, DbNode.c.uuid]).where(
select(DbNode.c.id, DbNode.c.uuid).where(
DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))).fetchall()

for pk, uuid in nodes:
attributes = connection.execute(select([DbNode.c.attributes]).where(DbNode.c.id == pk)).fetchone()
attributes = connection.execute(select(DbNode.c.attributes).where(DbNode.c.id == pk)).fetchone()
symbols = numpy.array(attributes['symbols'])
utils.store_numpy_array_in_repository(uuid, 'symbols', symbols)
key = op.inline_literal('{"array|symbols"}')
Expand Down
30 changes: 14 additions & 16 deletions aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def _join_descendants_recursive(
if expand_path:
selection_walk_list.append(array((link1.input_id, link1.output_id)).label('path'))

walk = select(selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)).where(
walk = select(*selection_walk_list).select_from(join(node1, link1, link1.input_id == node1.id)).where(
and_(
in_recursive_filters, # I apply filters for speed here
link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)) # I follow input and create links
Expand All @@ -214,13 +214,12 @@ def _join_descendants_recursive(

descendants_recursive = aliased(
aliased_walk.union_all(
select(selection_union_list).select_from(
join(
aliased_walk,
link2,
link2.input_id == aliased_walk.c.descendant_id,
)
).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
select(*selection_union_list
).select_from(join(
aliased_walk,
link2,
link2.input_id == aliased_walk.c.descendant_id,
)).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
)
) # .alias()

Expand Down Expand Up @@ -259,7 +258,7 @@ def _join_ancestors_recursive(
if expand_path:
selection_walk_list.append(array((link1.output_id, link1.input_id)).label('path'))

walk = select(selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where(
walk = select(*selection_walk_list).select_from(join(node1, link1, link1.output_id == node1.id)).where(
and_(in_recursive_filters, link1.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
).cte(recursive=True)

Expand All @@ -275,13 +274,12 @@ def _join_ancestors_recursive(

ancestors_recursive = aliased(
aliased_walk.union_all(
select(selection_union_list).select_from(
join(
aliased_walk,
link2,
link2.output_id == aliased_walk.c.ancestor_id,
)
).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
select(*selection_union_list
).select_from(join(
aliased_walk,
link2,
link2.output_id == aliased_walk.c.ancestor_id,
)).where(link2.type.in_((LinkType.CREATE.value, LinkType.INPUT_CALC.value)))
# I can't follow RETURN or CALL links
)
)
Expand Down

0 comments on commit 70a7018

Please sign in to comment.