Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor families + add family participants to graphql #740

Merged
merged 15 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ jobs:
uses: codecov/codecov-action@v3
with:
files: ./coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}


- name: "build web front-end"
run: |
Expand Down
59 changes: 42 additions & 17 deletions api/graphql/filters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any, Callable, Generic, TypeVar
from typing import Callable, Generic, TypeVar

import strawberry

from db.python.utils import GenericFilter
from db.python.utils import GenericFilter, GenericMetaFilter

T = TypeVar('T')
Y = TypeVar('Y')


@strawberry.input(description='Filter for GraphQL queries')
Expand Down Expand Up @@ -47,22 +48,8 @@ def all_values(self):

return v

def to_internal_filter(self, f: Callable[[T], Any] = None):
def to_internal_filter(self) -> GenericFilter[T]:
"""Convert from GraphQL to internal filter model"""

if f:
return GenericFilter(
eq=f(self.eq) if self.eq else None,
in_=list(map(f, self.in_)) if self.in_ else None,
nin=list(map(f, self.nin)) if self.nin else None,
gt=f(self.gt) if self.gt else None,
gte=f(self.gte) if self.gte else None,
lt=f(self.lt) if self.lt else None,
lte=f(self.lte) if self.lte else None,
contains=f(self.contains) if self.contains else None,
icontains=f(self.icontains) if self.icontains else None,
)

return GenericFilter(
eq=self.eq,
in_=self.in_,
Expand All @@ -75,5 +62,43 @@ def to_internal_filter(self, f: Callable[[T], Any] = None):
icontains=self.icontains,
)

def to_internal_filter_mapped(self, f: Callable[[T], Y]) -> GenericFilter[Y]:
"""
To internal filter, but apply a function to all values.
Separate this into a separate function to please linters and type checkers
"""
return GenericFilter(
eq=f(self.eq) if self.eq else None,
in_=list(map(f, self.in_)) if self.in_ else None,
nin=list(map(f, self.nin)) if self.nin else None,
gt=f(self.gt) if self.gt else None,
gte=f(self.gte) if self.gte else None,
lt=f(self.lt) if self.lt else None,
lte=f(self.lte) if self.lte else None,
contains=f(self.contains) if self.contains else None,
icontains=f(self.icontains) if self.icontains else None,
)


GraphQLMetaFilter = strawberry.scalars.JSON


def graphql_meta_filter_to_internal_filter(
f: GraphQLMetaFilter | None,
) -> GenericMetaFilter | None:
"""Convert from GraphQL to internal filter model

Args:
f (GraphQLMetaFilter | None): GraphQL filter

Returns:
GenericMetaFilter | None: internal filter
"""
if not f:
return None

d: GenericMetaFilter = {}
f_to_d: dict[str, GraphQLMetaFilter] = dict(f) # type: ignore
for k, v in f_to_d.items():
d[k] = GenericFilter(**v) if isinstance(v, dict) else GenericFilter(eq=v)
return d
88 changes: 61 additions & 27 deletions api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
)
from db.python.tables.analysis import AnalysisFilter
from db.python.tables.assay import AssayFilter
from db.python.tables.family import FamilyFilter
from db.python.tables.project import ProjectPermissionsTable
from db.python.tables.sample import SampleFilter
from db.python.tables.sequencing_group import SequencingGroupFilter
from db.python.utils import GenericFilter, NotFoundError
from db.python.utils import GenericFilter, NotFoundError, get_hashable_value
from models.models import (
AnalysisInternal,
AssayInternal,
Expand All @@ -36,6 +37,7 @@
SequencingGroupInternal,
)
from models.models.audit_log import AuditLogInternal
from models.models.family import PedRowInternal


class LoaderKeys(enum.Enum):
Expand Down Expand Up @@ -65,6 +67,9 @@ class LoaderKeys(enum.Enum):
PARTICIPANTS_FOR_PROJECTS = 'participants_for_projects'

FAMILIES_FOR_PARTICIPANTS = 'families_for_participants'
FAMILY_PARTICIPANTS_FOR_FAMILIES = 'family_participants_for_families'
FAMILY_PARTICIPANTS_FOR_PARTICIPANTS = 'family_participants_for_participants'
FAMILIES_FOR_IDS = 'families_for_ids'

SEQUENCING_GROUPS_FOR_IDS = 'sequencing_groups_for_ids'
SEQUENCING_GROUPS_FOR_SAMPLES = 'sequencing_groups_for_samples'
Expand All @@ -91,31 +96,8 @@ async def wrapped(*args, **kwargs):
return connected_data_loader_caller


def _prepare_partial_value_for_hashing(value):
if value is None:
return None
if isinstance(value, (int, str, float, bool)):
return value
if isinstance(value, enum.Enum):
return value.value
if isinstance(value, list):
# let's see if later we need to prepare the values in the list
return tuple(value)
if isinstance(value, dict):
return tuple(
sorted(
((k, _prepare_partial_value_for_hashing(v)) for k, v in value.items()),
key=lambda x: x[0],
)
)

return hash(value)


def _get_connected_data_loader_partial_key(kwargs):
return _prepare_partial_value_for_hashing(
{k: v for k, v in kwargs.items() if k != 'id'}
)
def _get_connected_data_loader_partial_key(kwargs) -> tuple:
return get_hashable_value({k: v for k, v in kwargs.items() if k != 'id'}) # type: ignore


def connected_data_loader_with_params(
Expand Down Expand Up @@ -381,7 +363,10 @@ async def load_families_for_participants(
Get families of participants, noting a participant can be in multiple families
"""
flayer = FamilyLayer(connection)
fam_map = await flayer.get_families_by_participants(participant_ids=participant_ids)

fam_map = await flayer.get_families_by_participants(
participant_ids=participant_ids, check_project_ids=False
)
return [fam_map.get(p, []) for p in participant_ids]


Expand Down Expand Up @@ -449,6 +434,55 @@ async def load_phenotypes_for_participants(
return [participant_phenotypes.get(pid, {}) for pid in participant_ids]


@connected_data_loader(LoaderKeys.FAMILIES_FOR_IDS)
async def load_families_for_ids(
family_ids: list[int], connection
) -> list[FamilyInternal]:
"""
DataLoader: get_families_for_ids
"""
flayer = FamilyLayer(connection)
families = await flayer.query(FamilyFilter(id=GenericFilter(in_=family_ids)))
f_by_id = {f.id: f for f in families}
return [f_by_id[f] for f in family_ids]


@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES)
async def load_family_participants_for_families(
family_ids: list[int], connection
) -> list[list[PedRowInternal]]:
"""
DataLoader: get_family_participants_for_families
"""
flayer = FamilyLayer(connection)
fp_map = await flayer.get_family_participants_by_family_ids(family_ids)

return [fp_map.get(fid, []) for fid in family_ids]


@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS)
async def load_family_participants_for_participants(
participant_ids: list[int], connection
) -> list[list[PedRowInternal]]:
"""data loader for family participants for participants

Args:
participant_ids (list[int]): list of internal participant ids
connection (_type_): (this is automatically filled in by the loader decorator)

Returns:
list[list[PedRowInternal]]: list of family participants for each participant
(in order)
"""
flayer = FamilyLayer(connection)
family_participants = await flayer.get_family_participants_for_participants(
participant_ids
)
fp_map = group_by(family_participants, lambda fp: fp.individual_id)

return [fp_map.get(pid, []) for pid in participant_ids]


async def get_context(
request: Request, connection=get_projectless_db_connection
): # pylint: disable=unused-argument
Expand Down
Loading
Loading