Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Allow modules to set a display name on registration #12009

Merged
merged 7 commits into from
Feb 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12009.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable modules to set a custom display name when registering a user.
35 changes: 31 additions & 4 deletions docs/modules/password_auth_provider_callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback. If every callback return `None`,
any of the subsequent implementations of this callback. If every callback returns `None`,
the authentication is denied.

### `on_logged_out`
Expand Down Expand Up @@ -162,10 +162,38 @@ return `None`.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback. If every callback return `None`,
any of the subsequent implementations of this callback. If every callback returns `None`,
the username provided by the user is used, if any (otherwise one is automatically
generated).

### `get_displayname_for_registration`

_First introduced in Synapse v1.54.0_

```python
async def get_displayname_for_registration(
uia_results: Dict[str, Any],
params: Dict[str, Any],
) -> Optional[str]
```

Called when registering a new user. The module can return a display name to set for the
user being registered by returning it as a string, or `None` if it doesn't wish to force a
display name for this user.

This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api)
has been completed by the user. It is not called when registering a user via SSO. It is
passed two dictionaries, which include the information that the user has provided during
the registration process. These dictionaries are identical to the ones passed to
[`get_username_for_registration`](#get_username_for_registration), so refer to the
documentation of this callback for more information about them.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback. If every callback returns `None`,
the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`).

## `is_3pid_allowed`

_First introduced in Synapse v1.53.0_
Expand Down Expand Up @@ -194,8 +222,7 @@ The example module below implements authentication checkers for two different lo
- Is checked by the method: `self.check_my_login`
- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based))
- Expects a `password` field to be sent to `/login`
- Is checked by the method: `self.check_pass`

- Is checked by the method: `self.check_pass`

```python
from typing import Awaitable, Callable, Optional, Tuple
Expand Down
58 changes: 58 additions & 0 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,6 +2064,10 @@ def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]


