Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set the headers and co on the normal response #29

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/main/java/org/pac4j/jax/rs/filters/AbstractFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.container.ContainerResponseContext;
import javax.ws.rs.container.ContainerResponseFilter;
import javax.ws.rs.ext.Providers;

import org.pac4j.core.authorization.authorizer.Authorizer;
import org.pac4j.core.config.Config;
import org.pac4j.core.http.HttpActionAdapter;
import org.pac4j.jax.rs.features.JaxRsContextFactoryProvider.JaxRsContextFactory;
Expand All @@ -18,7 +21,7 @@
* @since 1.0.0
*
*/
public abstract class AbstractFilter implements ContainerRequestFilter {
public abstract class AbstractFilter implements ContainerRequestFilter, ContainerResponseFilter {

protected Boolean skipResponse;

Expand All @@ -27,7 +30,7 @@ public abstract class AbstractFilter implements ContainerRequestFilter {
public AbstractFilter(Providers providers) {
this.providers = providers;
}

protected Config getConfig() {
return ProvidersHelper.getContext(providers, Config.class);
}
Expand All @@ -43,6 +46,22 @@ public void filter(ContainerRequestContext requestContext) throws IOException {
filter(context);
}

@Override
public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext)
throws IOException {
// in case the filter aborts the request, we never arrive here, but if it is not aborted
// there is case when pac4j sets things on the response, this is the role of this method.
// unfortunately, if skipResponse is used, we can't do that because pac4j considers
// its abort response in the same way as the normal response
if (skipResponse == null || !skipResponse) {
JaxRsContext context = ProvidersHelper.getContext(providers, JaxRsContextFactory.class)
.provides(requestContext);
assert context != null;

context.getResponseHolder().populateResponse(responseContext);
}
}

/**
* Prefer to set a specific {@link HttpActionAdapter} on the {@link Config} instead of overriding this method.
*
Expand Down Expand Up @@ -70,6 +89,9 @@ public Boolean isSkipResponse() {
}

/**
* Note that if this is set to <code>true</code>, this will also disable the effects of {@link Authorizer} and such
* that set things on the HTTP response! Use with caution!
*
* @param skipResponse
* If set to <code>true</code>, the pac4j response, such as redirect, will be skipped (the annotated
* method will be executed instead).
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.pac4j.jax.rs.filters;

import javax.ws.rs.core.Response;

import org.pac4j.core.http.HttpActionAdapter;
import org.pac4j.jax.rs.pac4j.JaxRsContext;

Expand All @@ -15,7 +17,9 @@ public class JaxRsHttpActionAdapter implements HttpActionAdapter<Object, JaxRsCo

@Override
public Object adapt(int code, JaxRsContext context) {
context.getRequestContext().abortWith(context.getAbortBuilder().build());
Response response = context.getAbortBuilder().build();
assert response.getStatus() == code;
context.getRequestContext().abortWith(response);
return null;
}

Expand Down
93 changes: 90 additions & 3 deletions src/main/java/org/pac4j/jax/rs/pac4j/JaxRsContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@
import java.net.URI;
import java.nio.charset.Charset;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerResponseContext;
import javax.ws.rs.core.Form;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedHashMap;
import javax.ws.rs.core.MultivaluedMap;
Expand All @@ -38,6 +44,8 @@
*/
public class JaxRsContext implements WebContext {

public static final String RESPONSE_HOLDER = JaxRsContext.class + ".ResponseHolder";

private final ContainerRequestContext requestContext;

private SessionStore sessionStore;
Expand Down Expand Up @@ -80,35 +88,114 @@ public ResponseBuilder getAbortBuilder() {
return abortResponse;
}

public ResponseHolder getResponseHolder() {
ResponseHolder prop = (ResponseHolder) requestContext.getProperty(RESPONSE_HOLDER);
if (prop == null) {
prop = new ResponseHolder();
requestContext.setProperty(RESPONSE_HOLDER, prop);
}
return prop;
}

public static class ResponseHolder {

private boolean hasResponseContent = false;

private String responseContent = null;

private boolean hasResponseStatus = false;

private int responseStatus = 0;

private boolean hasResponseContentType = false;

private MediaType responseContentType = null;

private final Map<String, String> responseHeaders = new HashMap<>();

private final Set<NewCookie> responseCookies = new HashSet<>();

public void writeResponseContent(String content) {
responseContent = content;
hasResponseContent = true;
}

public void setResponseStatus(int code) {
responseStatus = code;
hasResponseStatus = true;
}

public void setResponseHeader(String name, String value) {
responseHeaders.put(name, value);
}

public void addResponseCookie(NewCookie cookie) {
responseCookies.add(cookie);
}

public void setResponseContentType(MediaType type) {
responseContentType = type;
hasResponseContentType = true;
}

public void populateResponse(ContainerResponseContext responseContext) {
if (hasResponseContent) {
responseContext.setEntity(responseContent);
}
if (hasResponseContentType) {
responseContext.getHeaders().putSingle(HttpHeaders.CONTENT_TYPE, responseContentType);
}
if (hasResponseStatus) {
responseContext.setStatus(responseStatus);
}
for (Entry<String, String> headers : responseHeaders.entrySet()) {
responseContext.getHeaders().putSingle(headers.getKey(), headers.getValue());
}
for (NewCookie cookie : responseCookies) {
responseContext.getHeaders().add(HttpHeaders.SET_COOKIE, cookie);
}
}
}

@Override
public void writeResponseContent(String content) {
getAbortBuilder().entity(content);
getResponseHolder().writeResponseContent(content);
}


@Override
public void setResponseStatus(int code) {
getAbortBuilder().status(code);
getResponseHolder().setResponseStatus(code);
}

@Override
public void setResponseHeader(String name, String value) {
CommonHelper.assertNotNull("name", name);
// header() adds headers, so we must remove the previous value first
getAbortBuilder().header(name, null);
getAbortBuilder().header(name, value);
getResponseHolder().setResponseHeader(name, value);
}

@Override
public void setResponseContentType(String content) {
getAbortBuilder().type(content);
MediaType type = content == null ? null : MediaType.valueOf(content);
getAbortBuilder().type(type);
getResponseHolder().setResponseContentType(type);
}

@Override
public void addResponseCookie(Cookie cookie) {
CommonHelper.assertNotNull("cookie", cookie);
// Note: expiry is not in servlet and is meant to be superseeded by
// max-age, so we simply make it null
getAbortBuilder().cookie(new NewCookie(cookie.getName(), cookie.getValue(), cookie.getPath(),
NewCookie c = new NewCookie(cookie.getName(), cookie.getValue(), cookie.getPath(),
cookie.getDomain(), cookie.getVersion(), cookie.getComment(), cookie.getMaxAge(), null,
cookie.isSecure(), cookie.isHttpOnly()));
cookie.isSecure(), cookie.isHttpOnly());
getAbortBuilder().cookie(c);
getResponseHolder().addResponseCookie(c);
}

/**
Expand Down
13 changes: 13 additions & 0 deletions src/test/java/org/pac4j/jax/rs/AbstractTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import javax.ws.rs.core.Form;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;

import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -185,4 +186,16 @@ public void directInjectSkipFail() {
.post(Entity.entity(form, MediaType.APPLICATION_FORM_URLENCODED_TYPE), String.class);
assertThat(ok).isEqualTo("fail");
}

@Test
public void directResponseHeadersSet() {
Form form = new Form();
form.param("username", "foo");
form.param("password", "foo");
final Response ok = container.getTarget("/directResponseHeadersSet").request()
.post(Entity.entity(form, MediaType.APPLICATION_FORM_URLENCODED_TYPE));
assertThat(ok.getStatus()).isEqualTo(Status.OK.getStatusCode());
assertThat(ok.readEntity(String.class)).isEqualTo("ok");
assertThat(ok.getHeaderString("X-Content-Type-Options")).isEqualTo("nosniff");
}
}
8 changes: 8 additions & 0 deletions src/test/java/org/pac4j/jax/rs/resources/TestResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,12 @@ public String directInjectSkip(@Pac4JProfile(readFromSession = false) Optional<C
return "fail";
}
}

@POST
@Path("directResponseHeadersSet")
@Pac4JSecurity(clients = "DirectFormClient", authorizers = { DefaultAuthorizers.IS_AUTHENTICATED,
DefaultAuthorizers.NOSNIFF })
public String directResponseHeadersSet() {
return "ok";
}
}