Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement CSRF feature as ServerRequestFilter with form read #29977

Merged
merged 1 commit into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,31 +1,15 @@
package io.quarkus.csrf.reactive;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BooleanSupplier;

import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.MethodInfo;
import org.jboss.resteasy.reactive.server.model.FixedHandlersChainCustomizer;
import org.jboss.resteasy.reactive.server.model.HandlerChainCustomizer;
import org.jboss.resteasy.reactive.server.processor.scanning.MethodScanner;

import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.csrf.reactive.runtime.CsrfHandler;
import io.quarkus.csrf.reactive.runtime.CsrfReactiveConfig;
import io.quarkus.csrf.reactive.runtime.CsrfRecorder;
import io.quarkus.csrf.reactive.runtime.CsrfResponseFilter;
import io.quarkus.csrf.reactive.runtime.CsrfRequestResponseReactiveFilter;
import io.quarkus.csrf.reactive.runtime.CsrfTokenParameterProvider;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.BuildSteps;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.AdditionalIndexedClassesBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.resteasy.reactive.server.spi.HandlerConfigurationProviderBuildItem;
import io.quarkus.resteasy.reactive.server.spi.MethodScannerBuildItem;

@BuildSteps(onlyIf = CsrfReactiveBuildStep.IsEnabled.class)
public class CsrfReactiveBuildStep {
Expand All @@ -34,33 +18,11 @@ public class CsrfReactiveBuildStep {
void registerProvider(BuildProducer<AdditionalBeanBuildItem> additionalBeans,
BuildProducer<ReflectiveClassBuildItem> reflectiveClass,
BuildProducer<AdditionalIndexedClassesBuildItem> additionalIndexedClassesBuildItem) {
additionalBeans.produce(AdditionalBeanBuildItem.unremovableOf(CsrfResponseFilter.class));
reflectiveClass.produce(new ReflectiveClassBuildItem(true, true, CsrfResponseFilter.class));
additionalIndexedClassesBuildItem
.produce(new AdditionalIndexedClassesBuildItem(CsrfResponseFilter.class.getName()));

additionalBeans.produce(AdditionalBeanBuildItem.unremovableOf(CsrfRequestResponseReactiveFilter.class));
reflectiveClass.produce(new ReflectiveClassBuildItem(true, true, CsrfRequestResponseReactiveFilter.class));
additionalBeans.produce(AdditionalBeanBuildItem.unremovableOf(CsrfTokenParameterProvider.class));
}

@BuildStep
public MethodScannerBuildItem configureHandler() {
return new MethodScannerBuildItem(new MethodScanner() {
@Override
public List<HandlerChainCustomizer> scan(MethodInfo method, ClassInfo actualEndpointClass,
Map<String, Object> methodContext) {
return Collections.singletonList(
new FixedHandlersChainCustomizer(
List.of(new CsrfHandler()),
HandlerChainCustomizer.Phase.BEFORE_METHOD_INVOKE));
}
});
}

@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
public HandlerConfigurationProviderBuildItem applyRuntimeConfig(CsrfRecorder recorder,
CsrfReactiveConfig csrfReactiveConfig) {
return new HandlerConfigurationProviderBuildItem(CsrfReactiveConfig.class, recorder.configure(csrfReactiveConfig));
additionalIndexedClassesBuildItem
.produce(new AdditionalIndexedClassesBuildItem(CsrfRequestResponseReactiveFilter.class.getName()));
}

public static class IsEnabled implements BooleanSupplier {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
package io.quarkus.csrf.reactive.runtime;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.security.SecureRandom;
import java.util.Base64;

import javax.enterprise.inject.Instance;
import javax.inject.Inject;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerResponseContext;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;

import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.server.ServerRequestFilter;
import org.jboss.resteasy.reactive.server.ServerResponseFilter;
import org.jboss.resteasy.reactive.server.WithFormRead;
import org.jboss.resteasy.reactive.server.core.ResteasyReactiveRequestContext;
import org.jboss.resteasy.reactive.server.spi.GenericRuntimeConfigurableServerRestHandler;
import org.jboss.resteasy.reactive.server.spi.ResteasyReactiveContainerRequestContext;

import io.vertx.core.http.Cookie;
import io.vertx.core.http.impl.CookieImpl;
import io.vertx.core.http.impl.ServerCookie;
import io.vertx.ext.web.RoutingContext;

public class CsrfHandler implements GenericRuntimeConfigurableServerRestHandler<CsrfReactiveConfig> {
private static final Logger LOG = Logger.getLogger(CsrfHandler.class);
public class CsrfRequestResponseReactiveFilter {
private static final Logger LOG = Logger.getLogger(CsrfRequestResponseReactiveFilter.class);

/**
* CSRF token key.
Expand All @@ -30,27 +36,12 @@ public class CsrfHandler implements GenericRuntimeConfigurableServerRestHandler<
*/
private static final String CSRF_TOKEN_VERIFIED = "csrf_token_verified";

// although technically the field does not need to be volatile (since the access mode is determined by the VarHandle use)
// it is a recommended practice by Doug Lea meant to catch cases where the field is accessed directly (by accident)
@SuppressWarnings("unused")
private volatile SecureRandom secureRandom;

// use a VarHandle to access the secureRandom as the value is written only by the main thread
// and all other threads simply read the value, and thus we can use the Release / Acquire access mode
private static final VarHandle SECURE_RANDOM_VH;

static {
try {
SECURE_RANDOM_VH = MethodHandles.lookup().findVarHandle(CsrfHandler.class, "secureRandom",
SecureRandom.class);
} catch (NoSuchFieldException | IllegalAccessException e) {
throw new Error(e);
}
}
private final SecureRandom secureRandom = new SecureRandom();

private CsrfReactiveConfig config;
@Inject
Instance<CsrfReactiveConfig> configInstance;

public CsrfHandler() {
public CsrfRequestResponseReactiveFilter() {
}

/**
Expand All @@ -68,10 +59,10 @@ public CsrfHandler() {
* {@value #CSRF_TOKEN_KEY} and value that is equal to the one supplied in the cookie.</li>
* </ul>
*/
public void handle(ResteasyReactiveRequestContext reactiveRequestContext) {
final ContainerRequestContext requestContext = reactiveRequestContext.getContainerRequestContext();

final RoutingContext routing = reactiveRequestContext.serverRequest().unwrap(RoutingContext.class);
@ServerRequestFilter
@WithFormRead
public void filter(ResteasyReactiveContainerRequestContext requestContext, RoutingContext routing) {
final CsrfReactiveConfig config = this.configInstance.get();

String cookieToken = getCookieToken(routing, config);
if (cookieToken != null) {
Expand Down Expand Up @@ -99,7 +90,7 @@ public void handle(ResteasyReactiveRequestContext reactiveRequestContext) {
if (cookieToken == null && isCsrfTokenRequired(routing, config)) {
// Set the CSRF cookie with a randomly generated value
byte[] tokenBytes = new byte[config.tokenSize];
getSecureRandom().nextBytes(tokenBytes);
secureRandom.nextBytes(tokenBytes);
routing.put(CSRF_TOKEN_BYTES_KEY, tokenBytes);
routing.put(CSRF_TOKEN_KEY, Base64.getUrlEncoder().withoutPadding().encodeToString(tokenBytes));
}
Expand All @@ -115,7 +106,6 @@ public void handle(ResteasyReactiveRequestContext reactiveRequestContext) {
} else {
LOG.debugf("Request has the media type: %s, skipping the token verification",
requestContext.getMediaType().toString());
requestContext.abortWith(badClientRequest());
return;
}
}
Expand All @@ -132,8 +122,9 @@ public void handle(ResteasyReactiveRequestContext reactiveRequestContext) {
return;
}

String csrfToken = (String) reactiveRequestContext.getFormParameter(config.formFieldName, true, true);

ResteasyReactiveRequestContext rrContext = (ResteasyReactiveRequestContext) requestContext
.getServerRequestContext();
String csrfToken = (String) rrContext.getFormParameter(config.formFieldName, true, false);
if (csrfToken == null) {
LOG.debug("CSRF token is not found");
requestContext.abortWith(badClientRequest());
Expand All @@ -148,6 +139,7 @@ public void handle(ResteasyReactiveRequestContext reactiveRequestContext) {
return;
} else {
routing.put(CSRF_TOKEN_VERIFIED, true);
return;
}
}
} else if (cookieToken == null) {
Expand All @@ -156,10 +148,6 @@ public void handle(ResteasyReactiveRequestContext reactiveRequestContext) {
}
}

private SecureRandom getSecureRandom() {
return (SecureRandom) SECURE_RANDOM_VH.getAcquire(this);
}

private static boolean isMatchingMediaType(MediaType contentType, MediaType expectedType) {
return contentType.getType().equals(expectedType.getType())
&& contentType.getSubtype().equals(expectedType.getSubtype());
Expand All @@ -169,6 +157,47 @@ private static Response badClientRequest() {
return Response.status(400).build();
}

/**
* If the requirements below are true, sets a cookie by the name {@value #CSRF_TOKEN_KEY} that contains a CSRF token.
* <ul>
* <li>The request method is {@code GET}.</li>
* <li>The request does not contain a valid CSRF token cookie.</li>
* </ul>
*
* @throws IllegalStateException if the {@link RoutingContext} does not have a value for the key {@value #CSRF_TOKEN_KEY}
* and a cookie needs to be set.
*/
@ServerResponseFilter
public void filter(ContainerRequestContext requestContext,
ContainerResponseContext responseContext, RoutingContext routing) {
final CsrfReactiveConfig config = configInstance.get();
if (requestContext.getMethod().equals("GET") && isCsrfTokenRequired(routing, config)
&& getCookieToken(routing, config) == null) {

String cookieValue = null;
if (config.tokenSignatureKey.isPresent()) {
byte[] csrfTokenBytes = (byte[]) routing.get(CSRF_TOKEN_BYTES_KEY);

if (csrfTokenBytes == null) {
throw new IllegalStateException(
"CSRF Filter should have set the property " + CSRF_TOKEN_KEY + ", but it is null");
}
cookieValue = CsrfTokenUtils.signCsrfToken(csrfTokenBytes, config.tokenSignatureKey.get());
} else {
String csrfToken = (String) routing.get(CSRF_TOKEN_KEY);

if (csrfToken == null) {
throw new IllegalStateException(
"CSRF Filter should have set the property " + CSRF_TOKEN_KEY + ", but it is null");
}
cookieValue = csrfToken;
}

createCookie(cookieValue, routing, config);
}

}

/**
* Gets the CSRF token from the CSRF cookie from the current {@code RoutingContext}.
*
Expand All @@ -189,6 +218,19 @@ private boolean isCsrfTokenRequired(RoutingContext routing, CsrfReactiveConfig c
return config.createTokenPath.isPresent() ? config.createTokenPath.get().contains(routing.request().path()) : true;
}

private void createCookie(String csrfToken, RoutingContext routing, CsrfReactiveConfig config) {

ServerCookie cookie = new CookieImpl(config.cookieName, csrfToken);
cookie.setHttpOnly(true);
cookie.setSecure(config.cookieForceSecure || routing.request().isSSL());
cookie.setMaxAge(config.cookieMaxAge.toSeconds());
cookie.setPath(config.cookiePath);
if (config.cookieDomain.isPresent()) {
cookie.setDomain(config.cookieDomain.get());
}
routing.response().addCookie(cookie);
}

private static boolean requestMethodIsSafe(ContainerRequestContext context) {
switch (context.getMethod()) {
case "GET":
Expand All @@ -199,14 +241,4 @@ private static boolean requestMethodIsSafe(ContainerRequestContext context) {
return false;
}
}

public void configure(CsrfReactiveConfig configuration) {
this.config = configuration;
SECURE_RANDOM_VH.setRelease(this, new SecureRandom());
}

@Override
public Class<CsrfReactiveConfig> getConfigurationClass() {
return CsrfReactiveConfig.class;
}
}
Loading