Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QB Api change for handling MOTs #5251

Open
wants to merge 4 commits into
base: production
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions specifyweb/specify/tree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def get_search_filters(collection: spmodels.Collection, tree: str):
discipline_query |= Q(id=tree_at_discipline.id)
return discipline_query

def get_treedefs(collection: spmodels.Collection, tree_name: str) -> List[Tuple[int, int]]:
def get_treedefs(collection: spmodels.Collection, tree_name: str, treedef_id=None) -> List[Tuple[int, int]]:
# Get the appropriate TreeDef based on the Collection and tree_name

# Mimic the old behavior of limiting the query to the first item for trees other than taxon.
# Even though the queryconstruct can handle trees with multiple types.
_limit = lambda query: (query if tree_name.lower() == 'taxon' else query[:1])
search_filters = get_search_filters(collection, tree_name)
search_filters = get_search_filters(collection, tree_name) if treedef_id is None else Q(id=treedef_id)

lookup_tree = lookup(tree_name)
tree_table = datamodel.get_table_strict(lookup_tree)
Expand All @@ -45,4 +45,3 @@ def get_treedefs(collection: spmodels.Collection, tree_name: str) -> List[Tuple
assert len(result) > 0, "No definition to query on"

return result

54 changes: 38 additions & 16 deletions specifyweb/stored_queries/query_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ def _safe_filter(query):
return query.first()
raise Exception(f"Got more than one matching: {list(query)}")

class QueryConstruct(namedtuple('QueryConstruct', 'collection objectformatter query join_cache tree_rank_count internal_filters')):

class QueryConstruct(
namedtuple(
"QueryConstruct",
"collection objectformatter query join_cache tree_rank_count internal_filters",
)
):

def __new__(cls, *args, **kwargs):
kwargs['join_cache'] = dict()
Expand All @@ -27,7 +33,7 @@ def __new__(cls, *args, **kwargs):
kwargs['internal_filters'] = []
return super(QueryConstruct, cls).__new__(cls, *args, **kwargs)

def handle_tree_field(self, node, table, tree_rank, tree_field):
def handle_tree_field(self, node, table, tree_rank, tree_field, tree_def_id=None):
query = self
if query.collection is None: raise AssertionError( # Not sure it makes sense to query across collections
f"No Collection found in Query for {table}",
Expand All @@ -42,19 +48,18 @@ def handle_tree_field(self, node, table, tree_rank, tree_field):
logger.debug("using join cache for %r tree ranks.", table)
ancestors, treedefs = query.join_cache[(table, 'TreeRanks')]
else:
treedefs = get_treedefs(query.collection, table.name)

treedefs = get_treedefs(query.collection, table.name, tree_def_id)

# We need to take the max here. Otherwise, it is possible that the same rank
# name may not occur at the same level across tree defs.
max_depth = max(depth for _, depth in treedefs)

ancestors = [node]
for _ in range(max_depth-1):
ancestor = orm.aliased(node)
query = query.outerjoin(ancestor, ancestors[-1].ParentID == getattr(ancestor, ancestor._id))
ancestors.append(ancestor)


logger.debug("adding to join cache for %r tree ranks.", table)
query = query._replace(join_cache=query.join_cache.copy())
Expand All @@ -63,10 +68,21 @@ def handle_tree_field(self, node, table, tree_rank, tree_field):
item_model = getattr(spmodels, table.django_name + "treedefitem")

# TODO: optimize out the ranks that appear? cache them
treedefs_with_ranks: List[Tuple[int, int]] = [tup for tup in [
(treedef_id, _safe_filter(item_model.objects.filter(treedef_id=treedef_id, name=tree_rank).values_list('id', flat=True)))
for treedef_id, _ in treedefs
] if tup[1] is not None]
treedefs_with_ranks: List[Tuple[int, int]] = [
tup
for tup in [
(
treedef_id,
_safe_filter(
item_model.objects.filter(
treedef_id=treedef_id, name=tree_rank
).values_list("id", flat=True)
),
)
for treedef_id, _ in treedefs
]
if tup[1] is not None
]

assert len(treedefs_with_ranks) >= 1, "Didn't find the tree rank across any tree"

Expand All @@ -76,16 +92,23 @@ def handle_tree_field(self, node, table, tree_rank, tree_field):

def _predicates_for_node(_node):
return [
# TEST: consider taking the treedef_id comparison just to the first node, if it speeds things up (matching for higher is redundant..)
(sql.and_(getattr(_node, treedef_column)==treedef_id, getattr(_node, treedefitem_column)==treedefitem_id), getattr(_node, column_name))
# TEST: consider taking the treedef_id comparison just to the first node,
# if it speeds things up (matching for higher is redundant..)
(
sql.and_(
getattr(_node, treedef_column) == treedef_id,
getattr(_node, treedefitem_column) == treedefitem_id,
),
getattr(_node, column_name),
)
for (treedef_id, treedefitem_id) in treedefs_with_ranks
]

cases_per_ancestor = [
_predicates_for_node(ancestor)
for ancestor in ancestors
]
]

column = sql.case([case for per_ancestor in cases_per_ancestor for case in per_ancestor])

defs_to_filter_on = [def_id for (def_id, _) in treedefs_with_ranks]
Expand Down Expand Up @@ -135,7 +158,6 @@ def build_join(self, table, model, join_path):
table, model = next_table, aliased
return query, model, table, field


# To make things "simpler", it doesn't apply any filters, but returns a single predicate
# @model is an input parameter, because cannot guess if it is aliased or not (callers are supposed to know that)
def get_internal_filters(self):
Expand Down
56 changes: 46 additions & 10 deletions specifyweb/stored_queries/queryfieldspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
# Pull out author or groupnumber field from taxon query fields.
TAXON_FIELD_RE = re.compile(r'(.*) ((Author)|(groupNumber))$')

# MOTs tree query for a specific taxon tree (selected by its treedef id).
# Schema: <table_id>.<table_name>.<treedef_id>,<rank>,<field>
# ex. 4.taxon.1,Kingdom,Author
# The pattern is applied to the trailing "<treedef_id>,<rank>,<field>" part.
# Groups: (1) treedef id, (2) rank name, (3) field name -- all captured as
# strings. Every component must be non-empty, so a malformed name such as
# ",Kingdom,Author" does not match and can never yield an empty treedef id.
TAXON_MOT_FIELD_RE = re.compile(r'^(\d+),([^,]+),(.+)$')

# Pull out geographyCode field from geography query fields.
GEOGRAPHY_FIELD_RE = re.compile(r'(.*) ((geographyCode))$')

Expand Down Expand Up @@ -61,7 +66,12 @@ def make_stringid(fs, table_list):
return table_list, fs.table.name.lower(), field_name


class QueryFieldSpec(namedtuple("QueryFieldSpec", "root_table root_sql_table join_path table date_part tree_rank tree_field")):
class QueryFieldSpec(
namedtuple(
"QueryFieldSpec",
"root_table root_sql_table join_path table date_part tree_rank tree_field treedef_id",
)
):
@classmethod
def from_path(cls, path_in, add_id=False):
path = deque(path_in)
Expand All @@ -88,8 +98,8 @@ def from_path(cls, path_in, add_id=False):
table=node,
date_part='Full Date' if (join_path and join_path[-1].is_temporal()) else None,
tree_rank=None,
tree_field=None)

