Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Customize claim for local part of JWT logins #11361

Merged
merged 12 commits into from
Nov 22, 2021
1 change: 1 addition & 0 deletions changelog.d/11361.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update the JWT login type to support custom a `sub` claim.
5 changes: 3 additions & 2 deletions docs/jwt.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ will be removed in a future version of Synapse.

The `token` field should include the JSON web token with the following claims:

* The `sub` (subject) claim is required and should encode the local part of the
user ID.
* A claim that encodes the local part of the user ID is required. By default,
the `sub` (subject) claim is used, or a custom claim can be set in the
configuration file.
* The expiration time (`exp`), not before time (`nbf`), and issued at (`iat`)
claims are optional, but validated if present.
* The issuer (`iss`) claim is optional, but required and validated if configured.
Expand Down
6 changes: 6 additions & 0 deletions docs/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2039,6 +2039,12 @@ sso:
#
#algorithm: "provided-by-your-issuer"

# Name of the claim containing a unique identifier for the user.
#
# Optional, defaults to `sub`.
#
#subject_claim: "sub"

# The issuer to validate the "iss" claim against.
#
# Optional, if provided the "iss" claim will be required and
Expand Down
9 changes: 9 additions & 0 deletions synapse/config/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def read_config(self, config, **kwargs):
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]

self.jwt_subject_claim = jwt_config.get("subject_claim", "sub")

# The issuer and audiences are optional, if provided, it is asserted
# that the claims exist on the JWT.
self.jwt_issuer = jwt_config.get("issuer")
Expand All @@ -46,6 +48,7 @@ def read_config(self, config, **kwargs):
self.jwt_enabled = False
self.jwt_secret = None
self.jwt_algorithm = None
self.jwt_subject_claim = None
self.jwt_issuer = None
self.jwt_audiences = None

Expand Down Expand Up @@ -88,6 +91,12 @@ def generate_config_section(self, **kwargs):
#
#algorithm: "provided-by-your-issuer"

# Name of the claim containing a unique identifier for the user.
#
# Optional, defaults to `sub`.
#
#subject_claim: "sub"

# The issuer to validate the "iss" claim against.
#
# Optional, if provided the "iss" claim will be required and
Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/client/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self, hs: "HomeServer"):
# JWT configuration variables.
self.jwt_enabled = hs.config.jwt.jwt_enabled
self.jwt_secret = hs.config.jwt.jwt_secret
self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
self.jwt_algorithm = hs.config.jwt.jwt_algorithm
self.jwt_issuer = hs.config.jwt.jwt_issuer
self.jwt_audiences = hs.config.jwt.jwt_audiences
Expand Down Expand Up @@ -413,7 +414,7 @@ async def _do_jwt_login(
errcode=Codes.FORBIDDEN,
)

user = payload.get("sub", None)
user = payload.get(self.jwt_subject_claim, None)
if user is None:
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)

Expand Down
68 changes: 36 additions & 32 deletions tests/rest/client/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase):

jwt_secret = "secret"
jwt_algorithm = "HS256"
base_config = {
"enabled": True,
"secret": jwt_secret,
"algorithm": jwt_algorithm,
}

def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt.jwt_enabled = True
self.hs.config.jwt.jwt_secret = self.jwt_secret
self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
return self.hs
def default_config(self):
config = super().default_config()

# If jwt_config has been defined (eg via @override_config), don't replace it.
if config.get("jwt_config") is None:
config["jwt_config"] = self.base_config

return config

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
Expand Down Expand Up @@ -879,16 +886,7 @@ def test_login_no_sub(self):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")

@override_config(
{
"jwt_config": {
"jwt_enabled": True,
"secret": jwt_secret,
"algorithm": jwt_algorithm,
"issuer": "test-issuer",
}
}
)
@override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
def test_login_iss(self):
"""Test validating the issuer claim."""
# A valid issuer.
Expand Down Expand Up @@ -919,16 +917,7 @@ def test_login_iss_no_config(self):
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

@override_config(
{
"jwt_config": {
"jwt_enabled": True,
"secret": jwt_secret,
"algorithm": jwt_algorithm,
"audiences": ["test-audience"],
}
}
)
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
def test_login_aud(self):
"""Test validating the audience claim."""
# A valid audience.
Expand Down Expand Up @@ -962,6 +951,19 @@ def test_login_aud_no_config(self):
channel.json_body["error"], "JWT validation failed: Invalid audience"
)

def test_login_default_sub(self):
"""Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self):
"""Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")

def test_login_no_token(self):
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
Expand Down Expand Up @@ -1024,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
]
)

def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt.jwt_enabled = True
self.hs.config.jwt.jwt_secret = self.jwt_pubkey
self.hs.config.jwt.jwt_algorithm = "RS256"
return self.hs
def default_config(self):
config = super().default_config()
config["jwt_config"] = {
"enabled": True,
"secret": self.jwt_pubkey,
"algorithm": "RS256",
}
return config

def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
Expand Down