Skip to content

Commit

Permalink
feat: add E2E tests for token renewal
Browse files Browse the repository at this point in the history
  • Loading branch information
paullatzelsperger committed Mar 25, 2024
1 parent a67822e commit 049876d
Show file tree
Hide file tree
Showing 21 changed files with 565 additions and 180 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ plugins {

dependencies {
api(project(":spi:tokenrefresh-spi"))
api(project(":spi:core-spi"))
implementation(libs.edc.spi.core)
implementation(libs.edc.spi.token)
implementation(libs.edc.spi.keys)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@
public class DataPlaneTokenRefreshServiceExtension implements ServiceExtension {
public static final String NAME = "DataPlane Token Refresh Service extension";
public static final int DEFAULT_TOKEN_EXPIRY_TOLERANCE_SECONDS = 5;
public static final long DEFAULT_TOKEN_EXPIRY_SECONDS = 300L;
@Setting(value = "Token expiry tolerance period in seconds to allow for clock skew", defaultValue = "" + DEFAULT_TOKEN_EXPIRY_TOLERANCE_SECONDS)
public static final String TOKEN_EXPIRY_TOLERANCE_SECONDS_PROPERTY = "edc.dataplane.token.expiry.tolerance";

@Setting(value = "The HTTP endpoint where clients can request a renewal of their access token for the public dataplane API")
public static final String REFRESH_ENDPOINT_PROPERTY = "edc.dataplane.token.refresh.endpoint";
@Setting(value = "Alias of private key used for signing tokens, retrieved from private key resolver")
public static final String TOKEN_SIGNER_PRIVATE_KEY_ALIAS = "edc.transfer.proxy.token.signer.privatekey.alias";

@Setting(value = "Alias of public key used for verifying the tokens, retrieved from the vault")
public static final String TOKEN_VERIFIER_PUBLIC_KEY_ALIAS = "edc.transfer.proxy.token.verifier.publickey.alias";

@Setting(value = "Expiry time of access token in seconds", defaultValue = DEFAULT_TOKEN_EXPIRY_SECONDS + "")
public static final String TOKEN_EXPIRY_SECONDS_PROPERTY = "edc.dataplane.token.expiry";
@Inject
private TokenValidationService tokenValidationService;
@Inject
Expand Down Expand Up @@ -105,15 +105,20 @@ private DataPlaneTokenRefreshServiceImpl getTokenRefreshService(ServiceExtension
var monitor = context.getMonitor().withPrefix("DataPlane Token Refresh");
var expiryTolerance = getExpiryToleranceConfig(context);
var refreshEndpoint = getRefreshEndpointConfig(context, monitor);
var tokenExpiry = getExpiryConfig(context);
monitor.debug("Token refresh endpoint: %s".formatted(refreshEndpoint));
monitor.debug("Token refresh time tolerance: %d s".formatted(expiryTolerance));
tokenRefreshService = new DataPlaneTokenRefreshServiceImpl(clock, tokenValidationService, didPkResolver, accessTokenDataStore, new JwtGenerationService(),
getPrivateKeySupplier(context), context.getMonitor(), refreshEndpoint, expiryTolerance,
getPrivateKeySupplier(context), context.getMonitor(), refreshEndpoint, expiryTolerance, tokenExpiry,
() -> context.getConfig().getString(TOKEN_VERIFIER_PUBLIC_KEY_ALIAS), vault, typeManager.getMapper());
}
return tokenRefreshService;
}

private Long getExpiryConfig(ServiceExtensionContext context) {
return context.getConfig().getLong(TOKEN_EXPIRY_SECONDS_PROPERTY, DEFAULT_TOKEN_EXPIRY_SECONDS);
}

private String getRefreshEndpointConfig(ServiceExtensionContext context, Monitor monitor) {
var refreshEndpoint = context.getConfig().getString(REFRESH_ENDPOINT_PROPERTY, null);
if (refreshEndpoint == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.jwt.JWTClaimNames;
import org.eclipse.edc.connector.dataplane.spi.AccessTokenData;
import org.eclipse.edc.connector.dataplane.spi.iam.DataPlaneAccessTokenService;
import org.eclipse.edc.connector.dataplane.spi.store.AccessTokenDataStore;
Expand All @@ -41,7 +42,6 @@
import org.eclipse.edc.token.spi.TokenGenerationService;
import org.eclipse.edc.token.spi.TokenValidationRule;
import org.eclipse.edc.token.spi.TokenValidationService;
import org.eclipse.tractusx.edc.dataplane.tokenrefresh.core.rules.AuthTokenAudienceRule;
import org.eclipse.tractusx.edc.dataplane.tokenrefresh.core.rules.ClaimIsPresentRule;
import org.eclipse.tractusx.edc.dataplane.tokenrefresh.core.rules.IssuerEqualsSubjectRule;
import org.eclipse.tractusx.edc.dataplane.tokenrefresh.core.rules.RefreshTokenValidationRule;
Expand All @@ -62,17 +62,19 @@

import static org.eclipse.edc.jwt.spi.JwtRegisteredClaimNames.AUDIENCE;
import static org.eclipse.edc.jwt.spi.JwtRegisteredClaimNames.EXPIRATION_TIME;
import static org.eclipse.tractusx.edc.edr.spi.CoreConstants.EDR_PROPERTY_EXPIRES_IN;
import static org.eclipse.tractusx.edc.edr.spi.CoreConstants.EDR_PROPERTY_REFRESH_ENDPOINT;
import static org.eclipse.tractusx.edc.edr.spi.CoreConstants.EDR_PROPERTY_REFRESH_TOKEN;

/**
* This implementation of the {@link DataPlaneTokenRefreshService} validates an incoming authentication token.
*/
public class DataPlaneTokenRefreshServiceImpl implements DataPlaneTokenRefreshService, DataPlaneAccessTokenService {
public static final String ACCESS_TOKEN_CLAIM = "access_token";
public static final String ACCESS_TOKEN_CLAIM = "token";
public static final String TOKEN_ID_CLAIM = "jti";
public static final String REFRESH_TOKEN_PROPERTY = "refreshToken";
private static final Long DEFAULT_EXPIRY_IN_SECONDS = 60 * 5L;
private final long tokenExpirySeconds;
private final List<TokenValidationRule> authenticationTokenValidationRules;
private final List<TokenValidationRule> accessTokenRules;
private final List<TokenValidationRule> accessTokenAuthorizationRules;
private final TokenValidationService tokenValidationService;
private final DidPublicKeyResolver publicKeyResolver;
private final AccessTokenDataStore accessTokenDataStore;
Expand All @@ -94,6 +96,7 @@ public DataPlaneTokenRefreshServiceImpl(Clock clock,
Monitor monitor,
String refreshEndpoint,
int tokenExpiryToleranceSeconds,
long tokenExpirySeconds,
Supplier<String> publicKeyIdSupplier,
Vault vault,
ObjectMapper objectMapper) {
Expand All @@ -108,11 +111,13 @@ public DataPlaneTokenRefreshServiceImpl(Clock clock,
this.publicKeyIdSupplier = publicKeyIdSupplier;
this.vault = vault;
this.objectMapper = objectMapper;
this.tokenExpirySeconds = tokenExpirySeconds;
authenticationTokenValidationRules = List.of(new IssuerEqualsSubjectRule(),
new ClaimIsPresentRule(AUDIENCE), // we don't check the contents, only it is present
new ClaimIsPresentRule(ACCESS_TOKEN_CLAIM),
new ClaimIsPresentRule(TOKEN_ID_CLAIM));
accessTokenRules = List.of(new IssuerEqualsSubjectRule(),
new ClaimIsPresentRule(TOKEN_ID_CLAIM)
/*new AuthTokenAudienceRule(accessTokenDataStore)*/);
accessTokenAuthorizationRules = List.of(new IssuerEqualsSubjectRule(),
new ClaimIsPresentRule(AUDIENCE),
new ClaimIsPresentRule(TOKEN_ID_CLAIM),
new ExpirationIssuedAtValidationRule(clock, tokenExpiryToleranceSeconds));
Expand All @@ -137,15 +142,22 @@ public DataPlaneTokenRefreshServiceImpl(Clock clock,
@Override
public Result<TokenResponse> refreshToken(String refreshToken, String authenticationToken) {

var allRules = new ArrayList<>(authenticationTokenValidationRules);
allRules.add(new RefreshTokenValidationRule(vault, refreshToken, objectMapper));
allRules.add(new AuthTokenAudienceRule(accessTokenDataStore));

authenticationToken = authenticationToken.replace("Bearer", "").trim();

var accessTokenDataResult = resolveToken(authenticationToken, allRules);
// 1. validate authentication token
monitor.warning(" TOKEN REFRESH :: TEMPORARILY DISABLED RULE AuthTokenAudienceRule UNTIL THE 'audience' PROPERTY IS FORWARDED TO DATAPLANES%n");
var authTokenRes = tokenValidationService.validate(authenticationToken, publicKeyResolver, authenticationTokenValidationRules);
if (authTokenRes.failed()) {
return Result.failure("Authentication token validation failed: %s".formatted(authTokenRes.getFailureDetail()));
}

// 2. extract access token and validate it
var accessToken = authTokenRes.getContent().getStringClaim("token");
var accessTokenDataResult = tokenValidationService.validate(accessToken, publicKeyResolver, new RefreshTokenValidationRule(vault, refreshToken, objectMapper))
.map(accessTokenClaims -> accessTokenDataStore.getById(accessTokenClaims.getStringClaim(JwtRegisteredClaimNames.JWT_ID)));

if (accessTokenDataResult.failed()) {
return accessTokenDataResult.mapTo();
return Result.failure("Access token validation failed: %s".formatted(accessTokenDataResult.getFailureDetail()));
}

var existingAccessTokenData = accessTokenDataResult.getContent();
Expand All @@ -162,7 +174,7 @@ public Result<TokenResponse> refreshToken(String refreshToken, String authentica
return Result.failure("Failed to regenerate access/refresh token pair: %s".formatted(errors));
}

storeRefreshToken(existingAccessTokenData.id(), new RefreshToken(newRefreshToken.getContent(), DEFAULT_EXPIRY_IN_SECONDS, refreshEndpoint));
storeRefreshToken(existingAccessTokenData.id(), new RefreshToken(newRefreshToken.getContent(), tokenExpirySeconds, refreshEndpoint));

// the ClaimToken is created based solely on the TokenParameters. The additional information (refresh token...) is persisted separately
var claimToken = ClaimToken.Builder.newInstance().claims(newTokenParams.getClaims()).build();
Expand All @@ -171,7 +183,7 @@ public Result<TokenResponse> refreshToken(String refreshToken, String authentica
var storeResult = accessTokenDataStore.update(accessTokenData);
return storeResult.succeeded() ?
Result.success(new TokenResponse(newAccessToken.getContent(),
newRefreshToken.getContent(), DEFAULT_EXPIRY_IN_SECONDS, "bearer")) :
newRefreshToken.getContent(), tokenExpirySeconds, "bearer")) :
Result.failure(storeResult.getFailureMessages());
}

Expand Down Expand Up @@ -203,18 +215,18 @@ public Result<TokenRepresentation> obtainToken(TokenParameters tokenParameters,
var storeResult = accessTokenDataStore.store(accessTokenData);

storeRefreshToken(accessTokenResult.getContent().id(), new RefreshToken(refreshTokenResult.getContent().tokenRepresentation().getToken(),
DEFAULT_EXPIRY_IN_SECONDS, refreshEndpoint));
tokenExpirySeconds, refreshEndpoint));

// the refresh token information must be returned in the EDR
var edrAdditionalData = new HashMap<>(additionalTokenData);
edrAdditionalData.put("refreshToken", refreshTokenResult.getContent().tokenRepresentation().getToken());
edrAdditionalData.put("expiresIn", String.valueOf(DEFAULT_EXPIRY_IN_SECONDS));
edrAdditionalData.put("refreshEndpoint", refreshEndpoint);
edrAdditionalData.put(EDR_PROPERTY_REFRESH_TOKEN, refreshTokenResult.getContent().tokenRepresentation().getToken());
edrAdditionalData.put(EDR_PROPERTY_EXPIRES_IN, String.valueOf(tokenExpirySeconds));
edrAdditionalData.put(EDR_PROPERTY_REFRESH_ENDPOINT, refreshEndpoint);

var edrTokenRepresentation = TokenRepresentation.Builder.newInstance()
.token(accessTokenResult.getContent().tokenRepresentation().getToken()) // the access token
.additional(edrAdditionalData) //contains additional properties and the refresh token
.expiresIn(DEFAULT_EXPIRY_IN_SECONDS) //todo: needed?
.expiresIn(tokenExpirySeconds) //todo: needed?
.build();


Expand All @@ -223,7 +235,12 @@ public Result<TokenRepresentation> obtainToken(TokenParameters tokenParameters,

@Override
public Result<AccessTokenData> resolve(String token) {
return resolveToken(token, accessTokenRules);
return tokenValidationService.validate(token, publicKeyResolver, accessTokenAuthorizationRules)
.compose(claimToken -> {
var id = claimToken.getStringClaim(JWTClaimNames.JWT_ID);
var tokenData = accessTokenDataStore.getById(id);
return tokenData != null ? Result.success(tokenData) : Result.failure("AccessTokenData with ID '%s' does not exist.".formatted(id));
});
}

@Override
Expand Down Expand Up @@ -270,33 +287,15 @@ private Result<TokenRepresentationWithId> createToken(TokenParameters tokenParam
}
//if there is not "exp" header on the token params, we'll configure one
if (!tokenParameters.getClaims().containsKey(JwtRegisteredClaimNames.EXPIRATION_TIME)) {
monitor.info("No '%s' claim found on TokenParameters. Will use the default of %d seconds".formatted(EXPIRATION_TIME, DEFAULT_EXPIRY_IN_SECONDS));
var exp = clock.instant().plusSeconds(DEFAULT_EXPIRY_IN_SECONDS).getEpochSecond();
monitor.info("No '%s' claim found on TokenParameters. Will use the configured default of %d seconds".formatted(EXPIRATION_TIME, tokenExpirySeconds));
var exp = clock.instant().plusSeconds(tokenExpirySeconds).getEpochSecond();
allDecorators.add(tp -> tp.claims(JwtRegisteredClaimNames.EXPIRATION_TIME, exp));
}

return tokenGenerationService.generate(privateKeySupplier, allDecorators.toArray(new TokenDecorator[0]))
.map(tr -> new TokenRepresentationWithId(tokenId.get(), tr));
}

/**
* Parses the given token, and validates it against the given rules. For that, the publicKeyResolver is used.
* Once the token is deemed valid, the "jti" claim (which is mandatory) is extracted and used as key for a lookup in the
* {@link AccessTokenDataStore}. The result of that is then returned.
*/
private Result<AccessTokenData> resolveToken(String token, List<TokenValidationRule> rules) {
var validationResult = tokenValidationService.validate(token, publicKeyResolver, rules);
if (validationResult.failed()) {
return validationResult.mapTo();
}
var tokenId = validationResult.getContent().getStringClaim(TOKEN_ID_CLAIM);
var existingAccessToken = accessTokenDataStore.getById(tokenId);

return existingAccessToken == null ?
Result.failure("AccessTokenData with ID '%s' does not exist.".formatted(tokenId)) :
Result.success(existingAccessToken);
}

private Result<Void> storeRefreshToken(String id, RefreshToken refreshToken) {
try {
return vault.storeSecret(id, objectMapper.writeValueAsString(refreshToken));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2024 Bayerische Motoren Werke Aktiengesellschaft
*
* See the NOTICE file(s) distributed with this work for additional
* information regarding copyright ownership.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package org.eclipse.tractusx.edc.dataplane.tokenrefresh.core;

import com.nimbusds.jwt.SignedJWT;
import org.eclipse.edc.jwt.spi.JwtRegisteredClaimNames;
import org.jetbrains.annotations.Nullable;

import java.text.ParseException;

public class TokenFunctions {

/**
* Returns the "jti" claim of a JWT in serialized format. Will throw a {@link RuntimeException} if the token is not in valid
* serialized JWT format.
*/
public static @Nullable String getTokenId(String serializedJwt) {
try {
return SignedJWT.parse(serializedJwt).getJWTClaimsSet().getStringClaim(JwtRegisteredClaimNames.JWT_ID);
} catch (ParseException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@

import java.util.Map;

import static org.eclipse.tractusx.edc.dataplane.tokenrefresh.core.TokenFunctions.getTokenId;


/**
* Validates that the {@code iss} claim of a token is equal to the {@code audience} property found on the {@link org.eclipse.edc.connector.dataplane.spi.AccessTokenData}
* that is associated with that token (using the {@code jti} claim).
Expand All @@ -42,9 +45,13 @@ public AuthTokenAudienceRule(AccessTokenDataStore store) {
}

@Override
public Result<Void> checkRule(@NotNull ClaimToken claimToken, @Nullable Map<String, Object> map) {
var issuer = claimToken.getStringClaim(JWTClaimNames.ISSUER);
var tokenId = claimToken.getStringClaim(JWTClaimNames.JWT_ID);
public Result<Void> checkRule(@NotNull ClaimToken authenticationToken, @Nullable Map<String, Object> map) {
var issuer = authenticationToken.getStringClaim(JWTClaimNames.ISSUER);
var accessToken = authenticationToken.getStringClaim("token");
if (accessToken == null) {
return Result.failure("Authentication token must contain a 'token' claim");
}
var tokenId = getTokenId(accessToken);

var accessTokenData = store.getById(tokenId);
var expectedAudience = accessTokenData.additionalProperties().getOrDefault(AUDIENCE_PROPERTY, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ public RefreshTokenValidationRule(Vault vault, String incomingRefreshToken, Obje
}

@Override
public Result<Void> checkRule(@NotNull ClaimToken toVerify, @Nullable Map<String, Object> additional) {
var tokenId = toVerify.getStringClaim(JWTClaimNames.JWT_ID);
public Result<Void> checkRule(@NotNull ClaimToken accessToken, @Nullable Map<String, Object> additional) {

var tokenId = accessToken.getStringClaim(JWTClaimNames.JWT_ID);
var storedRefreshTokenJson = vault.resolveSecret(tokenId);
if (storedRefreshTokenJson == null) {
return failure("No refresh token with the ID '%s' was found in the vault.".formatted(tokenId));
Expand Down
Loading

0 comments on commit 049876d

Please sign in to comment.