Skip to content

Commit

Permalink
feat: Add dependee/dependent/graph ComputeSessionNode connection queries
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol committed Sep 18, 2024
1 parent 38d0b9a commit 90d3091
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions src/ai/backend/manager/models/gql_models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ class Meta:
kernel_nodes = PaginatedConnectionField(
KernelConnection,
)
dependents = PaginatedConnectionField(
"ai.backend.manager.models.gql_models.session.ComputeSessionConnection",
)
dependees = PaginatedConnectionField(
"ai.backend.manager.models.gql_models.session.ComputeSessionConnection",
)
graph = PaginatedConnectionField(
"ai.backend.manager.models.gql_models.session.ComputeSessionConnection",
)

@classmethod
def from_row(
Expand Down Expand Up @@ -264,6 +273,74 @@ async def resolve_kernel_nodes(
total_count=len(kernels),
)

async def resolve_dependees(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult:
ctx: GraphQueryContext = info.context
loader = ctx.dataloader_manager.get_loader(ctx, "ComputeSessionNode.by_dependee_id")
sessions = await loader.load(self.row_id)
return ConnectionResolverResult(
sessions,
None,
None,
None,
total_count=len(sessions),
)

async def resolve_dependents(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult:
ctx: GraphQueryContext = info.context
loader = ctx.dataloader_manager.get_loader(ctx, "ComputeSessionNode.by_dependent_id")
sessions = await loader.load(self.row_id)
return ConnectionResolverResult(
sessions,
None,
None,
None,
total_count=len(sessions),
)

async def resolve_graph(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult[ComputeSessionNode]:
from ..session import SessionDependencyRow, SessionRow

ctx: GraphQueryContext = info.context

async with ctx.db.begin_readonly_session() as db_sess:
dependency_cte = (
sa.select(SessionRow.id)
.filter(SessionRow.id == self.row_id)
.cte(name="dependency_cte", recursive=True)
)
dependee = sa.select(SessionDependencyRow.depends_on).join(
dependency_cte, SessionDependencyRow.session_id == dependency_cte.c.id
)
dependent = sa.select(SessionDependencyRow.session_id).join(
dependency_cte, SessionDependencyRow.depends_on == dependency_cte.c.id
)
dependency_cte = dependency_cte.union_all(dependee).union_all(dependent)
# Get the session IDs in the graph
query = sa.select(dependency_cte.c.id)
session_ids = (await db_sess.execute(query)).scalars().all()
# Get the session rows in the graph
query = sa.select(SessionRow).where(SessionRow.id.in_(session_ids))
session_rows = (await db_sess.execute(query)).scalars().all()

# Convert into GraphQL node objects
sessions = [ComputeSessionNode.from_row(ctx, r) for r in session_rows]
return ConnectionResolverResult(
sessions,
None,
None,
None,
total_count=len(sessions),
)

@classmethod
async def batch_load_idle_checks(
cls, ctx: GraphQueryContext, session_ids: Sequence[SessionId]
Expand Down

0 comments on commit 90d3091

Please sign in to comment.