From 375e0e6827216ceb38e9bcf0b97b46bb79d79de6 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Thu, 18 Jan 2024 15:32:01 +0100 Subject: [PATCH] Handle Content-Length in ShallowEtagHeaderFilter more robustly This commit ensures that setting the Content-Length through setHeader("Content-Length", x") has the same effect as calling setContentLength in the ShallowEtagHeaderFilter. It also filters out Content-Type headers similarly to Content-Length. Closes gh-32039 --- .../util/ContentCachingResponseWrapper.java | 139 +++++++++++++++++- .../filter/ShallowEtagHeaderFilterTests.java | 9 +- 2 files changed, 142 insertions(+), 6 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java index dd9a46cab6fc..4223a673976c 100644 --- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,10 @@ import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; import jakarta.servlet.ServletOutputStream; import jakarta.servlet.WriteListener; @@ -55,6 +59,9 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @Nullable private Integer contentLength; + @Nullable + private String contentType; + /** * Create a new ContentCachingResponseWrapper for the given servlet response. @@ -139,6 +146,122 @@ public void setContentLengthLong(long len) { this.contentLength = lenInt; } + @Override + public void setContentType(String type) { + this.contentType = type; + } + + @Override + @Nullable + public String getContentType() { + return this.contentType; + } + + @Override + public boolean containsHeader(String name) { + if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + return this.contentLength != null; + } + else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { + return this.contentType != null; + } + else { + return super.containsHeader(name); + } + } + + @Override + public void setHeader(String name, String value) { + if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + this.contentLength = Integer.valueOf(value); + } + else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { + this.contentType = value; + } + else { + super.setHeader(name, value); + } + } + + @Override + public void addHeader(String name, String value) { + if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + this.contentLength = Integer.valueOf(value); + } + else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { + this.contentType = value; + } + else { + super.addHeader(name, value); + } + } + + @Override + public void setIntHeader(String name, int value) { + if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + this.contentLength = Integer.valueOf(value); + } + else { + super.setIntHeader(name, value); + } + } + + @Override + public void addIntHeader(String name, int value) { + if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + this.contentLength = Integer.valueOf(value); + } + else { + super.addIntHeader(name, value); + } + } + + @Override + @Nullable + public String getHeader(String name) { + if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + return (this.contentLength != null) ? this.contentLength.toString() : null; + } + else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { + return this.contentType; + } + else { + return super.getHeader(name); + } + } + + @Override + public Collection getHeaders(String name) { + if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { + return this.contentLength != null ? Collections.singleton(this.contentLength.toString()) : + Collections.emptySet(); + } + else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { + return this.contentType != null ? Collections.singleton(this.contentType) : Collections.emptySet(); + } + else { + return super.getHeaders(name); + } + } + + @Override + public Collection getHeaderNames() { + Collection headerNames = super.getHeaderNames(); + if (this.contentLength != null || this.contentType != null) { + List result = new ArrayList<>(headerNames); + if (this.contentLength != null) { + result.add(HttpHeaders.CONTENT_LENGTH); + } + if (this.contentType != null) { + result.add(HttpHeaders.CONTENT_TYPE); + } + return result; + } + else { + return headerNames; + } + } + @Override public void setBufferSize(int size) { if (size > this.content.size()) { @@ -197,11 +320,17 @@ public void copyBodyToResponse() throws IOException { protected void copyBodyToResponse(boolean complete) throws IOException { if (this.content.size() > 0) { HttpServletResponse rawResponse = (HttpServletResponse) getResponse(); - if ((complete || this.contentLength != null) && !rawResponse.isCommitted()) { - if (rawResponse.getHeader(HttpHeaders.TRANSFER_ENCODING) == null) { - rawResponse.setContentLength(complete ? this.content.size() : this.contentLength); + if (!rawResponse.isCommitted()) { + if (complete || this.contentLength != null) { + if (rawResponse.getHeader(HttpHeaders.TRANSFER_ENCODING) == null) { + rawResponse.setContentLength(complete ? this.content.size() : this.contentLength); + } + this.contentLength = null; + } + if (complete || this.contentType != null) { + rawResponse.setContentType(this.contentType); + this.contentType = null; } - this.contentLength = null; } this.content.writeTo(rawResponse.getOutputStream()); this.content.reset(); diff --git a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java index f688e6cbb381..0de0ab476a1a 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java @@ -23,6 +23,7 @@ import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; +import org.springframework.http.MediaType; import org.springframework.util.FileCopyUtils; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; import org.springframework.web.testfixture.servlet.MockHttpServletResponse; @@ -68,6 +69,7 @@ void filterNoMatch() throws Exception { FilterChain filterChain = (filterRequest, filterResponse) -> { assertThat(filterRequest).as("Invalid request passed").isEqualTo(request); ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE); FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); }; filter.doFilter(request, response, filterChain); @@ -75,6 +77,7 @@ void filterNoMatch() throws Exception { assertThat(response.getStatus()).as("Invalid status").isEqualTo(200); assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\""); assertThat(response.getContentLength()).as("Invalid Content-Length header").isGreaterThan(0); + assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(MediaType.TEXT_PLAIN_VALUE); assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(responseBody); } @@ -88,6 +91,7 @@ void filterNoMatchWeakETag() throws Exception { FilterChain filterChain = (filterRequest, filterResponse) -> { assertThat(filterRequest).as("Invalid request passed").isEqualTo(request); ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE); FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); }; filter.doFilter(request, response, filterChain); @@ -95,6 +99,7 @@ void filterNoMatchWeakETag() throws Exception { assertThat(response.getStatus()).as("Invalid status").isEqualTo(200); assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("W/\"0b10a8db164e0754105b7a99be72e3fe5\""); assertThat(response.getContentLength()).as("Invalid Content-Length header").isGreaterThan(0); + assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(MediaType.TEXT_PLAIN_VALUE); assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(responseBody); } @@ -108,14 +113,16 @@ void filterMatch() throws Exception { FilterChain filterChain = (filterRequest, filterResponse) -> { assertThat(filterRequest).as("Invalid request passed").isEqualTo(request); byte[] responseBody = "Hello World".getBytes(StandardCharsets.UTF_8); - FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); filterResponse.setContentLength(responseBody.length); + filterResponse.setContentType(MediaType.TEXT_PLAIN_VALUE); + FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); }; filter.doFilter(request, response, filterChain); assertThat(response.getStatus()).as("Invalid status").isEqualTo(304); assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\""); assertThat(response.containsHeader("Content-Length")).as("Response has Content-Length header").isFalse(); + assertThat(response.containsHeader("Content-Type")).as("Response has Content-Type header").isFalse(); byte[] expecteds = new byte[0]; assertThat(response.getContentAsByteArray()).as("Invalid content").isEqualTo(expecteds); }