tree_field=None,
treedef_id=None)

@classmethod
def from_stringid(cls, stringid, is_relation):
Expand All @@ -114,15 +124,27 @@ def from_stringid(cls, stringid, is_relation):

extracted_fieldname, date_part = extract_date_part(field_name)
field = node.get_field(extracted_fieldname, strict=False)
tree_rank = tree_field = None
treedef_id = tree_rank = tree_field = None
if field is None:
tree_id_match = TREE_ID_FIELD_RE.match(extracted_fieldname)
if tree_id_match:
tree_rank = tree_id_match.group(1)
tree_field = 'ID'
else:
tree_field_match = TAXON_FIELD_RE.match(extracted_fieldname) if node is datamodel.get_table('Taxon') else GEOGRAPHY_FIELD_RE.match(extracted_fieldname) if node is datamodel.get_table('Geography') else None
if tree_field_match:
tree_mot_field_match = tree_field_match = None
if node is datamodel.get_table("Taxon"):
tree_mot_field_match = TAXON_MOT_FIELD_RE.match(extracted_fieldname)
tree_field_match = TAXON_FIELD_RE.match(extracted_fieldname)
elif node is datamodel.get_table("Geography"):
tree_field_match = GEOGRAPHY_FIELD_RE.match(extracted_fieldname)
else:
tree_field_match = None

if tree_mot_field_match:
treedef_id = tree_mot_field_match.group(1)
tree_rank = tree_mot_field_match.group(2)
tree_field = tree_mot_field_match.group(3)
elif tree_field_match:
tree_rank = tree_field_match.group(1)
tree_field = tree_field_match.group(2)
else:
Expand All @@ -138,7 +160,8 @@ def from_stringid(cls, stringid, is_relation):
table=node,
date_part=date_part,
tree_rank=tree_rank,
tree_field=tree_field)
tree_field=tree_field,
treedef_id=treedef_id)

logger.debug('parsed %s (is_relation %s) to %s. extracted_fieldname = %s',
stringid, is_relation, result, extracted_fieldname)
Expand Down Expand Up @@ -195,7 +218,12 @@ def is_auditlog_obj_format_field(self, formatauditobjs):
return self.get_field().name.lower() in ['oldvalue','newvalue']

def is_specify_username_end(self):
return len(self.join_path) > 2 and self.join_path[-1].name == 'name' and self.join_path[-2].is_relationship and self.join_path[-2].relatedModelName == 'SpecifyUser'
return (
len(self.join_path) > 2
and self.join_path[-1].name == "name"
and self.join_path[-2].is_relationship
and self.join_path[-2].relatedModelName == "SpecifyUser"
)

def apply_filter(self, query, orm_field, field, table, value=None, op_num=None, negate=False):
no_filter = op_num is None or (self.tree_rank is None and self.get_field() is None)
Expand Down Expand Up @@ -241,11 +269,19 @@ def add_spec_to_query(self, query, formatter=None, aggregator=None, cycle_detect
query, orm_field = query.objectformatter.objformat(query, orm_model, formatter, cycle_detector)
else:
query, orm_model, table, field = self.build_join(query, self.join_path[:-1])
orm_field = query.objectformatter.aggregate(query, self.get_field(), orm_model, aggregator or formatter, cycle_detector)
orm_field = query.objectformatter.aggregate(
query,
self.get_field(),
orm_model,
aggregator or formatter,
cycle_detector,
)
else:
query, orm_model, table, field = self.build_join(query, self.join_path)
if self.tree_rank is not None:
query, orm_field = query.handle_tree_field(orm_model, table, self.tree_rank, self.tree_field)
query, orm_field = query.handle_tree_field(
orm_model, table, self.tree_rank, self.tree_field, self.treedef_id
)
else:
orm_field = getattr(orm_model, self.get_field().name)

Expand Down
Loading