From f50688245edc63a99607382aa7d6e7eb26576fd6 Mon Sep 17 00:00:00 2001 From: Dave Hallam Date: Sat, 10 Jun 2023 22:25:31 +0100 Subject: [PATCH] 515 RFC7523 apply headers while signing --- authlib/oauth2/rfc7523/auth.py | 3 +- tests/core/test_oauth2/test_rfc7523.py | 410 +++++++++++++++++++++++++ 2 files changed, 412 insertions(+), 1 deletion(-) create mode 100644 tests/core/test_oauth2/test_rfc7523.py diff --git a/authlib/oauth2/rfc7523/auth.py b/authlib/oauth2/rfc7523/auth.py index 2cb60aa0..bd537552 100644 --- a/authlib/oauth2/rfc7523/auth.py +++ b/authlib/oauth2/rfc7523/auth.py @@ -41,7 +41,7 @@ def sign(self, auth, token_endpoint): client_id=auth.client_id, token_endpoint=token_endpoint, claims=self.claims, - headers=self.headers, + header=self.headers, alg=self.alg, ) @@ -89,5 +89,6 @@ def sign(self, auth, token_endpoint): client_id=auth.client_id, token_endpoint=token_endpoint, claims=self.claims, + header=self.headers, alg=self.alg, ) diff --git a/tests/core/test_oauth2/test_rfc7523.py b/tests/core/test_oauth2/test_rfc7523.py new file mode 100644 index 00000000..9bf0d5c3 --- /dev/null +++ b/tests/core/test_oauth2/test_rfc7523.py @@ -0,0 +1,410 @@ +import time +from unittest import TestCase, mock + +from authlib.jose import jwt +from authlib.oauth2.rfc7523 import ClientSecretJWT, PrivateKeyJWT +from tests.util import read_file_path + + +class ClientSecretJWTTest(TestCase): + def test_nothing_set(self): + jwt_signer = ClientSecretJWT() + + self.assertEqual(jwt_signer.token_endpoint, None) + self.assertEqual(jwt_signer.claims, None) + self.assertEqual(jwt_signer.headers, None) + self.assertEqual(jwt_signer.alg, "HS256") + + def test_endpoint_set(self): + jwt_signer = ClientSecretJWT(token_endpoint="https://example.com/oauth/access_token") + + self.assertEqual(jwt_signer.token_endpoint, "https://example.com/oauth/access_token") + self.assertEqual(jwt_signer.claims, None) + self.assertEqual(jwt_signer.headers, None) + self.assertEqual(jwt_signer.alg, "HS256") + + def test_alg_set(self): + jwt_signer = ClientSecretJWT(alg="HS512") + + self.assertEqual(jwt_signer.token_endpoint, None) + self.assertEqual(jwt_signer.claims, None) + self.assertEqual(jwt_signer.headers, None) + self.assertEqual(jwt_signer.alg, "HS512") + + def test_claims_set(self): + jwt_signer = ClientSecretJWT(claims={"foo1": "bar1"}) + + self.assertEqual(jwt_signer.token_endpoint, None) + self.assertEqual(jwt_signer.claims, {"foo1": "bar1"}) + self.assertEqual(jwt_signer.headers, None) + self.assertEqual(jwt_signer.alg, "HS256") + + def test_headers_set(self): + jwt_signer = ClientSecretJWT(headers={"foo1": "bar1"}) + + self.assertEqual(jwt_signer.token_endpoint, None) + self.assertEqual(jwt_signer.claims, None) + self.assertEqual(jwt_signer.headers, {"foo1": "bar1"}) + self.assertEqual(jwt_signer.alg, "HS256") + + def test_all_set(self): + jwt_signer = ClientSecretJWT( + token_endpoint="https://example.com/oauth/access_token", claims={"foo1a": "bar1a"}, + headers={"foo1b": "bar1b"}, alg="HS512" + ) + + self.assertEqual(jwt_signer.token_endpoint, "https://example.com/oauth/access_token") + self.assertEqual(jwt_signer.claims, {"foo1a": "bar1a"}) + self.assertEqual(jwt_signer.headers, {"foo1b": "bar1b"}) + self.assertEqual(jwt_signer.alg, "HS512") + + @staticmethod + def sign_and_decode(jwt_signer, client_id, client_secret, token_endpoint): + auth = mock.MagicMock() + auth.client_id = client_id + auth.client_secret = client_secret + + pre_sign_time = int(time.time()) + + data = jwt_signer.sign(auth, token_endpoint).decode("utf-8") + decoded = jwt.decode(data, client_secret) # , claims_cls=None, claims_options=None, claims_params=None): + + iat = decoded.pop("iat") + exp = decoded.pop("exp") + jti = decoded.pop("jti") + + return decoded, pre_sign_time, iat, exp, jti + + def test_sign_nothing_set(self): + jwt_signer = ClientSecretJWT() + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", "client_secret_1", "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertIsNotNone(jti) + + self.assertEqual( + {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", }, + decoded + ) + + self.assertEqual( + {"alg": "HS256", "typ": "JWT"}, + decoded.header + ) + + def test_sign_custom_jti(self): + jwt_signer = ClientSecretJWT(claims={"jti": "custom_jti"}) + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", "client_secret_1", "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertEqual("custom_jti", jti) + + self.assertEqual( + decoded, {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", } + ) + + self.assertEqual( + {"alg": "HS256", "typ": "JWT"}, + decoded.header + ) + + def test_sign_with_additional_header(self): + jwt_signer = ClientSecretJWT(headers={"kid": "custom_kid"}) + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", "client_secret_1", "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertIsNotNone(jti) + + self.assertEqual( + decoded, {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", } + ) + + self.assertEqual( + {"alg": "HS256", "typ": "JWT", "kid": "custom_kid"}, + decoded.header + ) + + def test_sign_with_additional_headers(self): + jwt_signer = ClientSecretJWT(headers={"kid": "custom_kid", "jku": "https://example.com/oauth/jwks"}) + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", "client_secret_1", "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertIsNotNone(jti) + + self.assertEqual( + decoded, {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", } + ) + + self.assertEqual( + {"alg": "HS256", "typ": "JWT", "kid": "custom_kid", "jku": "https://example.com/oauth/jwks"}, + decoded.header + ) + + def test_sign_with_additional_claim(self): + jwt_signer = ClientSecretJWT(claims={"name": "Foo"}) + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", "client_secret_1", "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertIsNotNone(jti) + + self.assertEqual( + decoded, {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", + "name": "Foo"} + ) + + self.assertEqual( + {"alg": "HS256", "typ": "JWT"}, + decoded.header + ) + + def test_sign_with_additional_claims(self): + jwt_signer = ClientSecretJWT(claims={"name": "Foo", "role": "bar"}) + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", "client_secret_1", "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertIsNotNone(jti) + + self.assertEqual( + decoded, {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", + "name": "Foo", "role": "bar"} + ) + + self.assertEqual( + {"alg": "HS256", "typ": "JWT"}, + decoded.header + ) + + +class PrivateKeyJWTTest(TestCase): + + @classmethod + def setUpClass(cls): + cls.public_key = read_file_path("rsa_public.pem") + cls.private_key = read_file_path("rsa_private.pem") + + def test_nothing_set(self): + jwt_signer = PrivateKeyJWT() + + self.assertEqual(jwt_signer.token_endpoint, None) + self.assertEqual(jwt_signer.claims, None) + self.assertEqual(jwt_signer.headers, None) + self.assertEqual(jwt_signer.alg, "RS256") + + def test_endpoint_set(self): + jwt_signer = PrivateKeyJWT(token_endpoint="https://example.com/oauth/access_token") + + self.assertEqual(jwt_signer.token_endpoint, "https://example.com/oauth/access_token") + self.assertEqual(jwt_signer.claims, None) + self.assertEqual(jwt_signer.headers, None) + self.assertEqual(jwt_signer.alg, "RS256") + + def test_alg_set(self): + jwt_signer = PrivateKeyJWT(alg="RS512") + + self.assertEqual(jwt_signer.token_endpoint, None) + self.assertEqual(jwt_signer.claims, None) + self.assertEqual(jwt_signer.headers, None) + self.assertEqual(jwt_signer.alg, "RS512") + + def test_claims_set(self): + jwt_signer = PrivateKeyJWT(claims={"foo1": "bar1"}) + + self.assertEqual(jwt_signer.token_endpoint, None) + self.assertEqual(jwt_signer.claims, {"foo1": "bar1"}) + self.assertEqual(jwt_signer.headers, None) + self.assertEqual(jwt_signer.alg, "RS256") + + def test_headers_set(self): + jwt_signer = PrivateKeyJWT(headers={"foo1": "bar1"}) + + self.assertEqual(jwt_signer.token_endpoint, None) + self.assertEqual(jwt_signer.claims, None) + self.assertEqual(jwt_signer.headers, {"foo1": "bar1"}) + self.assertEqual(jwt_signer.alg, "RS256") + + def test_all_set(self): + jwt_signer = PrivateKeyJWT( + token_endpoint="https://example.com/oauth/access_token", claims={"foo1a": "bar1a"}, + headers={"foo1b": "bar1b"}, alg="RS512" + ) + + self.assertEqual(jwt_signer.token_endpoint, "https://example.com/oauth/access_token") + self.assertEqual(jwt_signer.claims, {"foo1a": "bar1a"}) + self.assertEqual(jwt_signer.headers, {"foo1b": "bar1b"}) + self.assertEqual(jwt_signer.alg, "RS512") + + @staticmethod + def sign_and_decode(jwt_signer, client_id, public_key, private_key, token_endpoint): + auth = mock.MagicMock() + auth.client_id = client_id + auth.client_secret = private_key + + pre_sign_time = int(time.time()) + + data = jwt_signer.sign(auth, token_endpoint).decode("utf-8") + decoded = jwt.decode(data, public_key) # , claims_cls=None, claims_options=None, claims_params=None): + + iat = decoded.pop("iat") + exp = decoded.pop("exp") + jti = decoded.pop("jti") + + return decoded, pre_sign_time, iat, exp, jti + + def test_sign_nothing_set(self): + jwt_signer = PrivateKeyJWT() + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", self.public_key, self.private_key, "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertIsNotNone(jti) + + self.assertEqual( + {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", }, + decoded + ) + + self.assertEqual( + {"alg": "RS256", "typ": "JWT"}, + decoded.header + ) + + def test_sign_custom_jti(self): + jwt_signer = PrivateKeyJWT(claims={"jti": "custom_jti"}) + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", self.public_key, self.private_key, "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertEqual("custom_jti", jti) + + self.assertEqual( + decoded, {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", } + ) + + self.assertEqual( + {"alg": "RS256", "typ": "JWT"}, + decoded.header + ) + + def test_sign_with_additional_header(self): + jwt_signer = PrivateKeyJWT(headers={"kid": "custom_kid"}) + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", self.public_key, self.private_key, "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertIsNotNone(jti) + + self.assertEqual( + decoded, {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", } + ) + + self.assertEqual( + {"alg": "RS256", "typ": "JWT", "kid": "custom_kid"}, + decoded.header + ) + + def test_sign_with_additional_headers(self): + jwt_signer = PrivateKeyJWT(headers={"kid": "custom_kid", "jku": "https://example.com/oauth/jwks"}) + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", self.public_key, self.private_key, "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertIsNotNone(jti) + + self.assertEqual( + decoded, {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", } + ) + + self.assertEqual( + {"alg": "RS256", "typ": "JWT", "kid": "custom_kid", "jku": "https://example.com/oauth/jwks"}, + decoded.header + ) + + def test_sign_with_additional_claim(self): + jwt_signer = PrivateKeyJWT(claims={"name": "Foo"}) + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", self.public_key, self.private_key, "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertIsNotNone(jti) + + self.assertEqual( + decoded, {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", + "name": "Foo"} + ) + + self.assertEqual( + {"alg": "RS256", "typ": "JWT"}, + decoded.header + ) + + def test_sign_with_additional_claims(self): + jwt_signer = PrivateKeyJWT(claims={"name": "Foo", "role": "bar"}) + + decoded, pre_sign_time, iat, exp, jti = self.sign_and_decode( + jwt_signer, "client_id_1", self.public_key, self.private_key, "https://example.com/oauth/access_token" + ) + + self.assertGreaterEqual(iat, pre_sign_time) + self.assertGreaterEqual(exp, iat + 3600) + self.assertLessEqual(exp, iat + 3600 + 2) + self.assertIsNotNone(jti) + + self.assertEqual( + decoded, {"iss": "client_id_1", "aud": "https://example.com/oauth/access_token", "sub": "client_id_1", + "name": "Foo", "role": "bar"} + ) + + self.assertEqual( + {"alg": "RS256", "typ": "JWT"}, + decoded.header + )