Skip to content

Commit

Permalink
Handle Content-Length in ShallowEtagHeaderFilter more robustly
Browse files Browse the repository at this point in the history
This commit ensures that setting the Content-Length through
setHeader("Content-Length", x") has the same effect as calling
setContentLength in the ShallowEtagHeaderFilter. It also filters out
Content-Type headers similarly to Content-Length.

Closes gh-32039
  • Loading branch information
poutsma committed Jan 18, 2024
1 parent b8b31ff commit 375e0e6
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,6 +21,10 @@
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.WriteListener;
Expand Down Expand Up @@ -55,6 +59,9 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
@Nullable
private Integer contentLength;

@Nullable
private String contentType;


/**
* Create a new ContentCachingResponseWrapper for the given servlet response.
Expand Down Expand Up @@ -139,6 +146,122 @@ public void setContentLengthLong(long len) {
this.contentLength = lenInt;
}

@Override
public void setContentType(String type) {
this.contentType = type;
}

@Override
@Nullable
public String getContentType() {
return this.contentType;
}

@Override
public boolean containsHeader(String name) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return this.contentLength != null;
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType != null;
}
else {
return super.containsHeader(name);
}
}

@Override
public void setHeader(String name, String value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
this.contentType = value;
}
else {
super.setHeader(name, value);
}
}

@Override
public void addHeader(String name, String value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
this.contentType = value;
}
else {
super.addHeader(name, value);
}
}

@Override
public void setIntHeader(String name, int value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else {
super.setIntHeader(name, value);
}
}

@Override
public void addIntHeader(String name, int value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else {
super.addIntHeader(name, value);
}
}

@Override
@Nullable
public String getHeader(String name) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return (this.contentLength != null) ? this.contentLength.toString() : null;
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType;
}
else {
return super.getHeader(name);
}
}

@Override
public Collection<String> getHeaders(String name) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return this.contentLength != null ? Collections.singleton(this.contentLength.toString()) :
Collections.emptySet();
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType != null ? Collections.singleton(this.contentType) : Collections.emptySet();
}
else {
return super.getHeaders(name);
}
}

@Override
public Collection<String> getHeaderNames() {
Collection<String> headerNames = super.getHeaderNames();
if (this.contentLength != null || this.contentType != null) {
List<String> result = new ArrayList<>(headerNames);
if (this.contentLength != null) {
result.add(HttpHeaders.CONTENT_LENGTH);
}
if (this.contentType != null) {
result.add(HttpHeaders.CONTENT_TYPE);
}
return result;
}
else {
return headerNames;
}
}

@Override
public void setBufferSize(int size) {
if (size > this.content.size()) {
Expand Down Expand Up @@ -197,11 +320,17 @@ public void copyBodyToResponse() throws IOException {
protected void copyBodyToResponse(boolean complete) throws IOException {
if (this.content.size() > 0) {
HttpServletResponse rawResponse = (HttpServletResponse) getResponse();
if ((complete || this.contentLength != null) && !rawResponse.isCommitted()) {
if (rawResponse.getHeader(HttpHeaders.TRANSFER_ENCODING) == null) {
rawResponse.setContentLength(complete ? this.content.size() : this.contentLength);
if (!rawResponse.isCommitted()) {
if (complete || this.contentLength != null) {
if (rawResponse.getHeader(HttpHeaders.TRANSFER_ENCODING) == null) {
rawResponse.setContentLength(complete ? this.content.size() : this.contentLength);
}
this.contentLength = null;
}
if (complete || this.contentType != null) {
rawResponse.setContentType(this.contentType);
this.contentType = null;
}
this.contentLength = null;
}
this.content.writeTo(rawResponse.getOutputStream());
this.content.reset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Test;

import org.springframework.http.MediaType;
import org.springframework.util.FileCopyUtils;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
Expand Down Expand Up @@ -68,13 +69,15 @@ void filterNoMatch() throws Exception {
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
};
filter.doFilter(request, response, filterChain);

assertThat(response.getStatus()).as("Invalid status").isEqualTo(200);
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\"");
assertThat(response.getContentLength()).as("Invalid Content-Length header").isGreaterThan(0);
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(MediaType.TEXT_PLAIN_VALUE);
assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(responseBody);
}

Expand All @@ -88,13 +91,15 @@ void filterNoMatchWeakETag() throws Exception {
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
};
filter.doFilter(request, response, filterChain);

assertThat(response.getStatus()).as("Invalid status").isEqualTo(200);
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("W/\"0b10a8db164e0754105b7a99be72e3fe5\"");
assertThat(response.getContentLength()).as("Invalid Content-Length header").isGreaterThan(0);
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(MediaType.TEXT_PLAIN_VALUE);
assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(responseBody);
}

Expand All @@ -108,14 +113,16 @@ void filterMatch() throws Exception {
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
byte[] responseBody = "Hello World".getBytes(StandardCharsets.UTF_8);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
filterResponse.setContentLength(responseBody.length);
filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
};
filter.doFilter(request, response, filterChain);

assertThat(response.getStatus()).as("Invalid status").isEqualTo(304);
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\"");
assertThat(response.containsHeader("Content-Length")).as("Response has Content-Length header").isFalse();
assertThat(response.containsHeader("Content-Type")).as("Response has Content-Type header").isFalse();
byte[] expecteds = new byte[0];
assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(expecteds);
}
Expand Down

0 comments on commit 375e0e6

Please sign in to comment.