diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java
index 017c0f635628..1487809d942d 100644
--- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java
+++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java
@@ -43,6 +43,7 @@
*
Used e.g. by {@link org.springframework.web.filter.ShallowEtagHeaderFilter}.
*
* @author Juergen Hoeller
+ * @author Sam Brannen
* @since 4.1.3
* @see ContentCachingRequestWrapper
*/
@@ -157,16 +158,19 @@ public void setContentType(@Nullable String type) {
@Override
@Nullable
public String getContentType() {
- return this.contentType;
+ if (this.contentType != null) {
+ return this.contentType;
+ }
+ return super.getContentType();
}
@Override
public boolean containsHeader(String name) {
- if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
- return this.contentLength != null;
+ if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
+ return true;
}
- else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
- return this.contentType != null;
+ else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
+ return true;
}
else {
return super.containsHeader(name);
@@ -222,10 +226,10 @@ public void addIntHeader(String name, int value) {
@Override
@Nullable
public String getHeader(String name) {
- if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
- return (this.contentLength != null) ? this.contentLength.toString() : null;
+ if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
+ return this.contentLength.toString();
}
- else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
+ else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType;
}
else {
@@ -235,12 +239,11 @@ else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
@Override
public Collection getHeaders(String name) {
- if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
- return this.contentLength != null ? Collections.singleton(this.contentLength.toString()) :
- Collections.emptySet();
+ if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
+ return Collections.singleton(this.contentLength.toString());
}
- else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
- return this.contentType != null ? Collections.singleton(this.contentType) : Collections.emptySet();
+ else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
+ return Collections.singleton(this.contentType);
}
else {
return super.getHeaders(name);
@@ -330,7 +333,7 @@ protected void copyBodyToResponse(boolean complete) throws IOException {
}
this.contentLength = null;
}
- if (complete || this.contentType != null) {
+ if (this.contentType != null) {
rawResponse.setContentType(this.contentType);
this.contentType = null;
}
diff --git a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java
index 23912e9d7758..e63f206f9568 100644
--- a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java
+++ b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java
@@ -19,6 +19,7 @@
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Test;
+import org.springframework.http.MediaType;
import org.springframework.util.FileCopyUtils;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
import org.springframework.web.util.ContentCachingResponseWrapper;
@@ -26,12 +27,14 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH;
+import static org.springframework.http.HttpHeaders.CONTENT_TYPE;
import static org.springframework.http.HttpHeaders.TRANSFER_ENCODING;
/**
* Tests for {@link ContentCachingResponseWrapper}.
*
* @author Rossen Stoyanchev
+ * @author Sam Brannen
*/
class ContentCachingResponseWrapperTests {
@@ -50,6 +53,79 @@ void copyBodyToResponse() throws Exception {
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
}
+ @Test
+ void copyBodyToResponseWithPresetHeaders() throws Exception {
+ String PUZZLE = "puzzle";
+ String ENIGMA = "enigma";
+ String NUMBER = "number";
+ String MAGIC = "42";
+
+ byte[] responseBody = "Hello World".getBytes(UTF_8);
+ String responseLength = Integer.toString(responseBody.length);
+ String contentType = MediaType.APPLICATION_JSON_VALUE;
+
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ response.setContentType(contentType);
+ response.setContentLength(999);
+ response.setHeader(PUZZLE, ENIGMA);
+ response.setIntHeader(NUMBER, 42);
+
+ ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response);
+ responseWrapper.setStatus(HttpServletResponse.SC_OK);
+
+ assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
+ assertThat(responseWrapper.getContentSize()).isZero();
+ assertThat(responseWrapper.getHeaderNames())
+ .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH);
+
+ assertThat(responseWrapper.containsHeader(PUZZLE)).as(PUZZLE).isTrue();
+ assertThat(responseWrapper.getHeader(PUZZLE)).as(PUZZLE).isEqualTo(ENIGMA);
+ assertThat(responseWrapper.getHeaders(PUZZLE)).as(PUZZLE).containsExactly(ENIGMA);
+
+ assertThat(responseWrapper.containsHeader(NUMBER)).as(NUMBER).isTrue();
+ assertThat(responseWrapper.getHeader(NUMBER)).as(NUMBER).isEqualTo(MAGIC);
+ assertThat(responseWrapper.getHeaders(NUMBER)).as(NUMBER).containsExactly(MAGIC);
+
+ assertThat(responseWrapper.containsHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isTrue();
+ assertThat(responseWrapper.getHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isEqualTo(contentType);
+ assertThat(responseWrapper.getHeaders(CONTENT_TYPE)).as(CONTENT_TYPE).containsExactly(contentType);
+ assertThat(responseWrapper.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType);
+
+ assertThat(responseWrapper.containsHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isTrue();
+ assertThat(responseWrapper.getHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isEqualTo("999");
+ assertThat(responseWrapper.getHeaders(CONTENT_LENGTH)).as(CONTENT_LENGTH).containsExactly("999");
+
+ FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream());
+ responseWrapper.copyBodyToResponse();
+
+ assertThat(responseWrapper.getHeaderNames())
+ .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH);
+
+ assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
+ assertThat(response.getContentType()).isEqualTo(contentType);
+ assertThat(response.getContentLength()).isEqualTo(responseBody.length);
+ assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
+ assertThat(response.getHeaderNames())
+ .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH);
+
+ assertThat(response.containsHeader(PUZZLE)).as(PUZZLE).isTrue();
+ assertThat(response.getHeader(PUZZLE)).as(PUZZLE).isEqualTo(ENIGMA);
+ assertThat(response.getHeaders(PUZZLE)).as(PUZZLE).containsExactly(ENIGMA);
+
+ assertThat(response.containsHeader(NUMBER)).as(NUMBER).isTrue();
+ assertThat(response.getHeader(NUMBER)).as(NUMBER).isEqualTo(MAGIC);
+ assertThat(response.getHeaders(NUMBER)).as(NUMBER).containsExactly(MAGIC);
+
+ assertThat(response.containsHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isTrue();
+ assertThat(response.getHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isEqualTo(contentType);
+ assertThat(response.getHeaders(CONTENT_TYPE)).as(CONTENT_TYPE).containsExactly(contentType);
+ assertThat(response.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType);
+
+ assertThat(response.containsHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isTrue();
+ assertThat(response.getHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isEqualTo(responseLength);
+ assertThat(response.getHeaders(CONTENT_LENGTH)).as(CONTENT_LENGTH).containsExactly(responseLength);
+ }
+
@Test
void copyBodyToResponseWithTransferEncoding() throws Exception {
byte[] responseBody = "6\r\nHello 5\r\nWorld0\r\n\r\n".getBytes(UTF_8);