Skip to content

Commit

Permalink
feat: add Token Refresh e2e tests (#1160)
Browse files Browse the repository at this point in the history
* feat: add E2E tests for token renewal

* parallelize tests
  • Loading branch information
paullatzelsperger authored Mar 26, 2024
1 parent a67822e commit 5c8a549
Show file tree
Hide file tree
Showing 22 changed files with 575 additions and 197 deletions.
9 changes: 0 additions & 9 deletions .github/workflows/verify.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ jobs:
unit-tests:
runs-on: ubuntu-latest
needs: [ verify-formatting, verify-license-headers ]
steps:
- uses: actions/checkout@v4

Expand All @@ -90,7 +89,6 @@ jobs:

integration-tests:
runs-on: ubuntu-latest
needs: [ verify-formatting, verify-license-headers ]
steps:
- uses: actions/checkout@v4

Expand All @@ -101,7 +99,6 @@ jobs:

api-tests:
runs-on: ubuntu-latest
needs: [ verify-formatting, verify-license-headers ]
steps:
- uses: actions/checkout@v4

Expand All @@ -112,7 +109,6 @@ jobs:

end-to-end-tests:
runs-on: ubuntu-latest
needs: [ verify-formatting, verify-license-headers ]
strategy:
fail-fast: false
matrix:
Expand All @@ -134,9 +130,6 @@ jobs:
postgres-tests:
runs-on: ubuntu-latest

needs: [ verify-formatting, verify-license-headers ]

services:
postgres:
image: postgres:14.2
Expand All @@ -154,7 +147,6 @@ jobs:

dataplane-tests:
runs-on: ubuntu-latest
needs: [ verify-formatting, verify-license-headers ]

steps:
- uses: actions/checkout@v4
Expand All @@ -165,7 +157,6 @@ jobs:

miw-integration-tests:
runs-on: ubuntu-latest
needs: [ verify-formatting, verify-license-headers ]
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup-java
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ plugins {

dependencies {
api(project(":spi:tokenrefresh-spi"))
api(project(":spi:core-spi"))
implementation(libs.edc.spi.core)
implementation(libs.edc.spi.token)
implementation(libs.edc.spi.keys)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@
public class DataPlaneTokenRefreshServiceExtension implements ServiceExtension {
public static final String NAME = "DataPlane Token Refresh Service extension";
public static final int DEFAULT_TOKEN_EXPIRY_TOLERANCE_SECONDS = 5;
public static final long DEFAULT_TOKEN_EXPIRY_SECONDS = 300L;
@Setting(value = "Token expiry tolerance period in seconds to allow for clock skew", defaultValue = "" + DEFAULT_TOKEN_EXPIRY_TOLERANCE_SECONDS)
public static final String TOKEN_EXPIRY_TOLERANCE_SECONDS_PROPERTY = "edc.dataplane.token.expiry.tolerance";

@Setting(value = "The HTTP endpoint where clients can request a renewal of their access token for the public dataplane API")
public static final String REFRESH_ENDPOINT_PROPERTY = "edc.dataplane.token.refresh.endpoint";
@Setting(value = "Alias of private key used for signing tokens, retrieved from private key resolver")
public static final String TOKEN_SIGNER_PRIVATE_KEY_ALIAS = "edc.transfer.proxy.token.signer.privatekey.alias";

@Setting(value = "Alias of public key used for verifying the tokens, retrieved from the vault")
public static final String TOKEN_VERIFIER_PUBLIC_KEY_ALIAS = "edc.transfer.proxy.token.verifier.publickey.alias";

@Setting(value = "Expiry time of access token in seconds", defaultValue = DEFAULT_TOKEN_EXPIRY_SECONDS + "")
public static final String TOKEN_EXPIRY_SECONDS_PROPERTY = "edc.dataplane.token.expiry";
@Inject
private TokenValidationService tokenValidationService;
@Inject
Expand Down Expand Up @@ -105,15 +105,20 @@ private DataPlaneTokenRefreshServiceImpl getTokenRefreshService(ServiceExtension
var monitor = context.getMonitor().withPrefix("DataPlane Token Refresh");
var expiryTolerance = getExpiryToleranceConfig(context);
var refreshEndpoint = getRefreshEndpointConfig(context, monitor);
var tokenExpiry = getExpiryConfig(context);
monitor.debug("Token refresh endpoint: %s".formatted(refreshEndpoint));
monitor.debug("Token refresh time tolerance: %d s".formatted(expiryTolerance));
tokenRefreshService = new DataPlaneTokenRefreshServiceImpl(clock, tokenValidationService, didPkResolver, accessTokenDataStore, new JwtGenerationService(),
getPrivateKeySupplier(context), context.getMonitor(), refreshEndpoint, expiryTolerance,
getPrivateKeySupplier(context), context.getMonitor(), refreshEndpoint, expiryTolerance, tokenExpiry,
() -> context.getConfig().getString(TOKEN_VERIFIER_PUBLIC_KEY_ALIAS), vault, typeManager.getMapper());
}
return tokenRefreshService;
}

private Long getExpiryConfig(ServiceExtensionContext context) {
return context.getConfig().getLong(TOKEN_EXPIRY_SECONDS_PROPERTY, DEFAULT_TOKEN_EXPIRY_SECONDS);
}

private String getRefreshEndpointConfig(ServiceExtensionContext context, Monitor monitor) {
var refreshEndpoint = context.getConfig().getString(REFRESH_ENDPOINT_PROPERTY, null);
if (refreshEndpoint == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.jwt.JWTClaimNames;
import org.eclipse.edc.connector.dataplane.spi.AccessTokenData;
import org.eclipse.edc.connector.dataplane.spi.iam.DataPlaneAccessTokenService;
import org.eclipse.edc.connector.dataplane.spi.store.AccessTokenDataStore;
Expand All @@ -41,7 +42,6 @@
import org.eclipse.edc.token.spi.TokenGenerationService;
import org.eclipse.edc.token.spi.TokenValidationRule;
import org.eclipse.edc.token.spi.TokenValidationService;
import org.eclipse.tractusx.edc.dataplane.tokenrefresh.core.rules.AuthTokenAudienceRule;
import org.eclipse.tractusx.edc.dataplane.tokenrefresh.core.rules.ClaimIsPresentRule;
import org.eclipse.tractusx.edc.dataplane.tokenrefresh.core.rules.IssuerEqualsSubjectRule;
import org.eclipse.tractusx.edc.dataplane.tokenrefresh.core.rules.RefreshTokenValidationRule;
Expand All @@ -62,17 +62,19 @@

import static org.eclipse.edc.jwt.spi.JwtRegisteredClaimNames.AUDIENCE;
import static org.eclipse.edc.jwt.spi.JwtRegisteredClaimNames.EXPIRATION_TIME;
import static org.eclipse.tractusx.edc.edr.spi.CoreConstants.EDR_PROPERTY_EXPIRES_IN;
import static org.eclipse.tractusx.edc.edr.spi.CoreConstants.EDR_PROPERTY_REFRESH_ENDPOINT;
import static org.eclipse.tractusx.edc.edr.spi.CoreConstants.EDR_PROPERTY_REFRESH_TOKEN;

/**
* This implementation of the {@link DataPlaneTokenRefreshService} validates an incoming authentication token.
*/
public class DataPlaneTokenRefreshServiceImpl implements DataPlaneTokenRefreshService, DataPlaneAccessTokenService {
public static final String ACCESS_TOKEN_CLAIM = "access_token";
public static final String ACCESS_TOKEN_CLAIM = "token";
public static final String TOKEN_ID_CLAIM = "jti";
public static final String REFRESH_TOKEN_PROPERTY = "refreshToken";
private static final Long DEFAULT_EXPIRY_IN_SECONDS = 60 * 5L;
private final long tokenExpirySeconds;
private final List<TokenValidationRule> authenticationTokenValidationRules;
private final List<TokenValidationRule> accessTokenRules;
private final List<TokenValidationRule> accessTokenAuthorizationRules;
private final TokenValidationService tokenValidationService;
private final DidPublicKeyResolver publicKeyResolver;
private final AccessTokenDataStore accessTokenDataStore;
Expand All @@ -94,6 +96,7 @@ public DataPlaneTokenRefreshServiceImpl(Clock clock,
Monitor monitor,
String refreshEndpoint,
int tokenExpiryToleranceSeconds,
long tokenExpirySeconds,
Supplier<String> publicKeyIdSupplier,
Vault vault,
ObjectMapper objectMapper) {
Expand All @@ -108,11 +111,13 @@ public DataPlaneTokenRefreshServiceImpl(Clock clock,
this.publicKeyIdSupplier = publicKeyIdSupplier;
this.vault = vault;
this.objectMapper = objectMapper;
this.tokenExpirySeconds = tokenExpirySeconds;
authenticationTokenValidationRules = List.of(new IssuerEqualsSubjectRule(),
new ClaimIsPresentRule(AUDIENCE), // we don't check the contents, only it is present
new ClaimIsPresentRule(ACCESS_TOKEN_CLAIM),
new ClaimIsPresentRule(TOKEN_ID_CLAIM));
accessTokenRules = List.of(new IssuerEqualsSubjectRule(),
new ClaimIsPresentRule(TOKEN_ID_CLAIM)
/*new AuthTokenAudienceRule(accessTokenDataStore)*/);
accessTokenAuthorizationRules = List.of(new IssuerEqualsSubjectRule(),
new ClaimIsPresentRule(AUDIENCE),
new ClaimIsPresentRule(TOKEN_ID_CLAIM),
new ExpirationIssuedAtValidationRule(clock, tokenExpiryToleranceSeconds));
Expand All @@ -137,15 +142,22 @@ public DataPlaneTokenRefreshServiceImpl(Clock clock,
@Override
public Result<TokenResponse> refreshToken(String refreshToken, String authenticationToken) {

var allRules = new ArrayList<>(authenticationTokenValidationRules);
allRules.add(new RefreshTokenValidationRule(vault, refreshToken, objectMapper));
allRules.add(new AuthTokenAudienceRule(accessTokenDataStore));

authenticationToken = authenticationToken.replace("Bearer", "").trim();

var accessTokenDataResult = resolveToken(authenticationToken, allRules);
// 1. validate authentication token
monitor.warning(" TOKEN REFRESH :: TEMPORARILY DISABLED RULE AuthTokenAudienceRule UNTIL THE 'audience' PROPERTY IS FORWARDED TO DATAPLANES%n");
var authTokenRes = tokenValidationService.validate(authenticationToken, publicKeyResolver, authenticationTokenValidationRules);
if (authTokenRes.failed()) {
return Result.failure("Authentication token validation failed: %s".formatted(authTokenRes.getFailureDetail()));
}

// 2. extract access token and validate it
var accessToken = authTokenRes.getContent().getStringClaim("token");
var accessTokenDataResult = tokenValidationService.validate(accessToken, publicKeyResolver, new RefreshTokenValidationRule(vault, refreshToken, objectMapper))
.map(accessTokenClaims -> accessTokenDataStore.getById(accessTokenClaims.getStringClaim(JwtRegisteredClaimNames.JWT_ID)));

if (accessTokenDataResult.failed()) {
return accessTokenDataResult.mapTo();
return Result.failure("Access token validation failed: %s".formatted(accessTokenDataResult.getFailureDetail()));
}

var existingAccessTokenData = accessTokenDataResult.getContent();
Expand All @@ -162,7 +174,7 @@ public Result<TokenResponse> refreshToken(String refreshToken, String authentica
return Result.failure("Failed to regenerate access/refresh token pair: %s".formatted(errors));
}

storeRefreshToken(existingAccessTokenData.id(), new RefreshToken(newRefreshToken.getContent(), DEFAULT_EXPIRY_IN_SECONDS, refreshEndpoint));
storeRefreshToken(existingAccessTokenData.id(), new RefreshToken(newRefreshToken.getContent(), tokenExpirySeconds, refreshEndpoint));

// the ClaimToken is created based solely on the TokenParameters. The additional information (refresh token...) is persisted separately
var claimToken = ClaimToken.Builder.newInstance().claims(newTokenParams.getClaims()).build();
Expand All @@ -171,7 +183,7 @@ public Result<TokenResponse> refreshToken(String refreshToken, String authentica
var storeResult = accessTokenDataStore.update(accessTokenData);
return storeResult.succeeded() ?
Result.success(new TokenResponse(newAccessToken.getContent(),
newRefreshToken.getContent(), DEFAULT_EXPIRY_IN_SECONDS, "bearer")) :
newRefreshToken.getContent(), tokenExpirySeconds, "bearer")) :
Result.failure(storeResult.getFailureMessages());
}

Expand Down Expand Up @@ -203,18 +215,18 @@ public Result<TokenRepresentation> obtainToken(TokenParameters tokenParameters,
var storeResult = accessTokenDataStore.store(accessTokenData);

storeRefreshToken(accessTokenResult.getContent().id(), new RefreshToken(refreshTokenResult.getContent().tokenRepresentation().getToken(),
DEFAULT_EXPIRY_IN_SECONDS, refreshEndpoint));
tokenExpirySeconds, refreshEndpoint));

// the refresh token information must be returned in the EDR
var edrAdditionalData = new HashMap<>(additionalTokenData);
edrAdditionalData.put("refreshToken", refreshTokenResult.getContent().tokenRepresentation().getToken());
edrAdditionalData.put("expiresIn", String.valueOf(DEFAULT_EXPIRY_IN_SECONDS));
edrAdditionalData.put("refreshEndpoint", refreshEndpoint);
edrAdditionalData.put(EDR_PROPERTY_REFRESH_TOKEN, refreshTokenResult.getContent().tokenRepresentation().getToken());
edrAdditionalData.put(EDR_PROPERTY_EXPIRES_IN, String.valueOf(tokenExpirySeconds));
edrAdditionalData.put(EDR_PROPERTY_REFRESH_ENDPOINT, refreshEndpoint);

var edrTokenRepresentation = TokenRepresentation.Builder.newInstance()
.token(accessTokenResult.getContent().tokenRepresentation().getToken()) // the access token
.additional(edrAdditionalData) //contains additional properties and the refresh token
.expiresIn(DEFAULT_EXPIRY_IN_SECONDS) //todo: needed?
.expiresIn(tokenExpirySeconds) //todo: needed?
.build();


Expand All @@ -223,7 +235,12 @@ public Result<TokenRepresentation> obtainToken(TokenParameters tokenParameters,

@Override
public Result<AccessTokenData> resolve(String token) {
return resolveToken(token, accessTokenRules);
return tokenValidationService.validate(token, publicKeyResolver, accessTokenAuthorizationRules)
.compose(claimToken -> {
var id = claimToken.getStringClaim(JWTClaimNames.JWT_ID);
var tokenData = accessTokenDataStore.getById(id);
return tokenData != null ? Result.success(tokenData) : Result.failure("AccessTokenData with ID '%s' does not exist.".formatted(id));
});
}

@Override
Expand Down Expand Up @@ -270,33 +287,15 @@ private Result<TokenRepresentationWithId> createToken(TokenParameters tokenParam
}
//if there is not "exp" header on the token params, we'll configure one
if (!tokenParameters.getClaims().containsKey(JwtRegisteredClaimNames.EXPIRATION_TIME)) {
monitor.info("No '%s' claim found on TokenParameters. Will use the default of %d seconds".formatted(EXPIRATION_TIME, DEFAULT_EXPIRY_IN_SECONDS));
var exp = clock.instant().plusSeconds(DEFAULT_EXPIRY_IN_SECONDS).getEpochSecond();
monitor.info("No '%s' claim found on TokenParameters. Will use the configured default of %d seconds".formatted(EXPIRATION_TIME, tokenExpirySeconds));
var exp = clock.instant().plusSeconds(tokenExpirySeconds).getEpochSecond();
allDecorators.add(tp -> tp.claims(JwtRegisteredClaimNames.EXPIRATION_TIME, exp));
}

return tokenGenerationService.generate(privateKeySupplier, allDecorators.toArray(new TokenDecorator[0]))
.map(tr -> new TokenRepresentationWithId(tokenId.get(), tr));
}

/**
* Parses the given token, and validates it against the given rules. For that, the publicKeyResolver is used.
* Once the token is deemed valid, the "jti" claim (which is mandatory) is extracted and used as key for a lookup in the
* {@link AccessTokenDataStore}. The result of that is then returned.
*/
private Result<AccessTokenData> resolveToken(String token, List<TokenValidationRule> rules) {
var validationResult = tokenValidationService.validate(token, publicKeyResolver, rules);
if (validationResult.failed()) {
return validationResult.mapTo();
}
var tokenId = validationResult.getContent().getStringClaim(TOKEN_ID_CLAIM);
var existingAccessToken = accessTokenDataStore.getById(tokenId);

return existingAccessToken == null ?
Result.failure("AccessTokenData with ID '%s' does not exist.".formatted(tokenId)) :
Result.success(existingAccessToken);
}

private Result<Void> storeRefreshToken(String id, RefreshToken refreshToken) {
try {
return vault.storeSecret(id, objectMapper.writeValueAsString(refreshToken));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2024 Bayerische Motoren Werke Aktiengesellschaft
*
* See the NOTICE file(s) distributed with this work for additional
* information regarding copyright ownership.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.eclipse.tractusx.edc.dataplane.tokenrefresh.core;

import com.nimbusds.jwt.SignedJWT;
import org.eclipse.edc.jwt.spi.JwtRegisteredClaimNames;
import org.jetbrains.annotations.Nullable;

import java.text.ParseException;

public class TokenFunctions {

/**
* Returns the "jti" claim of a JWT in serialized format. Will throw a {@link RuntimeException} if the token is not in valid
* serialized JWT format.
*/
public static @Nullable String getTokenId(String serializedJwt) {
try {
return SignedJWT.parse(serializedJwt).getJWTClaimsSet().getStringClaim(JwtRegisteredClaimNames.JWT_ID);
} catch (ParseException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@

import java.util.Map;

import static org.eclipse.tractusx.edc.dataplane.tokenrefresh.core.TokenFunctions.getTokenId;


/**
* Validates that the {@code iss} claim of a token is equal to the {@code audience} property found on the {@link org.eclipse.edc.connector.dataplane.spi.AccessTokenData}
* that is associated with that token (using the {@code jti} claim).
Expand All @@ -42,9 +45,13 @@ public AuthTokenAudienceRule(AccessTokenDataStore store) {
}

@Override
public Result<Void> checkRule(@NotNull ClaimToken claimToken, @Nullable Map<String, Object> map) {
var issuer = claimToken.getStringClaim(JWTClaimNames.ISSUER);
var tokenId = claimToken.getStringClaim(JWTClaimNames.JWT_ID);
public Result<Void> checkRule(@NotNull ClaimToken authenticationToken, @Nullable Map<String, Object> map) {
var issuer = authenticationToken.getStringClaim(JWTClaimNames.ISSUER);
var accessToken = authenticationToken.getStringClaim("token");
if (accessToken == null) {
return Result.failure("Authentication token must contain a 'token' claim");
}
var tokenId = getTokenId(accessToken);

var accessTokenData = store.getById(tokenId);
var expectedAudience = accessTokenData.additionalProperties().getOrDefault(AUDIENCE_PROPERTY, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ public RefreshTokenValidationRule(Vault vault, String incomingRefreshToken, Obje
}

@Override
public Result<Void> checkRule(@NotNull ClaimToken toVerify, @Nullable Map<String, Object> additional) {
var tokenId = toVerify.getStringClaim(JWTClaimNames.JWT_ID);
public Result<Void> checkRule(@NotNull ClaimToken accessToken, @Nullable Map<String, Object> additional) {

var tokenId = accessToken.getStringClaim(JWTClaimNames.JWT_ID);
var storedRefreshTokenJson = vault.resolveSecret(tokenId);
if (storedRefreshTokenJson == null) {
return failure("No refresh token with the ID '%s' was found in the vault.".formatted(tokenId));
Expand Down
Loading

0 comments on commit 5c8a549

Please sign in to comment.