Skip to content

Commit

Permalink
add functionality to get/set suspended status
Browse files Browse the repository at this point in the history
  • Loading branch information
H-Shay committed Apr 4, 2024
1 parent 27e0c2a commit ebff147
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
54 changes: 53 additions & 1 deletion synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved,
COALESCE(locked, FALSE) AS locked
COALESCE(locked, FALSE) AS locked,
COALESCE(suspended, FALSE) AS suspended
FROM users
WHERE name = ?
""",
Expand All @@ -261,6 +262,7 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
shadow_banned,
approved,
locked,
suspended,
) = row

return UserInfo(
Expand All @@ -277,6 +279,7 @@ def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
user_type=user_type,
approved=bool(approved),
locked=bool(locked),
suspended=bool(suspended),
)

return await self.db_pool.runInteraction(
Expand Down Expand Up @@ -1180,6 +1183,26 @@ async def get_user_locked_status(self, user_id: str) -> bool:
# Convert the potential integer into a boolean.
return bool(res)

@cached()
async def get_user_suspended_status(self, user_id: str) -> bool:
"""
Determine whether the user's account is suspended.
Args:
user_id: The user ID of the user in question
Returns:
True if the user's account is suspended, false if not.
"""

res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="suspended",
allow_none=True,
desc="get_user_suspended",
)

return bool(res)

async def get_threepid_validation_session(
self,
medium: Optional[str],
Expand Down Expand Up @@ -2206,6 +2229,35 @@ def set_user_deactivated_status_txn(
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))

async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None:
"""
Set whether the user's account is suspended in the `users` table.
Args:
user_id: The user ID of the user in question
suspended: True if the user is suspended, false if not
"""
await self.db_pool.runInteraction(
"set_user_suspended_status",
self.set_user_suspended_status_txn,
user_id,
suspended,
)

def set_user_suspended_status_txn(
self, txn: LoggingTransaction, user_id: str, suspended: bool
) -> None:
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"suspended": 1 if suspended else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_suspended_status, (user_id,)
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))

async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
"""Set the `locked` property for the provided user to the provided value.
Expand Down
2 changes: 1 addition & 1 deletion tests/storage/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def test_register(self) -> None:

self.assertEqual(
UserInfo(
# TODO(paul): Surely this field should be 'user_id', not 'name'

This comment has been minimized.

Copy link
@H-Shay

H-Shay Apr 4, 2024

Author Contributor

This is just random cleanup while I was in the area and noticed it. Pretty sure this comment from 2014 is outdated.

user_id=UserID.from_string(self.user_id),
is_admin=False,
is_guest=False,
Expand All @@ -57,6 +56,7 @@ def test_register(self) -> None:
locked=False,
is_shadow_banned=False,
approved=True,
suspended=False,
),
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
Expand Down

0 comments on commit ebff147

Please sign in to comment.