Expand All @@ -2080,6 +2084,9 @@ def __init__(self) -> None:
self.get_username_for_registration_callbacks: List[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = []
self.get_displayname_for_registration_callbacks: List[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = []
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []

# Mapping from login type to login parameters
Expand All @@ -2099,6 +2106,9 @@ def register_password_auth_provider_callbacks(
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None,
get_displayname_for_registration: Optional[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
Expand Down Expand Up @@ -2148,6 +2158,11 @@ def register_password_auth_provider_callbacks(
get_username_for_registration,
)

if get_displayname_for_registration is not None:
self.get_displayname_for_registration_callbacks.append(
get_displayname_for_registration,
)

if is_3pid_allowed is not None:
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)

Expand Down Expand Up @@ -2350,6 +2365,49 @@ async def get_username_for_registration(

return None

async def get_displayname_for_registration(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a huge deal, but it might make sense to make a helper function which does most of this, these could then become lightweight wrappers of something like:

async def get_displayname_for_registration(self, uia_results, params):
    return await self._try_callbacks(self.get_displayname_for_registration_callbacks, uia_results, params)

async def _try_callbacks(self, callbacks, *args, **kwargs):
    ....

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(This could definitely be a follow-up too, just noting that it seems we have a lot of boilerplate code here.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I'd definitely like to have some standardised wrapping like this as a follow-up. (see also the aside mentioned in the description of #11031)

self,
uia_results: JsonDict,
params: JsonDict,
) -> Optional[str]:
"""Defines the display name to use when registering the user, using the
credentials and parameters provided during the UIA flow.

Stops at the first callback that returns a tuple containing at least one string.

Args:
uia_results: The credentials provided during the UIA flow.
params: The parameters provided by the registration request.

Returns:
A tuple which first element is the display name, and the second is an MXC URL
to the user's avatar.
"""
for callback in self.get_displayname_for_registration_callbacks:
try:
res = await callback(uia_results, params)

if isinstance(res, str):
return res
elif res is not None:
# mypy complains that this line is unreachable because it assumes the
# data returned by the module fits the expected type. We just want
# to make sure this is the case.
logger.warning( # type: ignore[unreachable]
"Ignoring non-string value returned by"
" get_displayname_for_registration callback %s: %s",
callback,
res,
)
except Exception as e:
logger.error(
"Module raised an exception in get_displayname_for_registration: %s",
e,
)
raise SynapseError(code=500, msg="Internal Server Error")

return None

async def is_3pid_allowed(
self,
medium: str,
Expand Down
5 changes: 5 additions & 0 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK,
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
IS_3PID_ALLOWED_CALLBACK,
ON_LOGGED_OUT_CALLBACK,
Expand Down Expand Up @@ -317,6 +318,9 @@ def register_password_auth_provider_callbacks(
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None,
get_displayname_for_registration: Optional[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None,
) -> None:
"""Registers callbacks for password auth provider capabilities.
Expand All @@ -328,6 +332,7 @@ def register_password_auth_provider_callbacks(
is_3pid_allowed=is_3pid_allowed,
auth_checkers=auth_checkers,
get_username_for_registration=get_username_for_registration,
get_displayname_for_registration=get_displayname_for_registration,
)

def register_background_update_controller_callbacks(
Expand Down
7 changes: 7 additions & 0 deletions synapse/rest/client/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,11 +694,18 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
session_id
)

display_name = await (
self.password_auth_provider.get_displayname_for_registration(
auth_result, params
)
)

registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password_hash=password_hash,
guest_access_token=guest_access_token,
threepid=threepid,
default_display_name=display_name,
address=client_addr,
user_agent_ips=entries,
)
Expand Down
123 changes: 93 additions & 30 deletions tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def parse_config(self):

def __init__(self, config, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
)

def check_auth(self, *args):
Expand Down Expand Up @@ -122,7 +122,7 @@ def __init__(self, config, api: ModuleApi):
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
("m.login.password", ("password",)): self.check_auth,
},
}
)
pass

Expand Down Expand Up @@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
account.register_servlets,
]

CALLBACK_USERNAME = "get_username_for_registration"
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"

def setUp(self):
# we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
Expand Down Expand Up @@ -754,7 +757,9 @@ def test_username(self):
"""Tests that the get_username_for_registration callback can define the username
of a user when registering.
"""
self._setup_get_username_for_registration()
self._setup_get_name_for_registration(
callback_name=self.CALLBACK_USERNAME,
)

username = "rin"
channel = self.make_request(
Expand All @@ -777,30 +782,14 @@ def test_username_uia(self):
"""Tests that the get_username_for_registration callback is only called at the
end of the UIA flow.
"""
m = self._setup_get_username_for_registration()

# Initiate the UIA flow.
username = "rin"
channel = self.make_request(
"POST",
"register",
{"username": username, "type": "m.login.password", "password": "bar"},
m = self._setup_get_name_for_registration(
callback_name=self.CALLBACK_USERNAME,
)
self.assertEqual(channel.code, 401)
self.assertIn("session", channel.json_body)

# Check that the callback hasn't been called yet.
m.assert_not_called()
username = "rin"
res = self._do_uia_assert_mock_not_called(username, m)

# Finish the UIA flow.
session = channel.json_body["session"]
channel = self.make_request(
"POST",
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
self.assertEqual(channel.code, 200, channel.json_body)
mxid = channel.json_body["user_id"]
mxid = res["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")

# Check that the callback has been called.
Expand All @@ -817,6 +806,56 @@ def test_3pid_allowed(self):
self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True)

def test_displayname(self):
"""Tests that the get_displayname_for_registration callback can define the
display name of a user when registering.
"""
self._setup_get_name_for_registration(
callback_name=self.CALLBACK_DISPLAYNAME,
)
babolivier marked this conversation as resolved.
Show resolved Hide resolved

username = "rin"
channel = self.make_request(
"POST",
"/register",
{
"username": username,
"password": "bar",
"auth": {"type": LoginType.DUMMY},
},
)
self.assertEqual(channel.code, 200)

# Our callback takes the username and appends "-foo" to it, check that's what we
# have.
user_id = UserID.from_string(channel.json_body["user_id"])
display_name = self.get_success(
self.hs.get_profile_handler().get_displayname(user_id)
)

self.assertEqual(display_name, username + "-foo")

def test_displayname_uia(self):
"""Tests that the get_displayname_for_registration callback is only called at the
end of the UIA flow.
"""
m = self._setup_get_name_for_registration(
callback_name=self.CALLBACK_DISPLAYNAME,
)

username = "rin"
res = self._do_uia_assert_mock_not_called(username, m)

user_id = UserID.from_string(res["user_id"])
display_name = self.get_success(
self.hs.get_profile_handler().get_displayname(user_id)
)

self.assertEqual(display_name, username + "-foo")

# Check that the callback has been called.
m.assert_called_once()

def _test_3pid_allowed(self, username: str, registration: bool):
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments.
Expand Down Expand Up @@ -877,23 +916,47 @@ def _test_3pid_allowed(self, username: str, registration: bool):

m.assert_called_once_with("email", "[email protected]", registration)

def _setup_get_username_for_registration(self) -> Mock:
"""Registers a get_username_for_registration callback that appends "-foo" to the
username the client is trying to register.
def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
"""Registers either a get_username_for_registration callback or a
get_displayname_for_registration callback that appends "-foo" to the username the
client is trying to register.
"""

async def get_username_for_registration(uia_results, params):
async def callback(uia_results, params):
self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"]
return username + "-foo"

m = Mock(side_effect=get_username_for_registration)
m = Mock(side_effect=callback)

password_auth_provider = self.hs.get_password_auth_provider()
password_auth_provider.get_username_for_registration_callbacks.append(m)
getattr(password_auth_provider, callback_name + "_callbacks").append(m)

return m

def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
# Initiate the UIA flow.
channel = self.make_request(
"POST",
"register",
{"username": username, "type": "m.login.password", "password": "bar"},
)
self.assertEqual(channel.code, 401)
self.assertIn("session", channel.json_body)

# Check that the callback hasn't been called yet.
m.assert_not_called()

# Finish the UIA flow.
session = channel.json_body["session"]
channel = self.make_request(
"POST",
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
self.assertEqual(channel.code, 200, channel.json_body)
return channel.json_body

def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
Expand Down