Skip to content

Commit

Permalink
refactor: Avoid using OrderedDict in the manager API and client SDK (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Sep 18, 2024
1 parent 6f9c9cb commit 9130c32
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
1 change: 1 addition & 0 deletions changes/2842.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid using `collections.OrderedDict` when not necessary in the manager API and client SDK
6 changes: 3 additions & 3 deletions src/ai/backend/client/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json as modjson
import logging
import sys
from collections import OrderedDict, namedtuple
from collections import namedtuple
from datetime import datetime
from decimal import Decimal
from pathlib import Path
Expand Down Expand Up @@ -412,7 +412,7 @@ async def text(self) -> str:
return await self._raw_response.text()

async def json(self, *, loads=modjson.loads) -> Any:
loads = functools.partial(loads, object_pairs_hook=OrderedDict)
loads = functools.partial(loads)
return await self._raw_response.json(loads=loads)

async def read(self, n: int = -1) -> bytes:
Expand All @@ -433,7 +433,7 @@ def text(self) -> str:
)

def json(self, *, loads=modjson.loads) -> Any:
loads = functools.partial(loads, object_pairs_hook=OrderedDict)
loads = functools.partial(loads)
sync_session = cast(SyncSession, self._session)
return sync_session.worker_thread.execute(
self._raw_response.json(loads=loads),
Expand Down
18 changes: 8 additions & 10 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import collections
import enum
import functools
import logging
Expand All @@ -20,7 +19,6 @@
Callable,
ClassVar,
Coroutine,
Dict,
Final,
Generic,
List,
Expand Down Expand Up @@ -847,8 +845,8 @@ async def batch_result(
"""
A batched query adaptor for (key -> item) resolving patterns.
"""
objs_per_key: Dict[_Key, Optional[_GenericSQLBasedGQLObject]]
objs_per_key = collections.OrderedDict()
objs_per_key: dict[_Key, Optional[_GenericSQLBasedGQLObject]]
objs_per_key = dict()
for key in key_list:
objs_per_key[key] = None
if isinstance(db_conn, SASession):
Expand All @@ -871,8 +869,8 @@ async def batch_multiresult(
"""
A batched query adaptor for (key -> [item]) resolving patterns.
"""
objs_per_key: Dict[_Key, List[_GenericSQLBasedGQLObject]]
objs_per_key = collections.OrderedDict()
objs_per_key: dict[_Key, list[_GenericSQLBasedGQLObject]]
objs_per_key = dict()
for key in key_list:
objs_per_key[key] = list()
if isinstance(db_conn, SASession):
Expand All @@ -898,8 +896,8 @@ async def batch_result_in_session(
A batched query adaptor for (key -> item) resolving patterns.
stream the result in async session.
"""
objs_per_key: Dict[_Key, Optional[_GenericSQLBasedGQLObject]]
objs_per_key = collections.OrderedDict()
objs_per_key: dict[_Key, Optional[_GenericSQLBasedGQLObject]]
objs_per_key = dict()
for key in key_list:
objs_per_key[key] = None
async for row in await db_sess.stream(query):
Expand All @@ -919,8 +917,8 @@ async def batch_multiresult_in_session(
A batched query adaptor for (key -> [item]) resolving patterns.
stream the result in async session.
"""
objs_per_key: Dict[_Key, List[_GenericSQLBasedGQLObject]]
objs_per_key = collections.OrderedDict()
objs_per_key: dict[_Key, list[_GenericSQLBasedGQLObject]]
objs_per_key = dict()
for key in key_list:
objs_per_key[key] = list()
async for row in await db_sess.stream(query):
Expand Down

0 comments on commit 9130c32

Please sign in to comment.