From e1f51cbce768870aa2f1d8eaa72628f4797c457c Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 13 Jul 2021 16:31:58 +0100 Subject: [PATCH] Check both https and wss in forwarded header checks Closes gh-27097 --- .../web/filter/ForwardedHeaderFilter.java | 4 ++-- .../web/util/UriComponentsBuilder.java | 2 +- .../web/filter/ForwardedHeaderFilterTests.java | 13 ++++++++----- .../web/util/UriComponentsBuilderTests.java | 11 +++++++---- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index 4a4938b49748..4b0b2db90283 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -239,7 +239,7 @@ private static class ForwardedHeaderExtractingRequest extends ForwardedHeaderRem int port = uriComponents.getPort(); this.scheme = uriComponents.getScheme(); - this.secure = "https".equals(this.scheme); + this.secure = "https".equals(this.scheme) || "wss".equals(this.scheme); this.host = uriComponents.getHost(); this.port = (port == -1 ? (this.secure ? 443 : 80) : port); diff --git a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java index f814955fe2b3..6787571c4c97 100644 --- a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java @@ -882,7 +882,7 @@ else if (isForwardedSslOn(headers)) { } if (this.scheme != null && ((this.scheme.equals("http") && "80".equals(this.port)) || - (this.scheme.equals("https") && "443".equals(this.port)))) { + ((this.scheme.equals("https") || this.scheme.equals("wss")) && "443".equals(this.port)))) { port(null); } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java index f224591e4e0b..35fa6df4c4f7 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java @@ -30,6 +30,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.springframework.web.testfixture.servlet.MockFilterChain; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; @@ -102,10 +104,11 @@ public void shouldNotFilter() { assertThat(this.filter.shouldNotFilter(new MockHttpServletRequest())).isTrue(); } - @Test - public void forwardedRequest() throws Exception { + @ParameterizedTest + @ValueSource(strings = {"https", "wss"}) + public void forwardedRequest(String protocol) throws Exception { this.request.setRequestURI("/mvc-showcase"); - this.request.addHeader(X_FORWARDED_PROTO, "https"); + this.request.addHeader(X_FORWARDED_PROTO, protocol); this.request.addHeader(X_FORWARDED_HOST, "84.198.58.199"); this.request.addHeader(X_FORWARDED_PORT, "443"); this.request.addHeader("foo", "bar"); @@ -115,8 +118,8 @@ public void forwardedRequest() throws Exception { HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest(); assertThat(actual).isNotNull(); - assertThat(actual.getRequestURL().toString()).isEqualTo("https://84.198.58.199/mvc-showcase"); - assertThat(actual.getScheme()).isEqualTo("https"); + assertThat(actual.getRequestURL().toString()).isEqualTo(protocol + "://84.198.58.199/mvc-showcase"); + assertThat(actual.getScheme()).isEqualTo(protocol); assertThat(actual.getServerName()).isEqualTo("84.198.58.199"); assertThat(actual.getServerPort()).isEqualTo(443); assertThat(actual.isSecure()).isTrue(); diff --git a/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java b/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java index 3b8a5168cfb2..52623dacc9af 100644 --- a/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java +++ b/spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java @@ -28,6 +28,8 @@ import java.util.function.BiConsumer; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpRequest; @@ -374,10 +376,11 @@ void fromHttpRequest() { assertThat(result.getQuery()).isEqualTo("a=1"); } - @Test // SPR-12771 - void fromHttpRequestResetsPortBeforeSettingIt() { + @ParameterizedTest // gh-17368, gh-27097 + @ValueSource(strings = {"https", "wss"}) + void fromHttpRequestResetsPortBeforeSettingIt(String protocol) { MockHttpServletRequest request = new MockHttpServletRequest(); - request.addHeader("X-Forwarded-Proto", "https"); + request.addHeader("X-Forwarded-Proto", protocol); request.addHeader("X-Forwarded-Host", "84.198.58.199"); request.addHeader("X-Forwarded-Port", 443); request.setScheme("http"); @@ -388,7 +391,7 @@ void fromHttpRequestResetsPortBeforeSettingIt() { HttpRequest httpRequest = new ServletServerHttpRequest(request); UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build(); - assertThat(result.getScheme()).isEqualTo("https"); + assertThat(result.getScheme()).isEqualTo(protocol); assertThat(result.getHost()).isEqualTo("84.198.58.199"); assertThat(result.getPort()).isEqualTo(-1); assertThat(result.getPath()).isEqualTo("/rest/mobile/users/1");