From 6cd483a3dbf675bed9069ba351dcdc71e1fef58e Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Tue, 9 Apr 2024 11:05:43 +1000 Subject: [PATCH 01/12] Initial commit --- api/graphql/loaders.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/api/graphql/loaders.py b/api/graphql/loaders.py index c85699e4a..383d5f981 100644 --- a/api/graphql/loaders.py +++ b/api/graphql/loaders.py @@ -36,6 +36,7 @@ SequencingGroupInternal, ) from models.models.audit_log import AuditLogInternal +from models.models.family import PedRowInternal class LoaderKeys(enum.Enum): @@ -65,6 +66,8 @@ class LoaderKeys(enum.Enum): PARTICIPANTS_FOR_PROJECTS = 'participants_for_projects' FAMILIES_FOR_PARTICIPANTS = 'families_for_participants' + FAMILY_PARTICIPANTS_FOR_FAMILIES = 'family_participants_for_families' + FAMILIES_FOR_IDS = 'families_for_ids' SEQUENCING_GROUPS_FOR_IDS = 'sequencing_groups_for_ids' SEQUENCING_GROUPS_FOR_SAMPLES = 'sequencing_groups_for_samples' @@ -446,6 +449,31 @@ async def load_phenotypes_for_participants( return [participant_phenotypes.get(pid, {}) for pid in participant_ids] +@connected_data_loader(LoaderKeys.FAMILIES_FOR_IDS) +async def load_families_for_ids( + family_ids: list[int], connection +) -> list[FamilyInternal]: + """ + DataLoader: get_families_for_ids + """ + flayer = FamilyLayer(connection) + families = await flayer.get_families_by_ids(family_ids) + f_by_id = {f.id: f for f in families} + return [f_by_id[f] for f in family_ids] + + +@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES) +async def load_family_participants_for_families( + family_ids: list[int], connection +) -> list[list[PedRowInternal]]: + flayer = FamilyLayer(connection) + family_participants = await flayer.get_family_participants_for_family_ids( + family_ids + ) + + return [family_participants.get(fid, []) for fid in family_ids] + + async def get_context( request: Request, connection=get_projectless_db_connection ): # pylint: disable=unused-argument From be87fd05acdf83a7045c10f531201cf85779169f Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Tue, 9 Apr 2024 16:10:54 +1000 Subject: [PATCH 02/12] Implement more family_participant + remove schema types --- api/graphql/filters.py | 20 +++- api/graphql/schema.py | 132 +++++++++++++++++++------ db/python/layers/family.py | 24 ++++- db/python/tables/family_participant.py | 44 +++++++++ db/python/utils.py | 2 +- models/models/analysis.py | 12 ++- test/test_graphql.py | 64 ++++++++++++ 7 files changed, 258 insertions(+), 40 deletions(-) diff --git a/api/graphql/filters.py b/api/graphql/filters.py index 398db3a00..392c94927 100644 --- a/api/graphql/filters.py +++ b/api/graphql/filters.py @@ -1,10 +1,11 @@ -from typing import Any, Callable, Generic, TypeVar +from typing import Callable, Generic, TypeVar import strawberry -from db.python.utils import GenericFilter +from db.python.utils import GenericFilter, GenericMetaFilter T = TypeVar('T') +Y = TypeVar('Y') @strawberry.input(description='Filter for GraphQL queries') @@ -41,7 +42,7 @@ def all_values(self): return v - def to_internal_filter(self, f: Callable[[T], Any] = None): + def to_internal_filter(self, f: Callable[[T], Y] | None = None) -> GenericFilter[Y]: """Convert from GraphQL to internal filter model""" if f: @@ -67,3 +68,16 @@ def to_internal_filter(self, f: Callable[[T], Any] = None): GraphQLMetaFilter = strawberry.scalars.JSON + + +def graphql_meta_filter_to_internal_filter( + f: GraphQLMetaFilter | None, +) -> GenericMetaFilter | None: + if not f: + return None + + d: GenericMetaFilter = {} + f_to_d: dict[str, GraphQLMetaFilter] = dict(f) # type: ignore + for k, v in f_to_d.items(): + d[k] = GenericFilter(**v) if isinstance(v, dict) else GenericFilter(eq=v) + return d diff --git a/api/graphql/schema.py b/api/graphql/schema.py index eec4b0ef9..569600e9e 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -1,5 +1,3 @@ -# type: ignore -# flake8: noqa # pylint: disable=no-value-for-parameter,redefined-builtin,missing-function-docstring,unused-argument """ Schema for GraphQL. @@ -15,7 +13,11 @@ from strawberry.fastapi import GraphQLRouter from strawberry.types import Info -from api.graphql.filters import GraphQLFilter, GraphQLMetaFilter +from api.graphql.filters import ( + GraphQLFilter, + GraphQLMetaFilter, + graphql_meta_filter_to_internal_filter, +) from api.graphql.loaders import LoaderKeys, get_context from db.python import enum_tables from db.python.layers import ( @@ -45,6 +47,8 @@ SequencingGroupInternal, ) from models.models.analysis_runner import AnalysisRunnerInternal +from models.models.family import PedRowInternal +from models.models.project import ProjectId from models.models.sample import sample_id_transform_to_raw from models.utils.sample_id_format import sample_id_format from models.utils.sequencing_group_id_format import ( @@ -72,6 +76,8 @@ async def m(info: Info) -> list[str]: GraphQLEnum = strawberry.type(type('GraphQLEnum', (object,), enum_methods)) +GraphQLAnalysisStatus = strawberry.enum(AnalysisStatus) + @strawberry.type class GraphQLProject: @@ -129,6 +135,12 @@ async def pedigree( connection = info.context['connection'] family_layer = FamilyLayer(connection) + if not root.id: + raise ValueError('Project must have an id') + + if not internal_family_ids: + return [] + pedigree_dicts = await family_layer.get_pedigree( project=root.id, family_ids=internal_family_ids, @@ -144,11 +156,11 @@ async def pedigree( async def families( self, info: Info, - root: 'Project', + root: 'GraphQLProject', ) -> list['GraphQLFamily']: connection = info.context['connection'] families = await FamilyLayer(connection).get_families(project=root.id) - return families + return [GraphQLFamily.from_internal(f) for f in families] @strawberry.field() async def participants( @@ -173,7 +185,7 @@ async def samples( type=type.to_internal_filter() if type else None, external_id=external_id.to_internal_filter() if external_id else None, id=id.to_internal_filter(sample_id_transform_to_raw) if id else None, - meta=meta, + meta=graphql_meta_filter_to_internal_filter(meta), ) samples = await loader.load({'id': root.id, 'filter': filter_}) return [GraphQLSample.from_internal(p) for p in samples] @@ -216,7 +228,7 @@ async def analyses( info: Info, root: 'Project', type: GraphQLFilter[str] | None = None, - status: GraphQLFilter[strawberry.enum(AnalysisStatus)] | None = None, + status: GraphQLFilter[GraphQLAnalysisStatus] | None = None, active: GraphQLFilter[bool] | None = None, meta: GraphQLMetaFilter | None = None, timestamp_completed: GraphQLFilter[datetime.datetime] | None = None, @@ -277,10 +289,12 @@ class GraphQLAnalysis: output: str | None timestamp_completed: datetime.datetime | None = None active: bool - meta: strawberry.scalars.JSON + meta: strawberry.scalars.JSON | None @staticmethod def from_internal(internal: AnalysisInternal) -> 'GraphQLAnalysis': + if not internal.id: + raise ValueError('Analysis must have an id') return GraphQLAnalysis( id=internal.id, type=internal.type, @@ -352,6 +366,52 @@ async def participants( ) return [GraphQLParticipant.from_internal(p) for p in participants] + @strawberry.field + async def family_participants( + self, info: Info, root: 'GraphQLFamily' + ) -> list['GraphQLFamilyParticipant']: + family_participants = await info.context[ + LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES + ].load(root.id) + return [ + GraphQLFamilyParticipant.from_internal(fp) for fp in family_participants + ] + + +@strawberry.type +class GraphQLFamilyParticipant: + + affected: int | None + notes: str | None + + participant_id: strawberry.Private[int] + family_id: strawberry.Private[int] + + @strawberry.field + async def participant( + self, info: Info, root: 'GraphQLFamilyParticipant' + ) -> 'GraphQLParticipant': + loader = info.context[LoaderKeys.PARTICIPANTS_FOR_IDS] + participant = await loader.load(root.participant_id) + return GraphQLParticipant.from_internal(participant) + + @strawberry.field + async def family( + self, info: Info, root: 'GraphQLFamilyParticipant' + ) -> GraphQLFamily: + loader = info.context[LoaderKeys.FAMILIES_FOR_IDS] + family = await loader.load(root.family_id) + return GraphQLFamily.from_internal(family) + + @staticmethod + def from_internal(internal: PedRowInternal) -> 'GraphQLFamilyParticipant': + return GraphQLFamilyParticipant( + affected=internal.affected, + notes=internal.notes, + participant_id=internal.participant_id, + family_id=internal.family_id, + ) + @strawberry.type class GraphQLParticipant: @@ -392,7 +452,7 @@ async def samples( ) -> list['GraphQLSample']: filter_ = SampleFilter( type=type.to_internal_filter() if type else None, - meta=meta.to_internal_filter() if meta else None, + meta=graphql_meta_filter_to_internal_filter(meta), active=active.to_internal_filter() if active else GenericFilter(eq=True), ) q = {'id': root.id, 'filter': filter_} @@ -413,6 +473,18 @@ async def families( ) -> list[GraphQLFamily]: fams = await info.context[LoaderKeys.FAMILIES_FOR_PARTICIPANTS].load(root.id) return [GraphQLFamily.from_internal(f) for f in fams] + + @strawberry.field + async def family_participants( + self, info: Info, root: 'GraphQLParticipant' + ) -> list[GraphQLFamilyParticipant]: + return [] + # family_participants = await info.context[ + # LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS + # ].load(root.id) + # return [ + # GraphQLFamilyParticipant.from_internal(fp) for fp in family_participants + # ] @strawberry.field async def project(self, info: Info, root: 'GraphQLParticipant') -> GraphQLProject: @@ -508,7 +580,7 @@ async def sequencing_groups( if id else None ), - meta=meta, + meta=graphql_meta_filter_to_internal_filter(meta), type=type.to_internal_filter() if type else None, technology=technology.to_internal_filter() if technology else None, platform=platform.to_internal_filter() if platform else None, @@ -540,6 +612,9 @@ class GraphQLSequencingGroup: @staticmethod def from_internal(internal: SequencingGroupInternal) -> 'GraphQLSequencingGroup': + if not internal.id: + raise ValueError('SequencingGroup must have an id') + return GraphQLSequencingGroup( id=sequencing_group_id_format(internal.id), type=internal.type, @@ -563,7 +638,7 @@ async def analyses( self, info: Info, root: 'GraphQLSequencingGroup', - status: GraphQLFilter[strawberry.enum(AnalysisStatus)] | None = None, + status: GraphQLFilter[GraphQLAnalysisStatus] | None = None, type: GraphQLFilter[str] | None = None, meta: GraphQLMetaFilter | None = None, active: GraphQLFilter[bool] | None = None, @@ -571,14 +646,18 @@ async def analyses( ) -> list[GraphQLAnalysis]: connection = info.context['connection'] loader = info.context[LoaderKeys.ANALYSES_FOR_SEQUENCING_GROUPS] - project_id_map = {} + + _project_filter: GenericFilter[ProjectId] | None = None if project: ptable = ProjectPermissionsTable(connection) project_ids = project.all_values() projects = await ptable.get_and_check_access_to_projects_for_names( user=connection.author, project_names=project_ids, readonly=True ) - project_id_map = {p.name: p.id for p in projects} + project_id_map: dict[str, int] = { + p.name: p.id for p in projects if p.name and p.id + } + _project_filter = project.to_internal_filter(lambda p: project_id_map[p]) analyses = await loader.load( { @@ -586,17 +665,13 @@ async def analyses( 'filter_': AnalysisFilter( status=status.to_internal_filter() if status else None, type=type.to_internal_filter() if type else None, - meta=meta, + meta=graphql_meta_filter_to_internal_filter(meta), active=( active.to_internal_filter() if active else GenericFilter(eq=True) ), - project=( - project.to_internal_filter(lambda val: project_id_map[val]) - if project - else None - ), + project=_project_filter, ), } ) @@ -623,6 +698,9 @@ class GraphQLAssay: @staticmethod def from_internal(internal: AssayInternal) -> 'GraphQLAssay': + if not internal.id: + raise ValueError('Assay must have an id') + return GraphQLAssay( id=internal.id, type=internal.type, @@ -740,12 +818,12 @@ async def sample( projects = await ptable.get_and_check_access_to_projects_for_names( user=connection.author, project_names=project_names, readonly=True ) - project_name_map = {p.name: p.id for p in projects} + project_name_map = {p.name: p.id for p in projects if p.name and p.id} filter_ = SampleFilter( id=id.to_internal_filter(sample_id_transform_to_raw) if id else None, type=type.to_internal_filter() if type else None, - meta=meta, + meta=graphql_meta_filter_to_internal_filter(meta), external_id=external_id.to_internal_filter() if external_id else None, participant_id=( participant_id.to_internal_filter() if participant_id else None @@ -780,20 +858,18 @@ async def sequencing_groups( raise ValueError('Must filter by project, sample or id') # we list project names, but internally we want project ids - project_id_map = {} + _project_filter: GenericFilter[ProjectId] | None = None + if project: project_names = project.all_values() projects = await ptable.get_and_check_access_to_projects_for_names( user=connection.author, project_names=project_names, readonly=True ) - project_id_map = {p.name: p.id for p in projects} + project_id_map = {p.name: p.id for p in projects if p.name and p.id} + _project_filter = project.to_internal_filter(lambda p: project_id_map[p]) filter_ = SequencingGroupFilter( - project=( - project.to_internal_filter(lambda val: project_id_map[val]) - if project - else None - ), + project=_project_filter, sample_id=( sample_id.to_internal_filter(sample_id_transform_to_raw) if sample_id diff --git a/db/python/layers/family.py b/db/python/layers/family.py index 3d77576b0..4cf3943c6 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -570,9 +570,9 @@ async def import_pedigree( reported_sex=row.sex, ) ) - external_participant_ids_map[ - row.individual_id - ] = upserted_participant.id + external_participant_ids_map[row.individual_id] = ( + upserted_participant.id + ) for external_family_id in missing_external_family_ids: internal_family_id = await self.ftable.create_family( @@ -682,3 +682,21 @@ def select_columns(col1: Optional[int], col2: Optional[int] = None): coded_phenotypes=select_columns(phenotype_idx), ) return True + + async def get_family_participants_for_family_ids( + self, family_ids: list[int], check_project_ids: bool = True + ) -> dict[int, list[PedRowInternal]]: + """Get family participants for family IDs""" + projects, fps = await self.fptable.get_family_participants_by_family_ids( + family_ids + ) + + if not fps: + return {} + + if check_project_ids: + await self.ptable.check_access_to_project_ids( + self.connection.author, projects, readonly=True + ) + + return fps diff --git a/db/python/tables/family_participant.py b/db/python/tables/family_participant.py index e782038ab..b1ab87635 100644 --- a/db/python/tables/family_participant.py +++ b/db/python/tables/family_participant.py @@ -257,3 +257,47 @@ async def delete_family_participant_row(self, family_id: int, participant_id: in ) return True + + async def get_family_participants_by_family_ids( + self, family_ids: list[int] + ) -> tuple[set[ProjectId], dict[int, list[PedRowInternal]]]: + """ + Get all participants in a list of families + """ + if not family_ids: + return set(), {} + + _query = """ + SELECT + p.project, + p.id, + fp.family_id, + fp.paternal_participant_id, + fp.maternal_participant_id, + fp.affected, + fp.notes + FROM + family_participant fp + INNER JOIN participant p ON p.id = fp.participant_id + WHERE fp.family_id IN :family_ids + """ + + rows = await self.connection.fetch_all(_query, {'family_ids': family_ids}) + + projects: set[ProjectId] = set() + by_family = defaultdict(list) + + for row in rows: + projects.add(row['project']) + by_family[row['family_id']].append( + PedRowInternal( + family_id=row['family_id'], + participant_id=row['id'], + paternal_id=row['paternal_participant_id'], + maternal_id=row['maternal_participant_id'], + affected=row['affected'], + notes=row['notes'], + ) + ) + + return projects, by_family diff --git a/db/python/utils.py b/db/python/utils.py index 0b0846dcb..487078d45 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -256,7 +256,7 @@ def __post_init__(self): setattr(self, field.name, GenericFilter(eq=value)) def to_sql( - self, field_overrides: dict[str, str] = None + self, field_overrides: dict[str, str] | None = None ) -> tuple[str, dict[str, Any]]: """Convert the model to SQL, and avoid SQL injection""" _foverrides = field_overrides or {} diff --git a/models/models/analysis.py b/models/models/analysis.py index fb6e3152d..dbe03ae55 100644 --- a/models/models/analysis.py +++ b/models/models/analysis.py @@ -19,11 +19,11 @@ class AnalysisInternal(SMBase): id: int | None = None type: str status: AnalysisStatus - output: str = None + active: bool + output: str | None = None sequencing_group_ids: list[int] = [] timestamp_completed: datetime | None = None project: int | None = None - active: bool | None = None meta: dict[str, Any] = {} author: str | None = None @@ -72,9 +72,11 @@ def to_external(self): self.sequencing_group_ids ), output=self.output, - timestamp_completed=self.timestamp_completed.isoformat() - if self.timestamp_completed - else None, + timestamp_completed=( + self.timestamp_completed.isoformat() + if self.timestamp_completed + else None + ), project=self.project, active=self.active, meta=self.meta, diff --git a/test/test_graphql.py b/test/test_graphql.py index 294e1ad56..2ba9048dc 100644 --- a/test/test_graphql.py +++ b/test/test_graphql.py @@ -4,6 +4,7 @@ import api.graphql.schema from db.python.layers import AnalysisLayer, ParticipantLayer +from db.python.layers.family import FamilyLayer from metamist.graphql import configure_sync_client, gql, validate from models.enums import AnalysisStatus from models.models import ( @@ -257,3 +258,66 @@ async def test_participant_phenotypes(self): self.assertIn('participant', resp) self.assertIn('phenotypes', resp['participant']) self.assertDictEqual(phenotypes, resp['participant']['phenotypes']) + + @run_as_sync + async def test_family_participants(self): + family_layer = FamilyLayer(self.connection) + + rows = [ + ["family1", "individual1", "paternal1", "maternal1", "m", "1", "note1"], + ["family1", "paternal1", None, None, "m", "0", "note2"], + ["family1", "maternal1", None, None, "f", "1", "note3"], + ] + + await family_layer.import_pedigree(None, rows, create_missing_participants=True) + + q = """ +query MyQuery($project: String!) { + project(name: $project) { + participants { + externalId + familyParticipants { + affected + notes + family { + externalId + } + } + } + families { + id + familyParticipants { + affected + notes + participant { + externalId + } + } + } + } +} +""" + + resp = await self.run_graphql_query_async(q, {'project': self.project_name}) + + { + "project": { + "participants": [ + {"externalId": "individual1", "familyParticipants": []}, + {"externalId": "maternal1", "familyParticipants": []}, + {"externalId": "paternal1", "familyParticipants": []}, + ], + "families": [ + { + "id": 1, + "familyParticipants": [ + {"affected": 0, "notes": "note2", "participant": {"id": 1}}, + {"affected": 1, "notes": "note3", "participant": {"id": 2}}, + {"affected": 1, "notes": "note1", "participant": {"id": 3}}, + ], + } + ], + } + } + + print(resp) From b930392b4cfafef5cdb9202fecd5b39a960299c1 Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Tue, 23 Apr 2024 13:40:10 +1000 Subject: [PATCH 03/12] Genercise the querying for families --- api/routes/family.py | 13 +- db/python/layers/family.py | 470 ++++++------------------- db/python/tables/family_participant.py | 226 ++++-------- 3 files changed, 194 insertions(+), 515 deletions(-) diff --git a/api/routes/family.py b/api/routes/family.py index 38b5298f2..c7e8a90d5 100644 --- a/api/routes/family.py +++ b/api/routes/family.py @@ -18,6 +18,8 @@ from api.utils.export import ExportType from api.utils.extensions import guess_delimiter_by_upload_file_obj from db.python.layers.family import FamilyLayer, PedRow +from db.python.tables.family import FamilyFilter +from db.python.utils import GenericFilter from models.models.family import Family from models.utils.sample_id_format import sample_id_transform_to_raw_list @@ -146,8 +148,13 @@ async def get_families( family_layer = FamilyLayer(connection) sample_ids_raw = sample_id_transform_to_raw_list(sample_ids) if sample_ids else None - families = await family_layer.get_families( - participant_ids=participant_ids, sample_ids=sample_ids_raw + families = await family_layer.query( + FamilyFilter( + participant_id=( + GenericFilter(in_=participant_ids) if participant_ids else None + ), + sample_id=GenericFilter(in_=sample_ids_raw) if sample_ids_raw else None, + ) ) return [f.to_external() for f in families] @@ -174,7 +181,7 @@ async def update_family( async def import_families( file: UploadFile = File(...), has_header: bool = True, - delimiter: str = None, + delimiter: str | None = None, connection: Connection = get_project_write_connection, ): """Import a family csv""" diff --git a/db/python/layers/family.py b/db/python/layers/family.py index 4cf3943c6..344032d16 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -1,298 +1,22 @@ # pylint: disable=used-before-assignment -import logging -from typing import Dict, List, Optional, Union +from api.utils import group_by from db.python.connect import Connection from db.python.layers.base import BaseLayer from db.python.layers.participant import ParticipantLayer -from db.python.tables.family import FamilyTable -from db.python.tables.family_participant import FamilyParticipantTable +from db.python.tables.family import FamilyFilter, FamilyTable +from db.python.tables.family_participant import ( + FamilyParticipantFilter, + FamilyParticipantTable, +) from db.python.tables.participant import ParticipantTable -from db.python.tables.sample import SampleFilter, SampleTable -from db.python.utils import GenericFilter -from models.models.family import FamilyInternal, PedRowInternal +from db.python.tables.sample import SampleTable +from db.python.utils import GenericFilter, NotFoundError +from models.models.family import FamilyInternal, PedRow, PedRowInternal from models.models.participant import ParticipantUpsertInternal from models.models.project import ProjectId -class PedRow: - """Class for capturing a row in a pedigree""" - - ALLOWED_SEX_VALUES = [0, 1, 2] - ALLOWED_AFFECTED_VALUES = [-9, 0, 1, 2] - - PedRowKeys = { - # seqr individual template: - # Family ID, Individual ID, Paternal ID, Maternal ID, Sex, Affected, Status, Notes - 'family_id': {'familyid', 'family id', 'family', 'family_id'}, - 'individual_id': {'individualid', 'id', 'individual_id', 'individual id'}, - 'paternal_id': {'paternal_id', 'paternal id', 'paternalid', 'father'}, - 'maternal_id': {'maternal_id', 'maternal id', 'maternalid', 'mother'}, - 'sex': {'sex', 'gender'}, - 'affected': { - 'phenotype', - 'affected', - 'phenotypes', - 'affected status', - 'affection', - 'affection status', - }, - 'notes': {'notes'}, - } - - @staticmethod - def default_header(): - """Default header (corresponds to the __init__ keys)""" - return [ - 'family_id', - 'individual_id', - 'paternal_id', - 'maternal_id', - 'sex', - 'affected', - 'notes', - ] - - @staticmethod - def row_header(): - """Default RowHeader for output""" - return [ - '#Family ID', - 'Individual ID', - 'Paternal ID', - 'Maternal ID', - 'Sex', - 'Affected', - ] - - def __init__( - self, - family_id, - individual_id, - paternal_id, - maternal_id, - sex, - affected, - notes=None, - ): - self.family_id = family_id.strip() - self.individual_id = individual_id.strip() - self.paternal_id = None - self.maternal_id = None - self.paternal_id = self.check_linking_id(paternal_id, 'paternal_id') - self.maternal_id = self.check_linking_id(maternal_id, 'maternal_id') - self.sex = self.parse_sex(sex) - self.affected = self.parse_affected_status(affected) - self.notes = notes - - @staticmethod - def check_linking_id(linking_id, description: str, blank_values=('0', '')): - """Check that the ID is a valid value, or return None if it's a blank value""" - if linking_id is None: - return None - if isinstance(linking_id, int): - linking_id = str(linking_id).strip() - - if isinstance(linking_id, str): - if linking_id.strip().lower() in blank_values: - return None - return linking_id.strip() - - raise TypeError( - f'Unexpected type {type(linking_id)} ({linking_id}) ' - f'for {description}, expected "str"' - ) - - @staticmethod - def parse_sex(sex: Union[str, int]) -> int: - """ - Parse the pedigree SEX value: - 0: unknown - 1: male (also accepts 'm') - 2: female (also accepts 'f') - """ - if isinstance(sex, str) and sex.isdigit(): - sex = int(sex) - if isinstance(sex, int): - if sex in PedRow.ALLOWED_SEX_VALUES: - return sex - raise ValueError( - f'Sex value ({sex}) was not an expected value {PedRow.ALLOWED_SEX_VALUES}.' - ) - - sl = sex.lower() - if sl in ('m', 'male'): - return 1 - if sl in ('f', 'female'): - return 2 - if sl in ('u', 'unknown'): - return 0 - - if sl == 'sex': - raise ValueError( - f'Unknown sex {sex!r}, did you mean to call import_pedigree with has_headers=True?' - ) - raise ValueError( - f'Unknown sex {sex!r}, please ensure sex is in {PedRow.ALLOWED_SEX_VALUES}' - ) - - @staticmethod - def parse_affected_status(affected): - """ - Parse the pedigree "AFFECTED" value: - -9 / 0: unknown - 1: unaffected - 2: affected - """ - if isinstance(affected, str) and not affected.isdigit(): - affected = affected.lower().strip() - if affected in ['unknown']: - return 0 - if affected in ['n', 'no']: - return 1 - if affected in ['y', 'yes', 'affected']: - return 2 - - affected = int(affected) - if affected not in PedRow.ALLOWED_AFFECTED_VALUES: - raise ValueError( - f'Affected value {affected} was not in expected value: {PedRow.ALLOWED_AFFECTED_VALUES}' - ) - - return affected - - def __str__(self): - return f'PedRow: {self.individual_id} ({self.sex})' - - @staticmethod - def order(rows: List['PedRow']) -> List['PedRow']: - """ - Order a list of PedRows, but also validates: - - There are no circular dependencies - - All maternal / paternal IDs are found in the pedigree - """ - rows_to_order: List['PedRow'] = [*rows] - ordered = [] - seen_individuals = set() - remaining_iterations_in_round = len(rows_to_order) - - while len(rows_to_order) > 0: - row = rows_to_order.pop(0) - reqs = [row.paternal_id, row.maternal_id] - if all(r is None or r in seen_individuals for r in reqs): - remaining_iterations_in_round = len(rows_to_order) - ordered.append(row) - seen_individuals.add(row.individual_id) - else: - remaining_iterations_in_round -= 1 - rows_to_order.append(row) - - # makes more sense to keep this comparison separate: - # - If remaining iterations is or AND we still have rows - # - Then raise an Exception - # pylint: disable=chained-comparison - if remaining_iterations_in_round <= 0 and len(rows_to_order) > 0: - participant_ids = ', '.join( - f'{r.individual_id} ({r.paternal_id} | {r.maternal_id})' - for r in rows_to_order - ) - raise ValueError( - "There was an issue in the pedigree, either a parent wasn't " - 'found in the pedigree, or a circular dependency detected ' - "(eg: someone's child is an ancestor's parent). " - f"Can't resolve participants with parental IDs: {participant_ids}" - ) - - return ordered - - @staticmethod - def validate_sexes(rows: List['PedRow'], throws=True) -> bool: - """ - Validate that individuals listed as mothers and fathers - have either unknown sex, male if paternal, and female if maternal. - - Future note: The pedigree has a simplified view of sex, especially - how it relates families together. This function might not handle - more complex cases around intersex disorders within families. The - best advice is either to skip this check, or provide sex as 0 (unknown) - - :param throws: If True is provided (default), raise a ValueError, else just return False - """ - keyed: Dict[str, PedRow] = {r.individual_id: r for r in rows} - paternal_ids = [r.paternal_id for r in rows if r.paternal_id] - mismatched_pat_sex = [ - pid for pid in paternal_ids if keyed[pid].sex not in (0, 1) - ] - maternal_ids = [r.maternal_id for r in rows if r.maternal_id] - mismatched_mat_sex = [ - mid for mid in maternal_ids if keyed[mid].sex not in (0, 2) - ] - - messages = [] - if mismatched_pat_sex: - actual_values = ', '.join( - f'{pid} ({keyed[pid].sex})' for pid in mismatched_pat_sex - ) - messages.append('(0, 1) as they are listed as fathers: ' + actual_values) - if mismatched_mat_sex: - actual_values = ', '.join( - f'{pid} ({keyed[pid].sex})' for pid in mismatched_mat_sex - ) - messages.append('(0, 2) as they are listed as mothers: ' + actual_values) - - if messages: - message = 'Expected individuals have sex values:' + ''.join( - '\n\t' + m for m in messages - ) - if throws: - raise ValueError(message) - logging.warning(message) - return False - - return True - - @staticmethod - def parse_header_order(header: List[str]): - """ - Takes a list of unformatted headers, and returns a list of ordered init_keys - - >>> PedRow.parse_header_order(['family', 'mother', 'paternal id', 'phenotypes', 'gender']) - ['family_id', 'maternal_id', 'paternal_id', 'affected', 'sex'] - - >>> PedRow.parse_header_order(['#family id']) - ['family_id'] - - >>> PedRow.parse_header_order(['unexpected header']) - Traceback (most recent call last): - ValueError: Unable to identity header elements: "unexpected header" - """ - ordered_init_keys = [] - unmatched = [] - for item in header: - litem = item.lower().strip().strip('#') - found = False - for h, options in PedRow.PedRowKeys.items(): - for potential_key in options: - if potential_key == litem: - ordered_init_keys.append(h) - found = True - break - if found: - break - - if not found: - unmatched.append(item) - - if unmatched: - # repr casts to string and quotes if applicable - unmatched_headers_str = ', '.join(map(repr, unmatched)) - raise ValueError( - 'Unable to identity header elements: ' + unmatched_headers_str - ) - - return ordered_init_keys - - class FamilyLayer(BaseLayer): """Layer for import logic""" @@ -316,48 +40,52 @@ async def get_family_by_internal_id( self, family_id: int, check_project_id: bool = True ) -> FamilyInternal: """Get family by internal ID""" - project, family = await self.ftable.get_family_by_internal_id(family_id) + projects, families = await self.ftable.query( + FamilyFilter(id=GenericFilter(eq=family_id)) + ) + if not families: + raise NotFoundError(f'Family with ID {family_id} not found') + family = families[0] if check_project_id: await self.ptable.check_access_to_project_ids( - self.author, [project], readonly=True + self.author, projects, readonly=True ) return family async def get_family_by_external_id( - self, external_id: str, project: ProjectId = None + self, external_id: str, project: ProjectId | None = None ): """Get family by external ID, requires project scope""" - return await self.ftable.get_family_by_external_id(external_id, project=project) + # return await self.ftable.get_family_by_external_id(external_id, project=project) + families = await self.ftable.query( + FamilyFilter( + external_id=GenericFilter(eq=external_id), + project=GenericFilter(eq=project or self.connection.project), + ) + ) + if not families: + raise NotFoundError(f'Family with external ID {external_id} not found') + + return families[0] - async def get_families( + async def query( self, - project: int = None, - participant_ids: List[int] = None, - sample_ids: List[int] = None, - ): + filter_: FamilyFilter, + check_project_ids: bool = True, + ) -> list[FamilyInternal]: """Get all families for a project""" - project = project or self.connection.project - # Merge sample_id and participant_ids into a single list - all_participants = participant_ids if participant_ids else [] + # don't need a project check, as we're being provided an explicit filter - # Find the participants from the given samples - if sample_ids is not None and len(sample_ids) > 0: - _, samples = await self.stable.query( - SampleFilter( - project=GenericFilter(eq=project), id=GenericFilter(in_=sample_ids) - ) - ) + projects, families = await self.ftable.query(filter_) - all_participants += [ - int(s.participant_id) for s in samples if s.participant_id - ] - all_participants = list(set(all_participants)) + if check_project_ids: + await self.ptable.check_access_to_project_ids( + self.connection.author, projects, readonly=True + ) - return await self.ftable.get_families( - project=project, participant_ids=all_participants - ) + return families async def get_families_by_ids( self, @@ -366,8 +94,8 @@ async def get_families_by_ids( check_project_ids: bool = True, ) -> list[FamilyInternal]: """Get families by internal IDs""" - projects, families = await self.ftable.get_families_by_ids( - family_ids=family_ids + projects, families = await self.ftable.query( + FamilyFilter(id=GenericFilter(in_=family_ids)) ) if not families: return [] @@ -427,14 +155,14 @@ async def update_family( async def get_pedigree( self, project: ProjectId, - family_ids: List[int] = None, + family_ids: list[int] | None = None, # pylint: disable=invalid-name - replace_with_participant_external_ids=False, + replace_with_participant_external_ids: bool = False, # pylint: disable=invalid-name - replace_with_family_external_ids=False, - empty_participant_value=None, - include_participants_not_in_families=False, - ) -> List[Dict[str, str]]: + replace_with_family_external_ids: bool = False, + empty_participant_value: str | None = None, + include_participants_not_in_families: bool = False, + ) -> list[dict[str, str | int | None]]: """ Generate pedigree file for ALL families in project (unless internal_family_ids is specified). @@ -442,46 +170,49 @@ async def get_pedigree( Use internal IDs unless specific options are specified. """ - # this is important because a PED file MUST be ordered like this - - pid_fields = { - 'individual_id', - 'paternal_id', - 'maternal_id', - } - - rows = await self.fptable.get_rows( - project=project, - family_ids=family_ids, + _, rows = await self.fptable.query( + FamilyParticipantFilter( + project=GenericFilter(eq=project), + family_id=GenericFilter(in_=family_ids) if family_ids else None, + ), include_participants_not_in_families=include_participants_not_in_families, ) - pmap = {} + # participant_id to external_id + pmap: dict[int, str] = {} + # family_id to external_id + fmap: dict[int, str] = {} if replace_with_participant_external_ids: participant_ids = set( s for r in rows - for s in [r[pfield] for pfield in pid_fields] + for s in (r.individual_id, r.maternal_id, r.paternal_id) if s is not None ) ptable = ParticipantTable(connection=self.connection) pmap = await ptable.get_id_map_by_internal_ids(list(participant_ids)) - for r in rows: - for pfield in pid_fields: - r[pfield] = pmap.get(r[pfield], r[pfield]) or empty_participant_value - if replace_with_family_external_ids: - family_ids = list( - set(r['family_id'] for r in rows if r['family_id'] is not None) - ) + family_ids = list(set(r.family_id for r in rows if r.family_id is not None)) fmap = await self.ftable.get_id_map_by_internal_ids(list(family_ids)) - for r in rows: - r['family_id'] = fmap.get(r['family_id'], r['family_id']) - return rows + mapped_rows: list[dict[str, str | int | None]] = [] + for r in rows: + mapped_rows.append( + { + 'family_id': fmap.get(r.family_id, str(r.family_id)), + 'individual_id': pmap.get(r.individual_id, empty_participant_value), + 'paternal_id': pmap.get(r.paternal_id, empty_participant_value), + 'maternal_id': pmap.get(r.maternal_id, empty_participant_value), + 'sex': r.sex, + 'affected': r.affected, + 'notes': r.notes, + } + ) + + return mapped_rows async def get_participant_family_map( - self, participant_ids: List[int], check_project_ids=False + self, participant_ids: list[int], check_project_ids=False ): """Get participant family map""" @@ -497,8 +228,8 @@ async def get_participant_family_map( async def import_pedigree( self, - header: Optional[List[str]], - rows: List[List[str]], + header: list[str] | None, + rows: list[list[str]], create_missing_participants=False, perform_sex_check=True, ): @@ -522,7 +253,7 @@ async def import_pedigree( if len(_header) > max_row_length: _header = _header[:max_row_length] - pedrows: List[PedRow] = [ + pedrows: list[PedRow] = [ PedRow(**{_header[i]: r[i] for i in range(len(_header))}) for r in rows ] # this validates a lot of the pedigree too @@ -587,11 +318,12 @@ async def import_pedigree( insertable_rows = [ PedRowInternal( family_id=external_family_id_map[row.family_id], - participant_id=external_participant_ids_map[row.individual_id], + individual_id=external_participant_ids_map[row.individual_id], paternal_id=external_participant_ids_map.get(row.paternal_id), maternal_id=external_participant_ids_map.get(row.maternal_id), affected=row.affected, notes=row.notes, + sex=row.sex, ) for row in pedrows ] @@ -613,9 +345,7 @@ async def update_family_members(self, rows: list[PedRowInternal]): """Update family members""" await self.fptable.create_rows(rows) - async def import_families( - self, headers: Optional[List[str]], rows: List[List[str]] - ): + async def import_families(self, headers: list[str] | None, rows: list[list[str]]): """Import a family table""" ordered_headers = [ 'Family ID', @@ -637,7 +367,7 @@ async def import_families( }, } - def get_idx_for_header(header) -> Optional[int]: + def get_idx_for_header(header) -> int | None: return next( iter(idx for idx, key in enumerate(lheaders) if key in key_map[header]), None, @@ -653,11 +383,13 @@ def replace_empty_string_with_none(val): """Don't set as empty string, prefer to set as null""" return None if val == '' else val - rows = [[replace_empty_string_with_none(el) for el in r] for r in rows] + _fixed_rows = [[replace_empty_string_with_none(el) for el in r] for r in rows] - empty = [None] * len(rows) + empty: list[str | None] = [None] * len(_fixed_rows) - def select_columns(col1: Optional[int], col2: Optional[int] = None): + def select_columns( + col1: int | None, col2: int | None = None + ) -> list[str | None]: """ - If col1 and col2 is None, return [None] * len(rows) - if either col1 or col2 is not None, return that column @@ -668,13 +400,13 @@ def select_columns(col1: Optional[int], col2: Optional[int] = None): return empty if col1 is not None and col2 is None: # if only col1 is set - return [r[col1] for r in rows] + return [r[col1] for r in _fixed_rows] if col2 is not None and col1 is None: # if only col2 is set - return [r[col2] for r in rows] + return [r[col2] for r in _fixed_rows] # if col1 AND col2 are not None assert col1 is not None and col2 is not None - return [r[col1] if r[col1] is not None else r[col2] for r in rows] + return [r[col1] if r[col1] is not None else r[col2] for r in _fixed_rows] await self.ftable.insert_or_update_multiple_families( external_ids=select_columns(external_identifier_idx, display_name_idx), @@ -683,12 +415,12 @@ def select_columns(col1: Optional[int], col2: Optional[int] = None): ) return True - async def get_family_participants_for_family_ids( + async def get_family_participants_by_family_ids( self, family_ids: list[int], check_project_ids: bool = True ) -> dict[int, list[PedRowInternal]]: """Get family participants for family IDs""" - projects, fps = await self.fptable.get_family_participants_by_family_ids( - family_ids + projects, fps = await self.fptable.query( + FamilyParticipantFilter(family_id=GenericFilter(in_=family_ids)) ) if not fps: @@ -699,4 +431,22 @@ async def get_family_participants_for_family_ids( self.connection.author, projects, readonly=True ) + return group_by(fps, lambda r: r.family_id) + + async def get_family_participants_for_participants( + self, participant_ids: list[int], check_project_ids: bool = True + ) -> list[PedRowInternal]: + """Get family participants for participant IDs""" + projects, fps = await self.fptable.query( + FamilyParticipantFilter(participant_id=GenericFilter(in_=participant_ids)) + ) + + if not fps: + return [] + + if check_project_ids: + await self.ptable.check_access_to_project_ids( + self.connection.author, projects, readonly=True + ) + return fps diff --git a/db/python/tables/family_participant.py b/db/python/tables/family_participant.py index b1ab87635..1808a2348 100644 --- a/db/python/tables/family_participant.py +++ b/db/python/tables/family_participant.py @@ -1,11 +1,22 @@ +import dataclasses from collections import defaultdict -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any from db.python.tables.base import DbBase +from db.python.utils import GenericFilter, GenericFilterModel from models.models.family import PedRowInternal from models.models.project import ProjectId +@dataclasses.dataclass +class FamilyParticipantFilter(GenericFilterModel): + """Filter for family_participant table""" + + project: GenericFilter[ProjectId] | None = None + participant_id: GenericFilter[int] | None = None + family_id: GenericFilter[int] | None = None + + class FamilyParticipantTable(DbBase): """ Capture Analysis table operations and queries @@ -21,7 +32,7 @@ async def create_row( maternal_id: int, affected: int, notes: str | None = None, - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """ Create a new sample, and add it to database """ @@ -64,16 +75,17 @@ async def create_rows( """ ignore_keys_during_update = {'participant_id'} - remapped_ds_by_keys: Dict[Tuple, List[Dict]] = defaultdict(list) + remapped_ds_by_keys: dict[tuple, list[dict]] = defaultdict(list) # this now works when only a portion of the keys are specified for row in rows: d: dict[str, Any] = { 'family_id': row.family_id, - 'participant_id': row.participant_id, + 'participant_id': row.individual_id, 'paternal_participant_id': row.paternal_id, 'maternal_participant_id': row.maternal_id, 'affected': row.affected, 'notes': row.notes, + # sex is NOT inserted here 'audit_log_id': await self.audit_log_id(), } @@ -97,113 +109,51 @@ async def create_rows( return True - async def get_rows( + async def query( self, - project: ProjectId, - family_ids: Optional[List[int]] = None, - include_participants_not_in_families=False, - ): - """ - Get rows from database, return ALL rows unless family_ids is specified. - If family_ids is not specified, and `include_participants_not_in_families` is True, - Get all participants from project and include them in pedigree - """ - keys = [ - 'fp.family_id', - 'p.id as individual_id', - 'fp.paternal_participant_id as paternal_id', - 'fp.maternal_participant_id as maternal_id', - 'p.reported_sex as sex', - 'fp.affected', - ] - keys_str = ', '.join(keys) - - values: Dict[str, Any] = {'project': project or self.project} - wheres = ['p.project = :project'] - if family_ids: - wheres.append('family_id in :family_ids') - values['family_ids'] = family_ids - - if not include_participants_not_in_families: - wheres.append('f.project = :project') - - conditions = ' AND '.join(wheres) - - _query = f""" - SELECT {keys_str} FROM family_participant fp - INNER JOIN family f ON f.id = fp.family_id - INNER JOIN participant p on fp.participant_id = p.id - WHERE {conditions}""" - if not family_ids and include_participants_not_in_families: - # rewrite the query to LEFT join from participants - # to include all participants - _query = f""" - SELECT {keys_str} FROM participant p - LEFT JOIN family_participant fp ON fp.participant_id = p.id - LEFT JOIN family f ON f.id = fp.family_id - WHERE {conditions}""" - - rows = await self.connection.fetch_all(_query, values) - - ordered_keys = [ - 'family_id', - 'individual_id', - 'paternal_id', - 'maternal_id', - 'sex', - 'affected', - ] - ds = [{k: row[k] for k in ordered_keys} for row in rows] - - return ds - - async def get_row( - self, - family_id: int, - participant_id: int, - ) -> dict | None: - """Get a single row from the family_participant table""" - values: Dict[str, Any] = { - 'family_id': family_id, - 'participant_id': participant_id, - } - - _query = """ -SELECT - fp.family_id as family_id, - p.id as individual_id, - fp.paternal_participant_id as paternal_id, - fp.maternal_participant_id as maternal_id, - p.reported_sex as sex, - fp.affected -FROM family_participant fp -INNER JOIN family f ON f.id = fp.family_id -INNER JOIN participant p on fp.participant_id = p.id -WHERE f.id = :family_id AND p.id = :participant_id -""" - - row = await self.connection.fetch_one(_query, values) - if not row: - return None - - ordered_keys = [ - 'family_id', - 'individual_id', - 'paternal_id', - 'maternal_id', - 'sex', - 'affected', - ] - ds = {k: row[k] for k in ordered_keys} + filter_: FamilyParticipantFilter, + include_participants_not_in_families: bool = False, + ) -> tuple[set[ProjectId], list[PedRowInternal]]: + """ + Query the family_participant table + """ + + wheres, values = filter_.to_sql() + if not wheres: + raise ValueError('No filter provided') + + join_type = 'LEFT' if include_participants_not_in_families else 'INNER' + query = f""" + SELECT + fp.family_id, + p.id as individual_id, + fp.paternal_participant_id as paternal_id, + fp.maternal_participant_id as maternal_id, + p.reported_sex as sex, + fp.affected, + fp.notes as notes, + p.project + FROM participant p + {join_type} JOIN family_participant fp on fp.participant_id = p.id + WHERE {wheres} + """ + + rows = await self.connection.fetch_all(query, values) + projects: set[ProjectId] = set() + pedrows: list[PedRowInternal] = [] + for row in rows: + r = dict(row) + projects.add(r.pop('project')) + pedrows.append(PedRowInternal(**r)) - return ds + return projects, pedrows async def get_participant_family_map( - self, participant_ids: List[int] - ) -> Tuple[Set[int], Dict[int, int]]: + self, + participant_ids: list[int], + ) -> tuple[set[int], dict[int, int]]: """ Get {participant_id: family_id} map - w/ projects """ if len(participant_ids) == 0: @@ -218,10 +168,26 @@ async def get_participant_family_map( rows = await self.connection.fetch_all( _query, {'participant_ids': participant_ids} ) + projects = set(r['project'] for r in rows) - m = {r['id']: r['family_id'] for r in rows} + conflicts: dict[int, list[int]] = {} + pid_to_fid_map: dict[int, int] = {} + for r in rows: + r_id = r['id'] - return projects, m + if r_id in pid_to_fid_map: + if r_id not in conflicts: + conflicts[r_id] = [pid_to_fid_map[r_id]] + conflicts[r_id].append(r['family_id']) + + pid_to_fid_map[r_id] = r['family_id'] + + if conflicts: + raise ValueError( + f'Participants were found in more than one family ({{pid: [fids]}}): {conflicts}' + ) + + return projects, pid_to_fid_map async def delete_family_participant_row(self, family_id: int, participant_id: int): """ @@ -233,7 +199,7 @@ async def delete_family_participant_row(self, family_id: int, participant_id: in _update_before_delete = """ UPDATE family_participant - SET audit_log_id = :audit_log_id + set audit_log_id = :audit_log_id WHERE family_id = :family_id AND participant_id = :participant_id """ @@ -257,47 +223,3 @@ async def delete_family_participant_row(self, family_id: int, participant_id: in ) return True - - async def get_family_participants_by_family_ids( - self, family_ids: list[int] - ) -> tuple[set[ProjectId], dict[int, list[PedRowInternal]]]: - """ - Get all participants in a list of families - """ - if not family_ids: - return set(), {} - - _query = """ - SELECT - p.project, - p.id, - fp.family_id, - fp.paternal_participant_id, - fp.maternal_participant_id, - fp.affected, - fp.notes - FROM - family_participant fp - INNER JOIN participant p ON p.id = fp.participant_id - WHERE fp.family_id IN :family_ids - """ - - rows = await self.connection.fetch_all(_query, {'family_ids': family_ids}) - - projects: set[ProjectId] = set() - by_family = defaultdict(list) - - for row in rows: - projects.add(row['project']) - by_family[row['family_id']].append( - PedRowInternal( - family_id=row['family_id'], - participant_id=row['id'], - paternal_id=row['paternal_participant_id'], - maternal_id=row['maternal_participant_id'], - affected=row['affected'], - notes=row['notes'], - ) - ) - - return projects, by_family From 284786cdb0bc5267a49ca36c20fa960b2cdb12b9 Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Tue, 23 Apr 2024 13:41:02 +1000 Subject: [PATCH 04/12] Cleanup for the family refactor --- api/graphql/loaders.py | 39 +++- api/graphql/schema.py | 62 ++++--- db/python/layers/participant.py | 43 ++++- db/python/tables/family.py | 110 +++++------- models/models/family.py | 306 +++++++++++++++++++++++++++++++- 5 files changed, 457 insertions(+), 103 deletions(-) diff --git a/api/graphql/loaders.py b/api/graphql/loaders.py index 383d5f981..a3a512ecc 100644 --- a/api/graphql/loaders.py +++ b/api/graphql/loaders.py @@ -21,6 +21,7 @@ ) from db.python.tables.analysis import AnalysisFilter from db.python.tables.assay import AssayFilter +from db.python.tables.family import FamilyFilter from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter from db.python.tables.sequencing_group import SequencingGroupFilter @@ -67,6 +68,7 @@ class LoaderKeys(enum.Enum): FAMILIES_FOR_PARTICIPANTS = 'families_for_participants' FAMILY_PARTICIPANTS_FOR_FAMILIES = 'family_participants_for_families' + FAMILY_PARTICIPANTS_FOR_PARTICIPANTS = 'family_participants_for_participants' FAMILIES_FOR_IDS = 'families_for_ids' SEQUENCING_GROUPS_FOR_IDS = 'sequencing_groups_for_ids' @@ -381,7 +383,10 @@ async def load_families_for_participants( Get families of participants, noting a participant can be in multiple families """ flayer = FamilyLayer(connection) - fam_map = await flayer.get_families_by_participants(participant_ids=participant_ids) + + fam_map = await flayer.get_families_by_participants( + participant_ids=participant_ids, check_project_ids=False + ) return [fam_map.get(p, []) for p in participant_ids] @@ -457,7 +462,7 @@ async def load_families_for_ids( DataLoader: get_families_for_ids """ flayer = FamilyLayer(connection) - families = await flayer.get_families_by_ids(family_ids) + families = await flayer.query(FamilyFilter(id=GenericFilter(in_=family_ids))) f_by_id = {f.id: f for f in families} return [f_by_id[f] for f in family_ids] @@ -466,12 +471,36 @@ async def load_families_for_ids( async def load_family_participants_for_families( family_ids: list[int], connection ) -> list[list[PedRowInternal]]: + """ + DataLoader: get_family_participants_for_families + """ + flayer = FamilyLayer(connection) + fp_map = await flayer.get_family_participants_by_family_ids(family_ids) + + return [fp_map.get(fid, []) for fid in family_ids] + + +@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS) +async def load_family_participants_for_participants( + participant_ids: list[int], connection +) -> list[list[PedRowInternal]]: + """data loader for family participants for participants + + Args: + participant_ids (list[int]): list of internal participant ids + connection (_type_): (this is automatically filled in by the loader decorator) + + Returns: + list[list[PedRowInternal]]: list of family participants for each participant + (in order) + """ flayer = FamilyLayer(connection) - family_participants = await flayer.get_family_participants_for_family_ids( - family_ids + family_participants = await flayer.get_family_participants_for_participants( + participant_ids ) + fp_map = group_by(family_participants, lambda fp: fp.individual_id) - return [family_participants.get(fid, []) for fid in family_ids] + return [fp_map.get(pid, []) for pid in participant_ids] async def get_context( diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 569600e9e..ecfb1c1a4 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -31,6 +31,7 @@ from db.python.tables.analysis import AnalysisFilter from db.python.tables.analysis_runner import AnalysisRunnerFilter from db.python.tables.assay import AssayFilter +from db.python.tables.family import FamilyFilter from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter from db.python.tables.sequencing_group import SequencingGroupFilter @@ -157,9 +158,17 @@ async def families( self, info: Info, root: 'GraphQLProject', + id: GraphQLFilter[int] | None = None, + external_id: GraphQLFilter[str] | None = None, ) -> list['GraphQLFamily']: connection = info.context['connection'] - families = await FamilyLayer(connection).get_families(project=root.id) + families = await FamilyLayer(connection).query( + FamilyFilter( + project=GenericFilter(eq=root.id), + id=id.to_internal_filter() if id else None, + external_id=external_id.to_internal_filter() if external_id else None, + ) + ) return [GraphQLFamily.from_internal(f) for f in families] @strawberry.field() @@ -184,7 +193,7 @@ async def samples( filter_ = SampleFilter( type=type.to_internal_filter() if type else None, external_id=external_id.to_internal_filter() if external_id else None, - id=id.to_internal_filter(sample_id_transform_to_raw) if id else None, + id=id.to_internal_filter_mapped(sample_id_transform_to_raw) if id else None, meta=graphql_meta_filter_to_internal_filter(meta), ) samples = await loader.load({'id': root.id, 'filter': filter_}) @@ -205,7 +214,7 @@ async def sequencing_groups( loader = info.context[LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS] filter_ = SequencingGroupFilter( id=( - id.to_internal_filter(sequencing_group_id_transform_to_raw) + id.to_internal_filter_mapped(sequencing_group_id_transform_to_raw) if id else None ), @@ -285,7 +294,7 @@ class GraphQLAnalysis: id: int type: str - status: strawberry.enum(AnalysisStatus) + status: strawberry.enum(AnalysisStatus) # type: ignore output: str | None timestamp_completed: datetime.datetime | None = None active: bool @@ -380,6 +389,10 @@ async def family_participants( @strawberry.type class GraphQLFamilyParticipant: + """ + A FamilyParticipant, an individual in a family, noting that a Family is bounded + by some 'affected' attribute + """ affected: int | None notes: str | None @@ -408,7 +421,7 @@ def from_internal(internal: PedRowInternal) -> 'GraphQLFamilyParticipant': return GraphQLFamilyParticipant( affected=internal.affected, notes=internal.notes, - participant_id=internal.participant_id, + participant_id=internal.individual_id, family_id=internal.family_id, ) @@ -473,18 +486,17 @@ async def families( ) -> list[GraphQLFamily]: fams = await info.context[LoaderKeys.FAMILIES_FOR_PARTICIPANTS].load(root.id) return [GraphQLFamily.from_internal(f) for f in fams] - + @strawberry.field async def family_participants( self, info: Info, root: 'GraphQLParticipant' ) -> list[GraphQLFamilyParticipant]: - return [] - # family_participants = await info.context[ - # LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS - # ].load(root.id) - # return [ - # GraphQLFamilyParticipant.from_internal(fp) for fp in family_participants - # ] + family_participants = await info.context[ + LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS + ].load(root.id) + return [ + GraphQLFamilyParticipant.from_internal(fp) for fp in family_participants + ] @strawberry.field async def project(self, info: Info, root: 'GraphQLParticipant') -> GraphQLProject: @@ -576,7 +588,7 @@ async def sequencing_groups( _filter = SequencingGroupFilter( id=( - id.to_internal_filter(sequencing_group_id_transform_to_raw) + id.to_internal_filter_mapped(sequencing_group_id_transform_to_raw) if id else None ), @@ -657,7 +669,9 @@ async def analyses( project_id_map: dict[str, int] = { p.name: p.id for p in projects if p.name and p.id } - _project_filter = project.to_internal_filter(lambda p: project_id_map[p]) + _project_filter = project.to_internal_filter_mapped( + lambda p: project_id_map[p] + ) analyses = await loader.load( { @@ -778,7 +792,7 @@ class Query: """GraphQL Queries""" @strawberry.field() - def enum(self, info: Info) -> GraphQLEnum: + def enum(self, info: Info) -> GraphQLEnum: # type: ignore return GraphQLEnum() @strawberry.field() @@ -821,7 +835,7 @@ async def sample( project_name_map = {p.name: p.id for p in projects if p.name and p.id} filter_ = SampleFilter( - id=id.to_internal_filter(sample_id_transform_to_raw) if id else None, + id=id.to_internal_filter_mapped(sample_id_transform_to_raw) if id else None, type=type.to_internal_filter() if type else None, meta=graphql_meta_filter_to_internal_filter(meta), external_id=external_id.to_internal_filter() if external_id else None, @@ -829,7 +843,7 @@ async def sample( participant_id.to_internal_filter() if participant_id else None ), project=( - project.to_internal_filter(lambda pname: project_name_map[pname]) + project.to_internal_filter_mapped(lambda pname: project_name_map[pname]) if project else None ), @@ -866,17 +880,19 @@ async def sequencing_groups( user=connection.author, project_names=project_names, readonly=True ) project_id_map = {p.name: p.id for p in projects if p.name and p.id} - _project_filter = project.to_internal_filter(lambda p: project_id_map[p]) + _project_filter = project.to_internal_filter_mapped( + lambda p: project_id_map[p] + ) filter_ = SequencingGroupFilter( project=_project_filter, sample_id=( - sample_id.to_internal_filter(sample_id_transform_to_raw) + sample_id.to_internal_filter_mapped(sample_id_transform_to_raw) if sample_id else None ), id=( - id.to_internal_filter(sequencing_group_id_transform_to_raw) + id.to_internal_filter_mapped(sequencing_group_id_transform_to_raw) if id else None ), @@ -941,4 +957,6 @@ async def analysis_runner( schema = strawberry.Schema( query=Query, mutation=None, extensions=[QueryDepthLimiter(max_depth=10)] ) -MetamistGraphQLRouter = GraphQLRouter(schema, graphiql=True, context_getter=get_context) +MetamistGraphQLRouter: GraphQLRouter = GraphQLRouter( + schema, graphiql=True, context_getter=get_context +) diff --git a/db/python/layers/participant.py b/db/python/layers/participant.py index 3698e0cc8..dacc464ec 100644 --- a/db/python/layers/participant.py +++ b/db/python/layers/participant.py @@ -7,11 +7,19 @@ from db.python.layers.base import BaseLayer from db.python.layers.sample import SampleLayer from db.python.tables.family import FamilyTable -from db.python.tables.family_participant import FamilyParticipantTable +from db.python.tables.family_participant import ( + FamilyParticipantFilter, + FamilyParticipantTable, +) from db.python.tables.participant import ParticipantTable from db.python.tables.participant_phenotype import ParticipantPhenotypeTable from db.python.tables.sample import SampleTable -from db.python.utils import NoOpAenter, NotFoundError, split_generic_terms +from db.python.utils import ( + GenericFilter, + NoOpAenter, + NotFoundError, + split_generic_terms, +) from models.models.family import PedRowInternal from models.models.participant import ParticipantInternal, ParticipantUpsertInternal from models.models.project import ProjectId @@ -489,11 +497,12 @@ async def generic_individual_metadata_importer( formed_rows = [ PedRowInternal( family_id=fmap_by_external[external_family_id], - participant_id=pid, + individual_id=pid, affected=0, maternal_id=None, paternal_id=None, notes=None, + sex=None, ) for external_family_id, pid in family_persons_to_insert ] @@ -766,11 +775,29 @@ async def get_seqr_individual_template( 'header_map': json_header_map, } - async def get_family_participant_data(self, family_id: int, participant_id: int): + async def get_family_participant_data( + self, family_id: int, participant_id: int, check_project_ids: bool = True + ) -> PedRowInternal: """Gets the family_participant row for a specific participant""" fptable = FamilyParticipantTable(self.connection) - return await fptable.get_row(family_id=family_id, participant_id=participant_id) + projects, rows = await fptable.query( + FamilyParticipantFilter( + family_id=GenericFilter(eq=family_id), + participant_id=GenericFilter(eq=participant_id), + ) + ) + if not rows: + raise NotFoundError( + f'Family participant row (family_id: {family_id}, ' + f'participant_id: {participant_id}) not found' + ) + if check_project_ids: + await self.ptable.check_access_to_project_ids( + self.author, projects, readonly=True + ) + + return rows[0] async def remove_participant_from_family(self, family_id: int, participant_id: int): """Deletes a participant from a family""" @@ -948,7 +975,7 @@ async def update_participant_family( return await self.add_participant_to_family( family_id=new_family_id, participant_id=participant_id, - paternal_id=fp_row['paternal_id'], - maternal_id=fp_row['maternal_id'], - affected=fp_row['affected'], + paternal_id=fp_row.paternal_id, + maternal_id=fp_row.maternal_id, + affected=fp_row.affected, ) diff --git a/db/python/tables/family.py b/db/python/tables/family.py index 18b13dd72..b936d458a 100644 --- a/db/python/tables/family.py +++ b/db/python/tables/family.py @@ -1,12 +1,29 @@ +import dataclasses from collections import defaultdict from typing import Any, Dict, List, Optional, Set from db.python.tables.base import DbBase -from db.python.utils import NotFoundError +from db.python.utils import GenericFilter, GenericFilterModel, NotFoundError from models.models.family import FamilyInternal from models.models.project import ProjectId +@dataclasses.dataclass +class FamilyFilter(GenericFilterModel): + """Filter mode for querying Families + + Args: + GenericFilterModel (_type_): _description_ + """ + + id: GenericFilter[int] | None = None + external_id: GenericFilter[str] | None = None + + project: GenericFilter[ProjectId] | None = None + participant_id: GenericFilter[int] | None = None + sample_id: GenericFilter[int] | None = None + + class FamilyTable(DbBase): """ Capture Analysis table operations and queries @@ -31,41 +48,53 @@ async def get_projects_by_family_ids(self, family_ids: List[int]) -> Set[Project ) return projects - async def get_families( - self, project: int = None, participant_ids: List[int] = None - ) -> List[FamilyInternal]: + async def query( + self, filter_: FamilyFilter + ) -> tuple[set[ProjectId], list[FamilyInternal]]: """Get all families for some project""" _query = """ - SELECT id, external_id, description, coded_phenotype, project - FROM family + SELECT f.id, f.external_id, f.description, f.coded_phenotype, f.project + FROM family f """ - values: Dict[str, Any] = {'project': project or self.project} - where: List[str] = [] + if not filter_.project and not filter_.id: + raise ValueError('Project or ID filter is required for family queries') + + field_overrides = {'id': 'f.id', 'external_id': 'f.external_id'} - if participant_ids: + if filter_.participant_id: + field_overrides['participant_id'] = 'fp.participant_id' + has_participant_join = True _query += """ - JOIN family_participant - ON family.id = family_participant.family_id + JOIN family_participant fp ON family.id = fp.family_id """ - where.append('participant_id IN :pids') - values['pids'] = participant_ids - if project or self.project: - where.append('project = :project') + if filter_.sample_id: + field_overrides['sample_id'] = 's.id' + if not has_participant_join: + _query += """ + JOIN family_participant fp ON family.id = fp.family_id + """ + + _query += """ + INNER JOIN sample s ON fp.participant_id = s.participant_id + """ - if where: - _query += 'WHERE ' + ' AND '.join(where) + wheres, values = filter_.to_sql(field_overrides) + if wheres: + _query += f'WHERE {wheres}' rows = await self.connection.fetch_all(_query, values) seen = set() families = [] + projects: set[ProjectId] = set() for r in rows: if r['id'] not in seen: + projects.add(r['project']) families.append(FamilyInternal.from_db(dict(r))) seen.add(r['id']) - return families + return projects, families async def get_families_by_participants( self, participant_ids: list[int] @@ -90,47 +119,6 @@ async def get_families_by_participants( return projects, ret_map - async def get_family_by_external_id( - self, external_id: str, project: ProjectId = None - ): - """Get single family by external ID (requires project)""" - _query = """ - SELECT id, external_id, description, coded_phenotype, project - FROM family - WHERE project = :project AND external_id = :external_id - """ - row = await self.connection.fetch_one( - _query, {'project': project or self.project, 'external_id': external_id} - ) - return FamilyInternal.from_db(row) - - async def get_family_by_internal_id( - self, family_id: int - ) -> tuple[ProjectId, FamilyInternal]: - """Get family (+ project) by internal ID""" - _query = """ - SELECT id, external_id, description, coded_phenotype, project - FROM family WHERE id = :fid - """ - row = await self.connection.fetch_one(_query, {'fid': family_id}) - if not row: - raise NotFoundError - project = row['project'] - return project, FamilyInternal.from_db(row) - - async def get_families_by_ids( - self, family_ids: list[int] - ) -> tuple[set[ProjectId], list[FamilyInternal]]: - """Get family (+ project) by internal ID""" - _query = """ - SELECT id, external_id, description, coded_phenotype, project - FROM family WHERE id IN :fids - """ - rows = list(await self.connection.fetch_all(_query, {'fids': family_ids})) - fams = [FamilyInternal.from_db(row) for row in rows] - project = set(f.project for f in fams) - return project, fams - async def search( self, query, project_ids: list[ProjectId], limit: int = 5 ) -> list[tuple[ProjectId, int, str]]: @@ -203,7 +191,7 @@ async def create_family( external_id: str, description: Optional[str], coded_phenotype: Optional[str], - project: ProjectId = None, + project: ProjectId | None = None, ) -> int: """ Create a new sample, and add it to database @@ -233,7 +221,7 @@ async def insert_or_update_multiple_families( external_ids: List[str], descriptions: List[str], coded_phenotypes: List[Optional[str]], - project: int = None, + project: ProjectId | None = None, ): """Upsert""" updater = [ diff --git a/models/models/family.py b/models/models/family.py index d7f18ae5b..6113f27d8 100644 --- a/models/models/family.py +++ b/models/models/family.py @@ -1,4 +1,4 @@ -from typing import Optional +import logging from pydantic import BaseModel @@ -23,8 +23,8 @@ class FamilyInternal(BaseModel): id: int external_id: str project: int - description: Optional[str] = None - coded_phenotype: Optional[str] = None + description: str | None = None + coded_phenotype: str | None = None @staticmethod def from_db(d): @@ -55,8 +55,8 @@ class Family(BaseModel): id: int | None external_id: str project: int - description: Optional[str] = None - coded_phenotype: Optional[str] = None + description: str | None = None + coded_phenotype: str | None = None def to_internal(self): """Convert to internal model""" @@ -75,15 +75,307 @@ class PedRowInternal: def __init__( self, family_id: int, - participant_id: int, + individual_id: int, paternal_id: int | None, maternal_id: int | None, + sex: int | None, affected: int | None, notes: str | None, ): self.family_id = family_id - self.participant_id = participant_id + self.individual_id = individual_id self.paternal_id = paternal_id self.maternal_id = maternal_id + self.sex = sex self.affected = affected self.notes = notes + + def to_dict(self): + """Convert to dictionary""" + return { + 'family_id': self.family_id, + 'individual_id': self.individual_id, + 'paternal_id': self.paternal_id, + 'maternal_id': self.maternal_id, + 'sex': self.sex, + 'affected': self.affected, + 'notes': self.notes, + } + + +class PedRow: + """Class for capturing a row in a pedigree""" + + ALLOWED_SEX_VALUES = [0, 1, 2] + ALLOWED_AFFECTED_VALUES = [-9, 0, 1, 2] + + PedRowKeys = { + # seqr individual template: + # Family ID, Individual ID, Paternal ID, Maternal ID, Sex, Affected, Status, Notes + 'family_id': {'familyid', 'family id', 'family', 'family_id'}, + 'individual_id': {'individualid', 'id', 'individual_id', 'individual id'}, + 'paternal_id': {'paternal_id', 'paternal id', 'paternalid', 'father'}, + 'maternal_id': {'maternal_id', 'maternal id', 'maternalid', 'mother'}, + 'sex': {'sex', 'gender'}, + 'affected': { + 'phenotype', + 'affected', + 'phenotypes', + 'affected status', + 'affection', + 'affection status', + }, + 'notes': {'notes'}, + } + + @staticmethod + def default_header(): + """Default header (corresponds to the __init__ keys)""" + return [ + 'family_id', + 'individual_id', + 'paternal_id', + 'maternal_id', + 'sex', + 'affected', + 'notes', + ] + + @staticmethod + def row_header(): + """Default RowHeader for output""" + return [ + '#Family ID', + 'Individual ID', + 'Paternal ID', + 'Maternal ID', + 'Sex', + 'Affected', + ] + + def __init__( + self, + family_id, + individual_id, + paternal_id, + maternal_id, + sex, + affected, + notes=None, + ): + self.family_id = family_id.strip() + self.individual_id = individual_id.strip() + self.paternal_id = None + self.maternal_id = None + self.paternal_id = self.check_linking_id(paternal_id, 'paternal_id') + self.maternal_id = self.check_linking_id(maternal_id, 'maternal_id') + self.sex = self.parse_sex(sex) + self.affected = self.parse_affected_status(affected) + self.notes = notes + + @staticmethod + def check_linking_id(linking_id, description: str, blank_values=('0', '')): + """Check that the ID is a valid value, or return None if it's a blank value""" + if linking_id is None: + return None + if isinstance(linking_id, int): + linking_id = str(linking_id).strip() + + if isinstance(linking_id, str): + if linking_id.strip().lower() in blank_values: + return None + return linking_id.strip() + + raise TypeError( + f'Unexpected type {type(linking_id)} ({linking_id}) ' + f'for {description}, expected "str"' + ) + + @staticmethod + def parse_sex(sex: str | int) -> int: + """ + Parse the pedigree SEX value: + 0: unknown + 1: male (also accepts 'm') + 2: female (also accepts 'f') + """ + if isinstance(sex, str) and sex.isdigit(): + sex = int(sex) + if isinstance(sex, int): + if sex in PedRow.ALLOWED_SEX_VALUES: + return sex + raise ValueError( + f'Sex value ({sex}) was not an expected value {PedRow.ALLOWED_SEX_VALUES}.' + ) + + sl = sex.lower() + if sl in ('m', 'male'): + return 1 + if sl in ('f', 'female'): + return 2 + if sl in ('u', 'unknown'): + return 0 + + if sl == 'sex': + raise ValueError( + f'Unknown sex {sex!r}, did you mean to call import_pedigree with has_headers=True?' + ) + raise ValueError( + f'Unknown sex {sex!r}, please ensure sex is in {PedRow.ALLOWED_SEX_VALUES}' + ) + + @staticmethod + def parse_affected_status(affected): + """ + Parse the pedigree "AFFECTED" value: + -9 / 0: unknown + 1: unaffected + 2: affected + """ + if isinstance(affected, str) and not affected.isdigit(): + affected = affected.lower().strip() + if affected in ['unknown']: + return 0 + if affected in ['n', 'no']: + return 1 + if affected in ['y', 'yes', 'affected']: + return 2 + + affected = int(affected) + if affected not in PedRow.ALLOWED_AFFECTED_VALUES: + raise ValueError( + f'Affected value {affected} was not in expected value: {PedRow.ALLOWED_AFFECTED_VALUES}' + ) + + return affected + + def __str__(self): + return f'PedRow: {self.individual_id} ({self.sex})' + + @staticmethod + def order(rows: list['PedRow']) -> list['PedRow']: + """ + Order a list of PedRows, but also validates: + - There are no circular dependencies + - All maternal / paternal IDs are found in the pedigree + """ + rows_to_order: list['PedRow'] = [*rows] + ordered = [] + seen_individuals = set() + remaining_iterations_in_round = len(rows_to_order) + + while len(rows_to_order) > 0: + row = rows_to_order.pop(0) + reqs = [row.paternal_id, row.maternal_id] + if all(r is None or r in seen_individuals for r in reqs): + remaining_iterations_in_round = len(rows_to_order) + ordered.append(row) + seen_individuals.add(row.individual_id) + else: + remaining_iterations_in_round -= 1 + rows_to_order.append(row) + + # makes more sense to keep this comparison separate: + # - If remaining iterations is or AND we still have rows + # - Then raise an Exception + # pylint: disable=chained-comparison + if remaining_iterations_in_round <= 0 and len(rows_to_order) > 0: + participant_ids = ', '.join( + f'{r.individual_id} ({r.paternal_id} | {r.maternal_id})' + for r in rows_to_order + ) + raise ValueError( + "There was an issue in the pedigree, either a parent wasn't " + 'found in the pedigree, or a circular dependency detected ' + "(eg: someone's child is an ancestor's parent). " + f"Can't resolve participants with parental IDs: {participant_ids}" + ) + + return ordered + + @staticmethod + def validate_sexes(rows: list['PedRow'], throws=True) -> bool: + """ + Validate that individuals listed as mothers and fathers + have either unknown sex, male if paternal, and female if maternal. + + Future note: The pedigree has a simplified view of sex, especially + how it relates families together. This function might not handle + more complex cases around intersex disorders within families. The + best advice is either to skip this check, or provide sex as 0 (unknown) + + :param throws: If True is provided (default), raise a ValueError, else just return False + """ + keyed: dict[str, PedRow] = {r.individual_id: r for r in rows} + paternal_ids = [r.paternal_id for r in rows if r.paternal_id] + mismatched_pat_sex = [ + pid for pid in paternal_ids if keyed[pid].sex not in (0, 1) + ] + maternal_ids = [r.maternal_id for r in rows if r.maternal_id] + mismatched_mat_sex = [ + mid for mid in maternal_ids if keyed[mid].sex not in (0, 2) + ] + + messages = [] + if mismatched_pat_sex: + actual_values = ', '.join( + f'{pid} ({keyed[pid].sex})' for pid in mismatched_pat_sex + ) + messages.append('(0, 1) as they are listed as fathers: ' + actual_values) + if mismatched_mat_sex: + actual_values = ', '.join( + f'{pid} ({keyed[pid].sex})' for pid in mismatched_mat_sex + ) + messages.append('(0, 2) as they are listed as mothers: ' + actual_values) + + if messages: + message = 'Expected individuals have sex values:' + ''.join( + '\n\t' + m for m in messages + ) + if throws: + raise ValueError(message) + logging.warning(message) + return False + + return True + + @staticmethod + def parse_header_order(header: list[str]): + """ + Takes a list of unformatted headers, and returns a list of ordered init_keys + + >>> PedRow.parse_header_order(['family', 'mother', 'paternal id', 'phenotypes', 'gender']) + ['family_id', 'maternal_id', 'paternal_id', 'affected', 'sex'] + + >>> PedRow.parse_header_order(['#family id']) + ['family_id'] + + >>> PedRow.parse_header_order(['unexpected header']) + Traceback (most recent call last): + ValueError: Unable to identity header elements: "unexpected header" + """ + ordered_init_keys = [] + unmatched = [] + for item in header: + litem = item.lower().strip().strip('#') + found = False + for h, options in PedRow.PedRowKeys.items(): + for potential_key in options: + if potential_key == litem: + ordered_init_keys.append(h) + found = True + break + if found: + break + + if not found: + unmatched.append(item) + + if unmatched: + # repr casts to string and quotes if applicable + unmatched_headers_str = ', '.join(map(repr, unmatched)) + raise ValueError( + 'Unable to identity header elements: ' + unmatched_headers_str + ) + + return ordered_init_keys From 9f500633be8e363598311543a1ccd07e144b5d73 Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Tue, 23 Apr 2024 13:41:36 +1000 Subject: [PATCH 05/12] Fix tests + filters --- api/graphql/filters.py | 38 ++++++++----- db/python/utils.py | 22 +++++--- models/models/analysis.py | 2 +- test/test_graphql.py | 74 ++++++++++++++++++-------- test/test_search.py | 3 +- test/test_update_participant_family.py | 42 +++++++++------ 6 files changed, 123 insertions(+), 58 deletions(-) diff --git a/api/graphql/filters.py b/api/graphql/filters.py index 392c94927..bfb3e0ecc 100644 --- a/api/graphql/filters.py +++ b/api/graphql/filters.py @@ -42,20 +42,11 @@ def all_values(self): return v - def to_internal_filter(self, f: Callable[[T], Y] | None = None) -> GenericFilter[Y]: + def to_internal_filter( + self, + ) -> GenericFilter[T]: """Convert from GraphQL to internal filter model""" - if f: - return GenericFilter( - eq=f(self.eq) if self.eq else None, - in_=list(map(f, self.in_)) if self.in_ else None, - nin=list(map(f, self.nin)) if self.nin else None, - gt=f(self.gt) if self.gt else None, - gte=f(self.gte) if self.gte else None, - lt=f(self.lt) if self.lt else None, - lte=f(self.lte) if self.lte else None, - ) - return GenericFilter( eq=self.eq, in_=self.in_, @@ -66,6 +57,21 @@ def to_internal_filter(self, f: Callable[[T], Y] | None = None) -> GenericFilter lte=self.lte, ) + def to_internal_filter_mapped(self, f: Callable[[T], Y]) -> GenericFilter[Y]: + """ + To internal filter, but apply a function to all values. + Separate this into a separate function to please linters and type checkers + """ + return GenericFilter( + eq=f(self.eq) if self.eq else None, + in_=list(map(f, self.in_)) if self.in_ else None, + nin=list(map(f, self.nin)) if self.nin else None, + gt=f(self.gt) if self.gt else None, + gte=f(self.gte) if self.gte else None, + lt=f(self.lt) if self.lt else None, + lte=f(self.lte) if self.lte else None, + ) + GraphQLMetaFilter = strawberry.scalars.JSON @@ -73,6 +79,14 @@ def to_internal_filter(self, f: Callable[[T], Y] | None = None) -> GenericFilter def graphql_meta_filter_to_internal_filter( f: GraphQLMetaFilter | None, ) -> GenericMetaFilter | None: + """Convert from GraphQL to internal filter model + + Args: + f (GraphQLMetaFilter | None): GraphQL filter + + Returns: + GenericMetaFilter | None: internal filter + """ if not f: return None diff --git a/db/python/utils.py b/db/python/utils.py index 487078d45..1e42d016a 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -85,8 +85,8 @@ class GenericFilter(Generic[T]): """ eq: T | None = None - in_: list[T] | None = None - nin: list[T] | None = None + in_: Sequence[T] | None = None + nin: Sequence[T] | None = None gt: T | None = None gte: T | None = None lt: T | None = None @@ -96,8 +96,8 @@ def __init__( self, *, eq: T | None = None, - in_: list[T] | None = None, - nin: list[T] | None = None, + in_: Sequence[T] | None = None, + nin: Sequence[T] | None = None, gt: T | None = None, gte: T | None = None, lt: T | None = None, @@ -147,9 +147,19 @@ def generate_field_name(name): return NONFIELD_CHARS_REGEX.sub('_', name) def to_sql( - self, column: str, column_name: str = None + self, column: str, column_name: str | None = None ) -> tuple[str, dict[str, T | list[T]]]: - """Convert to SQL, and avoid SQL injection""" + """Convert to SQL, and avoid SQL injection + + Args: + column (str): The expression, or column name that derives the values + column_name (str, optional): A column name to use in the field_override. + We'll replace any non-alphanumeric characters with an _. + (Defaults to None) + + Returns: + tuple[str, dict[str, T | list[T]]]: (condition, prepared_values) + """ conditionals = [] values: dict[str, T | list[T]] = {} _column_name = column_name or column diff --git a/models/models/analysis.py b/models/models/analysis.py index dbe03ae55..fceea7705 100644 --- a/models/models/analysis.py +++ b/models/models/analysis.py @@ -19,7 +19,7 @@ class AnalysisInternal(SMBase): id: int | None = None type: str status: AnalysisStatus - active: bool + active: bool | None = None output: str | None = None sequencing_group_ids: list[int] = [] timestamp_completed: datetime | None = None diff --git a/test/test_graphql.py b/test/test_graphql.py index 2ba9048dc..71e6a079e 100644 --- a/test/test_graphql.py +++ b/test/test_graphql.py @@ -261,12 +261,15 @@ async def test_participant_phenotypes(self): @run_as_sync async def test_family_participants(self): + """Test inserting + querying family participants from different directions""" family_layer = FamilyLayer(self.connection) + family_eid = 'family1' + rows = [ - ["family1", "individual1", "paternal1", "maternal1", "m", "1", "note1"], - ["family1", "paternal1", None, None, "m", "0", "note2"], - ["family1", "maternal1", None, None, "f", "1", "note3"], + [family_eid, 'individual1', 'paternal1', 'maternal1', 'm', '1', 'note1'], + [family_eid, 'paternal1', None, None, 'm', '0', 'note2'], + [family_eid, 'maternal1', None, None, 'f', '1', 'note3'], ] await family_layer.import_pedigree(None, rows, create_missing_participants=True) @@ -285,7 +288,7 @@ async def test_family_participants(self): } } families { - id + externalId familyParticipants { affected notes @@ -299,25 +302,50 @@ async def test_family_participants(self): """ resp = await self.run_graphql_query_async(q, {'project': self.project_name}) + assert resp is not None - { - "project": { - "participants": [ - {"externalId": "individual1", "familyParticipants": []}, - {"externalId": "maternal1", "familyParticipants": []}, - {"externalId": "paternal1", "familyParticipants": []}, - ], - "families": [ - { - "id": 1, - "familyParticipants": [ - {"affected": 0, "notes": "note2", "participant": {"id": 1}}, - {"affected": 1, "notes": "note3", "participant": {"id": 2}}, - {"affected": 1, "notes": "note1", "participant": {"id": 3}}, - ], - } + family_simple_obj = {'family': {'externalId': family_eid}} + + participants = resp['project']['participants'] + families = resp['project']['families'] + + participants_by_eid = {p['externalId']: p for p in participants} + self.assertEqual(3, len(participants)) + + self.assertDictEqual( + { + 'externalId': 'individual1', + 'familyParticipants': [ + {'affected': 1, 'notes': 'note1', **family_simple_obj} ], - } - } + }, + participants_by_eid['individual1'], + ) - print(resp) + self.assertEqual(1, len(families)) + self.assertEqual(family_eid, families[0]['externalId']) + + sorted_fps = sorted( + families[0]['familyParticipants'], + key=lambda x: x['participant']['externalId'], + ) + self.assertListEqual( + sorted_fps, + [ + { + 'affected': 1, + 'notes': 'note1', + 'participant': {'externalId': 'individual1'}, + }, + { + 'affected': 1, + 'notes': 'note3', + 'participant': {'externalId': 'maternal1'}, + }, + { + 'affected': 0, + 'notes': 'note2', + 'participant': {'externalId': 'paternal1'}, + }, + ], + ) diff --git a/test/test_search.py b/test/test_search.py index ca6718ec5..d5498b098 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -204,10 +204,11 @@ async def test_search_mixed(self): [ PedRowInternal( family_id=f_id, - participant_id=p.id, + individual_id=p.id, paternal_id=None, maternal_id=None, affected=0, + sex=0, notes=None, ) ] diff --git a/test/test_update_participant_family.py b/test/test_update_participant_family.py index e8517dd91..204295644 100644 --- a/test/test_update_participant_family.py +++ b/test/test_update_participant_family.py @@ -20,13 +20,21 @@ async def setUp(self) -> None: self.fid_2 = await fl.create_family(external_id='FAM02') pl = ParticipantLayer(self.connection) - self.pid = (await pl.upsert_participant(ParticipantUpsertInternal(external_id='EX01', reported_sex=2))).id - self.pat_pid = (await pl.upsert_participant( - ParticipantUpsertInternal(external_id='EX01_pat', reported_sex=1) - )).id - self.mat_pid = (await pl.upsert_participant( - ParticipantUpsertInternal(external_id='EX01_mat', reported_sex=2) - )).id + self.pid = ( + await pl.upsert_participant( + ParticipantUpsertInternal(external_id='EX01', reported_sex=2) + ) + ).id + self.pat_pid = ( + await pl.upsert_participant( + ParticipantUpsertInternal(external_id='EX01_pat', reported_sex=1) + ) + ).id + self.mat_pid = ( + await pl.upsert_participant( + ParticipantUpsertInternal(external_id='EX01_mat', reported_sex=2) + ) + ).id await pl.add_participant_to_family( family_id=self.fid_1, @@ -54,8 +62,9 @@ async def test_get_remove_add_family_participant_data(self): 'maternal_id': self.mat_pid, 'sex': 2, 'affected': 2, + 'notes': None, } - self.assertDictEqual(fp_row, expected_fp_row) + self.assertDictEqual(expected_fp_row, fp_row.to_dict()) await pl.remove_participant_from_family( family_id=self.fid_1, participant_id=self.pid @@ -64,9 +73,9 @@ async def test_get_remove_add_family_participant_data(self): await pl.add_participant_to_family( family_id=self.fid_2, participant_id=self.pid, - paternal_id=fp_row['paternal_id'], - maternal_id=fp_row['maternal_id'], - affected=fp_row['affected'], + paternal_id=fp_row.paternal_id, + maternal_id=fp_row.maternal_id, + affected=fp_row.affected, ) updated_fp_row = await pl.get_family_participant_data( @@ -80,8 +89,9 @@ async def test_get_remove_add_family_participant_data(self): 'maternal_id': self.mat_pid, 'sex': 2, 'affected': 2, + 'notes': None, } - self.assertDictEqual(updated_fp_row, expected_updated_fp_row) + self.assertDictEqual(expected_updated_fp_row, updated_fp_row.to_dict()) await pl.remove_participant_from_family( family_id=self.fid_2, participant_id=self.pid @@ -106,8 +116,9 @@ async def test_update_participant_family(self): 'maternal_id': self.mat_pid, 'sex': 2, 'affected': 2, + 'notes': None, } - self.assertDictEqual(updated_fp_row, expected_updated_fp_row) + self.assertDictEqual(expected_updated_fp_row, updated_fp_row.to_dict()) await pl.remove_participant_from_family( family_id=self.fid_2, participant_id=self.pid @@ -128,8 +139,9 @@ async def test_update_participant_to_nonexistent_family(self): 'maternal_id': self.mat_pid, 'sex': 2, 'affected': 2, + 'notes': None, } - self.assertDictEqual(fp_row, expected_fp_row) + self.assertDictEqual(expected_fp_row, fp_row.to_dict()) with self.assertRaises(IntegrityError): await pl.update_participant_family( @@ -141,4 +153,4 @@ async def test_update_participant_to_nonexistent_family(self): ) # Update transaction should rollback, so no change expected - self.assertDictEqual(rollback_fp_row, expected_fp_row) + self.assertDictEqual(expected_fp_row, rollback_fp_row.to_dict()) From 7b0bdc90e0e666e6f8fee4a99ab1b7e361bf153a Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Tue, 23 Apr 2024 22:21:53 +1000 Subject: [PATCH 06/12] Add codecov test value --- .github/workflows/test.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1c81caa00..abf2e6a29 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -76,6 +76,8 @@ jobs: uses: codecov/codecov-action@v3 with: files: ./coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} + - name: "build web front-end" run: | From e5c4b236e15c0e5c2e0588df96b90a1772ea764d Mon Sep 17 00:00:00 2001 From: Michael Franklin <22381693+illusional@users.noreply.github.com> Date: Tue, 23 Apr 2024 22:31:40 +1000 Subject: [PATCH 07/12] Apply suggestions from code review --- api/graphql/schema.py | 2 ++ db/python/layers/family.py | 6 +++--- db/python/tables/family_participant.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/api/graphql/schema.py b/api/graphql/schema.py index ecfb1c1a4..37d3f97b1 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -161,6 +161,8 @@ async def families( id: GraphQLFilter[int] | None = None, external_id: GraphQLFilter[str] | None = None, ) -> list['GraphQLFamily']: + # don't need a data loader here as we're presuming we're not often running + # the "families" method for many projects at once. If so, we might need to fix that connection = info.context['connection'] families = await FamilyLayer(connection).query( FamilyFilter( diff --git a/db/python/layers/family.py b/db/python/layers/family.py index 344032d16..63015fa1a 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -200,9 +200,9 @@ async def get_pedigree( mapped_rows.append( { 'family_id': fmap.get(r.family_id, str(r.family_id)), - 'individual_id': pmap.get(r.individual_id, empty_participant_value), - 'paternal_id': pmap.get(r.paternal_id, empty_participant_value), - 'maternal_id': pmap.get(r.maternal_id, empty_participant_value), + 'individual_id': pmap.get(r.individual_id, r.individual_id) or empty_participant_value, + 'paternal_id': pmap.get(r.paternal_id, r.paternal_id) or empty_participant_value), + 'maternal_id': pmap.get(r.maternal_id, r.maternal_id) or empty_participant_value), 'sex': r.sex, 'affected': r.affected, 'notes': r.notes, diff --git a/db/python/tables/family_participant.py b/db/python/tables/family_participant.py index 1808a2348..53b774d0c 100644 --- a/db/python/tables/family_participant.py +++ b/db/python/tables/family_participant.py @@ -199,7 +199,7 @@ async def delete_family_participant_row(self, family_id: int, participant_id: in _update_before_delete = """ UPDATE family_participant - set audit_log_id = :audit_log_id + SET audit_log_id = :audit_log_id WHERE family_id = :family_id AND participant_id = :participant_id """ From ace51071f6057ab7fa492ca319e72d40acfe2e48 Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Tue, 23 Apr 2024 22:32:02 +1000 Subject: [PATCH 08/12] Fix minor formatting oddity --- api/graphql/filters.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/api/graphql/filters.py b/api/graphql/filters.py index bfb3e0ecc..2e41653b8 100644 --- a/api/graphql/filters.py +++ b/api/graphql/filters.py @@ -42,9 +42,7 @@ def all_values(self): return v - def to_internal_filter( - self, - ) -> GenericFilter[T]: + def to_internal_filter(self) -> GenericFilter[T]: """Convert from GraphQL to internal filter model""" return GenericFilter( From 82b3a6590f6ac220e661129841265ea7ef1b1c49 Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Tue, 23 Apr 2024 22:34:30 +1000 Subject: [PATCH 09/12] Linting --- api/graphql/schema.py | 2 +- db/python/layers/family.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 37d3f97b1..98f8c6838 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -161,7 +161,7 @@ async def families( id: GraphQLFilter[int] | None = None, external_id: GraphQLFilter[str] | None = None, ) -> list['GraphQLFamily']: - # don't need a data loader here as we're presuming we're not often running + # don't need a data loader here as we're presuming we're not often running # the "families" method for many projects at once. If so, we might need to fix that connection = info.context['connection'] families = await FamilyLayer(connection).query( diff --git a/db/python/layers/family.py b/db/python/layers/family.py index 63015fa1a..0f7613488 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -200,9 +200,12 @@ async def get_pedigree( mapped_rows.append( { 'family_id': fmap.get(r.family_id, str(r.family_id)), - 'individual_id': pmap.get(r.individual_id, r.individual_id) or empty_participant_value, - 'paternal_id': pmap.get(r.paternal_id, r.paternal_id) or empty_participant_value), - 'maternal_id': pmap.get(r.maternal_id, r.maternal_id) or empty_participant_value), + 'individual_id': pmap.get(r.individual_id, r.individual_id) + or empty_participant_value, + 'paternal_id': pmap.get(r.paternal_id, r.paternal_id) + or empty_participant_value, + 'maternal_id': pmap.get(r.maternal_id, r.maternal_id) + or empty_participant_value, 'sex': r.sex, 'affected': r.affected, 'notes': r.notes, From 1ea48a3d55169950eccc129949cea808dc166e68 Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Wed, 24 Apr 2024 17:00:32 +1000 Subject: [PATCH 10/12] Add a GraphQLMetaFilter test + improve filter hashing --- api/graphql/loaders.py | 29 +++------------------------- db/python/utils.py | 41 +++++++++++++++++++++++++++++++++++---- test/test_graphql.py | 44 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 30 deletions(-) diff --git a/api/graphql/loaders.py b/api/graphql/loaders.py index a3a512ecc..a819631ed 100644 --- a/api/graphql/loaders.py +++ b/api/graphql/loaders.py @@ -25,7 +25,7 @@ from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter from db.python.tables.sequencing_group import SequencingGroupFilter -from db.python.utils import GenericFilter, NotFoundError +from db.python.utils import GenericFilter, NotFoundError, get_hashable_value from models.models import ( AnalysisInternal, AssayInternal, @@ -96,31 +96,8 @@ async def wrapped(*args, **kwargs): return connected_data_loader_caller -def _prepare_partial_value_for_hashing(value): - if value is None: - return None - if isinstance(value, (int, str, float, bool)): - return value - if isinstance(value, enum.Enum): - return value.value - if isinstance(value, list): - # let's see if later we need to prepare the values in the list - return tuple(value) - if isinstance(value, dict): - return tuple( - sorted( - ((k, _prepare_partial_value_for_hashing(v)) for k, v in value.items()), - key=lambda x: x[0], - ) - ) - - return hash(value) - - -def _get_connected_data_loader_partial_key(kwargs): - return _prepare_partial_value_for_hashing( - {k: v for k, v in kwargs.items() if k != 'id'} - ) +def _get_connected_data_loader_partial_key(kwargs) -> tuple: + return get_hashable_value({k: v for k, v in kwargs.items() if k != 'id'}) # type: ignore def connected_data_loader_with_params( diff --git a/db/python/utils.py b/db/python/utils.py index 1e42d016a..e5507cdca 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -118,10 +118,11 @@ def __repr__(self): ) return f'{self.__class__.__name__}({inner_values})' - def __hash__(self): - """Override to ensure we can hash this object""" - return hash( + def get_hashable_value(self): + """Get value that we could run hash on""" + return get_hashable_value( ( + self.__class__.__name__, self.eq, tuple(self.in_) if self.in_ is not None else None, tuple(self.nin) if self.nin is not None else None, @@ -132,6 +133,10 @@ def __hash__(self): ) ) + def __hash__(self): + """Override to ensure we can hash this object""" + return hash(self.get_hashable_value()) + @staticmethod def generate_field_name(name): """ @@ -219,6 +224,30 @@ def _sql_value_prep(value): return value +def get_hashable_value(value): + """Prepare a value that can be hashed, for use in a dict or set""" + if value is None: + return None + if isinstance(value, (int, str, float, bool)): + return value + if isinstance(value, Enum): + return value.value + if isinstance(value, (tuple, list)): + # let's see if later we need to prepare the values in the list + return tuple(get_hashable_value(v) for v in value) + if isinstance(value, dict): + return tuple( + sorted( + ((k, get_hashable_value(v)) for k, v in value.items()), + key=lambda x: x[0], + ) + ) + if hasattr(value, 'get_hashable_value'): + return value.get_hashable_value() + + return hash(value) + + # pylint: disable=missing-class-docstring GenericMetaFilter = dict[str, GenericFilter[Any]] @@ -231,7 +260,11 @@ class GenericFilterModel: def __hash__(self): """Hash the GenericFilterModel, this doesn't override well""" - return hash(dataclasses.astuple(self)) + return hash(self.get_hashable_value()) + + def get_hashable_value(self): + """Get value that we could run hash on""" + return get_hashable_value((self.__class__.__name__, *dataclasses.astuple(self))) def __post_init__(self): for field in dataclasses.fields(self): diff --git a/test/test_graphql.py b/test/test_graphql.py index 71e6a079e..5dff32b88 100644 --- a/test/test_graphql.py +++ b/test/test_graphql.py @@ -191,6 +191,45 @@ async def test_basic_graphql_query(self): p.samples[0].sequencing_groups[0].assays[0].id, assays[0]['id'] ) + @run_as_sync + async def test_query_sample_by_meta(self): + """Test querying a participant""" + await self.player.upsert_participant( + ParticipantUpsertInternal( + meta={}, + external_id='Demeter', + samples=[ + SampleUpsertInternal( + external_id='sample_id001', + meta={'thisKey': 'value'}, + ) + ], + ) + ) + q = """ + query MyQuery($project: String!, $meta: JSON!) { + project(name: $project) { + participants { + samples(meta: $meta) { + id + } + } + } + }""" + values = await self.run_graphql_query_async( + q, {'project': self.project_name, 'meta': {'thisKey': 'value'}} + ) + assert values + + self.assertEqual(1, len(values['project']['participants'][0]['samples'])) + + values2 = await self.run_graphql_query_async( + q, {'project': self.project_name, 'meta': {'thisKeyDoesNotExistEver': '-1'}} + ) + assert values2 + + self.assertEqual(0, len(values2['project']['participants'][0]['samples'])) + @run_as_sync async def test_sg_analyses_query(self): """Example graphql query of analyses from sequencing-group""" @@ -286,6 +325,9 @@ async def test_family_participants(self): externalId } } + families { + externalId + } } families { externalId @@ -315,12 +357,14 @@ async def test_family_participants(self): self.assertDictEqual( { 'externalId': 'individual1', + 'families': [{'externalId': family_eid}], 'familyParticipants': [ {'affected': 1, 'notes': 'note1', **family_simple_obj} ], }, participants_by_eid['individual1'], ) + self.assertEqual(1, len(participants_by_eid['individual1']['families'])) self.assertEqual(1, len(families)) self.assertEqual(family_eid, families[0]['externalId']) From 8363ea1259adba90b430ffda3c5e9cf4e165c8c3 Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Wed, 24 Apr 2024 17:04:08 +1000 Subject: [PATCH 11/12] Apply review feedback --- db/python/layers/family.py | 6 ++-- db/python/layers/participant.py | 50 ++++++++++++++++----------------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/db/python/layers/family.py b/db/python/layers/family.py index 0f7613488..d861f9d32 100644 --- a/db/python/layers/family.py +++ b/db/python/layers/family.py @@ -57,7 +57,6 @@ async def get_family_by_external_id( self, external_id: str, project: ProjectId | None = None ): """Get family by external ID, requires project scope""" - # return await self.ftable.get_family_by_external_id(external_id, project=project) families = await self.ftable.query( FamilyFilter( external_id=GenericFilter(eq=external_id), @@ -304,9 +303,8 @@ async def import_pedigree( reported_sex=row.sex, ) ) - external_participant_ids_map[row.individual_id] = ( - upserted_participant.id - ) + pid = upserted_participant.id + external_participant_ids_map[row.individual_id] = pid for external_family_id in missing_external_family_ids: internal_family_id = await self.ftable.create_family( diff --git a/db/python/layers/participant.py b/db/python/layers/participant.py index dacc464ec..a945fd62a 100644 --- a/db/python/layers/participant.py +++ b/db/python/layers/participant.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from enum import Enum -from typing import Any, Dict, List, Optional, Tuple +from typing import Any from db.python.layers.base import BaseLayer from db.python.layers.sample import SampleLayer @@ -188,7 +188,7 @@ def parse_age_of_onset(age_of_onset: str): ) @staticmethod - def parse_hpo_terms(hpo_terms: str) -> List[str]: + def parse_hpo_terms(hpo_terms: str) -> list[str]: """ Validate that comma-separated HPO terms must start with 'HP:' @@ -272,8 +272,8 @@ async def get_participants_by_ids( async def get_participants( self, project: int, - external_participant_ids: List[str] = None, - internal_participant_ids: List[int] = None, + external_participant_ids: list[str] | None = None, + internal_participant_ids: list[int] | None = None, ) -> list[ParticipantInternal]: """ Get participants for a project @@ -360,8 +360,8 @@ async def insert_participant_phenotypes( async def generic_individual_metadata_importer( self, - headers: List[str], - rows: List[List[str]], + headers: list[str], + rows: list[list[str]], extra_participants_method: ExtraParticipantImporterHandler = ExtraParticipantImporterHandler.FAIL, ): """ @@ -457,8 +457,8 @@ async def generic_individual_metadata_importer( } missing_family_ids = external_family_ids - set(fmap_by_external.keys()) - family_persons_to_insert: List[Tuple[str, int]] = [] - incompatible_familes: List[str] = [] + family_persons_to_insert: list[tuple[str, int]] = [] + incompatible_familes: list[str] = [] for pid, external_family_id in provided_pid_to_external_family.items(): if pid in pid_to_internal_family: # we know the family @@ -539,10 +539,10 @@ async def get_participants_by_families( async def get_id_map_by_external_ids( self, - external_ids: List[str], - project: Optional[ProjectId], + external_ids: list[str], + project: ProjectId | None, allow_missing=False, - ) -> Dict[str, int]: + ) -> dict[str, int]: """Get participant ID map by external_ids""" id_map = await self.pttable.get_id_map_by_external_ids( external_ids, project=project @@ -563,7 +563,7 @@ async def get_id_map_by_external_ids( async def get_external_participant_id_to_internal_sequencing_group_id_map( self, project: int, sequencing_type: str = None - ) -> List[Tuple[str, int]]: + ) -> list[tuple[str, int]]: """ Get a map of {external_participant_id} -> {internal_sequencing_group_id} useful to matching joint-called samples in the matrix table to the participant @@ -656,7 +656,7 @@ async def upsert_participants( return participants async def update_many_participant_external_ids( - self, internal_to_external_id: Dict[int, str], check_project_ids=True + self, internal_to_external_id: dict[int, str], check_project_ids=True ): """Update many participant external ids""" if check_project_ids: @@ -688,13 +688,13 @@ async def get_seqr_individual_template( self, project: int, *, - internal_participant_ids: Optional[list[int]] = None, - external_participant_ids: Optional[List[str]] = None, + internal_participant_ids: list[int] | None = None, + external_participant_ids: list[str] | None = None, # pylint: disable=invalid-name replace_with_participant_external_ids=True, replace_with_family_external_ids=True, ) -> dict[str, Any]: - """Get seqr individual level metadata template as List[List[str]]""" + """Get seqr individual level metadata template as list[list[str]]""" # avoid circular imports # pylint: disable=import-outside-toplevel,cyclic-import,too-many-locals @@ -747,7 +747,7 @@ async def get_seqr_individual_template( ] json_header_map = dict(zip(json_headers, headers)) lheader_to_json = dict(zip(lheaders, json_headers)) - rows: List[Dict[str, str]] = [] + rows: list[dict[str, str]] = [] for pid, d in pid_to_features.items(): d[SeqrMetadataKeys.INDIVIDUAL_ID.value] = internal_to_external_pid_map.get( pid, str(pid) @@ -878,22 +878,22 @@ def _validate_individual_metadata_participant_ids(rows, participant_id_field_ind @staticmethod def _prepare_individual_metadata_insertable_rows( - storeable_keys: List[str], - lheaders_to_idx_map: Dict[str, int], + storeable_keys: list[str], + lheaders_to_idx_map: dict[str, int], participant_id_field_idx: int, - pid_map: Dict[str, int], - rows: List[List[str]], + pid_map: dict[str, int], + rows: list[list[str]], ): # do all the matching in lowercase space, but store in regular case space # pylint: disable=invalid-name - storeable_header_col_number_tuples: List[Tuple[str, int]] = [ + storeable_header_col_number_tuples: list[tuple[str, int]] = [ (k, lheaders_to_idx_map[k.lower()]) for k in storeable_keys if k.lower() in lheaders_to_idx_map ] - # List of (PersonId, Key, value) to insert into the participant_phenotype table - insertable_rows: List[Tuple[int, str, Any]] = [] + # list of (PersonId, Key, value) to insert into the participant_phenotype table + insertable_rows: list[tuple[int, str, Any]] = [] parsers = {k.value: v for k, v in SeqrMetadataKeys.get_key_parsers().items()} hpo_col_indices = [ @@ -937,7 +937,7 @@ def _prepare_individual_metadata_insertable_rows( # endregion PHENOTYPES / SEQR async def check_project_access_for_participants_families( - self, participant_ids: List[int], family_ids: List[int] + self, participant_ids: list[int], family_ids: list[int] ): """Checks user access for the projects associated with participant IDs and family IDs""" pprojects = await self.pttable.get_project_ids_for_participant_ids( From 83eb5625b902c0bb59d5e18416cdcf854e1de866 Mon Sep 17 00:00:00 2001 From: Michael Franklin Date: Wed, 1 May 2024 13:43:20 +1000 Subject: [PATCH 12/12] Linting --- api/graphql/schema.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 737acab28..79d01e491 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -88,8 +88,6 @@ async def m(info: Info) -> list[str]: GraphQLAnalysisStatus = strawberry.enum(AnalysisStatus) -GraphQLAnalysisStatus = strawberry.enum(AnalysisStatus) - # Create cohort GraphQL model @strawberry.type @@ -377,11 +375,13 @@ async def cohorts( connection.project = root.id c_filter = CohortFilter( - id=id.to_internal_filter(cohort_id_transform_to_raw) if id else None, + id=id.to_internal_filter_mapped(cohort_id_transform_to_raw) if id else None, name=name.to_internal_filter() if name else None, author=author.to_internal_filter() if author else None, template_id=( - template_id.to_internal_filter(cohort_template_id_transform_to_raw) + template_id.to_internal_filter_mapped( + cohort_template_id_transform_to_raw + ) if template_id else None ), @@ -947,13 +947,13 @@ async def cohort_templates( user=connection.author, project_names=project_names, readonly=True ) project_name_map = {p.name: p.id for p in projects} - project_filter = project.to_internal_filter( + project_filter = project.to_internal_filter_mapped( lambda pname: project_name_map[pname] ) filter_ = CohortTemplateFilter( id=( - id.to_internal_filter(cohort_template_id_transform_to_raw) + id.to_internal_filter_mapped(cohort_template_id_transform_to_raw) if id else None ), @@ -988,17 +988,19 @@ async def cohorts( user=connection.author, project_names=project_names, readonly=True ) project_name_map = {p.name: p.id for p in projects} - project_filter = project.to_internal_filter( + project_filter = project.to_internal_filter_mapped( lambda pname: project_name_map[pname] ) filter_ = CohortFilter( - id=id.to_internal_filter(cohort_id_transform_to_raw) if id else None, + id=id.to_internal_filter_mapped(cohort_id_transform_to_raw) if id else None, name=name.to_internal_filter() if name else None, project=project_filter, author=author.to_internal_filter() if author else None, template_id=( - template_id.to_internal_filter(cohort_template_id_transform_to_raw) + template_id.to_internal_filter_mapped( + cohort_template_id_transform_to_raw + ) if template_id else None ),