Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add missing type hints to synapse.replication. (#11938)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Feb 8, 2022
1 parent 8c94b3a commit d0e78af
Show file tree
Hide file tree
Showing 19 changed files with 209 additions and 147 deletions.
1 change: 1 addition & 0 deletions changelog.d/11938.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to replication code.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ disallow_untyped_defs = True
[mypy-synapse.push.*]
disallow_untyped_defs = True

[mypy-synapse.replication.*]
disallow_untyped_defs = True

[mypy-synapse.rest.*]
disallow_untyped_defs = True

Expand Down
2 changes: 1 addition & 1 deletion synapse/replication/slave/storage/_slaved_id_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
for table, column in extra_tables:
self.advance(None, _load_current_id(db_conn, table, column))

def advance(self, instance_name: Optional[str], new_id: int):
def advance(self, instance_name: Optional[str], new_id: int) -> None:
self._current = (max if self.step > 0 else min)(self._current, new_id)

def get_current_token(self) -> int:
Expand Down
4 changes: 3 additions & 1 deletion synapse/replication/slave/storage/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def __init__(
cache_name="client_ip_last_seen", max_size=50000
)

async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
async def insert_client_ip(
self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
) -> None:
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)

Expand Down
10 changes: 7 additions & 3 deletions synapse/replication/slave/storage/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable

from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
Expand Down Expand Up @@ -60,7 +60,9 @@ def __init__(
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
Expand All @@ -70,7 +72,9 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

def _invalidate_caches_for_devices(self, token, rows):
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
Expand Down
8 changes: 5 additions & 3 deletions synapse/replication/slave/storage/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable

from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
Expand Down Expand Up @@ -44,10 +44,12 @@ def __init__(
self._group_updates_id_gen.get_current_token(),
)

def get_group_stream_token(self):
def get_group_stream_token(self) -> int:
return self._group_updates_id_gen.get_current_token()

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == GroupServerStream.NAME:
self._group_updates_id_gen.advance(instance_name, token)
for row in rows:
Expand Down
7 changes: 5 additions & 2 deletions synapse/replication/slave/storage/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable

from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
Expand All @@ -20,10 +21,12 @@


class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
def get_max_push_rules_stream_id(self) -> int:
return self._push_rules_stream_id_gen.get_current_token()

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
Expand Down
6 changes: 3 additions & 3 deletions synapse/replication/slave/storage/pushers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Iterable

from synapse.replication.tcp.streams import PushersStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
Expand Down Expand Up @@ -41,8 +41,8 @@ def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()

def process_replication_rows(
self, stream_name: str, instance_name: str, token, rows
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token) # type: ignore
self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
45 changes: 26 additions & 19 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
"""A replication client for use by synapse workers.
"""
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple

from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory
from twisted.python.failure import Failure

from synapse.api.constants import EventTypes
from synapse.federation import send_queue
Expand Down Expand Up @@ -79,10 +81,10 @@ def __init__(

hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)

def startedConnecting(self, connector):
def startedConnecting(self, connector: IConnector) -> None:
logger.info("Connecting to replication: %r", connector.getDestination())

def buildProtocol(self, addr):
def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol:
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
self.hs,
Expand All @@ -92,11 +94,11 @@ def buildProtocol(self, addr):
self.command_handler,
)

def clientConnectionLost(self, connector, reason):
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
logger.error("Lost replication conn: %r", reason)
ReconnectingClientFactory.clientConnectionLost(self, connector, reason)

def clientConnectionFailed(self, connector, reason):
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
logger.error("Failed to connect to replication: %r", reason)
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)

Expand Down Expand Up @@ -131,7 +133,7 @@ def __init__(self, hs: "HomeServer"):

async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
) -> None:
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
Expand Down Expand Up @@ -252,14 +254,16 @@ async def on_rdata(
# loop. (This maintains the order so no need to resort)
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]

async def on_position(self, stream_name: str, instance_name: str, token: int):
async def on_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
await self.on_rdata(stream_name, instance_name, token, [])

# We poke the generic "replication" notifier to wake anything up that
# may be streaming.
self.notifier.notify_replication()

def on_remote_server_up(self, server: str):
def on_remote_server_up(self, server: str) -> None:
"""Called when get a new REMOTE_SERVER_UP command."""

# Let's wake up the transaction queue for the server in case we have
Expand All @@ -269,7 +273,7 @@ def on_remote_server_up(self, server: str):

async def wait_for_stream_position(
self, instance_name: str, stream_name: str, position: int
):
) -> None:
"""Wait until this instance has received updates up to and including
the given stream position.
"""
Expand Down Expand Up @@ -304,7 +308,7 @@ async def wait_for_stream_position(
"Finished waiting for repl stream %r to reach %s", stream_name, position
)

def stop_pusher(self, user_id, app_id, pushkey):
def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
if not self._notify_pushers:
return

Expand All @@ -316,13 +320,13 @@ def stop_pusher(self, user_id, app_id, pushkey):
logger.info("Stopping pusher %r / %r", user_id, key)
pusher.on_stop()

async def start_pusher(self, user_id, app_id, pushkey):
async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
if not self._notify_pushers:
return

key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key)
return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)


class FederationSenderHandler:
Expand Down Expand Up @@ -353,10 +357,12 @@ def __init__(self, hs: "HomeServer"):

self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")

def wake_destination(self, server: str):
def wake_destination(self, server: str) -> None:
self.federation_sender.wake_destination(server)

async def process_replication_rows(self, stream_name, token, rows):
async def process_replication_rows(
self, stream_name: str, token: int, rows: list
) -> None:
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
if stream_name == "federation":
Expand Down Expand Up @@ -384,11 +390,12 @@ async def process_replication_rows(self, stream_name, token, rows):
for host in hosts:
self.federation_sender.send_device_messages(host)

async def _on_new_receipts(self, rows):
async def _on_new_receipts(
self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
) -> None:
"""
Args:
rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
new receipts to be processed
rows: new receipts to be processed
"""
for receipt in rows:
# we only want to send on receipts for our own users
Expand All @@ -408,7 +415,7 @@ async def _on_new_receipts(self, rows):
)
await self.federation_sender.send_read_receipt(receipt_info)

async def update_token(self, token):
async def update_token(self, token: int) -> None:
"""Update the record of where we have processed to in the federation stream.
Called after we have processed a an update received over replication. Sends
Expand All @@ -428,7 +435,7 @@ async def update_token(self, token):

run_as_background_process("_save_and_send_ack", self._save_and_send_ack)

async def _save_and_send_ack(self):
async def _save_and_send_ack(self) -> None:
"""Save the current federation position in the database and send an ACK
to master with where we're up to.
"""
Expand Down
Loading

0 comments on commit d0e78af

Please sign in to comment.