Skip to content

Commit

Permalink
Merge pull request #16446 from sberyozkin/select_matching_auth_mechanism
Browse files Browse the repository at this point in the history
Do the correct challenge when more than one auth mechanism is used
  • Loading branch information
sberyozkin authored Apr 13, 2021
2 parents eb081bf + 9a0ecff commit 2ceef3d
Show file tree
Hide file tree
Showing 11 changed files with 108 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.quarkus.oidc.runtime;

import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.quarkus.oidc.AccessTokenCredential;
import io.quarkus.oidc.OidcTenantConfig;
Expand All @@ -15,7 +16,7 @@
public class BearerAuthenticationMechanism extends AbstractOidcAuthenticationMechanism {

protected static final ChallengeData UNAUTHORIZED_CHALLENGE = new ChallengeData(HttpResponseStatus.UNAUTHORIZED.code(),
null, null);
HttpHeaderNames.WWW_AUTHENTICATE, OidcConstants.BEARER_SCHEME);

public Uni<SecurityIdentity> authenticate(RoutingContext context,
IdentityProviderManager identityProviderManager) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkus.jwt.test;

import static org.hamcrest.Matchers.equalTo;

import java.io.StringReader;
import java.net.HttpURLConnection;
import java.util.HashMap;
Expand Down Expand Up @@ -48,11 +50,11 @@ public void generateToken() throws Exception {
authTimeClaim = timeClaims.get(Claims.auth_time.name());
}

