Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix SSO-login-via-a-worker
Browse files Browse the repository at this point in the history
Expose the SSO login endpoints on workers, like the documentation says.
  • Loading branch information
richvdh committed Feb 1, 2021
1 parent 3fb8c08 commit 64f5b6c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
11 changes: 7 additions & 4 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing_extensions import ContextManager

from twisted.internet import address
from twisted.web.resource import IResource

import synapse
import synapse.events
Expand Down Expand Up @@ -90,9 +91,8 @@
ToDeviceStream,
)
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events, room
from synapse.rest.client.v1 import events, login, room
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.rest.client.v1.login import LoginRestServlet
from synapse.rest.client.v1.profile import (
ProfileAvatarURLRestServlet,
ProfileDisplaynameRestServlet,
Expand Down Expand Up @@ -127,6 +127,7 @@
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer, cache_in_self
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
Expand Down Expand Up @@ -507,7 +508,7 @@ def _listen_http(self, listener_config: ListenerConfig):
site_tag = port

# We always include a health resource.
resources = {"/health": HealthResource()}
resources = {"/health": HealthResource()} # type: Dict[str, IResource]

for res in listener_config.http_options.resources:
for name in res.names:
Expand All @@ -517,7 +518,7 @@ def _listen_http(self, listener_config: ListenerConfig):
resource = JsonResource(self, canonical_json=False)

RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
login.register_servlets(self, resource)
ThreepidRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource)
KeyQueryServlet(self).register(resource)
Expand Down Expand Up @@ -557,6 +558,8 @@ def _listen_http(self, listener_config: ListenerConfig):
groups.register_servlets(self, resource)

resources.update({CLIENT_API_PREFIX: resource})

resources.update(build_synapse_client_resource_tree(self))
elif name == "federation":
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
elif name == "media":
Expand Down
40 changes: 20 additions & 20 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,26 @@ def f(txn):

return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)

async def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> None:
"""Record a mapping from an external user id to a mxid
Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
await self.db_pool.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
"external_id": external_id,
"user_id": user_id,
},
desc="record_user_external_id",
)

async def get_user_by_external_id(
self, auth_provider: str, external_id: str
) -> Optional[str]:
Expand Down Expand Up @@ -1371,26 +1391,6 @@ def _register_user(

self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))

async def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
) -> None:
"""Record a mapping from an external user id to a mxid
Args:
auth_provider: identifier for the remote auth provider
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
await self.db_pool.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
"external_id": external_id,
"user_id": user_id,
},
desc="record_user_external_id",
)

async def user_set_password_hash(
self, user_id: str, password_hash: Optional[str]
) -> None:
Expand Down

0 comments on commit 64f5b6c

Please sign in to comment.