Skip to content

Commit

Permalink
fix: Regression of AgentSummary GQL resolver (#3045) (#3172)
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine authored Nov 28, 2024
1 parent 64582ee commit 101f0c5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
1 change: 1 addition & 0 deletions changes/3045.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix regression of the `AgentSummary` resolver caused by an incorrect `batch_load_func` assignment.
21 changes: 16 additions & 5 deletions src/ai/backend/manager/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@
import enum
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence, TypeAlias, cast, override
from typing import (
TYPE_CHECKING,
Any,
Dict,
Mapping,
Optional,
Self,
Sequence,
TypeAlias,
cast,
override,
)

import graphene
import sqlalchemy as sa
Expand Down Expand Up @@ -539,7 +550,7 @@ def from_row(
cls,
ctx: GraphQueryContext,
row: Row,
) -> Agent:
) -> Self:
return cls(
id=row["id"],
status=row["status"].name,
Expand Down Expand Up @@ -572,11 +583,11 @@ async def batch_load(
graph_ctx: GraphQueryContext,
agent_ids: Sequence[AgentId],
*,
access_key: AccessKey,
domain_name: str | None,
raw_status: Optional[str] = None,
scaling_group: Optional[str] = None,
access_key: str,
) -> Sequence[Agent | None]:
) -> Sequence[Optional[Self]]:
query = (
sa.select([agents])
.select_from(agents)
Expand Down Expand Up @@ -638,7 +649,7 @@ async def load_slice(
raw_status: Optional[str] = None,
filter: Optional[str] = None,
order: Optional[str] = None,
) -> Sequence[Agent]:
) -> Sequence[Self]:
query = sa.select([agents]).select_from(agents).limit(limit).offset(offset)
query = await _append_sgroup_from_clause(
graph_ctx, query, access_key, domain_name, scaling_group
Expand Down
9 changes: 8 additions & 1 deletion src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
TYPE_CHECKING,
Any,
ClassVar,
Concatenate,
Final,
Generic,
NamedTuple,
Expand Down Expand Up @@ -752,7 +753,12 @@ def _get_func_key(
def get_loader_by_func(
self,
context: ContextT,
batch_load_func: Callable[[ContextT, Sequence[LoaderKeyT]], Awaitable[LoaderResultT]],
batch_load_func: Callable[
Concatenate[ContextT, Sequence[LoaderKeyT], ...], Awaitable[LoaderResultT]
],
# Using kwargs-only to prevent argument position confusion
# when DataLoader calls `batch_load_func(keys)` which is `partial(batch_load_func, **kwargs)(keys)`.
**kwargs,
) -> DataLoader:
key = self._get_func_key(batch_load_func)
loader = self.cache.get(key)
Expand All @@ -761,6 +767,7 @@ def get_loader_by_func(
functools.partial(
batch_load_func,
context,
**kwargs,
),
max_batch_size=128,
)
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,9 +892,9 @@ async def resolve_agent_summary(
if ctx.local_config["manager"]["hide-agents"]:
raise ObjectNotFound(object_name="agent")

loader = ctx.dataloader_manager.get_loader(
loader = ctx.dataloader_manager.get_loader_by_func(
ctx,
"Agent",
AgentSummary.batch_load,
raw_status=None,
scaling_group=scaling_group,
domain_name=domain_name,
Expand Down

0 comments on commit 101f0c5

Please sign in to comment.