Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert DeviceLastConnectionInfo to attrs. (#16507)
Browse files Browse the repository at this point in the history
To improve type safety & memory usage.
  • Loading branch information
clokep authored Oct 17, 2023
1 parent 77dfc1f commit 6ad1f9e
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 103 deletions.
1 change: 1 addition & 0 deletions changelog.d/16507.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
23 changes: 7 additions & 16 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
)
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Set, Tuple

from synapse.api import errors
from synapse.api.constants import EduTypes, EventTypes
Expand All @@ -41,6 +31,7 @@
run_as_background_process,
wrap_as_background_process,
)
from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
from synapse.types import (
JsonDict,
JsonMapping,
Expand Down Expand Up @@ -1008,14 +999,14 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None:


def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
device: JsonDict, client_ips: Mapping[Tuple[str, str], DeviceLastConnectionInfo]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
ip = client_ips.get((device["user_id"], device["device_id"]))
device.update(
{
"last_seen_user_agent": ip.get("user_agent"),
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
"last_seen_user_agent": ip.user_agent if ip else None,
"last_seen_ts": ip.last_seen if ip else None,
"last_seen_ip": ip.ip if ip else None,
}
)

Expand Down
46 changes: 26 additions & 20 deletions synapse/storage/databases/main/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast

import attr
from typing_extensions import TypedDict

from synapse.metrics.background_process_metrics import wrap_as_background_process
Expand Down Expand Up @@ -42,7 +43,8 @@
LAST_SEEN_GRANULARITY = 120 * 1000


class DeviceLastConnectionInfo(TypedDict):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLastConnectionInfo:
"""Metadata for the last connection seen for a user and device combination"""

# These types must match the columns in the `devices` table
Expand Down Expand Up @@ -499,24 +501,29 @@ async def _get_last_client_ip_by_device_from_database(
device_id: If None fetches all devices for the user
Returns:
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
"""

keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id

res = cast(
List[DeviceLastConnectionInfo],
await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
)

return {(d["user_id"], d["device_id"]): d for d in res}
return {
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
user_id=d["user_id"],
device_id=d["device_id"],
ip=d["ip"],
user_agent=d["user_agent"],
last_seen=d["last_seen"],
)
for d in res
}

async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0
Expand Down Expand Up @@ -683,8 +690,7 @@ async def get_last_client_ip_by_device(
device_id: If None fetches all devices for the user
Returns:
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
A dictionary mapping a tuple of (user_id, device_id) to DeviceLastConnectionInfo.
"""
ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)

Expand All @@ -705,13 +711,13 @@ async def get_last_client_ip_by_device(
continue

if not device_id or did == device_id:
ret[(user_id, did)] = {
"user_id": user_id,
"ip": ip,
"user_agent": user_agent,
"device_id": did,
"last_seen": last_seen,
}
ret[(user_id, did)] = DeviceLastConnectionInfo(
user_id=user_id,
ip=ip,
user_agent=user_agent,
device_id=did,
last_seen=last_seen,
)
return ret

async def get_user_ip_and_agents(
Expand Down
137 changes: 70 additions & 67 deletions tests/storage/test_client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.storage.databases.main.client_ips import (
LAST_SEEN_GRANULARITY,
DeviceLastConnectionInfo,
)
from synapse.types import UserID
from synapse.util import Clock

Expand Down Expand Up @@ -65,15 +68,15 @@ def test_insert_new_client_ip(self) -> None:
)

r = result[(user_id, device_id)]
self.assertLessEqual(
{
"user_id": user_id,
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
}.items(),
r.items(),
self.assertEqual(
DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip="ip",
user_agent="user_agent",
last_seen=12345678000,
),
r,
)

def test_insert_new_client_ip_none_device_id(self) -> None:
Expand Down Expand Up @@ -201,13 +204,13 @@ def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None:
self.assertEqual(
result,
{
(user_id, device_id): {
"user_id": user_id,
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
},
(user_id, device_id): DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip="ip",
user_agent="user_agent",
last_seen=12345678000,
),
},
)

Expand Down Expand Up @@ -292,20 +295,20 @@ def test_get_last_client_ip_by_device_combined_data(self) -> None:
self.assertEqual(
result,
{
(user_id, device_id_1): {
"user_id": user_id,
"device_id": device_id_1,
"ip": "ip_1",
"user_agent": "user_agent_1",
"last_seen": 12345678000,
},
(user_id, device_id_2): {
"user_id": user_id,
"device_id": device_id_2,
"ip": "ip_2",
"user_agent": "user_agent_3",
"last_seen": 12345688000 + LAST_SEEN_GRANULARITY,
},
(user_id, device_id_1): DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id_1,
ip="ip_1",
user_agent="user_agent_1",
last_seen=12345678000,
),
(user_id, device_id_2): DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id_2,
ip="ip_2",
user_agent="user_agent_3",
last_seen=12345688000 + LAST_SEEN_GRANULARITY,
),
},
)

Expand Down Expand Up @@ -526,15 +529,15 @@ def test_devices_last_seen_bg_update(self) -> None:
)

r = result[(user_id, device_id)]
self.assertLessEqual(
{
"user_id": user_id,
"device_id": device_id,
"ip": None,
"user_agent": None,
"last_seen": None,
}.items(),
r.items(),
self.assertEqual(
DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip=None,
user_agent=None,
last_seen=None,
),
r,
)

# Register the background update to run again.
Expand All @@ -561,15 +564,15 @@ def test_devices_last_seen_bg_update(self) -> None:
)

r = result[(user_id, device_id)]
self.assertLessEqual(
{
"user_id": user_id,
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 0,
}.items(),
r.items(),
self.assertEqual(
DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip="ip",
user_agent="user_agent",
last_seen=0,
),
r,
)

def test_old_user_ips_pruned(self) -> None:
Expand Down Expand Up @@ -640,15 +643,15 @@ def test_old_user_ips_pruned(self) -> None:
)

r = result2[(user_id, device_id)]
self.assertLessEqual(
{
"user_id": user_id,
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 0,
}.items(),
r.items(),
self.assertEqual(
DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip="ip",
user_agent="user_agent",
last_seen=0,
),
r,
)

def test_invalid_user_agents_are_ignored(self) -> None:
Expand Down Expand Up @@ -777,13 +780,13 @@ def _runtest(
self.store.get_last_client_ip_by_device(self.user_id, device_id)
)
r = result[(self.user_id, device_id)]
self.assertLessEqual(
{
"user_id": self.user_id,
"device_id": device_id,
"ip": expected_ip,
"user_agent": "Mozzila pizza",
"last_seen": 123456100,
}.items(),
r.items(),
self.assertEqual(
DeviceLastConnectionInfo(
user_id=self.user_id,
device_id=device_id,
ip=expected_ip,
user_agent="Mozzila pizza",
last_seen=123456100,
),
r,
)

0 comments on commit 6ad1f9e

Please sign in to comment.