Add JdbcOidcSessionRegistry implementation #14511

Open
jzheaux opened this issue Jan 30, 2024 · 3 comments
Labels
in: oauth2 An issue in OAuth2 modules (oauth2-core, oauth2-client, oauth2-resource-server, oauth2-jose) type: enhancement A general enhancement

Comments

@jzheaux
Contributor

jzheaux commented Jan 30, 2024

An InMemoryOidcSessionRegistry can only store sessions on a single instance. A JDBC-based implementation would make OIDC Backchannel Logout work in a clustered environment.

@jsantana3c

I was trying to do this with Redis too, but I need a Jackson mixin for OidcSessionInformation.

@aelillie

aelillie commented Jun 4, 2024

Sample from my code to implement this:

```java
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;

import lombok.extern.slf4j.Slf4j;
import org.springframework.security.oauth2.client.oidc.authentication.logout.LogoutTokenClaimNames;
import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;

/**
 * OIDC session registry for a clustered server setup with multiple nodes,
 * which saves user session information in a central database.
 * This follows the suggestion in the Spring Security docs:
 * <a href="https://docs.spring.io/spring-security/reference/servlet/oauth2/login/logout.html#_customizing_the_oidc_provider_session_strategy">Customizing the OIDC Provider Session Strategy</a>
 * The implementation logic mirrors the default OIDC session registry, {@code InMemoryOidcSessionRegistry}.
 * <p>
 * {@code OidcUserSession} (an entity holding the client session id and its {@link OidcSessionInformation})
 * and {@code OidcUserSessionRepository} are application classes not shown here. The repository is assumed
 * to declare {@code Optional<OidcUserSession> findBySessionId(String)} and to return a {@code List}
 * from {@code findAll()}, for example by extending Spring Data's {@code JpaRepository}.
 * @see org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry
 */
@Slf4j
@Component
public class ClusteredOidcSessionRegistry implements OidcSessionRegistry {
    private final OidcUserSessionRepository oidcUserSessionRepository;

    public ClusteredOidcSessionRegistry(OidcUserSessionRepository oidcUserSessionRepository) {
        this.oidcUserSessionRepository = oidcUserSessionRepository;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void saveSessionInformation(OidcSessionInformation info) {
        var oidcUserSession = new OidcUserSession();
        oidcUserSession.setSessionId(info.getSessionId());
        oidcUserSession.setSessionInformation(info);
        oidcUserSessionRepository.save(oidcUserSession);
    }

    /**
     * {@inheritDoc}
     */
    @Transactional
    @Override
    public OidcSessionInformation removeSessionInformation(String clientSessionId) {
        Optional<OidcUserSession> oidcUserSession = oidcUserSessionRepository.findBySessionId(clientSessionId);
        oidcUserSession.ifPresent(oidcUserSessionRepository::delete);
        return oidcUserSession.map(OidcUserSession::getSessionInformation).orElse(null);
    }

    /**
     * {@inheritDoc}
     */
    // Transactional so that all sessions matched by one logout token are deleted atomically
    @Transactional
    @Override
    public Iterable<OidcSessionInformation> removeSessionInformation(OidcLogoutToken token) {
        List<String> audience = token.getAudience();
        String issuer = token.getIssuer().toString();
        String subject = token.getSubject();
        String providerSessionId = token.getSessionId();
        // Match on the provider session id (sid) when present, otherwise on the subject (sub)
        Predicate<OidcSessionInformation> matcher = (providerSessionId != null)
                ? sessionIdMatcher(audience, issuer, providerSessionId)
                : subjectMatcher(audience, issuer, subject);
        var allSavedSessions = oidcUserSessionRepository.findAll();
        var deletedOidcSessions = deleteAndGetMatchedSessions(allSavedSessions, matcher);
        if (deletedOidcSessions.isEmpty()) {
            log.debug("Failed to remove any sessions since none matched");
        } else {
            log.trace("Found and removed {} session(s) from mapping of {} session(s)", deletedOidcSessions.size(), allSavedSessions.size());
        }
        return deletedOidcSessions;
    }

    private Set<OidcSessionInformation> deleteAndGetMatchedSessions(List<OidcUserSession> oidcUserSessions,
                                                                    Predicate<OidcSessionInformation> matcher) {
        Set<OidcSessionInformation> infos = new HashSet<>();
        oidcUserSessions.forEach(oidcUserSession -> {
            var sessionInfo = oidcUserSession.getSessionInformation();
            if (matcher.test(sessionInfo)) {
                oidcUserSessionRepository.delete(oidcUserSession);
                infos.add(sessionInfo);
            }
        });
        return infos;
    }

    private static Predicate<OidcSessionInformation> sessionIdMatcher(List<String> audience, String issuer,
                                                                      String sessionId) {
        log.trace("Looking up sessions by issuer [{}] and {} [{}]", issuer, LogoutTokenClaimNames.SID, sessionId);
        return session -> {
            List<String> thatAudience = session.getPrincipal().getAudience();
            String thatIssuer = session.getPrincipal().getIssuer().toString();
            String thatSessionId = session.getPrincipal().getClaimAsString(LogoutTokenClaimNames.SID);
            if (thatAudience == null) {
                return false;
            }
            return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer)
                    && sessionId.equals(thatSessionId);
        };
    }

    private static Predicate<OidcSessionInformation> subjectMatcher(List<String> audience, String issuer,
                                                                    String subject) {
        log.trace("Looking up sessions by issuer [{}] and {} [{}]", issuer, LogoutTokenClaimNames.SUB, subject);
        return session -> {
            List<String> thatAudience = session.getPrincipal().getAudience();
            String thatIssuer = session.getPrincipal().getIssuer().toString();
            String thatSubject = session.getPrincipal().getSubject();
            if (thatAudience == null) {
                return false;
            }
            return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer)
                    && subject.equals(thatSubject);
        };
    }
}
```
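The token-based `removeSessionInformation` above selects sessions purely from claims: if the logout token carries a `sid`, only sessions with that provider session id match; otherwise every session for the token's `sub` matches. In both cases the issuer must be equal and the audiences must overlap. A framework-free sketch of just those rules, assuming hypothetical stand-in records `StoredSession` and `LogoutClaims` in place of `OidcUser` and `OidcLogoutToken` (these are not Spring Security types):

```java
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical sketch of the logout-token matching rules; not Spring Security API.
final class LogoutMatching {

    // Stand-in for the claims remembered for one client session (normally read from OidcUser).
    record StoredSession(String clientSessionId, List<String> audience, String issuer, String sub, String sid) {}

    // Stand-in for the relevant claims of an incoming logout token (normally OidcLogoutToken).
    record LogoutClaims(List<String> audience, String issuer, String sub, String sid) {}

    // Prefer the provider session id (sid) when present; otherwise fall back to the subject (sub).
    static Predicate<StoredSession> matcherFor(LogoutClaims token) {
        return session -> {
            if (session.audience() == null || Collections.disjoint(token.audience(), session.audience())) {
                return false; // the session must share at least one audience with the token
            }
            if (!token.issuer().equals(session.issuer())) {
                return false; // issuers must match exactly
            }
            return (token.sid() != null)
                    ? token.sid().equals(session.sid())
                    : token.sub().equals(session.sub());
        };
    }

    // Returns, in input order, the client session ids a logout token would terminate.
    static List<String> matchingSessionIds(List<StoredSession> sessions, LogoutClaims token) {
        Predicate<StoredSession> matcher = matcherFor(token);
        return sessions.stream().filter(matcher).map(StoredSession::clientSessionId).toList();
    }

    public static void main(String[] args) {
        List<StoredSession> sessions = List.of(
                new StoredSession("A", List.of("client-1"), "https://idp.example", "alice", "sid-1"),
                new StoredSession("B", List.of("client-1"), "https://idp.example", "alice", "sid-2"),
                new StoredSession("C", List.of("client-2"), "https://idp.example", "alice", "sid-1"));
        // sid present: only A matches (C has sid-1 but no overlapping audience)
        System.out.println(matchingSessionIds(sessions,
                new LogoutClaims(List.of("client-1"), "https://idp.example", "alice", "sid-1"))); // prints [A]
        // sid absent: every session of alice with an overlapping audience matches
        System.out.println(matchingSessionIds(sessions,
                new LogoutClaims(List.of("client-1"), "https://idp.example", "alice", null))); // prints [A, B]
    }
}
```

Because the sample applies this predicate to every row returned by `findAll()`, each logout loads the whole table. A database-backed registry could, as a design option the sample does not take, store issuer, `sid`, and `sub` in their own columns and filter in the query instead.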

@xiechangning20

@aelillie Thanks for sharing this. Would you mind also sharing your implementation of the OidcUserSessionRepository?
I'm having a really hard time implementing the logic to interact properly with JDBC. Thank you so much!
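The sample does not show how `OidcUserSession` persists the `OidcSessionInformation` passed to `setSessionInformation`. One possible approach, sketched here as an assumption rather than as @aelillie's code, is to store it as serialized bytes in a BLOB/BYTEA column; this works only if the object and everything it references are `Serializable` (`SessionInformation`, which `OidcSessionInformation` extends, implements `java.io.Serializable`). `BlobCodec` and `DemoSession` are hypothetical names; with JPA, such helpers would typically sit behind an `AttributeConverter<OidcSessionInformation, byte[]>`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

// Hypothetical helper: converts a Serializable value to and from the bytes stored in a BLOB column.
final class BlobCodec {

    // Serializes the value with standard Java serialization.
    static byte[] toBytes(Serializable value) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(value);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray(); // read only after the stream is closed and fully flushed
    }

    // Deserializes bytes previously produced by toBytes and checks the expected type.
    static <T extends Serializable> T fromBytes(byte[] data, Class<T> type) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return type.cast(in.readObject());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("Stored class is not on the classpath", e);
        }
    }

    // Stand-in for the stored session information (an assumption, not a Spring Security type).
    record DemoSession(String clientSessionId, String issuer, String sid) implements Serializable {}

    public static void main(String[] args) {
        DemoSession original = new DemoSession("client-session-1", "https://idp.example", "sid-1");
        byte[] column = toBytes(original); // what would be written to the BLOB column
        DemoSession restored = fromBytes(column, DemoSession.class);
        System.out.println(restored.equals(original)); // prints true
    }
}
```

Java serialization ties the stored bytes to the class versions on the classpath, so rows written before a library upgrade may fail to deserialize afterwards. A JSON representation (which, as noted above for Redis, needs a Jackson mixin for OidcSessionInformation) avoids that at the cost of extra mapping code.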

Projects
None yet
Development

No branches or pull requests

4 participants