Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Remove remaining usage of cursor_to_dict. #16564

Merged
merged 12 commits into from
Oct 31, 2023
9 changes: 6 additions & 3 deletions synapse/rest/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import attr

from synapse.api.constants import Direction, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
Expand Down Expand Up @@ -161,11 +163,12 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
)

# If support for MSC3866 is not enabled, don't show the approval flag.
filter = None
if not self._msc3866_enabled:
for user in users:
del user["approved"]
def _filter(a: attr.Attribute) -> bool:
return a.name != "approved"

ret = {"users": users, "total": total}
ret = {"users": [attr.asdict(u, filter=filter) for u in users], "total": total}
if (start + limit) < total:
ret["next_token"] = str(start + len(users))

Expand Down
52 changes: 42 additions & 10 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast

import attr

from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig
from synapse.storage._base import make_in_list_sql_clause
Expand All @@ -28,7 +30,7 @@
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
from synapse.types import JsonDict, get_domain_from_id
from synapse.types import get_domain_from_id

from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
Expand Down Expand Up @@ -82,6 +84,25 @@
logger = logging.getLogger(__name__)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class UserPaginateResponse:
"""This is very similar to UserInfo, but not quite the same."""

name: str
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
user_type: Optional[str]
is_guest: bool
admin: bool
deactivated: bool
shadow_banned: bool
displayname: Optional[str]
avatar_url: Optional[str]
creation_ts: int
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
approved: bool
erased: bool
last_seen_ts: int
locked: bool


class DataStore(
EventsBackgroundUpdatesStore,
ExperimentalFeaturesStore,
Expand Down Expand Up @@ -156,7 +177,7 @@ async def get_users_paginate(
approved: bool = True,
not_user_types: Optional[List[str]] = None,
locked: bool = False,
) -> Tuple[List[JsonDict], int]:
) -> Tuple[List[UserPaginateResponse], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
Expand All @@ -182,7 +203,7 @@ async def get_users_paginate(

def get_users_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
) -> Tuple[List[UserPaginateResponse], int]:
filters = []
args: list = []

Expand Down Expand Up @@ -282,13 +303,24 @@ def get_users_paginate_txn(
"""
args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)

# some of those boolean values are returned as integers when we're on SQLite
columns_to_boolify = ["erased"]
for user in users:
for column in columns_to_boolify:
user[column] = bool(user[column])
users = [
UserPaginateResponse(
name=row[0],
user_type=row[1],
is_guest=bool(row[2]),
admin=bool(row[3]),
deactivated=bool(row[4]),
shadow_banned=bool(row[5]),
displayname=row[6],
avatar_url=row[7],
creation_ts=row[8],
approved=bool(row[9]),
erased=bool(row[10]),
last_seen_ts=row[11],
locked=bool(row[12]),
)
for row in txn
]

return users, count

Expand Down
4 changes: 2 additions & 2 deletions tests/storage/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def test_get_users_paginate(self) -> None:
)

self.assertEqual(1, total)
self.assertEqual(self.displayname, users.pop()["displayname"])
self.assertEqual(self.displayname, users.pop().displayname)

users, total = self.get_success(
self.store.get_users_paginate(0, 10, name="BC", guests=False)
)

self.assertEqual(1, total)
self.assertEqual(self.displayname, users.pop()["displayname"])
self.assertEqual(self.displayname, users.pop().displayname)