// Basic @ServletSecurity tests
@Test()
public void testSecureAccessFailure() {
RestAssured.when().get("/endp/verifyInjectedIssuer").then()
.statusCode(401);
.statusCode(401)
.header("WWW-Authenticate", equalTo("Bearer"));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public Uni<ChallengeData> getChallenge(RoutingContext context) {
ChallengeData result = new ChallengeData(
HttpResponseStatus.UNAUTHORIZED.code(),
HttpHeaderNames.WWW_AUTHENTICATE,
"Bearer {token}");
"Bearer");
return Uni.createFrom().item(result);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,14 @@ public void testBasicAuthFailure() {
CookieFilter cookies = new CookieFilter();
RestAssured
.given()
.auth().basic("admin", "wrongpassword")
.auth().preemptive().basic("admin", "wrongpassword")
.filter(cookies)
.redirects().follow(false)
.when()
.get("/admin")
.then()
.assertThat()
.statusCode(302)
.header("location", containsString("/login"))
.cookie("quarkus-redirect-location", containsString("/admin"));
.statusCode(401)
.header("WWW-Authenticate", equalTo("basic realm=\"Quarkus\""));

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ public class BasicAuthenticationMechanism implements HttpAuthenticationMechanism
*/
public static final String USER_AGENT_CHARSETS = "user-agent-charsets";

private final String name;
private final String challenge;

private static final String BASIC = "basic";
Expand All @@ -87,32 +86,40 @@ public class BasicAuthenticationMechanism implements HttpAuthenticationMechanism
private final Map<Pattern, Charset> userAgentCharsets;

public BasicAuthenticationMechanism(final String realmName) {
this(realmName, "BASIC");
this(realmName, false);
}

public BasicAuthenticationMechanism(final String realmName, final boolean silent) {
this(realmName, silent, StandardCharsets.UTF_8, Collections.emptyMap());
}

public BasicAuthenticationMechanism(final String realmName, final boolean silent,
Charset charset, Map<Pattern, Charset> userAgentCharsets) {
this.challenge = BASIC_PREFIX + "realm=\"" + realmName + "\"";
this.silent = silent;
this.charset = charset;
this.userAgentCharsets = Collections.unmodifiableMap(new LinkedHashMap<>(userAgentCharsets));
}

@Deprecated
public BasicAuthenticationMechanism(final String realmName, final String mechanismName) {
this(realmName, mechanismName, false);
}

@Deprecated
public BasicAuthenticationMechanism(final String realmName, final String mechanismName, final boolean silent) {
this(realmName, mechanismName, silent, StandardCharsets.UTF_8, Collections.emptyMap());
}

@Deprecated
public BasicAuthenticationMechanism(final String realmName, final String mechanismName, final boolean silent,
Charset charset, Map<Pattern, Charset> userAgentCharsets) {
this.challenge = BASIC_PREFIX + "realm=\"" + realmName + "\"";
this.name = mechanismName;
this.silent = silent;
this.charset = charset;
this.userAgentCharsets = Collections.unmodifiableMap(new LinkedHashMap<>(userAgentCharsets));
}

private static void clear(final char[] array) {
for (int i = 0; i < array.length; i++) {
array[i] = 0x00;
}
}

@Override
public Uni<SecurityIdentity> authenticate(RoutingContext context,
IdentityProviderManager identityProviderManager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ static Uni<ChallengeData> getRedirect(final RoutingContext exchange, final Strin
public Uni<SecurityIdentity> authenticate(RoutingContext context,
IdentityProviderManager identityProviderManager) {

if (context.normalisedPath().endsWith(postLocation) && context.request().method().equals(HttpMethod.POST)) {
if (context.normalizedPath().endsWith(postLocation) && context.request().method().equals(HttpMethod.POST)) {
//we always re-auth if it is a post to the auth URL
return runFormAuth(context, identityProviderManager);
} else {
Expand All @@ -173,7 +173,7 @@ public void accept(SecurityIdentity securityIdentity) {

@Override
public Uni<ChallengeData> getChallenge(RoutingContext context) {
if (context.normalisedPath().endsWith(postLocation) && context.request().method().equals(HttpMethod.POST)) {
if (context.normalizedPath().endsWith(postLocation) && context.request().method().equals(HttpMethod.POST)) {
log.debugf("Serving form auth error page %s for %s", loginPage, context);
// This method would no longer be called if authentication had already occurred.
return getRedirect(context, errorPage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.security.identity.request.AnonymousAuthenticationRequest;
import io.quarkus.security.identity.request.AuthenticationRequest;
import io.quarkus.vertx.http.runtime.security.HttpCredentialTransport.Type;
import io.smallrye.mutiny.Uni;
import io.vertx.core.http.HttpHeaders;
import io.vertx.ext.web.RoutingContext;

/**
* Class that is responsible for running the HTTP based authentication
*/
@ApplicationScoped
public class HttpAuthenticator {

final HttpAuthenticationMechanism[] mechanisms;
@Inject
IdentityProviderManager identityProviderManager;
Expand Down Expand Up @@ -97,6 +98,11 @@ IdentityProviderManager getIdentityProviderManager() {
*/
public Uni<SecurityIdentity> attemptAuthentication(RoutingContext routingContext) {

HttpAuthenticationMechanism matchingMech = findMechanismWithAuthorizationScheme(routingContext);
if (matchingMech != null) {
return matchingMech.authenticate(routingContext, identityProviderManager);
}

Uni<SecurityIdentity> result = mechanisms[0].authenticate(routingContext, identityProviderManager);
for (int i = 1; i < mechanisms.length; ++i) {
HttpAuthenticationMechanism mech = mechanisms[i];
Expand All @@ -118,18 +124,26 @@ public Uni<SecurityIdentity> apply(SecurityIdentity data) {
* @return
*/
public Uni<Boolean> sendChallenge(RoutingContext routingContext) {
Uni<Boolean> result = mechanisms[0].sendChallenge(routingContext);
for (int i = 1; i < mechanisms.length; ++i) {
HttpAuthenticationMechanism mech = mechanisms[i];
result = result.onItem().transformToUni(new Function<Boolean, Uni<? extends Boolean>>() {
@Override
public Uni<? extends Boolean> apply(Boolean authDone) {
if (authDone) {
return Uni.createFrom().item(authDone);
Uni<Boolean> result = null;

HttpAuthenticationMechanism matchingMech = findMechanismWithAuthorizationScheme(routingContext);
if (matchingMech != null) {
result = matchingMech.sendChallenge(routingContext);
}
if (result == null) {
result = mechanisms[0].sendChallenge(routingContext);
for (int i = 1; i < mechanisms.length; ++i) {
HttpAuthenticationMechanism mech = mechanisms[i];
result = result.onItem().transformToUni(new Function<Boolean, Uni<? extends Boolean>>() {
@Override
public Uni<? extends Boolean> apply(Boolean authDone) {
if (authDone) {
return Uni.createFrom().item(authDone);
}
return mech.sendChallenge(routingContext);
}
return mech.sendChallenge(routingContext);
}
});
});
}
}
return result.onItem().transformToUni(new Function<Boolean, Uni<? extends Boolean>>() {
@Override
Expand All @@ -144,6 +158,10 @@ public Uni<? extends Boolean> apply(Boolean authDone) {
}

public Uni<ChallengeData> getChallenge(RoutingContext routingContext) {
HttpAuthenticationMechanism matchingMech = findMechanismWithAuthorizationScheme(routingContext);
if (matchingMech != null) {
return matchingMech.getChallenge(routingContext);
}
Uni<ChallengeData> result = mechanisms[0].getChallenge(routingContext);
for (int i = 1; i < mechanisms.length; ++i) {
HttpAuthenticationMechanism mech = mechanisms[i];
Expand All @@ -161,6 +179,32 @@ public Uni<? extends ChallengeData> apply(ChallengeData data) {
return result;
}

private HttpAuthenticationMechanism findMechanismWithAuthorizationScheme(RoutingContext routingContext) {
String authScheme = getAuthorizationScheme(routingContext);
if (authScheme == null) {
return null;
}
for (int i = 0; i < mechanisms.length; ++i) {
HttpCredentialTransport credType = mechanisms[i].getCredentialTransport();
if (credType != null && credType.getTransportType() == Type.AUTHORIZATION
&& credType.getTypeTarget().toLowerCase().startsWith(authScheme.toLowerCase())) {
return mechanisms[i];
}
}
return null;
}

private static String getAuthorizationScheme(RoutingContext routingContext) {
String authorization = routingContext.request().getHeader(HttpHeaders.AUTHORIZATION);
if (authorization != null) {
int spaceIndex = authorization.indexOf(' ');
if (spaceIndex > 0) {
return authorization.substring(0, spaceIndex);
}
}
return null;
}

static class NoAuthenticationMechanism implements HttpAuthenticationMechanism {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ public boolean equals(Object o) {

if (transportType != that.transportType)
return false;
return typeTarget != null ? typeTarget.equals(that.typeTarget) : that.typeTarget == null;
return typeTarget.equals(that.typeTarget);
}

@Override
public int hashCode() {
int result = transportType != null ? transportType.hashCode() : 0;
result = 31 * result + (typeTarget != null ? typeTarget.hashCode() : 0);
int result = transportType.hashCode();
result = 31 * result + typeTarget.hashCode();
return result;
}

Expand All @@ -74,4 +74,12 @@ public String toString() {
", typeTarget='" + typeTarget + '\'' +
'}';
}

public Type getTransportType() {
return transportType;
}

public String getTypeTarget() {
return typeTarget;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ public Supplier<?> setupBasicAuth(HttpBuildTimeConfig buildTimeConfig) {
return new Supplier<BasicAuthenticationMechanism>() {
@Override
public BasicAuthenticationMechanism get() {
return new BasicAuthenticationMechanism(buildTimeConfig.auth.realm, "BASIC", buildTimeConfig.auth.form.enabled);
return new BasicAuthenticationMechanism(buildTimeConfig.auth.realm, buildTimeConfig.auth.form.enabled);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ public void testDeniedAccessAdminResource() {
public void testDeniedNoBearerToken() {
RestAssured.given()
.when().get("/api/users/me").then()
.statusCode(401);
.statusCode(401)
.header("WWW-Authenticate", equalTo("Bearer"));
}

@Test
Expand All @@ -71,7 +72,8 @@ public void testExpiredBearerToken() {
RestAssured.given().auth().oauth2(token).when()
.get("/api/users/me")
.then()
.statusCode(401);
.statusCode(401)
.header("WWW-Authenticate", equalTo("Bearer"));
}

@Test
Expand All @@ -81,7 +83,8 @@ public void testBearerTokenWrongIssuer() {
RestAssured.given().auth().oauth2(token).when()
.get("/api/users/me")
.then()
.statusCode(401);
.statusCode(401)
.header("WWW-Authenticate", equalTo("Bearer"));
}

@Test
Expand All @@ -91,7 +94,8 @@ public void testBearerTokenWrongAudience() {
RestAssured.given().auth().oauth2(token).when()
.get("/api/users/me")
.then()
.statusCode(401);
.statusCode(401)
.header("WWW-Authenticate", equalTo("Bearer"));
}

private String getAccessToken(String userName, Set<String> groups) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,16 @@ public void testDeniedAccessAdminResource() {
public void testVerificationFailedNoBearerToken() {
RestAssured.given()
.when().get("/api/users/me").then()
.statusCode(401);
.statusCode(401)
.header("WWW-Authenticate", equalTo("Bearer"));
}

@Test
public void testVerificationFailedInvalidToken() {
RestAssured.given().auth().oauth2("123")
.when().get("/api/users/me").then()
.statusCode(401);
.statusCode(401)
.header("WWW-Authenticate", equalTo("Bearer"));
}

//see https://github.com/quarkusio/quarkus/issues/5809
Expand Down

0 comments on commit 2ceef3d

Please sign in to comment.