Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use inline type hints in handlers/ and rest/ #10382

Merged
merged 5 commits into from
Jul 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10382.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert internal type variable syntax to reflect wider ecosystem use.
8 changes: 4 additions & 4 deletions synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ class BaseHandler:
"""

def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.clock = hs.get_clock()
self.hs = hs
Expand All @@ -55,12 +55,12 @@ def __init__(self, hs: "HomeServer"):
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
) # type: Optional[Ratelimiter]
)
else:
self.admin_redaction_ratelimiter = None

Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->
to_key = RoomStreamToken(None, stream_ordering)

# Events that we've processed in this room
written_events = set() # type: Set[str]
written_events: Set[str] = set()

# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
Expand All @@ -152,7 +152,7 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
# events "children".
unseen_to_child_events = {} # type: Dict[str, Set[str]]
unseen_to_child_events: Dict[str, Set[str]] = {}

# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def _notify_interested_services(self, max_token: RoomStreamToken):
self.current_max, limit
)

events_by_room = {} # type: Dict[str, List[EventBase]]
events_by_room: Dict[str, List[EventBase]] = {}
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)

Expand Down Expand Up @@ -275,7 +275,7 @@ async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
events = [] # type: List[JsonDict]
events: List[JsonDict] = []
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
Expand Down Expand Up @@ -375,7 +375,7 @@ async def get_3pe_protocols(
self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]:
services = self.store.get_app_services()
protocols = {} # type: Dict[str, List[JsonDict]]
protocols: Dict[str, List[JsonDict]] = {}

# Collect up all the individual protocol responses out of the ASes
for s in services:
Expand Down
16 changes: 8 additions & 8 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
Expand Down Expand Up @@ -296,7 +296,7 @@ def __init__(self, hs: "HomeServer"):

# A mapping of user ID to extra attributes to include in the login
# response.
self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}

async def validate_user_via_ui_auth(
self,
Expand Down Expand Up @@ -500,7 +500,7 @@ async def check_ui_auth(
all the stages in any of the permitted flows.
"""

sid = None # type: Optional[str]
sid: Optional[str] = None
authdict = clientdict.pop("auth", {})
if "session" in authdict:
sid = authdict["session"]
Expand Down Expand Up @@ -588,9 +588,9 @@ async def check_ui_auth(
)

# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
errordict: Dict[str, Any] = {}
if "type" in authdict:
login_type = authdict["type"] # type: str
login_type: str = authdict["type"]
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
Expand Down Expand Up @@ -766,7 +766,7 @@ def _auth_dict_for_flows(
LoginType.TERMS: self._get_params_terms,
}

params = {} # type: Dict[str, Any]
params: Dict[str, Any] = {}

for f in public_flows:
for stage in f:
Expand Down Expand Up @@ -1530,9 +1530,9 @@ async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> s
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

user_id_to_verify = await self.get_session_data(
user_id_to_verify: str = await self.get_session_data(
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
) # type: str
)

idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
user_id_to_verify
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:

# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes = {} # type: Dict[str, List[Optional[str]]]
attributes: Dict[str, List[Optional[str]]] = {}
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
Expand Down
14 changes: 7 additions & 7 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ async def notify_device_update(
user_id
)

hosts = set() # type: Set[str]
hosts: Set[str] = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
Expand Down Expand Up @@ -613,20 +613,20 @@ def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
self._remote_edu_linearizer = Linearizer(name="remote_device_list")

# user_id -> list of updates waiting to be handled.
self._pending_updates = (
{}
) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
self._pending_updates: Dict[
str, List[Tuple[str, str, Iterable[str], JsonDict]]
] = {}

# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
cache_name="device_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
) # type: ExpiringCache[str, Set[str]]
)

# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
Expand Down Expand Up @@ -755,7 +755,7 @@ async def _need_to_do_resync(
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
seen_updates: Set[str] = self._seen_updates.get(user_id, set())

extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/devicemessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ async def send_device_message(
log_kv({"number_of_to_device_messages": len(messages)})
set_tag("sender", sender_user_id)
local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
# Ratelimit local cross-user key requests by the sending device.
if (
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@ async def _delete_association(self, room_alias: RoomAlias) -> str:
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None
if self.hs.is_mine(room_alias):
result = await self.get_association_from_room_alias(
room_alias
) # type: Optional[RoomAliasMapping]
result: Optional[
RoomAliasMapping
] = await self.get_association_from_room_alias(room_alias)

if result:
room_id = result.room_id
Expand Down
40 changes: 19 additions & 21 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ async def query_devices(
the number of in-flight queries at a time.
"""
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query = query_body.get(
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]
)

# separate users by domain.
# make a map from domain to user_id to device_ids
Expand All @@ -136,7 +136,7 @@ async def query_devices(

# First get local devices.
# A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
failures: Dict[str, JsonDict] = {}
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
Expand All @@ -151,11 +151,9 @@ async def query_devices(

# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = (
{}
) # type: Dict[str, Dict[str, Iterable[str]]]
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
query_list = [] # type: List[Tuple[str, Optional[str]]]
query_list: List[Tuple[str, Optional[str]]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
Expand Down Expand Up @@ -362,9 +360,9 @@ async def query_local_devices(
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
local_query = [] # type: List[Tuple[str, Optional[str]]]
local_query: List[Tuple[str, Optional[str]]] = []

result_dict = {} # type: Dict[str, Dict[str, dict]]
result_dict: Dict[str, Dict[str, dict]] = {}
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
Expand Down Expand Up @@ -402,9 +400,9 @@ async def on_federation_query_client_keys(
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
"""Handle a device key query from a federated server"""
device_keys_query = query_body.get(
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {}
) # type: Dict[str, Optional[List[str]]]
)
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}

Expand All @@ -421,8 +419,8 @@ async def on_federation_query_client_keys(
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
) -> JsonDict:
local_query = [] # type: List[Tuple[str, str, str]]
remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}

for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
Expand All @@ -439,8 +437,8 @@ async def claim_one_time_keys(
results = await self.store.claim_e2e_one_time_keys(local_query)

# A map of user ID -> device ID -> key ID -> key.
json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
failures = {} # type: Dict[str, JsonDict]
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
failures: Dict[str, JsonDict] = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
Expand Down Expand Up @@ -768,8 +766,8 @@ async def _process_self_signatures(
Raises:
SynapseError: if the input is malformed
"""
signature_list = [] # type: List[SignatureListItem]
failures = {} # type: Dict[str, Dict[str, JsonDict]]
signature_list: List["SignatureListItem"] = []
failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures

Expand Down Expand Up @@ -930,8 +928,8 @@ async def _process_other_signatures(
Raises:
SynapseError: if the input is malformed
"""
signature_list = [] # type: List[SignatureListItem]
failures = {} # type: Dict[str, Dict[str, JsonDict]]
signature_list: List["SignatureListItem"] = []
failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures

Expand Down Expand Up @@ -1300,7 +1298,7 @@ def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")

# user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}

async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
Expand Down Expand Up @@ -1349,7 +1347,7 @@ async def _handle_signing_key_updates(self, user_id: str) -> None:
# This can happen since we batch updates
return

device_ids = [] # type: List[str]
device_ids: List[str] = []

logger.info("pending updates: %r", pending_updates)

Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def get_stream(

# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = [] # type: List[JsonDict]
to_add: List[JsonDict] = []
for event in events:
if not isinstance(event, EventBase):
continue
Expand All @@ -103,9 +103,9 @@ async def get_stream(
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = await self.store.get_users_in_room(
users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
) # type: Iterable[str]
)
else:
users = [event.state_key]

Expand Down
Loading