From 7603dcd8c87c05b5eece25ab6aa7de993fe5acbf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 27 Jun 2022 14:28:17 +0100 Subject: [PATCH] Fix lint --- .../databases/main/event_push_actions.py | 3 +- synapse/storage/databases/main/stream.py | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index d0916e7874b9..9bdfa2d091b9 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -25,7 +25,6 @@ LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.util import json_encoder @@ -219,7 +218,7 @@ def _get_unread_counts_by_receipt_txn( retcol="event_id", ) - stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined] + stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) return self._get_unread_counts_by_pos_txn( txn, room_id, user_id, stream_ordering diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 8e88784d3ce3..20a294f10891 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -42,10 +42,12 @@ Collection, Dict, List, + Literal, Optional, Set, Tuple, cast, + overload, ) import attr @@ -795,6 +797,32 @@ async def get_current_room_stream_token_for_room_id( ) return RoomStreamToken(topo, stream_ordering) + @overload + def get_stream_id_for_event_txn( + self, + txn: LoggingTransaction, + event_id: str, + allow_none: Literal[True], + ) -> int: + ... + + @overload + def get_stream_id_for_event_txn( + self, + txn: LoggingTransaction, + event_id: str, + ) -> int: + ... + + @overload + def get_stream_id_for_event_txn( + self, + txn: LoggingTransaction, + event_id: str, + allow_none: bool = False, + ) -> Optional[int]: + ... + def get_stream_id_for_event_txn( self, txn: LoggingTransaction,