diff --git a/flask_appbuilder/security/manager.py b/flask_appbuilder/security/manager.py index 869735b8db..20b35d2644 100644 --- a/flask_appbuilder/security/manager.py +++ b/flask_appbuilder/security/manager.py @@ -627,7 +627,9 @@ def get_oauth_user_info( log.debug("User info from Azure: %s", me) # https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims return { - "email": me["email"], + # To keep backward compatibility with previous versions + # of FAB, we use upn if available, otherwise we use email + "email": me["upn"] if "upn" in me else me["email"], "first_name": me.get("given_name", ""), "last_name": me.get("family_name", ""), "username": me["oid"], diff --git a/tests/security/test_auth_oauth.py b/tests/security/test_auth_oauth.py index 5c411f3120..acf4b9e833 100644 --- a/tests/security/test_auth_oauth.py +++ b/tests/security/test_auth_oauth.py @@ -480,6 +480,42 @@ def test_oauth_user_info_unknown_provider(self): with self.assertRaises(OAuthProviderUnknown): self.appbuilder.sm.oauth_user_info("unknown", {}) + def test_oauth_user_info_azure_email_upn(self): + self.appbuilder = AppBuilder(self.app, self.db.session) + claims = { + "aud": "test-aud", + "iss": "https://sts.windows.net/test/", + "iat": 7282182129, + "nbf": 7282182129, + "exp": 1000000000, + "amr": ["pwd"], + "email": "test@gmail.com", + "upn": "test@upn.com", + "family_name": "user", + "given_name": "test", + "idp": "live.com", + "name": "Test user", + "oid": "b1a54a40-8dfa-4a6d-a2b8-f90b84d4b1df", + "unique_name": "live.com#test@gmail.com", + "ver": "1.0", + } + + # Create an unsigned JWT + unsigned_jwt = jwt.encode(claims, key=None, algorithm="none") + user_info = self.appbuilder.sm.get_oauth_user_info( + "azure", {"access_token": "", "id_token": unsigned_jwt} + ) + self.assertEqual( + user_info, + { + "email": "test@upn.com", + "first_name": "test", + "last_name": "user", + "role_keys": [], + "username": "b1a54a40-8dfa-4a6d-a2b8-f90b84d4b1df", + }, + ) + def test_oauth_user_info_azure(self): self.appbuilder = AppBuilder(self.app, self.db.session) claims = { diff --git a/tests/test_api.py b/tests/test_api.py index af676f5bbd..d97a8103b4 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -961,7 +961,7 @@ def test_get_item_not_found(self): client = self.app.test_client() token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) with model1_data(self.appbuilder.session, 1): - model_id = MODEL1_DATA_SIZE + 1 + model_id = self.appbuilder.session.query(func.max(Model1.id)).scalar() + 1 rv = self.auth_client_get(client, token, f"api/v1/model1api/{model_id}") self.assertEqual(rv.status_code, 404)