diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java index 017c0f635628..1487809d942d 100644 --- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java @@ -43,6 +43,7 @@ *

Used e.g. by {@link org.springframework.web.filter.ShallowEtagHeaderFilter}. * * @author Juergen Hoeller + * @author Sam Brannen * @since 4.1.3 * @see ContentCachingRequestWrapper */ @@ -157,16 +158,19 @@ public void setContentType(@Nullable String type) { @Override @Nullable public String getContentType() { - return this.contentType; + if (this.contentType != null) { + return this.contentType; + } + return super.getContentType(); } @Override public boolean containsHeader(String name) { - if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { - return this.contentLength != null; + if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + return true; } - else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { - return this.contentType != null; + else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { + return true; } else { return super.containsHeader(name); @@ -222,10 +226,10 @@ public void addIntHeader(String name, int value) { @Override @Nullable public String getHeader(String name) { - if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { - return (this.contentLength != null) ? this.contentLength.toString() : null; + if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + return this.contentLength.toString(); } - else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { + else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { return this.contentType; } else { @@ -235,12 +239,11 @@ else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { @Override public Collection getHeaders(String name) { - if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { - return this.contentLength != null ? Collections.singleton(this.contentLength.toString()) : - Collections.emptySet(); + if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + return Collections.singleton(this.contentLength.toString()); } - else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { - return this.contentType != null ? Collections.singleton(this.contentType) : Collections.emptySet(); + else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { + return Collections.singleton(this.contentType); } else { return super.getHeaders(name); @@ -330,7 +333,7 @@ protected void copyBodyToResponse(boolean complete) throws IOException { } this.contentLength = null; } - if (complete || this.contentType != null) { + if (this.contentType != null) { rawResponse.setContentType(this.contentType); this.contentType = null; } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java index 23912e9d7758..e63f206f9568 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java @@ -19,6 +19,7 @@ import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; +import org.springframework.http.MediaType; import org.springframework.util.FileCopyUtils; import org.springframework.web.testfixture.servlet.MockHttpServletResponse; import org.springframework.web.util.ContentCachingResponseWrapper; @@ -26,12 +27,14 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.http.HttpHeaders.CONTENT_LENGTH; +import static org.springframework.http.HttpHeaders.CONTENT_TYPE; import static org.springframework.http.HttpHeaders.TRANSFER_ENCODING; /** * Tests for {@link ContentCachingResponseWrapper}. * * @author Rossen Stoyanchev + * @author Sam Brannen */ class ContentCachingResponseWrapperTests { @@ -50,6 +53,79 @@ void copyBodyToResponse() throws Exception { assertThat(response.getContentAsByteArray()).isEqualTo(responseBody); } + @Test + void copyBodyToResponseWithPresetHeaders() throws Exception { + String PUZZLE = "puzzle"; + String ENIGMA = "enigma"; + String NUMBER = "number"; + String MAGIC = "42"; + + byte[] responseBody = "Hello World".getBytes(UTF_8); + String responseLength = Integer.toString(responseBody.length); + String contentType = MediaType.APPLICATION_JSON_VALUE; + + MockHttpServletResponse response = new MockHttpServletResponse(); + response.setContentType(contentType); + response.setContentLength(999); + response.setHeader(PUZZLE, ENIGMA); + response.setIntHeader(NUMBER, 42); + + ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response); + responseWrapper.setStatus(HttpServletResponse.SC_OK); + + assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + assertThat(responseWrapper.getContentSize()).isZero(); + assertThat(responseWrapper.getHeaderNames()) + .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH); + + assertThat(responseWrapper.containsHeader(PUZZLE)).as(PUZZLE).isTrue(); + assertThat(responseWrapper.getHeader(PUZZLE)).as(PUZZLE).isEqualTo(ENIGMA); + assertThat(responseWrapper.getHeaders(PUZZLE)).as(PUZZLE).containsExactly(ENIGMA); + + assertThat(responseWrapper.containsHeader(NUMBER)).as(NUMBER).isTrue(); + assertThat(responseWrapper.getHeader(NUMBER)).as(NUMBER).isEqualTo(MAGIC); + assertThat(responseWrapper.getHeaders(NUMBER)).as(NUMBER).containsExactly(MAGIC); + + assertThat(responseWrapper.containsHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isTrue(); + assertThat(responseWrapper.getHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isEqualTo(contentType); + assertThat(responseWrapper.getHeaders(CONTENT_TYPE)).as(CONTENT_TYPE).containsExactly(contentType); + assertThat(responseWrapper.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType); + + assertThat(responseWrapper.containsHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isTrue(); + assertThat(responseWrapper.getHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isEqualTo("999"); + assertThat(responseWrapper.getHeaders(CONTENT_LENGTH)).as(CONTENT_LENGTH).containsExactly("999"); + + FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream()); + responseWrapper.copyBodyToResponse(); + + assertThat(responseWrapper.getHeaderNames()) + .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH); + + assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + assertThat(response.getContentType()).isEqualTo(contentType); + assertThat(response.getContentLength()).isEqualTo(responseBody.length); + assertThat(response.getContentAsByteArray()).isEqualTo(responseBody); + assertThat(response.getHeaderNames()) + .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH); + + assertThat(response.containsHeader(PUZZLE)).as(PUZZLE).isTrue(); + assertThat(response.getHeader(PUZZLE)).as(PUZZLE).isEqualTo(ENIGMA); + assertThat(response.getHeaders(PUZZLE)).as(PUZZLE).containsExactly(ENIGMA); + + assertThat(response.containsHeader(NUMBER)).as(NUMBER).isTrue(); + assertThat(response.getHeader(NUMBER)).as(NUMBER).isEqualTo(MAGIC); + assertThat(response.getHeaders(NUMBER)).as(NUMBER).containsExactly(MAGIC); + + assertThat(response.containsHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isTrue(); + assertThat(response.getHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isEqualTo(contentType); + assertThat(response.getHeaders(CONTENT_TYPE)).as(CONTENT_TYPE).containsExactly(contentType); + assertThat(response.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType); + + assertThat(response.containsHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isTrue(); + assertThat(response.getHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isEqualTo(responseLength); + assertThat(response.getHeaders(CONTENT_LENGTH)).as(CONTENT_LENGTH).containsExactly(responseLength); + } + @Test void copyBodyToResponseWithTransferEncoding() throws Exception { byte[] responseBody = "6\r\nHello 5\r\nWorld0\r\n\r\n".getBytes(UTF_8);