Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor get_user_by_id #16316

Merged
merged 4 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16316.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `get_user_by_id`.
2 changes: 1 addition & 1 deletion synapse/api/auth/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ async def get_user_by_access_token(
stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
if not stored_user.is_guest:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)
Expand Down
2 changes: 1 addition & 1 deletion synapse/api/auth/msc3861_delegated.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ async def get_user_by_access_token(
user_id = UserID(username, self._hostname)

# First try to find a user from the username claim
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
if user_info is None:
# If the user does not exist, we should create it on the fly
# TODO: we could use SCIM to provision users ahead of time and listen
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def _get_local_account_status(self, user_id: UserID) -> JsonDict:
"""
status = {"exists": False}

userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
userinfo = await self._main_store.get_user_by_id(user_id.to_string())

if userinfo is not None:
status = {
Expand Down
49 changes: 22 additions & 27 deletions synapse/handlers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
from synapse.visibility import filter_events_for_client

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,38 +57,30 @@ async def get_whois(self, user: UserID) -> JsonDict:

async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
user_info_dict = await self._store.get_user_by_id(user.to_string())
if user_info_dict is None:
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
user.to_string()
)
if user_info is None:
return None

# Restrict returned information to a known set of fields. This prevents additional
# fields added to get_user_by_id from modifying Synapse's external API surface.
user_info_to_return = {
"name",
"admin",
"deactivated",
"locked",
"shadow_banned",
"creation_ts",
"appservice_id",
"consent_server_notice_sent",
"consent_version",
"consent_ts",
"user_type",
"is_guest",
"last_seen_ts",
user_info_dict = {
"name": user.to_string(),
"admin": user_info.is_admin,
"deactivated": user_info.is_deactivated,
"locked": user_info.locked,
"shadow_banned": user_info.is_shadow_banned,
"creation_ts": user_info.creation_ts,
"appservice_id": user_info.appservice_id,
"consent_server_notice_sent": user_info.consent_server_notice_sent,
"consent_version": user_info.consent_version,
"consent_ts": user_info.consent_ts,
"user_type": user_info.user_type,
"is_guest": user_info.is_guest,
}

if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled.
user_info_to_return.add("approved")

# Restrict returned keys to a known set.
user_info_dict = {
key: value
for key, value in user_info_dict.items()
if key in user_info_to_return
}
user_info_dict["approved"] = user_info.approved

# Add additional user metadata
profile = await self._store.get_profileinfo(user)
Expand All @@ -105,6 +97,9 @@ async def get_user(self, user: UserID) -> Optional[JsonDict]:
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())

last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
user_info_dict["last_seen_ts"] = last_seen_ts

return user_info_dict

async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,13 +828,13 @@ async def assert_accepted_privacy_policy(self, requester: Requester) -> None:

u = await self.store.get_user_by_id(user_id)
assert u is not None
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent
return
if u["appservice_id"] is not None:
if u.appservice_id is not None:
# users registered by an appservice are exempt
return
if u["consent_version"] == self.config.consent.user_consent_version:
if u.consent_version == self.config.consent.user_consent_version:
return

consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
Expand Down
4 changes: 2 additions & 2 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
Returns:
UserInfo object if a user was found, otherwise None
"""
return await self._store.get_userinfo_by_id(user_id)
return await self._store.get_user_by_id(user_id)

async def get_user_by_req(
self,
Expand Down Expand Up @@ -1878,7 +1878,7 @@ async def put_global(
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")

# Ensure the user exists, so we don't just write to users that aren't there.
if await self._store.get_userinfo_by_id(user_id) is None:
if await self._store.get_user_by_id(user_id) is None:
raise ValueError(f"User {user_id} does not exist on this server.")

await self._handler.add_account_data_for_user(user_id, data_type, new_data)
2 changes: 1 addition & 1 deletion synapse/rest/consent/consent_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def _async_render_GET(self, request: Request) -> None:
if u is None:
raise NotFoundError("Unknown user")

has_consented = u["consent_version"] == version
has_consented = u.consent_version == version
userhmac = userhmac_bytes.decode("ascii")

try:
Expand Down
6 changes: 3 additions & 3 deletions synapse/server_notices/consent_server_notices.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@ async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
if u is None:
return

if u["is_guest"] and not self._send_to_guests:
if u.is_guest and not self._send_to_guests:
# don't send to guests
return

if u["consent_version"] == self._current_consent_version:
if u.consent_version == self._current_consent_version:
# user has already consented
return

if u["consent_server_notice_sent"] == self._current_consent_version:
if u.consent_server_notice_sent == self._current_consent_version:
# we've already sent a notice to the user
return

Expand Down
11 changes: 11 additions & 0 deletions synapse/storage/databases/main/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,3 +764,14 @@ async def get_user_ip_and_agents(
}

return list(results.values())

async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
"""Get the last seen timestamp for a user, if we have it."""

return await self.db_pool.simple_select_one_onecol(
table="user_ips",
keyvalues={"user_id": user_id},
retcol="MAX(last_seen)",
allow_none=True,
desc="get_last_seen_for_user_id",
)
76 changes: 22 additions & 54 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import attr

Expand Down Expand Up @@ -192,8 +192,8 @@ def __init__(
)

@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists."""

def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# We could technically use simple_select_one here, but it would not perform
Expand All @@ -202,16 +202,12 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
txn.execute(
"""
SELECT
name, password_hash, is_guest, admin, consent_version, consent_ts,
name, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved,
COALESCE(locked, FALSE) AS locked, last_seen_ts
COALESCE(locked, FALSE) AS locked
FROM users
LEFT JOIN (
SELECT user_id, MAX(last_seen) AS last_seen_ts
FROM user_ips GROUP BY user_id
) ls ON users.name = ls.user_id
WHERE name = ?
""",
(user_id,),
Expand All @@ -228,51 +224,23 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
desc="get_user_by_id",
func=get_user_by_id_txn,
)

if row is not None:
# If we're using SQLite our boolean values will be integers. Because we
# present some of this data as is to e.g. server admins via REST APIs, we
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
boolean_columns = [
"admin",
"deactivated",
"shadow_banned",
"approved",
"locked",
]
for column in boolean_columns:
row[column] = bool(row[column])

return row

async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID.

Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
this method should be cached.

Args:
user_id: The user to fetch user info for.
Returns:
`UserInfo` object if user found, otherwise `None`.
"""
user_data = await self.get_user_by_id(user_id)
if not user_data:
if row is None:
return None

return UserInfo(
appservice_id=user_data["appservice_id"],
consent_server_notice_sent=user_data["consent_server_notice_sent"],
consent_version=user_data["consent_version"],
creation_ts=user_data["creation_ts"],
is_admin=bool(user_data["admin"]),
is_deactivated=bool(user_data["deactivated"]),
is_guest=bool(user_data["is_guest"]),
is_shadow_banned=bool(user_data["shadow_banned"]),
user_id=UserID.from_string(user_data["name"]),
user_type=user_data["user_type"],
last_seen_ts=user_data["last_seen_ts"],
appservice_id=row["appservice_id"],
consent_server_notice_sent=row["consent_server_notice_sent"],
consent_version=row["consent_version"],
consent_ts=row["consent_ts"],
creation_ts=row["creation_ts"],
is_admin=bool(row["admin"]),
is_deactivated=bool(row["deactivated"]),
is_guest=bool(row["is_guest"]),
is_shadow_banned=bool(row["shadow_banned"]),
user_id=UserID.from_string(row["name"]),
user_type=row["user_type"],
approved=bool(row["approved"]),
locked=bool(row["locked"]),
)

async def is_trial_user(self, user_id: str) -> bool:
Expand All @@ -290,10 +258,10 @@ async def is_trial_user(self, user_id: str) -> bool:

now = self._clock.time_msec()
days = self.config.server.mau_appservice_trial_days.get(
info["appservice_id"], self.config.server.mau_trial_days
info.appservice_id, self.config.server.mau_trial_days
)
trial_duration_ms = days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
return is_trial

@cached()
Expand Down
10 changes: 7 additions & 3 deletions synapse/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(

@attr.s(auto_attribs=True, frozen=True, slots=True)
class UserInfo:
"""Holds information about a user. Result of get_userinfo_by_id.
"""Holds information about a user. Result of get_user_by_id.

Attributes:
user_id: ID of the user.
appservice_id: Application service ID that created this user.
consent_server_notice_sent: Version of policy documents the user has been sent.
consent_version: Version of policy documents the user has consented to.
consent_ts: Time the user consented
creation_ts: Creation timestamp of the user.
is_admin: True if the user is an admin.
is_deactivated: True if the user has been deactivated.
is_guest: True if the user is a guest user.
is_shadow_banned: True if the user has been shadow-banned.
user_type: User type (None for normal user, 'support' and 'bot' other options).
last_seen_ts: Last activity timestamp of the user.
approved: If the user has been "approved" to register on the server.
locked: Whether the user's account has been locked
"""

user_id: UserID
appservice_id: Optional[int]
consent_server_notice_sent: Optional[str]
consent_version: Optional[str]
consent_ts: Optional[int]
user_type: Optional[str]
creation_ts: int
is_admin: bool
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool
last_seen_ts: Optional[int]
approved: bool
locked: bool


class UserProfile(TypedDict):
Expand Down
12 changes: 9 additions & 3 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,11 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None:
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value.
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})

class FakeUserInfo:
is_guest = False

self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)

request = Mock(args={})
Expand Down Expand Up @@ -341,7 +344,10 @@ def test_get_user_from_macaroon(self) -> None:
)

def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
class FakeUserInfo:
is_guest = True

self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)

user_id = "@baldrick:matrix.org"
Expand Down
Loading