This repository has been archived by the owner on Apr 26, 2024. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow modules to set a display name on registration (#12009)
Co-authored-by: Patrick Cloke <[email protected]>
- Loading branch information
1 parent
da0e9f8
commit 707049c
Showing
6 changed files
with
195 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Enable modules to set a custom display name when registering a user. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,7 +84,7 @@ def parse_config(self): | |
|
||
def __init__(self, config, api: ModuleApi): | ||
api.register_password_auth_provider_callbacks( | ||
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, | ||
auth_checkers={("test.login_type", ("test_field",)): self.check_auth} | ||
) | ||
|
||
def check_auth(self, *args): | ||
|
@@ -122,7 +122,7 @@ def __init__(self, config, api: ModuleApi): | |
auth_checkers={ | ||
("test.login_type", ("test_field",)): self.check_auth, | ||
("m.login.password", ("password",)): self.check_auth, | ||
}, | ||
} | ||
) | ||
pass | ||
|
||
|
@@ -163,6 +163,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |
account.register_servlets, | ||
] | ||
|
||
CALLBACK_USERNAME = "get_username_for_registration" | ||
CALLBACK_DISPLAYNAME = "get_displayname_for_registration" | ||
|
||
def setUp(self): | ||
# we use a global mock device, so make sure we are starting with a clean slate | ||
mock_password_provider.reset_mock() | ||
|
@@ -754,7 +757,9 @@ def test_username(self): | |
"""Tests that the get_username_for_registration callback can define the username | ||
of a user when registering. | ||
""" | ||
self._setup_get_username_for_registration() | ||
self._setup_get_name_for_registration( | ||
callback_name=self.CALLBACK_USERNAME, | ||
) | ||
|
||
username = "rin" | ||
channel = self.make_request( | ||
|
@@ -777,30 +782,14 @@ def test_username_uia(self): | |
"""Tests that the get_username_for_registration callback is only called at the | ||
end of the UIA flow. | ||
""" | ||
m = self._setup_get_username_for_registration() | ||
|
||
# Initiate the UIA flow. | ||
username = "rin" | ||
channel = self.make_request( | ||
"POST", | ||
"register", | ||
{"username": username, "type": "m.login.password", "password": "bar"}, | ||
m = self._setup_get_name_for_registration( | ||
callback_name=self.CALLBACK_USERNAME, | ||
) | ||
self.assertEqual(channel.code, 401) | ||
self.assertIn("session", channel.json_body) | ||
|
||
# Check that the callback hasn't been called yet. | ||
m.assert_not_called() | ||
username = "rin" | ||
res = self._do_uia_assert_mock_not_called(username, m) | ||
|
||
# Finish the UIA flow. | ||
session = channel.json_body["session"] | ||
channel = self.make_request( | ||
"POST", | ||
"register", | ||
{"auth": {"session": session, "type": LoginType.DUMMY}}, | ||
) | ||
self.assertEqual(channel.code, 200, channel.json_body) | ||
mxid = channel.json_body["user_id"] | ||
mxid = res["user_id"] | ||
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") | ||
|
||
# Check that the callback has been called. | ||
|
@@ -817,6 +806,56 @@ def test_3pid_allowed(self): | |
self._test_3pid_allowed("rin", False) | ||
self._test_3pid_allowed("kitay", True) | ||
|
||
def test_displayname(self): | ||
"""Tests that the get_displayname_for_registration callback can define the | ||
display name of a user when registering. | ||
""" | ||
self._setup_get_name_for_registration( | ||
callback_name=self.CALLBACK_DISPLAYNAME, | ||
) | ||
|
||
username = "rin" | ||
channel = self.make_request( | ||
"POST", | ||
"/register", | ||
{ | ||
"username": username, | ||
"password": "bar", | ||
"auth": {"type": LoginType.DUMMY}, | ||
}, | ||
) | ||
self.assertEqual(channel.code, 200) | ||
|
||
# Our callback takes the username and appends "-foo" to it, check that's what we | ||
# have. | ||
user_id = UserID.from_string(channel.json_body["user_id"]) | ||
display_name = self.get_success( | ||
self.hs.get_profile_handler().get_displayname(user_id) | ||
) | ||
|
||
self.assertEqual(display_name, username + "-foo") | ||
|
||
def test_displayname_uia(self): | ||
"""Tests that the get_displayname_for_registration callback is only called at the | ||
end of the UIA flow. | ||
""" | ||
m = self._setup_get_name_for_registration( | ||
callback_name=self.CALLBACK_DISPLAYNAME, | ||
) | ||
|
||
username = "rin" | ||
res = self._do_uia_assert_mock_not_called(username, m) | ||
|
||
user_id = UserID.from_string(res["user_id"]) | ||
display_name = self.get_success( | ||
self.hs.get_profile_handler().get_displayname(user_id) | ||
) | ||
|
||
self.assertEqual(display_name, username + "-foo") | ||
|
||
# Check that the callback has been called. | ||
m.assert_called_once() | ||
|
||
def _test_3pid_allowed(self, username: str, registration: bool): | ||
"""Tests that the "is_3pid_allowed" module callback is called correctly, using | ||
either /register or /account URLs depending on the arguments. | ||
|
@@ -877,23 +916,47 @@ def _test_3pid_allowed(self, username: str, registration: bool): | |
|
||
m.assert_called_once_with("email", "[email protected]", registration) | ||
|
||
def _setup_get_username_for_registration(self) -> Mock: | ||
"""Registers a get_username_for_registration callback that appends "-foo" to the | ||
username the client is trying to register. | ||
def _setup_get_name_for_registration(self, callback_name: str) -> Mock: | ||
"""Registers either a get_username_for_registration callback or a | ||
get_displayname_for_registration callback that appends "-foo" to the username the | ||
client is trying to register. | ||
""" | ||
|
||
async def get_username_for_registration(uia_results, params): | ||
async def callback(uia_results, params): | ||
self.assertIn(LoginType.DUMMY, uia_results) | ||
username = params["username"] | ||
return username + "-foo" | ||
|
||
m = Mock(side_effect=get_username_for_registration) | ||
m = Mock(side_effect=callback) | ||
|
||
password_auth_provider = self.hs.get_password_auth_provider() | ||
password_auth_provider.get_username_for_registration_callbacks.append(m) | ||
getattr(password_auth_provider, callback_name + "_callbacks").append(m) | ||
|
||
return m | ||
|
||
def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict: | ||
# Initiate the UIA flow. | ||
channel = self.make_request( | ||
"POST", | ||
"register", | ||
{"username": username, "type": "m.login.password", "password": "bar"}, | ||
) | ||
self.assertEqual(channel.code, 401) | ||
self.assertIn("session", channel.json_body) | ||
|
||
# Check that the callback hasn't been called yet. | ||
m.assert_not_called() | ||
|
||
# Finish the UIA flow. | ||
session = channel.json_body["session"] | ||
channel = self.make_request( | ||
"POST", | ||
"register", | ||
{"auth": {"session": session, "type": LoginType.DUMMY}}, | ||
) | ||
self.assertEqual(channel.code, 200, channel.json_body) | ||
return channel.json_body | ||
|
||
def _get_login_flows(self) -> JsonDict: | ||
channel = self.make_request("GET", "/_matrix/client/r0/login") | ||
self.assertEqual(channel.code, 200, channel.result) | ||
|