From 2aedf5899b2e99191ea2643f9439dfd0e4ec844e Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Thu, 7 Jul 2022 16:36:35 -0500 Subject: [PATCH] LazyCsrfTokenRepository#loadToken Supports Deferring Delegation Previously LazyCsrfTokenRepository supported lazily saving the CsrfToken which allowed for lazily saving the CsrfToken. However, it did not support lazily reading the CsrfToken. This meant every request required reading the CsrfToken (often the HttpSession). This commit allows for lazily reading the CsrfToken and thus prevents unnecessary reads to the HttpSession. Closes gh-11700 --- .../web/csrf/LazyCsrfTokenRepository.java | 63 +++++++++++++++++++ .../csrf/LazyCsrfTokenRepositoryTests.java | 12 ++++ 2 files changed, 75 insertions(+) diff --git a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java index d5c0c211904..082cbc12bf6 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java @@ -38,6 +38,8 @@ public final class LazyCsrfTokenRepository implements CsrfTokenRepository { private final CsrfTokenRepository delegate; + private boolean deferLoadToken; + /** * Creates a new instance * @param delegate the {@link CsrfTokenRepository} to use. Cannot be null @@ -48,6 +50,15 @@ public LazyCsrfTokenRepository(CsrfTokenRepository delegate) { this.delegate = delegate; } + /** + * Determines if {@link #loadToken(HttpServletRequest)} should be lazily loaded. + * @param deferLoadToken true if should lazily load + * {@link #loadToken(HttpServletRequest)}. Default false. + */ + public void setDeferLoadToken(boolean deferLoadToken) { + this.deferLoadToken = deferLoadToken; + } + /** * Generates a new token * @param request the {@link HttpServletRequest} to use. The @@ -77,6 +88,9 @@ public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletRe */ @Override public CsrfToken loadToken(HttpServletRequest request) { + if (this.deferLoadToken) { + return new LazyLoadCsrfToken(request, this.delegate); + } return this.delegate.loadToken(request); } @@ -92,6 +106,55 @@ private HttpServletResponse getResponse(HttpServletRequest request) { return response; } + private final class LazyLoadCsrfToken implements CsrfToken { + + private final HttpServletRequest request; + + private final CsrfTokenRepository tokenRepository; + + private CsrfToken token; + + private LazyLoadCsrfToken(HttpServletRequest request, CsrfTokenRepository tokenRepository) { + this.request = request; + this.tokenRepository = tokenRepository; + } + + private CsrfToken getDelegate() { + if (this.token != null) { + return this.token; + } + // load from the delegate repository + this.token = LazyCsrfTokenRepository.this.delegate.loadToken(this.request); + if (this.token == null) { + // return a generated token that is lazily saved since + // LazyCsrfTokenRepository#loadToken always returns a value + this.token = generateToken(this.request); + } + return this.token; + } + + @Override + public String getHeaderName() { + return getDelegate().getHeaderName(); + } + + @Override + public String getParameterName() { + return getDelegate().getParameterName(); + } + + @Override + public String getToken() { + return getDelegate().getToken(); + } + + @Override + public String toString() { + return "LazyLoadCsrfToken{" + "token=" + this.token + '}'; + } + + } + private static final class SaveOnAccessCsrfToken implements CsrfToken { private transient CsrfTokenRepository tokenRepository; diff --git a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java index e41ed591314..9be9d96518d 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java @@ -30,6 +30,7 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyZeroInteractions; /** @@ -97,4 +98,15 @@ public void loadTokenDelegates() { verify(this.delegate).loadToken(this.request); } + @Test + public void loadTokenWhenDeferLoadToken() { + given(this.delegate.loadToken(this.request)).willReturn(this.token); + this.repository.setDeferLoadToken(true); + CsrfToken loadToken = this.repository.loadToken(this.request); + verifyNoInteractions(this.delegate); + assertThat(loadToken.getToken()).isEqualTo(this.token.getToken()); + assertThat(loadToken.getHeaderName()).isEqualTo(this.token.getHeaderName()); + assertThat(loadToken.getParameterName()).isEqualTo(this.token.getParameterName()); + } + }