Skip to content

Commit

Permalink
Merge pull request #39328 from sberyozkin/oidc_client_name_in_filter
Browse files Browse the repository at this point in the history
Pass the client and tenant id to OIDC request filters
  • Loading branch information
sberyozkin authored Mar 11, 2024
2 parents 04358cb + 96e162c commit fd062d2
Show file tree
Hide file tree
Showing 10 changed files with 33 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
public class OidcClientImpl implements OidcClient {

private static final Logger LOG = Logger.getLogger(OidcClientImpl.class);

private static final String CLIENT_ID_ATTRIBUTE = "client-id";
private static final String DEFAULT_OIDC_CLIENT_ID = "Default";
private static final String AUTHORIZATION_HEADER = String.valueOf(HttpHeaders.AUTHORIZATION);

private final WebClient client;
Expand Down Expand Up @@ -279,7 +280,8 @@ private void checkClosed() {

private HttpRequest<Buffer> filter(OidcEndpoint.Type endpointType, HttpRequest<Buffer> request, Buffer body) {
if (!filters.isEmpty()) {
OidcRequestContextProperties props = new OidcRequestContextProperties();
OidcRequestContextProperties props = new OidcRequestContextProperties(
Map.of(CLIENT_ID_ATTRIBUTE, oidcConfig.getId().orElse(DEFAULT_OIDC_CLIENT_ID)));
for (OidcRequestFilter filter : OidcCommonUtils.getMatchingOidcRequestFilters(filters, endpointType)) {
filter.filter(request, body, props);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.quarkus.oidc.client.OidcClients;
import io.quarkus.oidc.client.Tokens;
import io.quarkus.oidc.common.OidcEndpoint;
import io.quarkus.oidc.common.OidcRequestContextProperties;
import io.quarkus.oidc.common.OidcRequestFilter;
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
import io.quarkus.oidc.common.runtime.OidcConstants;
Expand All @@ -35,6 +36,7 @@
public class OidcClientRecorder {

private static final Logger LOG = Logger.getLogger(OidcClientRecorder.class);
private static final String CLIENT_ID_ATTRIBUTE = "client-id";
private static final String DEFAULT_OIDC_CLIENT_ID = "Default";

public OidcClients setup(OidcClientsConfig oidcClientsConfig, TlsConfig tlsConfig, Supplier<Vertx> vertx) {
Expand Down Expand Up @@ -224,8 +226,10 @@ private static Uni<OidcConfigurationMetadata> discoverTokenUris(WebClient client
Map<OidcEndpoint.Type, List<OidcRequestFilter>> oidcRequestFilters,
String authServerUrl, OidcClientConfig oidcConfig, io.vertx.mutiny.core.Vertx vertx) {
final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig);
OidcRequestContextProperties contextProps = new OidcRequestContextProperties(
Map.of(CLIENT_ID_ATTRIBUTE, oidcConfig.getId().orElse(DEFAULT_OIDC_CLIENT_ID)));
return OidcCommonUtils
.discoverMetadata(client, oidcRequestFilters, authServerUrl, connectionDelayInMillisecs, vertx,
.discoverMetadata(client, oidcRequestFilters, contextProps, authServerUrl, connectionDelayInMillisecs, vertx,
oidcConfig.useBlockingDnsLookup)
.onItem().transform(json -> new OidcConfigurationMetadata(json.getString("token_endpoint"),
json.getString("revocation_endpoint")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,15 @@ public static Predicate<? super Throwable> oidcEndpointNotAvailable() {
}

public static Uni<JsonObject> discoverMetadata(WebClient client, Map<OidcEndpoint.Type, List<OidcRequestFilter>> filters,
String authServerUrl, long connectionDelayInMillisecs, Vertx vertx, boolean blockingDnsLookup) {
OidcRequestContextProperties contextProperties, String authServerUrl,
long connectionDelayInMillisecs, Vertx vertx, boolean blockingDnsLookup) {
final String discoveryUrl = getDiscoveryUri(authServerUrl);
HttpRequest<Buffer> request = client.getAbs(discoveryUrl);
if (!filters.isEmpty()) {
OidcRequestContextProperties requestProps = new OidcRequestContextProperties(
Map.of(OidcRequestContextProperties.DISCOVERY_ENDPOINT, discoveryUrl));
Map<String, Object> newProperties = contextProperties == null ? new HashMap<>()
: new HashMap<>(contextProperties.getAll());
newProperties.put(OidcRequestContextProperties.DISCOVERY_ENDPOINT, discoveryUrl);
OidcRequestContextProperties requestProps = new OidcRequestContextProperties(newProperties);
for (OidcRequestFilter filter : getMatchingOidcRequestFilters(filters, OidcEndpoint.Type.DISCOVERY)) {
filter.filter(request, null, requestProps);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
public class OidcProviderClient implements Closeable {
private static final Logger LOG = Logger.getLogger(OidcProviderClient.class);

private static final String TENANT_ID_ATTRIBUTE = "oidc-tenant-id";
private static final String AUTHORIZATION_HEADER = String.valueOf(HttpHeaders.AUTHORIZATION);
private static final String CONTENT_TYPE_HEADER = String.valueOf(HttpHeaders.CONTENT_TYPE);
private static final String ACCEPT_HEADER = String.valueOf(HttpHeaders.ACCEPT);
Expand Down Expand Up @@ -265,6 +266,7 @@ private HttpRequest<Buffer> filter(OidcEndpoint.Type endpointType, HttpRequest<B
if (!filters.isEmpty()) {
Map<String, Object> newProperties = contextProperties == null ? new HashMap<>()
: new HashMap<>(contextProperties.getAll());
newProperties.put(OidcUtils.TENANT_ID_ATTRIBUTE, oidcConfig.getTenantId().orElse(OidcUtils.DEFAULT_TENANT_ID));
newProperties.put(OidcConfigurationMetadata.class.getName(), metadata);
OidcRequestContextProperties newContextProperties = new OidcRequestContextProperties(newProperties);
for (OidcRequestFilter filter : OidcCommonUtils.getMatchingOidcRequestFilters(filters, endpointType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.quarkus.oidc.TenantConfigResolver;
import io.quarkus.oidc.TenantIdentityProvider;
import io.quarkus.oidc.common.OidcEndpoint;
import io.quarkus.oidc.common.OidcRequestContextProperties;
import io.quarkus.oidc.common.OidcRequestFilter;
import io.quarkus.oidc.common.runtime.OidcCommonConfig;
import io.quarkus.oidc.common.runtime.OidcCommonUtils;
Expand Down Expand Up @@ -487,8 +488,11 @@ protected static Uni<OidcProviderClient> createOidcClientUni(OidcTenantConfig oi
metadataUni = Uni.createFrom().item(createLocalMetadata(oidcConfig, authServerUriString));
} else {
final long connectionDelayInMillisecs = OidcCommonUtils.getConnectionDelayInMillis(oidcConfig);
OidcRequestContextProperties contextProps = new OidcRequestContextProperties(
Map.of(OidcUtils.TENANT_ID_ATTRIBUTE, oidcConfig.getTenantId().orElse(OidcUtils.DEFAULT_TENANT_ID)));
metadataUni = OidcCommonUtils
.discoverMetadata(client, oidcRequestFilters, authServerUriString, connectionDelayInMillisecs, mutinyVertx,
.discoverMetadata(client, oidcRequestFilters, contextProps, authServerUriString, connectionDelayInMillisecs,
mutinyVertx,
oidcConfig.useBlockingDnsLookup)
.onItem()
.transform(new Function<JsonObject, OidcConfigurationMetadata>() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class OidcRequestCustomizer implements OidcRequestFilter {
public void filter(HttpRequest<Buffer> request, Buffer buffer, OidcRequestContextProperties contextProps) {
String uri = request.uri();
if (uri.endsWith("/non-standard-tokens")) {
request.putHeader("client-id", contextProps.getString("client-id"));
request.putHeader("GrantType", getGrantType(buffer.toString()));
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.quarkus.it.keycloak;

import static com.github.tomakehurst.wiremock.client.WireMock.containing;
import static com.github.tomakehurst.wiremock.client.WireMock.matching;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;

Expand Down Expand Up @@ -54,6 +55,7 @@ public Map<String, String> start() {
server.stubFor(WireMock.post("/non-standard-tokens")
.withHeader("X-Custom", matching("XCustomHeaderValue"))
.withHeader("GrantType", matching("password"))
.withHeader("client-id", containing("non-standard-response"))
.withRequestBody(matching("grant_type=password&username=alice&password=alice&extra_param=extra_param_value"))
.willReturn(WireMock
.aResponse()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.quarkus.oidc.common.OidcEndpoint.Type;
import io.quarkus.oidc.common.OidcRequestContextProperties;
import io.quarkus.oidc.common.OidcRequestFilter;
import io.quarkus.oidc.runtime.OidcUtils;
import io.vertx.mutiny.core.buffer.Buffer;
import io.vertx.mutiny.ext.web.client.HttpRequest;

Expand All @@ -23,6 +24,7 @@ public void filter(HttpRequest<Buffer> request, Buffer buffer, OidcRequestContex
throw new OIDCException("Filter is applied to the wrong endpoint: " + request.uri());
}
request.putHeader("Filter", "OK");
request.putHeader(OidcUtils.TENANT_ID_ATTRIBUTE, contextProps.getString(OidcUtils.TENANT_ID_ATTRIBUTE));
}

private boolean isJwksRequest(HttpRequest<Buffer> request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ public void testAccessResourceAzure() throws Exception {
String azureJwk = readFile("jwks.json");
wireMockServer.stubFor(WireMock.get("/auth/azure/jwk")
.withHeader("Authorization", matching("Access token: " + azureToken))
.withHeader("Filter", matching("OK"))
.withHeader("tenant-id", matching("bearer-azure"))
.willReturn(WireMock.aResponse().withBody(azureJwk)));
RestAssured.given().auth().oauth2(azureToken)
.when().get("/api/admin/bearer-azure")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package io.quarkus.it.keycloak;

import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.absent;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.get;
import static com.github.tomakehurst.wiremock.client.WireMock.not;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;

Expand All @@ -26,6 +28,7 @@ public void start() {
server.stubFor(
get(urlEqualTo("/auth/realms/quarkus2/.well-known/openid-configuration"))
.withHeader("Filter", equalTo("OK"))
.withHeader("tenant-id", not(absent()))
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("{\n" +
Expand All @@ -36,6 +39,7 @@ public void start() {
server.stubFor(
get(urlEqualTo("/auth/realms/quarkus2/protocol/openid-connect/certs"))
.withHeader("Filter", equalTo("OK"))
.withHeader("tenant-id", not(absent()))
.willReturn(aResponse()
.withHeader("Content-Type", "application/json")
.withBody("{\n" +
Expand Down

0 comments on commit fd062d2

Please sign in to comment.