Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fixup synapse.replication to pass mypy checks (#6667)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston authored Jan 14, 2020
1 parent 1177d3f commit e8b68a4
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 86 deletions.
1 change: 1 addition & 0 deletions changelog.d/6667.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixup `synapse.replication` to pass mypy checks.
10 changes: 5 additions & 5 deletions synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import abc
import logging
import re
from typing import Dict, List, Tuple

from six import raise_from
from six.moves import urllib
Expand Down Expand Up @@ -78,9 +79,8 @@ class ReplicationEndpoint(object):

__metaclass__ = abc.ABCMeta

NAME = abc.abstractproperty()
PATH_ARGS = abc.abstractproperty()

NAME = abc.abstractproperty() # type: str # type: ignore
PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
METHOD = "POST"
CACHE = True
RETRY_ON_TIMEOUT = True
Expand Down Expand Up @@ -171,7 +171,7 @@ def send_request(**kwargs):
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
headers = {}
headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = yield request_func(uri, data, headers=headers)
Expand Down Expand Up @@ -207,7 +207,7 @@ def register(self, http_server):
method = self.METHOD

if self.CACHE:
handler = self._cached_handler
handler = self._cached_handler # type: ignore
url_args.append("txn_id")

args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
Expand Down
7 changes: 4 additions & 3 deletions synapse/replication/slave/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import Dict
from typing import Dict, Optional

import six

Expand All @@ -41,7 +41,7 @@ def __init__(self, database: Database, db_conn, hs):
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id"
)
) # type: Optional[SlavedIdTracker]
else:
self._cache_id_gen = None

Expand All @@ -62,7 +62,8 @@ def stream_positions(self) -> Dict[str, int]:

def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
self._cache_id_gen.advance(token)
if self._cache_id_gen:
self._cache_id_gen.advance(token)
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
room_id = row.keys[0]
Expand Down
2 changes: 1 addition & 1 deletion synapse/replication/slave/storage/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, database: Database, db_conn, hs):

self._presence_on_startup = self._get_active_presence(db_conn)

self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)

Expand Down
12 changes: 7 additions & 5 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

import logging
from typing import Dict
from typing import Dict, List, Optional

from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
Expand All @@ -28,6 +28,7 @@
)

from .commands import (
Command,
FederationAckCommand,
InvalidateCacheCommand,
RemovePusherCommand,
Expand Down Expand Up @@ -89,15 +90,15 @@ def __init__(self, store: BaseSlavedStore):

# Any pending commands to be sent once a new connection has been
# established
self.pending_commands = []
self.pending_commands = [] # type: List[Command]

# Map from string -> deferred, to wake up when receiveing a SYNC with
# the given string.
# Used for tests.
self.awaiting_syncs = {}
self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]

# The factory used to create connections.
self.factory = None
self.factory = None # type: Optional[ReplicationClientFactory]

def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
Expand Down Expand Up @@ -235,4 +236,5 @@ def finished_connecting(self):
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
self.factory.resetDelay()
if self.factory:
self.factory.resetDelay()
42 changes: 21 additions & 21 deletions synapse/replication/tcp/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@

import logging
import platform
from typing import Tuple, Type

if platform.python_implementation() == "PyPy":
import json

_json_encoder = json.JSONEncoder()
else:
import simplejson as json
import simplejson as json # type: ignore[no-redef] # noqa: F821

_json_encoder = json.JSONEncoder(namedtuple_as_object=False)
_json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821

logger = logging.getLogger(__name__)

Expand All @@ -44,7 +45,7 @@ class Command(object):
The default implementation creates a command of form `<NAME> <data>`
"""

NAME = None
NAME = None # type: str

def __init__(self, data):
self.data = data
Expand Down Expand Up @@ -386,25 +387,24 @@ def to_line(self):
)


_COMMANDS = (
ServerCommand,
RdataCommand,
PositionCommand,
ErrorCommand,
PingCommand,
NameCommand,
ReplicateCommand,
UserSyncCommand,
FederationAckCommand,
SyncCommand,
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
) # type: Tuple[Type[Command], ...]

# Map of command name to command type.
COMMAND_MAP = {
cmd.NAME: cmd
for cmd in (
ServerCommand,
RdataCommand,
PositionCommand,
ErrorCommand,
PingCommand,
NameCommand,
ReplicateCommand,
UserSyncCommand,
FederationAckCommand,
SyncCommand,
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
)
}
COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}

# The commands the server is allowed to send
VALID_SERVER_COMMANDS = (
Expand Down
36 changes: 21 additions & 15 deletions synapse/replication/tcp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import logging
import struct
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Tuple

from six import iteritems, iterkeys

Expand All @@ -65,13 +66,11 @@
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
from synapse.util.stringutils import random_string

from .commands import (
from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
Command,
ErrorCommand,
NameCommand,
PingCommand,
Expand All @@ -82,6 +81,10 @@
SyncCommand,
UserSyncCommand,
)
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string

from .streams import STREAMS_MAP

connection_close_counter = Counter(
Expand Down Expand Up @@ -124,8 +127,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):

delimiter = b"\n"

VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send
# Valid commands we expect to receive
VALID_INBOUND_COMMANDS = [] # type: Collection[str]

# Valid commands we can send
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]

max_line_buffer = 10000

Expand All @@ -144,13 +150,13 @@ def __init__(self, clock):
self.conn_id = random_string(5) # To dedupe in case of name clashes.

# List of pending commands to send once we've established the connection
self.pending_commands = []
self.pending_commands = [] # type: List[Command]

# The LoopingCall for sending pings.
self._send_ping_loop = None

self.inbound_commands_counter = defaultdict(int)
self.outbound_commands_counter = defaultdict(int)
self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]

def connectionMade(self):
logger.info("[%s] Connection established", self.id())
Expand Down Expand Up @@ -409,14 +415,14 @@ def __init__(self, server_name, clock, streamer):
self.streamer = streamer

# The streams the client has subscribed to and is up to date with
self.replication_streams = set()
self.replication_streams = set() # type: Set[str]

# The streams the client is currently subscribing to.
self.connecting_streams = set()
self.connecting_streams = set() # type: Set[str]

# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
self.pending_rdata = {}
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]

def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
Expand Down Expand Up @@ -642,11 +648,11 @@ def __init__(
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = set()
self.streams_connecting = set() # type: Set[str]

# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {}
self.pending_batches = {} # type: Dict[str, Any]

def connectionMade(self):
self.send_command(NameCommand(self.client_name))
Expand Down Expand Up @@ -766,7 +772,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
op = SIOCINQ
else:
op = SIOCOUTQ
size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
return size
return 0

Expand Down
3 changes: 2 additions & 1 deletion synapse/replication/tcp/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import logging
import random
from typing import List

from six import itervalues

Expand Down Expand Up @@ -79,7 +80,7 @@ def __init__(self, hs):
self._replication_torture_level = hs.config.replication_torture_level

# Current connections.
self.connections = []
self.connections = [] # type: List[ServerReplicationStreamProtocol]

LaterGauge(
"synapse_replication_tcp_resource_total_connections",
Expand Down
Loading

0 comments on commit e8b68a4

Please sign in to comment.