Skip to content

Commit

Permalink
Add interoperability with email account validity (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
babolivier authored Apr 26, 2022
1 parent 7fd8d61 commit f7ed5cb
Show file tree
Hide file tree
Showing 5 changed files with 400 additions and 20 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ across a closed federation. When users update their global `im.vector.hide_profi
account data with `{"hide_profile": True}`, they are removed from this discovery room,
and added to a local database table to filter them out from local results.

This module can also interact with the [synapse-email-account-validity](https://github.com/matrix-org/synapse-email-account-validity)
module. If this compatibility feature is enabled, the module will automatically scan for
expired and renewed users every hour. It will then add expired users to the red list and
remove renewed users from it (without updating the users' account data).

## Installation

From the virtual environment that you use for Synapse, install this module with:
Expand All @@ -24,6 +29,9 @@ modules:
# ID of the room used for user discovery.
# Optional, defaults to no room.
discovery_room: "!someroom:example.com"
# Whether to enable compatibility with the synapse-email-account-validity module.
# Optional, defaults to false.
use_email_account_validity: false
```
Expand Down
185 changes: 175 additions & 10 deletions tchap_red_list/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Optional, Tuple
import time
from typing import Any, Dict, List, Optional, Tuple

import attr
from synapse.module_api import (
Expand All @@ -24,6 +25,7 @@
cached,
run_in_background,
)
from synapse.module_api.errors import ConfigError, SynapseError

logger = logging.getLogger(__name__)

Expand All @@ -33,6 +35,7 @@
@attr.s(auto_attribs=True, frozen=True)
class RedListManagerConfig:
discovery_room: Optional[str] = None
use_email_account_validity: bool = False


class RedListManager:
Expand All @@ -59,6 +62,12 @@ def __init__(
# the table to be accessed before it's fully created.
run_in_background(self._setup_db)

if self._config.use_email_account_validity:
self._api.looping_background_call(self._add_expired_users, 60 * 60 * 1000)
self._api.looping_background_call(
self._remove_renewed_users, 60 * 60 * 1000
)

@staticmethod
def parse_config(config: Dict[str, Any]) -> RedListManagerConfig:
return RedListManagerConfig(**config)
Expand All @@ -79,16 +88,23 @@ async def update_red_list_status(
# Compare what status (in the list, not in the list) the user wants to have with
# what it already has. If they're the same, don't do anything more.
desired_status = bool(content.get("hide_profile"))
current_status, _ = await self._get_user_status(user_id)
current_status, because_expired = await self._get_user_status(user_id)

if current_status == desired_status:
return

# Add or remove the user depending on whether they want their profile hidden.
if desired_status is True:
await self._add_to_red_list(user_id)
if because_expired is True:
# There can be a delay between the user renewing their account (from an
# account validity perspective) and the module actually picking up the
# renewal, during which the user might decide to add their profile to the
# red list.
# In this case, we want to clear the because_expired flag so the user
# isn't removed from the red list next time we check account validity
# data.
await self._make_addition_permanent(user_id)
else:
await self._remove_from_red_list(user_id)
if desired_status is True:
await self._add_to_red_list(user_id)
else:
await self._remove_from_red_list(user_id)

async def _maybe_change_membership_in_discovery_room(
self, user_id: str, membership: str
Expand Down Expand Up @@ -119,8 +135,125 @@ async def check_user_in_red_list(self, user_profile: UserProfile) -> bool:
user_in_red_list, _ = await self._get_user_status(user_profile["user_id"])
return user_in_red_list

async def _add_expired_users(self) -> None:
"""Retrieve all expired users and adds them to the red list."""

def add_expired_users_txn(txn: LoggingTransaction) -> List[str]:
# Retrieve all the expired users.
sql = """
SELECT user_id FROM email_account_validity WHERE expiration_ts_ms <= ?
"""

now_ms = int(time.time() * 1000)
txn.execute(sql, (now_ms,))
expired_users_rows = txn.fetchall()

expired_users = [row[0] for row in expired_users_rows]

# Figure out which users are in the red list.
# We could also inspect the cache on self._get_user_status and only query the
# status of the users that aren't cached, but
# 1) it's probably digging too much into Synapse's internals (i.e. it could
# easily break without warning)
# 2) it's not clear that there would be such a huge perf gain from doing
# things this way.
red_list_users_rows = DatabasePool.simple_select_many_txn(
txn=txn,
table="tchap_red_list",
column="user_id",
iterable=expired_users,
keyvalues={},
retcols=["user_id"],
)

# Figure out which users we need to add to the red list by looking up whether
# they're already in it.
users_in_red_list = [row["user_id"] for row in red_list_users_rows]
users_to_add = [
user for user in expired_users if user not in users_in_red_list
]

# Add all the expired users not in the red list.
sql = """
INSERT INTO tchap_red_list(user_id, because_expired) VALUES(?, ?)
"""
for user in users_to_add:
txn.execute(sql, (user, True))
self._get_user_status.invalidate((user,))

return users_to_add

users_added = await self._api.run_db_interaction(
"tchap_red_list_hide_expired_users",
add_expired_users_txn,
)

# Make the expired users leave the discovery room if there's one.
for user in users_added:
await self._maybe_change_membership_in_discovery_room(user, "leave")

async def _remove_renewed_users(self) -> None:
"""Remove users from the red list if they have been added by _add_expired_users
and have since then renewed their account.
"""

def remove_renewed_users_txn(txn: LoggingTransaction) -> List[str]:
# Retrieve the list of users we have previously added because their account
# expired.
rows = DatabasePool.simple_select_list_txn(
txn=txn,
table="tchap_red_list",
keyvalues={"because_expired": True},
retcols=["user_id"],
)

previously_expired_users = [row["user_id"] for row in rows]

# Among these users, figure out which ones are still expired.
rows = DatabasePool.simple_select_many_txn(
txn=txn,
table="email_account_validity",
column="user_id",
iterable=previously_expired_users,
keyvalues={},
retcols=["user_id", "expiration_ts_ms"],
)

renewed_users: List[str] = []
now_ms = int(time.time() * 1000)
for row in rows:
if row["expiration_ts_ms"] > now_ms:
renewed_users.append(row["user_id"])

# Remove the users who aren't expired anymore.
DatabasePool.simple_delete_many_txn(
txn=txn,
table="tchap_red_list",
column="user_id",
values=renewed_users,
keyvalues={},
)

for user in renewed_users:
self._get_user_status.invalidate((user,))

return renewed_users

users_removed = await self._api.run_db_interaction(
"tchap_red_list_remove_renewed_users",
remove_renewed_users_txn,
)

# Make the renewed users re-join the discovery room if there's one.
for user in users_removed:
await self._maybe_change_membership_in_discovery_room(user, "join")

async def _setup_db(self) -> None:
"""Create the table needed to store the red list data."""
"""Create the table needed to store the red list data.
If the module is configured to interact with the email account validity module,
also check that the table exists.
"""

def setup_db_txn(txn: LoggingTransaction) -> None:
sql = """
Expand All @@ -131,6 +264,15 @@ def setup_db_txn(txn: LoggingTransaction) -> None:
"""
txn.execute(sql, ())

if self._config.use_email_account_validity:
try:
txn.execute("SELECT * FROM email_account_validity LIMIT 0", ())
except SynapseError:
raise ConfigError(
"use_email_account_validity is set but no email account validity"
" database table found."
)

await self._api.run_db_interaction(
"tchap_red_list_setup_db",
setup_db_txn,
Expand Down Expand Up @@ -165,6 +307,29 @@ def _add_to_red_list_txn(txn: LoggingTransaction) -> None:
# If there is a room used for user discovery, make them leave it.
await self._maybe_change_membership_in_discovery_room(user_id, "leave")

async def _make_addition_permanent(self, user_id: str) -> None:
"""Update a user's addition to the red list to make it permanent so it's not
removed automatically when the user renews their account.
Args:
user_id: the user to update.
"""

def make_addition_permanent(txn: LoggingTransaction) -> None:
DatabasePool.simple_update_one_txn(
txn=txn,
table="tchap_red_list",
keyvalues={"user_id": user_id},
updatevalues={"because_expired": False},
)

self._get_user_status.invalidate((user_id,))

await self._api.run_db_interaction(
"tchap_red_list_make_addition_permanent",
make_addition_permanent,
)

async def _remove_from_red_list(self, user_id: str) -> None:
"""Remove the given user from the red list.
Expand Down Expand Up @@ -216,7 +381,7 @@ def _get_user_status_txn(txn: LoggingTransaction) -> Tuple[bool, bool]:
if row is None:
return False, False

return True, row["because_expired"]
return True, bool(row["because_expired"])

return await self._api.run_db_interaction(
"tchap_red_list_get_status",
Expand Down
46 changes: 41 additions & 5 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self) -> None:
async def run_db_interaction(
self, desc: str, f: Callable[..., RV], *args: Any, **kwargs: Any
) -> RV:
cur = self.conn.cursor()
cur = CursorWrapper(self.conn.cursor())
try:
res = f(cur, *args, **kwargs)
self.conn.commit()
Expand All @@ -46,6 +46,41 @@ async def run_db_interaction(
raise


class MockEngine:
supports_using_any_list = False


class CursorWrapper:
"""Wrapper around a SQLite cursor."""

def __init__(self, cursor: sqlite3.Cursor) -> None:
self.cur = cursor
self.database_engine = MockEngine()

def execute(self, sql: str, args: Any) -> None:
self.cur.execute(sql, args)

@property
def description(self) -> Any:
return self.cur.description

@property
def rowcount(self) -> Any:
return self.cur.rowcount

def fetchone(self) -> Any:
return self.cur.fetchone()

def fetchall(self) -> Any:
return self.cur.fetchall()

def __iter__(self) -> Any:
return self.cur.__iter__()

def __next__(self) -> Any:
return self.cur.__next__()


def make_awaitable(result: TV) -> Awaitable[TV]:
"""
Makes an awaitable, suitable for mocking an `async` function.
Expand All @@ -59,15 +94,16 @@ def make_awaitable(result: TV) -> Awaitable[TV]:

async def create_module(
config: Optional[JsonDict] = None,
) -> Tuple[RedListManager, Mock]:
) -> Tuple[RedListManager, Mock, SQLiteStore]:
"""Create an instance of the module.
Args:
config: the config to give the module, if any.
Returns:
The instance of the module and the mock for the module API so the tests can check
its calls.
The instance of the module, the mock for the module API so the tests can check
its calls, and the store used by the module so the test can e.g. maintain a dummy
account validity table.
"""
store = SQLiteStore()

Expand All @@ -91,4 +127,4 @@ async def create_module(
# call history.
module_api.run_db_interaction.reset_mock()

return module, module_api
return module, module_api, store
Loading

0 comments on commit f7ed5cb

Please sign in to comment.