diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java index 3a01b31a59633..f6e7965d9630b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.core.security.authc.support; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -33,9 +32,10 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable { private final List invalidatedTokens; private final List previouslyInvalidatedTokens; private final List errors; + private final int attemptCount; public TokensInvalidationResult(List invalidatedTokens, List previouslyInvalidatedTokens, - @Nullable List errors) { + @Nullable List errors, int attemptCount) { Objects.requireNonNull(invalidatedTokens, "invalidated_tokens must be provided"); this.invalidatedTokens = invalidatedTokens; Objects.requireNonNull(previouslyInvalidatedTokens, "previously_invalidated_tokens must be provided"); @@ -45,19 +45,18 @@ public TokensInvalidationResult(List invalidatedTokens, List pre } else { this.errors = Collections.emptyList(); } + this.attemptCount = attemptCount; } public TokensInvalidationResult(StreamInput in) throws IOException { this.invalidatedTokens = in.readStringList(); this.previouslyInvalidatedTokens = in.readStringList(); this.errors = in.readList(StreamInput::readException); - if (in.getVersion().before(Version.V_8_0_0)) { - in.readVInt(); - } + this.attemptCount = in.readVInt(); } public static TokensInvalidationResult emptyResult() { - return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0); } @@ -73,6 +72,10 @@ public List getErrors() { return errors; } + public int getAttemptCount() { + return attemptCount; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject() @@ -97,8 +100,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeStringCollection(invalidatedTokens); out.writeStringCollection(previouslyInvalidatedTokens); out.writeCollection(errors, StreamOutput::writeException); - if (out.getVersion().before(Version.V_8_0_0)) { - out.writeVInt(5); - } + out.writeVInt(attemptCount); } } diff --git a/x-pack/plugin/core/src/main/resources/security-index-template.json b/x-pack/plugin/core/src/main/resources/security-index-template.json index e938464ac6f50..183ffff4ea534 100644 --- a/x-pack/plugin/core/src/main/resources/security-index-template.json +++ b/x-pack/plugin/core/src/main/resources/security-index-template.json @@ -199,13 +199,6 @@ "refreshed" : { "type" : "boolean" }, - "refresh_time": { - "type": "date", - "format": "epoch_millis" - }, - "superseded_by": { - "type": "keyword" - }, "invalidated" : { "type" : "boolean" }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java index bbfba920e385a..c9c2e4706440e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java @@ -29,7 +29,8 @@ public void testSerialization() throws IOException { TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), Arrays.asList(generateRandomStringArray(20, 15, false)), Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), - new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2")))); + new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), + randomIntBetween(0, 5)); InvalidateTokenResponse response = new InvalidateTokenResponse(result); try (BytesStreamOutput output = new BytesStreamOutput()) { response.writeTo(output); @@ -46,7 +47,8 @@ public void testSerialization() throws IOException { } result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), - Arrays.asList(generateRandomStringArray(20, 15, false)), Collections.emptyList()); + Arrays.asList(generateRandomStringArray(20, 15, false)), + Collections.emptyList(), randomIntBetween(0, 5)); response = new InvalidateTokenResponse(result); try (BytesStreamOutput output = new BytesStreamOutput()) { response.writeTo(output); @@ -66,7 +68,8 @@ public void testToXContent() throws IOException { List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false)); TokensInvalidationResult result = new TokensInvalidationResult(invalidatedTokens, previouslyInvalidatedTokens, Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), - new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2")))); + new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), + randomIntBetween(0, 5)); InvalidateTokenResponse response = new InvalidateTokenResponse(result); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java index 0e5acf5394f40..dee12f4a6bd7f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java @@ -63,7 +63,7 @@ protected void doExecute(Task task, SamlAuthenticateRequest request, ActionListe final Map tokenMeta = (Map) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA); tokenService.createUserToken(authentication, originatingAuthentication, ActionListener.wrap(tuple -> { - final String tokenString = tokenService.getAccessTokenAsString(tuple.v1()); + final String tokenString = tokenService.getUserTokenString(tuple.v1()); final TimeValue expiresIn = tokenService.getExpirationDelay(); listener.onResponse( new SamlAuthenticateResponse(authentication.getUser().principal(), tokenString, tuple.v2(), expiresIn)); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java index 75c3ee9df42af..5d5442803e3af 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java @@ -89,7 +89,7 @@ private void createToken(CreateTokenRequest request, Authentication authenticati boolean includeRefreshToken, ActionListener listener) { try { tokenService.createUserToken(authentication, originatingAuth, ActionListener.wrap(tuple -> { - final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1()); + final String tokenStr = tokenService.getUserTokenString(tuple.v1()); final String scope = getResponseScopeValue(request.getScope()); final CreateTokenResponse response = diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java index 71aeb64bc4276..0eac8d71fb20f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java @@ -31,7 +31,7 @@ public TransportRefreshTokenAction(TransportService transportService, ActionFilt @Override protected void doExecute(Task task, CreateTokenRequest request, ActionListener listener) { tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> { - final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1()); + final String tokenStr = tokenService.getUserTokenString(tuple.v1()); final String scope = getResponseScopeValue(request.getScope()); final CreateTokenResponse response = diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index a9994fcf2d1fe..36144899d2842 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -19,7 +19,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest.OpType; import org.elasticsearch.action.DocWriteResponse; -import org.elasticsearch.action.bulk.BackoffPolicy; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; @@ -29,7 +28,6 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.support.TransportActions; import org.elasticsearch.action.support.WriteRequest.RefreshPolicy; import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.action.update.UpdateRequest; @@ -69,7 +67,6 @@ import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.index.seqno.SequenceNumbers; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.XPackField; @@ -116,12 +113,12 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashMap; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Predicate; @@ -132,7 +129,6 @@ import static org.elasticsearch.search.SearchService.DEFAULT_KEEPALIVE_SETTING; import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.threadpool.ThreadPool.Names.GENERIC; /** * Service responsible for the creation, validation, and other management of {@link UserToken} @@ -159,7 +155,6 @@ public final class TokenService { private static final String MALFORMED_TOKEN_WWW_AUTH_VALUE = "Bearer realm=\"" + XPackField.SECURITY + "\", error=\"invalid_token\", error_description=\"The access token is malformed\""; private static final String TYPE = "doc"; - private static final BackoffPolicy DEFAULT_BACKOFF = BackoffPolicy.exponentialBackoff(); public static final String THREAD_POOL_NAME = XPackField.SECURITY + "-token-key"; public static final Setting TOKEN_EXPIRATION = Setting.timeSetting("xpack.security.authc.token.timeout", @@ -172,7 +167,8 @@ public final class TokenService { private static final String TOKEN_DOC_TYPE = "token"; private static final String TOKEN_DOC_ID_PREFIX = TOKEN_DOC_TYPE + "_"; static final int MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1; - static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); + private static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); + private static final int MAX_RETRY_ATTEMPTS = 5; private static final Logger logger = LogManager.getLogger(TokenService.class); private final SecureRandom secureRandom = new SecureRandom(); @@ -225,22 +221,12 @@ public static Boolean isTokenServiceEnabled(Settings settings) { } /** - * Creates a token based on the provided authentication and metadata with an auto-generated token id. + * Create a token based on the provided authentication and metadata. * The created token will be stored in the security index. */ public void createUserToken(Authentication authentication, Authentication originatingClientAuth, ActionListener> listener, Map metadata, boolean includeRefreshToken) throws IOException { - createUserToken(UUIDs.randomBase64UUID(), authentication, originatingClientAuth, listener, metadata, includeRefreshToken); - } - - /** - * Create a token based on the provided authentication and metadata with the given token id. - * The created token will be stored in the security index. - */ - private void createUserToken(String userTokenId, Authentication authentication, Authentication originatingClientAuth, - ActionListener> listener, Map metadata, - boolean includeRefreshToken) throws IOException { ensureEnabled(); if (authentication == null) { listener.onFailure(traceLog("create token", new IllegalArgumentException("authentication must be provided"))); @@ -253,7 +239,7 @@ private void createUserToken(String userTokenId, Authentication authentication, final Version version = clusterService.state().nodes().getMinNodeVersion(); final Authentication tokenAuth = new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(), authentication.getLookedUpBy(), version, AuthenticationType.TOKEN, authentication.getMetadata()); - final UserToken userToken = new UserToken(userTokenId, version, tokenAuth, expiration, metadata); + final UserToken userToken = new UserToken(version, tokenAuth, expiration, metadata); final String refreshToken = includeRefreshToken ? UUIDs.randomBase64UUID() : null; try (XContentBuilder builder = XContentFactory.jsonBuilder()) { @@ -294,33 +280,9 @@ private void createUserToken(String userTokenId, Authentication authentication, } } - /** - * Reconstructs the {@link UserToken} from the existing {@code userTokenSource} and call the listener with the {@link UserToken} and the - * refresh token string - */ - private void reIssueTokens(Map userTokenSource, - String refreshToken, ActionListener> listener) { - final String authString = (String) userTokenSource.get("authentication"); - final Integer version = (Integer) userTokenSource.get("version"); - final Map metadata = (Map) userTokenSource.get("metadata"); - final String id = (String) userTokenSource.get("id"); - final Long expiration = (Long) userTokenSource.get("expiration_time"); - - Version authVersion = Version.fromId(version); - try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) { - in.setVersion(authVersion); - Authentication authentication = new Authentication(in); - UserToken userToken = new UserToken(id, authVersion, authentication, Instant.ofEpochMilli(expiration), metadata); - listener.onResponse(new Tuple<>(userToken, refreshToken)); - } catch (IOException e) { - logger.error("Unable to decode existing user token", e); - listener.onFailure(invalidGrantException("could not refresh the requested token")); - } - } - /** * Looks in the context to see if the request provided a header with a user token and if so the - * token is validated, which might include authenticated decryption and verification that the token + * token is validated, which includes authenticated decryption and verification that the token * has not been revoked or is expired. */ void getAndValidateToken(ThreadContext ctx, ActionListener listener) { @@ -367,78 +329,23 @@ public void getAuthenticationAndMetaData(String token, ActionListener listener) { - if (securityIndex.isAvailable() == false) { - logger.warn("failed to get token [{}] since index is not available", userTokenId); - listener.onResponse(null); - } else { - securityIndex.checkIndexVersionThenExecute( - ex -> listener.onFailure(traceLog("prepare security index", userTokenId, ex)), - () -> { - final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, - getTokenDocumentId(userTokenId)).request(); - Consumer onFailure = ex -> listener.onFailure(traceLog("decode token", userTokenId, ex)); - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, - ActionListener.wrap(response -> { - if (response.isExists()) { - Map accessTokenSource = - (Map) response.getSource().get("access_token"); - if (accessTokenSource == null) { - onFailure.accept(new IllegalStateException( - "token document is missing the access_token field")); - } else if (accessTokenSource.containsKey("user_token") == false) { - onFailure.accept(new IllegalStateException( - "token document is missing the user_token field")); - } else { - Map userTokenSource = - (Map) accessTokenSource.get("user_token"); - listener.onResponse(UserToken.fromSourceMap(userTokenSource)); - } - } else { - onFailure.accept( - new IllegalStateException("token document is missing and must be present")); - } - }, e -> { - // if the index or the shard is not there / available we assume that - // the token is not valid - if (isShardNotAvailableException(e)) { - logger.warn("failed to get token [{}] since index is not available", userTokenId); - listener.onResponse(null); - } else { - logger.error(new ParameterizedMessage("failed to get token [{}]", userTokenId), e); - listener.onFailure(e); - } - }), client::get); - }); - } - } - - /* - * If needed, for tokens that were created in a pre 7.1.0 cluster, it asynchronously decodes the token to get the token document Id. - * The process for this is asynchronous as we may need to compute a key, which can be computationally expensive - * so this should not block the current thread, which is typically a network thread. A second reason for being asynchronous is that - * we can restrain the amount of resources consumed by the key computation to a single thread. - * For tokens created in an after 7.1.0 cluster, the token is just the token document Id so this is used directly without decryption + * Asynchronously decodes the string representation of a {@link UserToken}. The process for + * this is asynchronous as we may need to compute a key, which can be computationally expensive + * so this should not block the current thread, which is typically a network thread. A second + * reason for being asynchronous is that we can restrain the amount of resources consumed by + * the key computation to a single thread. */ void decodeToken(String token, ActionListener listener) throws IOException { // We intentionally do not use try-with resources since we need to keep the stream open if we need to compute a key! byte[] bytes = token.getBytes(StandardCharsets.UTF_8); StreamInput in = new InputStreamStreamInput(Base64.getDecoder().wrap(new ByteArrayInputStream(bytes)), bytes.length); - final Version version = Version.readVersion(in); - if (version.onOrAfter(Version.V_8_0_0)) { - // The token was created in a > 7.1.0 cluster so it contains the tokenId as a String - String usedTokenId = in.readString(); - getUserTokenFromId(usedTokenId, listener); + if (in.available() < MINIMUM_BASE64_BYTES) { + logger.debug("invalid token"); + listener.onResponse(null); } else { - // The token was created in a < 7.1.0 cluster so we need to decrypt it to get the tokenId + // the token exists and the value is at least as long as we'd expect + final Version version = Version.readVersion(in); in.setVersion(version); - if (in.available() < MINIMUM_BASE64_BYTES) { - logger.debug("invalid token, smaller than [{}] bytes", MINIMUM_BASE64_BYTES); - listener.onResponse(null); - return; - } final BytesKey decodedSalt = new BytesKey(in.readByteArray()); final BytesKey passphraseHash = new BytesKey(in.readByteArray()); KeyAndCache keyAndCache = keyCache.get(passphraseHash); @@ -447,8 +354,51 @@ void decodeToken(String token, ActionListener listener) throws IOExce try { final byte[] iv = in.readByteArray(); final Cipher cipher = getDecryptionCipher(iv, decodeKey, version, decodedSalt); - decryptTokenId(in, cipher, version, ActionListener.wrap(tokenId -> getUserTokenFromId(tokenId, listener), - listener::onFailure)); + decryptTokenId(in, cipher, version, ActionListener.wrap(tokenId -> { + if (securityIndex.isAvailable() == false) { + logger.warn("failed to get token [{}] since index is not available", tokenId); + listener.onResponse(null); + } else { + securityIndex.checkIndexVersionThenExecute( + ex -> listener.onFailure(traceLog("prepare security index", tokenId, ex)), + () -> { + final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, + getTokenDocumentId(tokenId)).request(); + Consumer onFailure = ex -> listener.onFailure(traceLog("decode token", tokenId, ex)); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, + ActionListener.wrap(response -> { + if (response.isExists()) { + Map accessTokenSource = + (Map) response.getSource().get("access_token"); + if (accessTokenSource == null) { + onFailure.accept(new IllegalStateException( + "token document is missing the access_token field")); + } else if (accessTokenSource.containsKey("user_token") == false) { + onFailure.accept(new IllegalStateException( + "token document is missing the user_token field")); + } else { + Map userTokenSource = + (Map) accessTokenSource.get("user_token"); + listener.onResponse(UserToken.fromSourceMap(userTokenSource)); + } + } else { + onFailure.accept( + new IllegalStateException("token document is missing and must be present")); + } + }, e -> { + // if the index or the shard is not there / available we assume that + // the token is not valid + if (isShardNotAvailableException(e)) { + logger.warn("failed to get token [{}] since index is not available", tokenId); + listener.onResponse(null); + } else { + logger.error(new ParameterizedMessage("failed to get token [{}]", tokenId), e); + listener.onFailure(e); + } + }), client::get); + }); + } + }, listener::onFailure)); } catch (GeneralSecurityException e) { // could happen with a token that is not ours logger.warn("invalid token", e); @@ -492,8 +442,8 @@ private static void decryptTokenId(StreamInput in, Cipher cipher, Version versio /** * This method performs the steps necessary to invalidate a token so that it may no longer be - * used. The process of invalidation involves performing an update to the token document and setting - * the invalidated field to true + * used. The process of invalidation involves performing an update to + * the token document and setting the invalidated field to true */ public void invalidateAccessToken(String tokenString, ActionListener listener) { ensureEnabled(); @@ -502,13 +452,12 @@ public void invalidateAccessToken(String tokenString, ActionListener backoff = DEFAULT_BACKOFF.iterator(); try { decodeToken(tokenString, ActionListener.wrap(userToken -> { if (userToken == null) { listener.onFailure(traceLog("invalidate token", tokenString, malformedTokenException())); } else { - indexInvalidation(Collections.singleton(userToken.getId()), listener, backoff, + indexInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), "access_token", null); } }, listener::onFailure)); @@ -531,14 +480,12 @@ public void invalidateAccessToken(UserToken userToken, ActionListener backoff = DEFAULT_BACKOFF.iterator(); - indexInvalidation(Collections.singleton(userToken.getId()), listener, backoff, "access_token", null); + indexInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), "access_token", null); } } /** - * This method onvalidates a refresh token so that it may no longer be used. Iinvalidation involves performing an update to the token - * document and setting the refresh_token.invalidated field to true + * This method performs the steps necessary to invalidate a refresh token so that it may no longer be used. * * @param refreshToken The string representation of the refresh token * @param listener the listener to notify upon completion @@ -550,17 +497,16 @@ public void invalidateRefreshToken(String refreshToken, ActionListener backoff = DEFAULT_BACKOFF.iterator(); findTokenFromRefreshToken(refreshToken, - ActionListener.wrap(searchResponse -> { - final String docId = getTokenIdFromDocumentId(searchResponse.getHits().getAt(0).getId()); - indexInvalidation(Collections.singletonList(docId), listener, backoff, "refresh_token", null); - }, listener::onFailure), backoff); + ActionListener.wrap(tuple -> { + final String docId = getTokenIdFromDocumentId(tuple.v1().getHits().getAt(0).getId()); + indexInvalidation(Collections.singletonList(docId), listener, tuple.v2(), "refresh_token", null); + }, listener::onFailure), new AtomicInteger(0)); } } /** - * Invalidates all access tokens and all refresh tokens of a given {@code realmName} and/or of a given + * Invalidate all access tokens and all refresh tokens of a given {@code realmName} and/or of a given * {@code username} so that they may no longer be used * * @param realmName the realm of which the tokens should be invalidated @@ -611,30 +557,32 @@ private void invalidateAllTokens(Collection accessTokenIds, ActionListen maybeStartTokenRemover(); // Invalidate the refresh tokens first so that they cannot be used to get new // access tokens while we invalidate the access tokens we currently know about - final Iterator backoff = DEFAULT_BACKOFF.iterator(); indexInvalidation(accessTokenIds, ActionListener.wrap(result -> - indexInvalidation(accessTokenIds, listener, backoff, "access_token", result), - listener::onFailure), backoff, "refresh_token", null); + indexInvalidation(accessTokenIds, listener, new AtomicInteger(result.getAttemptCount()), + "access_token", result), + listener::onFailure), new AtomicInteger(0), "refresh_token", null); } /** - * Performs the actual invalidation of a collection of tokens. In case of recoverable errors ( see - * {@link TransportActions#isShardNotAvailableException} ) the UpdateRequests to mark the tokens as invalidated are retried using - * an exponential backoff policy. + * Performs the actual invalidation of a collection of tokens * * @param tokenIds the tokens to invalidate * @param listener the listener to notify upon completion - * @param backoff the amount of time to delay between attempts + * @param attemptCount the number of attempts to invalidate that have already been tried * @param srcPrefix the prefix to use when constructing the doc to update, either refresh_token or access_token depending on * what type of tokens should be invalidated * @param previousResult if this not the initial attempt for invalidation, it contains the result of invalidating * tokens up to the point of the retry. This result is added to the result of the current attempt */ private void indexInvalidation(Collection tokenIds, ActionListener listener, - Iterator backoff, String srcPrefix, @Nullable TokensInvalidationResult previousResult) { + AtomicInteger attemptCount, String srcPrefix, @Nullable TokensInvalidationResult previousResult) { if (tokenIds.isEmpty()) { logger.warn("No [{}] tokens provided for invalidation", srcPrefix); listener.onFailure(invalidGrantException("No tokens provided for invalidation")); + } else if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { + logger.warn("Failed to invalidate [{}] tokens after [{}] attempts", tokenIds.size(), + attemptCount.get()); + listener.onFailure(invalidGrantException("failed to invalidate tokens")); } else { BulkRequestBuilder bulkRequestBuilder = client.prepareBulk(); for (String tokenId : tokenIds) { @@ -679,30 +627,20 @@ private void indexInvalidation(Collection tokenIds, ActionListener indexInvalidation(retryTokenDocIds, listener, backoff, srcPrefix, incompleteResult), - backoff.next(), GENERIC); - } else { - logger.warn("failed to invalidate [{}] tokens out of [{}] after all retries", - retryTokenDocIds.size(), tokenIds.size()); - } - } else { - TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated, - failedRequestResponses); - listener.onResponse(result); + TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated, previouslyInvalidated, + failedRequestResponses, attemptCount.get()); + attemptCount.incrementAndGet(); + indexInvalidation(retryTokenDocIds, listener, attemptCount, srcPrefix, incompleteResult); } + TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated, + failedRequestResponses, attemptCount.get()); + listener.onResponse(result); }, e -> { Throwable cause = ExceptionsHelper.unwrapCause(e); traceLog("invalidate tokens", cause); - if (isShardNotAvailableException(cause) && backoff.hasNext()) { - logger.debug("failed to invalidate tokens, retrying "); - client.threadPool().schedule( - () -> indexInvalidation(tokenIds, listener, backoff, srcPrefix, previousResult), backoff.next(), GENERIC); + if (isShardNotAvailableException(cause)) { + attemptCount.incrementAndGet(); + indexInvalidation(tokenIds, listener, attemptCount, srcPrefix, previousResult); } else { listener.onFailure(e); } @@ -711,272 +649,142 @@ private void indexInvalidation(Collection tokenIds, ActionListener> listener) { ensureEnabled(); - final Instant refreshRequested = clock.instant(); - final Iterator backoff = DEFAULT_BACKOFF.iterator(); findTokenFromRefreshToken(refreshToken, - ActionListener.wrap(searchResponse -> { + ActionListener.wrap(tuple -> { final Authentication clientAuth = Authentication.readFromContext(client.threadPool().getThreadContext()); - final SearchHit tokenDocHit = searchResponse.getHits().getHits()[0]; - final String tokenDocId = tokenDocHit.getId(); - innerRefresh(tokenDocId, tokenDocHit.getSourceAsMap(), tokenDocHit.getSeqNo(), tokenDocHit.getPrimaryTerm(), clientAuth, - listener, backoff, refreshRequested); + final String tokenDocId = tuple.v1().getHits().getHits()[0].getId(); + innerRefresh(tokenDocId, clientAuth, listener, tuple.v2()); }, listener::onFailure), - backoff); + new AtomicInteger(0)); } - /** - * Performs an asynchronous search request for the token document that contains the {@code refreshToken} and calls the listener with the - * {@link SearchResponse}. In case of recoverable errors the SearchRequest is retried using an exponential backoff policy. - */ - private void findTokenFromRefreshToken(String refreshToken, ActionListener listener, - Iterator backoff) { - SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) - .setQuery(QueryBuilders.boolQuery() - .filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE)) - .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken))) - .seqNoAndPrimaryTerm(true) - .request(); - - final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); - if (frozenSecurityIndex.indexExists() == false) { - logger.warn("security index does not exist therefore refresh token [{}] cannot be validated", refreshToken); + private void findTokenFromRefreshToken(String refreshToken, ActionListener> listener, + AtomicInteger attemptCount) { + if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { + logger.warn("Failed to find token for refresh token [{}] after [{}] attempts", refreshToken, attemptCount.get()); listener.onFailure(invalidGrantException("could not refresh the requested token")); - } else if (frozenSecurityIndex.isAvailable() == false) { - logger.debug("security index is not available to find token from refresh token, retrying"); - client.threadPool().scheduleWithFixedDelay( - () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC); } else { - Consumer onFailure = ex -> listener.onFailure(traceLog("find by refresh token", refreshToken, ex)); - securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, - ActionListener.wrap(searchResponse -> { - if (searchResponse.isTimedOut()) { - if (backoff.hasNext()) { - client.threadPool().scheduleWithFixedDelay( - () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC); - } else { - logger.warn("could not find token document with refresh_token [{}] after all retries", refreshToken); + SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) + .setQuery(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE)) + .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken))) + .setVersion(true) + .request(); + + final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); + if (frozenSecurityIndex.indexExists() == false) { + logger.warn("security index does not exist therefore refresh token [{}] cannot be validated", refreshToken); + listener.onFailure(invalidGrantException("could not refresh the requested token")); + } else if (frozenSecurityIndex.isAvailable() == false) { + logger.debug("security index is not available to find token from refresh token, retrying"); + attemptCount.incrementAndGet(); + findTokenFromRefreshToken(refreshToken, listener, attemptCount); + } else { + Consumer onFailure = ex -> listener.onFailure(traceLog("find by refresh token", refreshToken, ex)); + securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, + ActionListener.wrap(searchResponse -> { + if (searchResponse.isTimedOut()) { + attemptCount.incrementAndGet(); + findTokenFromRefreshToken(refreshToken, listener, attemptCount); + } else if (searchResponse.getHits().getHits().length < 1) { + logger.info("could not find token document with refresh_token [{}]", refreshToken); onFailure.accept(invalidGrantException("could not refresh the requested token")); + } else if (searchResponse.getHits().getHits().length > 1) { + onFailure.accept(new IllegalStateException("multiple tokens share the same refresh token")); + } else { + listener.onResponse(new Tuple<>(searchResponse, attemptCount)); } - } else if (searchResponse.getHits().getHits().length < 1) { - logger.warn("could not find token document with refresh_token [{}]", refreshToken); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } else if (searchResponse.getHits().getHits().length > 1) { - onFailure.accept(new IllegalStateException("multiple tokens share the same refresh token")); - } else { - listener.onResponse(searchResponse); - } - }, e -> { - if (isShardNotAvailableException(e)) { - if (backoff.hasNext()) { - logger.debug("failed to find token for refresh token [{}], retrying", refreshToken); - client.threadPool().scheduleWithFixedDelay( - () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC); + }, e -> { + if (isShardNotAvailableException(e)) { + logger.debug("failed to search for token document, retrying", e); + attemptCount.incrementAndGet(); + findTokenFromRefreshToken(refreshToken, listener, attemptCount); } else { - logger.warn("could not find token document with refresh_token [{}] after all retries", refreshToken); - onFailure.accept(invalidGrantException("could not refresh the requested token")); + onFailure.accept(e); } - } else { - onFailure.accept(e); - } - }), - client::search)); + }), + client::search)); + } } } /** - * Performs the actual refresh of the token with retries in case of certain exceptions that may be recoverable. The - * refresh involves two steps: - * First, we check if the token document is still valid for refresh ({@link TokenService#checkTokenDocForRefresh(Map, Authentication)} - * Then, in the case that the token has been refreshed within the previous 30 seconds (see - * {@link TokenService#checkLenientlyIfTokenAlreadyRefreshed(Map, Authentication)}), we do not create a new token document - * but instead retrieve the one that was created by the original refresh and return a user token and - * refresh token based on that ( see {@link TokenService#reIssueTokens(Map, String, ActionListener)} ). - * Otherwise this token document gets its refresh_token marked as refreshed, while also storing the Instant when it was - * refreshed along with a pointer to the new token document that holds the refresh_token that supersedes this one. The new - * document that contains the new access token and refresh token is created and finally the new access token and refresh token are - * returned to the listener. + * Performs the actual refresh of the token with retries in case of certain exceptions that + * may be recoverable. The refresh involves retrieval of the token document and then + * updating the token document to indicate that the document has been refreshed. */ - private void innerRefresh(String tokenDocId, Map source, long seqNo, long primaryTerm, Authentication clientAuth, - ActionListener> listener, Iterator backoff, Instant refreshRequested) { - logger.debug("Attempting to refresh token [{}]", tokenDocId); - Consumer onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex)); - final Optional invalidSource = checkTokenDocForRefresh(source, clientAuth); - if (invalidSource.isPresent()) { - onFailure.accept(invalidSource.get()); + private void innerRefresh(String tokenDocId, Authentication clientAuth, ActionListener> listener, + AtomicInteger attemptCount) { + if (attemptCount.getAndIncrement() > MAX_RETRY_ATTEMPTS) { + logger.warn("Failed to refresh token for doc [{}] after [{}] attempts", tokenDocId, attemptCount.get()); + listener.onFailure(invalidGrantException("could not refresh the requested token")); } else { - if (eligibleForMultiRefresh(source, refreshRequested)) { - final Map refreshTokenSrc = (Map) source.get("refresh_token"); - final String supersedingTokenDocId = (String) refreshTokenSrc.get("superseded_by"); - logger.debug("Token document [{}] was recently refreshed, attempting to reuse [{}] for returning an " + - "access token and refresh token", tokenDocId, supersedingTokenDocId); - final ActionListener getSupersedingListener = new ActionListener() { - @Override - public void onResponse(GetResponse response) { - if (response.isExists()) { - logger.debug("Found superseding token document [{}] ", supersedingTokenDocId); - final Map supersedingTokenSource = response.getSource(); - final Map supersedingUserTokenSource = (Map) - ((Map) supersedingTokenSource.get("access_token")).get("user_token"); - final Map supersedingRefreshTokenSrc = - (Map) supersedingTokenSource.get("refresh_token"); - final String supersedingRefreshTokenValue = (String) supersedingRefreshTokenSrc.get("token"); - reIssueTokens(supersedingUserTokenSource, supersedingRefreshTokenValue, listener); - } else if (backoff.hasNext()) { - // We retry this since the creation of the superseding token document might already be in flight but not - // yet completed, triggered by a refresh request that came a few milliseconds ago - logger.info("could not find superseding token document [{}] for token document [{}], retrying", - supersedingTokenDocId, tokenDocId); - client.threadPool().schedule(() -> getTokenDocAsync(supersedingTokenDocId, this), backoff.next(), GENERIC); - } else { - logger.warn("could not find superseding token document [{}] for token document [{}] after all retries", - supersedingTokenDocId, tokenDocId); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } - } - - @Override - public void onFailure(Exception e) { - if (isShardNotAvailableException(e)) { - if (backoff.hasNext()) { - logger.info("could not find superseding token document [{}] for refresh, retrying", supersedingTokenDocId); - client.threadPool().schedule( - () -> getTokenDocAsync(supersedingTokenDocId, this), backoff.next(), GENERIC); - } else { - logger.warn("could not find token document [{}] for refresh after all retries", supersedingTokenDocId); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } + Consumer onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex)); + GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId).request(); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, + ActionListener.wrap(response -> { + if (response.isExists()) { + final Map source = response.getSource(); + final Optional invalidSource = checkTokenDocForRefresh(source, clientAuth); + + if (invalidSource.isPresent()) { + onFailure.accept(invalidSource.get()); } else { - logger.warn("could not find superseding token document [{}] for refresh", supersedingTokenDocId); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } - } - }; - getTokenDocAsync(supersedingTokenDocId, getSupersedingListener); - } else { - final Map userTokenSource = (Map) - ((Map) source.get("access_token")).get("user_token"); - final String authString = (String) userTokenSource.get("authentication"); - final Integer version = (Integer) userTokenSource.get("version"); - final Map metadata = (Map) userTokenSource.get("metadata"); - Version authVersion = Version.fromId(version); - Authentication authentication; - try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) { - in.setVersion(authVersion); - authentication = new Authentication(in); - } catch (IOException e) { - logger.error("failed to decode the authentication stored with token document [{}]", tokenDocId); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - return; - } - final String newUserTokenId = UUIDs.randomBase64UUID(); - final Instant refreshTime = clock.instant(); - Map updateMap = new HashMap<>(); - updateMap.put("refreshed", true); - updateMap.put("refresh_time", refreshTime.toEpochMilli()); - updateMap.put("superseded_by", getTokenDocumentId(newUserTokenId)); - UpdateRequestBuilder updateRequest = - client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId) - .setDoc("refresh_token", updateMap) - .setFetchSource(true) - .setRefreshPolicy(RefreshPolicy.IMMEDIATE); - assert seqNo != SequenceNumbers.UNASSIGNED_SEQ_NO : "expected an assigned sequence number"; - updateRequest.setIfSeqNo(seqNo); - assert primaryTerm != SequenceNumbers.UNASSIGNED_PRIMARY_TERM : "expected an assigned primary term"; - updateRequest.setIfPrimaryTerm(primaryTerm); - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, updateRequest.request(), - ActionListener.wrap( - updateResponse -> { - if (updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { - logger.debug("updated the original token document to {}", updateResponse.getGetResult().sourceAsMap()); - createUserToken(newUserTokenId, authentication, clientAuth, listener, metadata, true); - } else if (backoff.hasNext()) { - logger.info("failed to update the original token document [{}], the update result was [{}]. Retrying", - tokenDocId, updateResponse.getResult()); - client.threadPool().schedule( - () -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff, - refreshRequested), - backoff.next(), GENERIC); - } else { - logger.info("failed to update the original token document [{}] after all retries, " + - "the update result was [{}]. ", tokenDocId, updateResponse.getResult()); - listener.onFailure(invalidGrantException("could not refresh the requested token")); - } - }, e -> { - Throwable cause = ExceptionsHelper.unwrapCause(e); - if (cause instanceof VersionConflictEngineException) { - //The document has been updated by another thread, get it again. - if (backoff.hasNext()) { - logger.debug("version conflict while updating document [{}], attempting to get it again", - tokenDocId); - final ActionListener getListener = new ActionListener() { - @Override - public void onResponse(GetResponse response) { - if (response.isExists()) { - innerRefresh(tokenDocId, response.getSource(), response.getSeqNo(), - response.getPrimaryTerm(), clientAuth, listener, backoff, refreshRequested); - } else { - logger.warn("could not find token document [{}] for refresh", tokenDocId); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } - } - - @Override - public void onFailure(Exception e) { - if (isShardNotAvailableException(e)) { - if (backoff.hasNext()) { - logger.info("could not get token document [{}] for refresh, " + - "retrying", tokenDocId); - client.threadPool().schedule( - () -> getTokenDocAsync(tokenDocId, this), backoff.next(), GENERIC); - } else { - logger.warn("could not get token document [{}] for refresh after all retries", - tokenDocId); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } + final Map userTokenSource = (Map) + ((Map) source.get("access_token")).get("user_token"); + final String authString = (String) userTokenSource.get("authentication"); + final Integer version = (Integer) userTokenSource.get("version"); + final Map metadata = (Map) userTokenSource.get("metadata"); + + Version authVersion = Version.fromId(version); + try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) { + in.setVersion(authVersion); + Authentication authentication = new Authentication(in); + UpdateRequestBuilder updateRequest = + client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId) + .setDoc("refresh_token", Collections.singletonMap("refreshed", true)) + .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL); + updateRequest.setIfSeqNo(response.getSeqNo()); + updateRequest.setIfPrimaryTerm(response.getPrimaryTerm()); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, updateRequest.request(), + ActionListener.wrap( + updateResponse -> createUserToken(authentication, clientAuth, listener, metadata, true), + e -> { + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof VersionConflictEngineException || + isShardNotAvailableException(e)) { + innerRefresh(tokenDocId, clientAuth, + listener, attemptCount); } else { onFailure.accept(e); } - } - }; - getTokenDocAsync(tokenDocId, getListener); - } else { - logger.warn("version conflict while updating document [{}], no retries left", tokenDocId); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } - } else if (isShardNotAvailableException(e)) { - if (backoff.hasNext()) { - logger.debug("failed to update the original token document [{}], retrying", tokenDocId); - client.threadPool().schedule( - () -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff, - refreshRequested), - backoff.next(), GENERIC); - } else { - logger.warn("failed to update the original token document [{}], after all retries", tokenDocId); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } - } else { - onFailure.accept(e); + }), + client::update); } - }), - client::update); - } + } + } else { + logger.info("could not find token document [{}] for refresh", tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + }, e -> { + if (isShardNotAvailableException(e)) { + innerRefresh(tokenDocId, clientAuth, listener, attemptCount); + } else { + listener.onFailure(e); + } + }), client::get); } } - private void getTokenDocAsync(String tokenDocId, ActionListener listener) { - GetRequest getRequest = - client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId).request(); - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, listener, client::get); - } - /** * Performs checks on the retrieved source and returns an {@link Optional} with the exception - * if there is an issue that makes the retrieved token unsuitable to be refreshed + * if there is an issue */ private Optional checkTokenDocForRefresh(Map source, Authentication clientAuth) { final Map refreshTokenSrc = (Map) source.get("refresh_token"); @@ -997,6 +805,8 @@ private Optional checkTokenDocForRefresh(Map checkTokenDocForRefresh(Map checkClient(Map if (clientInfo == null) { return Optional.of(invalidGrantException("token is missing client information")); } else if (clientAuth.getUser().principal().equals(clientInfo.get("user")) == false) { - logger.warn("Token was originally created by [{}] but [{}] attempted to refresh it", clientInfo.get("user"), - clientAuth.getUser().principal()); return Optional.of(invalidGrantException("tokens must be refreshed by the creating client")); } else if (clientAuth.getAuthenticatedBy().getName().equals(clientInfo.get("realm")) == false) { - logger.warn("[{}] created the refresh token while authenticated by [{}] but is now authenticated by [{}]", - clientInfo.get("user"), clientInfo.get("realm"), clientAuth.getAuthenticatedBy().getName()); return Optional.of(invalidGrantException("tokens must be refreshed by the creating client")); } else { return Optional.empty(); } } - /** - * Checks if the retrieved refresh token is already refreshed taking into consideration that we allow refresh tokens - * to be refreshed multiple times for a very small time window in order to gracefully handle multiple concurrent requests - * from clients - */ - @SuppressWarnings("unchecked") - private Optional checkLenientlyIfTokenAlreadyRefreshed(Map source, - Authentication userAuth) { - final Map refreshTokenSrc = (Map) source.get("refresh_token"); - final Map userTokenSource = (Map) - ((Map) source.get("access_token")).get("user_token"); - final Integer version = (Integer) userTokenSource.get("version"); - Version authVersion = Version.fromId(version); - final Boolean refreshed = (Boolean) refreshTokenSrc.get("refreshed"); - if (refreshed) { - if (authVersion.onOrAfter(Version.V_8_0_0)) { - final Long refreshedEpochMilli = (Long) refreshTokenSrc.get("refresh_time"); - final Instant refreshTime = refreshedEpochMilli == null ? null : Instant.ofEpochMilli(refreshedEpochMilli); - final String supersededBy = (String) refreshTokenSrc.get("superseded_by"); - if (supersededBy == null) { - return Optional.of(invalidGrantException("token document is missing superseded by value")); - } else if (refreshTime == null) { - return Optional.of(invalidGrantException("token document is missing refresh time value")); - } else if (clock.instant().isAfter(refreshTime.plus(30L, ChronoUnit.SECONDS))) { - return Optional.of(invalidGrantException("token has already been refreshed more than 30 seconds in the past")); - } - } else { - return Optional.of(invalidGrantException("token has already been refreshed")); - } - } - return checkClient(refreshTokenSrc, userAuth); - } - - /** - * Checks if a refreshed token is eligible to be refreshed again. This is only allowed for versions after 7.1.0 and - * when the refresh_token contains the refresh_time and superseded_by fields and it has been refreshed in a specific - * time period of 60 seconds. The period is defined as 30 seconds before the token was refreshed until 30 seconds after. The - * time window needs to handle instants before the request time as we capture an instant early on in - * {@link TokenService#refreshToken(String, ActionListener)} and in the case of multiple concurrent requests, - * the {@code refreshRequested} when dealing with one of the subsequent requests might well be before the instant when - * the first of the requests refreshed the token. - * - * @param source The source of the token document that contains the originally refreshed token - * @param refreshRequested The instant when the this refresh request was acknowledged by the TokenService - */ - private boolean eligibleForMultiRefresh(Map source, Instant refreshRequested) { - final Map refreshTokenSrc = (Map) source.get("refresh_token"); - final Map userTokenSource = (Map) - ((Map) source.get("access_token")).get("user_token"); - final Integer version = (Integer) userTokenSource.get("version"); - Version authVersion = Version.fromId(version); - final Long refreshedEpochMilli = (Long) refreshTokenSrc.get("refresh_time"); - final Instant refreshTime = refreshedEpochMilli == null ? null : Instant.ofEpochMilli(refreshedEpochMilli); - final String supersededBy = (String) refreshTokenSrc.get("superseded_by"); - return authVersion.onOrAfter(Version.V_8_0_0) - && supersededBy != null - && refreshTime != null - && refreshRequested.isBefore(refreshTime.plus(30L, ChronoUnit.SECONDS)) - && refreshRequested.isAfter(refreshTime.minus(30L, ChronoUnit.SECONDS)); - } - /** * Find stored refresh and access tokens that have not been invalidated or expired, and were issued against - * the specified realm. + * the specified realm. * * @param realmName The name of the realm for which to get the tokens - * @param listener The listener to notify upon completion - * @param filter an optional Predicate to test the source of the found documents against + * @param listener The listener to notify upon completion + * @param filter an optional Predicate to test the source of the found documents against */ public void findActiveTokensForRealm(String realmName, ActionListener>> listener, @Nullable Predicate> filter) { @@ -1148,6 +893,7 @@ public void findActiveTokensForRealm(String realmName, ActionListener>> listener) { ensureEnabled(); + final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); if (Strings.isNullOrEmpty(username)) { listener.onFailure(new IllegalArgumentException("username is required")); @@ -1212,15 +958,17 @@ private Tuple filterAndParseHit(SearchHit hit, @Nullable Pred } /** + * * Parses a token document into a Tuple of a {@link UserToken} and a String representing the corresponding refresh_token * * @param source The token document source as retrieved * @param filter an optional Predicate to test the source of the UserToken against * @return A {@link Tuple} of access-token and refresh-token-id or null if a Predicate is defined and the userToken source doesn't - * satisfy it + * satisfy it */ private Tuple parseTokensFromDocument(Map source, @Nullable Predicate> filter) throws IOException { + final String refreshToken = (String) ((Map) source.get("refresh_token")).get("token"); final Map userTokenSource = (Map) ((Map) source.get("access_token")).get("user_token"); @@ -1238,7 +986,7 @@ private Tuple parseTokensFromDocument(Map sou in.setVersion(authVersion); Authentication authentication = new Authentication(in); return new Tuple<>(new UserToken(id, Version.fromId(version), authentication, Instant.ofEpochMilli(expiration), metadata), - refreshToken); + refreshToken); } } @@ -1315,6 +1063,7 @@ private void checkIfTokenIsValid(UserToken userToken, ActionListener } } + public TimeValue getExpirationDelay() { return expirationDelay; } @@ -1339,47 +1088,34 @@ private void maybeStartTokenRemover() { private String getFromHeader(ThreadContext threadContext) { String header = threadContext.getHeader("Authorization"); if (Strings.hasText(header) && header.regionMatches(true, 0, "Bearer ", 0, "Bearer ".length()) - && header.length() > "Bearer ".length()) { + && header.length() > "Bearer ".length()) { return header.substring("Bearer ".length()); } return null; } /** - * Serializes a token to a String containing the version of the node that created the token and - * either an encrypted representation of the token id for versions earlier to 7.0.0 or the token ie - * itself for versions after 7.0.0 + * Serializes a token to a String containing an encrypted representation of the token */ - public String getAccessTokenAsString(UserToken userToken) throws IOException, GeneralSecurityException { - if (clusterService.state().nodes().getMinNodeVersion().onOrAfter(Version.V_8_0_0)) { - try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); - OutputStream base64 = Base64.getEncoder().wrap(os); - StreamOutput out = new OutputStreamStreamOutput(base64)) { - out.setVersion(userToken.getVersion()); - Version.writeVersion(userToken.getVersion(), out); - out.writeString(userToken.getId()); - return new String(os.toByteArray(), StandardCharsets.UTF_8); - } - } else { - // we know that the minimum length is larger than the default of the ByteArrayOutputStream so set the size to this explicitly - try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); - OutputStream base64 = Base64.getEncoder().wrap(os); - StreamOutput out = new OutputStreamStreamOutput(base64)) { - out.setVersion(userToken.getVersion()); - KeyAndCache keyAndCache = keyCache.activeKeyCache; - Version.writeVersion(userToken.getVersion(), out); - out.writeByteArray(keyAndCache.getSalt().bytes); - out.writeByteArray(keyAndCache.getKeyHash().bytes); - final byte[] initializationVector = getNewInitializationVector(); - out.writeByteArray(initializationVector); - try (CipherOutputStream encryptedOutput = + public String getUserTokenString(UserToken userToken) throws IOException, GeneralSecurityException { + // we know that the minimum length is larger than the default of the ByteArrayOutputStream so set the size to this explicitly + try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); + OutputStream base64 = Base64.getEncoder().wrap(os); + StreamOutput out = new OutputStreamStreamOutput(base64)) { + out.setVersion(userToken.getVersion()); + KeyAndCache keyAndCache = keyCache.activeKeyCache; + Version.writeVersion(userToken.getVersion(), out); + out.writeByteArray(keyAndCache.getSalt().bytes); + out.writeByteArray(keyAndCache.getKeyHash().bytes); + final byte[] initializationVector = getNewInitializationVector(); + out.writeByteArray(initializationVector); + try (CipherOutputStream encryptedOutput = new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache, userToken.getVersion())); - StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { - encryptedStreamOutput.setVersion(userToken.getVersion()); - encryptedStreamOutput.writeString(userToken.getId()); - encryptedStreamOutput.close(); - return new String(os.toByteArray(), StandardCharsets.UTF_8); - } + StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { + encryptedStreamOutput.setVersion(userToken.getVersion()); + encryptedStreamOutput.writeString(userToken.getId()); + encryptedStreamOutput.close(); + return new String(os.toByteArray(), StandardCharsets.UTF_8); } } } @@ -1389,8 +1125,7 @@ private void ensureEncryptionCiphersSupported() throws NoSuchPaddingException, N SecretKeyFactory.getInstance(KDF_ALGORITHM); } - // Package private for testing - Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException { + private Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException { Cipher cipher = Cipher.getInstance(ENCRYPTION_CIPHER); BytesKey salt = keyAndCache.getSalt(); try { @@ -1412,8 +1147,7 @@ private Cipher getDecryptionCipher(byte[] iv, SecretKey key, Version version, return cipher; } - // Package private for testing - byte[] getNewInitializationVector() { + private byte[] getNewInitializationVector() { final byte[] initializationVector = new byte[IV_BYTES]; secureRandom.nextBytes(initializationVector); return initializationVector; @@ -1424,7 +1158,7 @@ byte[] getNewInitializationVector() { * This method is computationally expensive. */ static SecretKey computeSecretKey(char[] rawPassword, byte[] salt) - throws NoSuchAlgorithmException, InvalidKeySpecException { + throws NoSuchAlgorithmException, InvalidKeySpecException { SecretKeyFactory secretKeyFactory = SecretKeyFactory.getInstance(KDF_ALGORITHM); PBEKeySpec keySpec = new PBEKeySpec(rawPassword, salt, ITERATIONS, 128); SecretKey tmp = secretKeyFactory.generateSecret(keySpec); @@ -1438,7 +1172,7 @@ static SecretKey computeSecretKey(char[] rawPassword, byte[] salt) */ private static ElasticsearchSecurityException expiredTokenException() { ElasticsearchSecurityException e = - new ElasticsearchSecurityException("token expired", RestStatus.UNAUTHORIZED); + new ElasticsearchSecurityException("token expired", RestStatus.UNAUTHORIZED); e.addHeader("WWW-Authenticate", EXPIRED_TOKEN_WWW_AUTH_VALUE); return e; } @@ -1535,8 +1269,8 @@ protected void doRun() { listener.onResponse(computedKey); } catch (ExecutionException e) { if (e.getCause() != null && - (e.getCause() instanceof GeneralSecurityException || e.getCause() instanceof IOException - || e.getCause() instanceof IllegalArgumentException)) { + (e.getCause() instanceof GeneralSecurityException || e.getCause() instanceof IOException + || e.getCause() instanceof IllegalArgumentException)) { // this could happen if another realm supports the Bearer token so we should // see if another realm can use this token! logger.debug("unable to decode bearer token", e); @@ -1571,7 +1305,7 @@ synchronized TokenMetaData generateSpareKey() { continue; // collision -- generate a new key } return newTokenMetaData(keyCache.currentTokenKeyHash, Iterables.concat(keyCache.cache.values(), - Collections.singletonList(keyAndCache))); + Collections.singletonList(keyAndCache))); } } return newTokenMetaData(keyCache.currentTokenKeyHash, keyCache.cache.values()); @@ -1601,10 +1335,10 @@ synchronized TokenMetaData pruneKeys(int numKeysToKeep) { KeyAndCache currentKey = keyCache.get(keyCache.currentTokenKeyHash); ArrayList entries = new ArrayList<>(keyCache.cache.values()); Collections.sort(entries, - (left, right) -> Long.compare(right.keyAndTimestamp.getTimestamp(), left.keyAndTimestamp.getTimestamp())); + (left, right) -> Long.compare(right.keyAndTimestamp.getTimestamp(), left.keyAndTimestamp.getTimestamp())); for (KeyAndCache value : entries) { if (map.size() < numKeysToKeep || value.keyAndTimestamp.getTimestamp() >= currentKey - .keyAndTimestamp.getTimestamp()) { + .keyAndTimestamp.getTimestamp()) { logger.debug("keeping key {} ", value.getKeyHash()); map.put(value.getKeyHash(), value); } else { @@ -1683,16 +1417,16 @@ void rotateKeysOnMaster(ActionListener listener) { logger.info("rotate keys on master"); TokenMetaData tokenMetaData = generateSpareKey(); clusterService.submitStateUpdateTask("publish next key to prepare key rotation", - new TokenMetadataPublishAction( - ActionListener.wrap((res) -> { - if (res.isAcknowledged()) { - TokenMetaData metaData = rotateToSpareKey(); - clusterService.submitStateUpdateTask("publish next key to prepare key rotation", - new TokenMetadataPublishAction(listener, metaData)); - } else { - listener.onFailure(new IllegalStateException("not acked")); - } - }, listener::onFailure), tokenMetaData)); + new TokenMetadataPublishAction( + ActionListener.wrap((res) -> { + if (res.isAcknowledged()) { + TokenMetaData metaData = rotateToSpareKey(); + clusterService.submitStateUpdateTask("publish next key to prepare key rotation", + new TokenMetadataPublishAction(listener, metaData)); + } else { + listener.onFailure(new IllegalStateException("not acked")); + } + }, listener::onFailure), tokenMetaData)); } private final class TokenMetadataPublishAction extends AckedClusterStateUpdateTask { @@ -1794,19 +1528,12 @@ public void clusterStateProcessed(String source, ClusterState oldState, ClusterS } /** - * Package private for testing + * For testing */ void clearActiveKeyCache() { this.keyCache.activeKeyCache.keyCache.invalidateAll(); } - /** - * Package private for testing - */ - KeyAndCache getActiveKeyCache() { - return this.keyCache.activeKeyCache; - } - static final class KeyAndCache implements Closeable { private final KeyAndTimestamp keyAndTimestamp; private final Cache keyCache; @@ -1816,9 +1543,9 @@ static final class KeyAndCache implements Closeable { private KeyAndCache(KeyAndTimestamp keyAndTimestamp, BytesKey salt) { this.keyAndTimestamp = keyAndTimestamp; keyCache = CacheBuilder.builder() - .setExpireAfterAccess(TimeValue.timeValueMinutes(60L)) - .setMaximumWeight(500L) - .build(); + .setExpireAfterAccess(TimeValue.timeValueMinutes(60L)) + .setMaximumWeight(500L) + .build(); try { SecretKey secretKey = computeSecretKey(keyAndTimestamp.getKey().getChars(), salt.bytes); keyCache.put(salt, secretKey); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java index 795cc9fb225d4..085df140f3ecb 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java @@ -242,7 +242,7 @@ public void testLogoutInvalidatesToken() throws Exception { tokenService.createUserToken(authentication, authentication, future, tokenMetaData, true); final UserToken userToken = future.actionGet().v1(); mockGetTokenFromId(userToken, false, client); - final String tokenString = tokenService.getAccessTokenAsString(userToken); + final String tokenString = tokenService.getUserTokenString(userToken); final SamlLogoutRequest request = new SamlLogoutRequest(); request.setToken(tokenString); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java index cda0586886c1f..5eee33711e26f 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java @@ -1109,7 +1109,7 @@ public void testAuthenticateWithToken() throws Exception { Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true); } - String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1()); + String token = tokenService.getUserTokenString(tokenFuture.get().v1()); when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE)); mockGetTokenFromId(tokenFuture.get().v1(), false, client); when(securityIndex.isAvailable()).thenReturn(true); @@ -1192,7 +1192,7 @@ public void testExpiredToken() throws Exception { Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true); } - String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1()); + String token = tokenService.getUserTokenString(tokenFuture.get().v1()); mockGetTokenFromId(tokenFuture.get().v1(), true, client); doAnswer(invocationOnMock -> { ((Runnable) invocationOnMock.getArguments()[1]).run(); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java index 7499d8be7d18b..61ea4ef967224 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java @@ -5,14 +5,11 @@ */ package org.elasticsearch.xpack.security.authc; -import org.apache.directory.api.util.Strings; import org.elasticsearch.ElasticsearchSecurityException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ack.ClusterStateUpdateResponse; import org.elasticsearch.common.settings.SecureString; @@ -26,7 +23,6 @@ import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.xpack.core.XPackSettings; -import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest; import org.elasticsearch.xpack.core.security.action.token.CreateTokenResponse; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenResponse; @@ -42,13 +38,7 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; -import java.util.ArrayList; import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -340,7 +330,7 @@ public void testRefreshingInvalidatedToken() { assertEquals("token has been invalidated", e.getHeader("error_description").get(0)); } - public void testRefreshingMultipleTimesFails() throws Exception { + public void testRefreshingMultipleTimes() { Client client = client().filterWithHeader(Collections.singletonMap("Authorization", UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); @@ -353,101 +343,12 @@ public void testRefreshingMultipleTimesFails() throws Exception { assertNotNull(createTokenResponse.getRefreshToken()); CreateTokenResponse refreshResponse = securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get(); assertNotNull(refreshResponse); - // We now have two documents, the original(now refreshed) token doc and the new one with the new access doc - AtomicReference docId = new AtomicReference<>(); - assertBusy(() -> { - SearchResponse searchResponse = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) - .setSource(SearchSourceBuilder.searchSource() - .query(QueryBuilders.boolQuery() - .must(QueryBuilders.termQuery("doc_type", "token")) - .must(QueryBuilders.termQuery("refresh_token.refreshed", "true")))) - .setSize(1) - .setTerminateAfter(1) - .get(); - assertThat(searchResponse.getHits().getTotalHits().value, equalTo(1L)); - docId.set(searchResponse.getHits().getAt(0).getId()); - }); - // hack doc to modify the refresh time to 50 seconds ago so that we don't hit the lenient refresh case - Instant refreshed = Instant.now(); - Instant aWhileAgo = refreshed.minus(50L, ChronoUnit.SECONDS); - assertTrue(Instant.now().isAfter(aWhileAgo)); - UpdateResponse updateResponse = client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, "doc", docId.get()) - .setDoc("refresh_token", Collections.singletonMap("refresh_time", aWhileAgo.toEpochMilli())) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .setFetchSource("refresh_token", Strings.EMPTY_STRING) - .get(); - assertNotNull(updateResponse); - Map refreshTokenMap = (Map) updateResponse.getGetResult().sourceAsMap().get("refresh_token"); - assertTrue( - Instant.ofEpochMilli((long) refreshTokenMap.get("refresh_time")).isBefore(Instant.now().minus(30L, ChronoUnit.SECONDS))); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get()); assertEquals("invalid_grant", e.getMessage()); assertEquals(RestStatus.BAD_REQUEST, e.status()); - assertEquals("token has already been refreshed more than 30 seconds in the past", e.getHeader("error_description").get(0)); - } - - public void testRefreshingMultipleTimesWithinWindowSucceeds() throws Exception { - Client client = client().filterWithHeader(Collections.singletonMap("Authorization", - UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, - SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); - SecurityClient securityClient = new SecurityClient(client); - Set refreshTokens = new HashSet<>(); - Set accessTokens = new HashSet<>(); - CreateTokenResponse createTokenResponse = securityClient.prepareCreateToken() - .setGrantType("password") - .setUsername(SecuritySettingsSource.TEST_USER_NAME) - .setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray())) - .get(); - assertNotNull(createTokenResponse.getRefreshToken()); - final int numberOfProcessors = Runtime.getRuntime().availableProcessors(); - final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3); - List threads = new ArrayList<>(numberOfThreads); - final CountDownLatch readyLatch = new CountDownLatch(numberOfThreads + 1); - final CountDownLatch completedLatch = new CountDownLatch(numberOfThreads); - AtomicBoolean failed = new AtomicBoolean(); - for (int i = 0; i < numberOfThreads; i++) { - threads.add(new Thread(() -> { - // Each thread gets its own client so that more than one nodes will be hit - Client threadClient = client().filterWithHeader(Collections.singletonMap("Authorization", - UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, - SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); - SecurityClient threadSecurityClient = new SecurityClient(threadClient); - CreateTokenRequest refreshRequest = - threadSecurityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).request(); - readyLatch.countDown(); - try { - readyLatch.await(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - completedLatch.countDown(); - return; - } - threadSecurityClient.refreshToken(refreshRequest, ActionListener.wrap(result -> { - accessTokens.add(result.getTokenString()); - refreshTokens.add(result.getRefreshToken()); - logger.info("received access token [{}] and refresh token [{}]", result.getTokenString(), result.getRefreshToken()); - completedLatch.countDown(); - }, e -> { - failed.set(true); - completedLatch.countDown(); - logger.error("caught exception", e); - })); - })); - } - for (Thread thread : threads) { - thread.start(); - } - readyLatch.countDown(); - readyLatch.await(); - for (Thread thread : threads) { - thread.join(); - } - completedLatch.await(); - assertThat(failed.get(), equalTo(false)); - assertThat(accessTokens.size(), equalTo(1)); - assertThat(refreshTokens.size(), equalTo(1)); + assertEquals("token has already been refreshed", e.getHeader("error_description").get(0)); } public void testRefreshAsDifferentUser() { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index 7efb4b51632d4..8caf82e8648cb 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.security.authc; import org.elasticsearch.ElasticsearchSecurityException; -import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NoShardAvailableActionException; import org.elasticsearch.action.get.GetAction; @@ -24,8 +23,6 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Tuple; -import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -54,11 +51,7 @@ import org.junit.Before; import org.junit.BeforeClass; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.OutputStream; -import java.nio.charset.StandardCharsets; -import java.security.GeneralSecurityException; import java.time.Clock; import java.time.Instant; import java.time.temporal.ChronoUnit; @@ -68,7 +61,6 @@ import java.util.Map; import java.util.function.Consumer; -import javax.crypto.CipherOutputStream; import javax.crypto.SecretKey; import static java.time.Clock.systemUTC; @@ -159,7 +151,7 @@ public void testAttachAndGetToken() throws Exception { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getAccessTokenAsString(token)); + requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getUserTokenString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -206,7 +198,7 @@ public void testRotateKey() throws Exception { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -227,10 +219,10 @@ public void testRotateKey() throws Exception { tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true); final UserToken newToken = newTokenFuture.get().v1(); assertNotNull(newToken); - assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token)); + assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token)); requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken)); mockGetTokenFromId(newToken, false); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { @@ -255,7 +247,7 @@ public void testKeyExchange() throws Exception { rotateKeys(tokenService); } TokenService otherTokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, - clusterService); + clusterService); otherTokenService.refreshMetaData(tokenService.getTokenMetaData()); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); PlainActionFuture> tokenFuture = new PlainActionFuture<>(); @@ -266,7 +258,7 @@ public void testKeyExchange() throws Exception { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); otherTokenService.getAndValidateToken(requestContext, future); @@ -297,7 +289,7 @@ public void testPruneKeys() throws Exception { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -324,7 +316,7 @@ public void testPruneKeys() throws Exception { tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true); final UserToken newToken = newTokenFuture.get().v1(); assertNotNull(newToken); - assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token)); + assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token)); metaData = tokenService.pruneKeys(1); tokenService.refreshMetaData(metaData); @@ -337,7 +329,7 @@ public void testPruneKeys() throws Exception { } requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken)); mockGetTokenFromId(newToken, false); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -359,7 +351,7 @@ public void testPassphraseWorks() throws Exception { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -370,8 +362,8 @@ public void testPassphraseWorks() throws Exception { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { // verify a second separate token service with its own passphrase cannot verify - TokenService anotherService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, - clusterService); + TokenService anotherService = new TokenService(Settings.EMPTY, systemUTC(), client, securityIndex, + clusterService); PlainActionFuture future = new PlainActionFuture<>(); anotherService.getAndValidateToken(requestContext, future); assertNull(future.get()); @@ -385,10 +377,10 @@ public void testGetTokenWhenKeyCacheHasExpired() throws Exception { PlainActionFuture> tokenFuture = new PlainActionFuture<>(); tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true); UserToken token = tokenFuture.get().v1(); - assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue()); + assertThat(tokenService.getUserTokenString(token), notNullValue()); tokenService.clearActiveKeyCache(); - assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue()); + assertThat(tokenService.getUserTokenString(token), notNullValue()); } public void testInvalidatedToken() throws Exception { @@ -403,7 +395,7 @@ public void testInvalidatedToken() throws Exception { mockGetTokenFromId(token, true); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -457,7 +449,7 @@ public void testTokenExpiry() throws Exception { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { // the clock is still frozen, so the cookie should be valid @@ -567,7 +559,7 @@ public void testIndexNotAvailable() throws Exception { //mockGetTokenFromId(token, false); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); doAnswer(invocationOnMock -> { ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; @@ -606,7 +598,7 @@ public void testGetAuthenticationWorksWithExpiredUserToken() throws Exception { Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); UserToken expired = new UserToken(authentication, Instant.now().minus(3L, ChronoUnit.DAYS)); mockGetTokenFromId(expired, false); - String userTokenString = tokenService.getAccessTokenAsString(expired); + String userTokenString = tokenService.getUserTokenString(expired); PlainActionFuture>> authFuture = new PlainActionFuture<>(); tokenService.getAuthenticationAndMetaData(userTokenString, authFuture); Authentication retrievedAuth = authFuture.actionGet().v1(); @@ -647,28 +639,4 @@ public static void assertAuthentication(Authentication result, Authentication ex assertEquals(expected.getMetadata(), result.getMetadata()); assertEquals(AuthenticationType.TOKEN, result.getAuthenticationType()); } - - protected String getDeprecatedAccessTokenString(TokenService tokenService, UserToken userToken) throws IOException, - GeneralSecurityException { - try (ByteArrayOutputStream os = new ByteArrayOutputStream(TokenService.MINIMUM_BASE64_BYTES); - OutputStream base64 = Base64.getEncoder().wrap(os); - StreamOutput out = new OutputStreamStreamOutput(base64)) { - out.setVersion(Version.V_7_0_0); - TokenService.KeyAndCache keyAndCache = tokenService.getActiveKeyCache(); - Version.writeVersion(Version.V_7_0_0, out); - out.writeByteArray(keyAndCache.getSalt().bytes); - out.writeByteArray(keyAndCache.getKeyHash().bytes); - final byte[] initializationVector = tokenService.getNewInitializationVector(); - out.writeByteArray(initializationVector); - try (CipherOutputStream encryptedOutput = - new CipherOutputStream(out, tokenService.getEncryptionCipher(initializationVector, keyAndCache, Version.V_7_0_0)); - StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { - encryptedStreamOutput.setVersion(Version.V_7_0_0); - encryptedStreamOutput.writeString(userToken.getId()); - encryptedStreamOutput.close(); - return new String(os.toByteArray(), StandardCharsets.UTF_8); - } - } - } - } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java index 55ae297ae4e01..f180e356b767c 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java @@ -25,7 +25,8 @@ public void testToXcontent() throws Exception{ TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), Arrays.asList("token3", "token4"), Arrays.asList(new ElasticsearchException("foo", new IllegalStateException("bar")), - new ElasticsearchException("boo", new IllegalStateException("far")))); + new ElasticsearchException("boo", new IllegalStateException("far"))), + randomIntBetween(0, 5)); try (XContentBuilder builder = JsonXContent.contentBuilder()) { result.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -55,8 +56,9 @@ public void testToXcontent() throws Exception{ } public void testToXcontentWithNoErrors() throws Exception{ - TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), Collections.emptyList(), - Collections.emptyList()); + TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), + Collections.emptyList(), + Collections.emptyList(), randomIntBetween(0, 5)); try (XContentBuilder builder = JsonXContent.contentBuilder()) { result.toXContent(builder, ToXContent.EMPTY_PARAMS); assertThat(Strings.toString(builder),