Skip to content

Commit

Permalink
Merge pull request playframework#912 from playframework/mergify/bp/ma…
Browse files Browse the repository at this point in the history
…in/pr-909

[main] StreamedResponse `headers` map is now case insensitive (backport playframework#909) by @mkurz
  • Loading branch information
mkurz authored Jul 8, 2024
2 parents 1f4f1d1 + b9799ae commit a98acd4
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
import scala.jdk.javaapi.StreamConverters;

import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Predicate;

import static java.util.stream.Collectors.toMap;
Expand Down Expand Up @@ -121,7 +123,16 @@ public URI getUri() {
}

private static java.util.Map<String, List<String>> asJava(scala.collection.Map<String, Seq<String>> scalaMap) {
return StreamConverters.asJavaSeqStream(scalaMap).collect(toMap(f -> f._1(), f -> CollectionConverters.asJava(f._2())));
return StreamConverters.asJavaSeqStream(scalaMap).collect(toMap(f -> f._1(), f -> CollectionConverters.asJava(f._2()),
(l, r) -> {
final List<String> merged = new ArrayList<>(l.size() + r.size());
merged.addAll(l);
merged.addAll(r);
return merged;
},
() -> new TreeMap<>(String.CASE_INSENSITIVE_ORDER)
)
);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import play.api.libs.ws.StandaloneWSResponse
import play.api.libs.ws.WSCookie
import play.shaded.ahc.org.asynchttpclient.HttpResponseBodyPart

import scala.collection.immutable.TreeMap
import scala.collection.mutable

/**
* A streamed response containing a response header and a streamable body.
*
Expand Down Expand Up @@ -39,17 +42,49 @@ class StreamedResponse(
val status: Int,
val statusText: String,
val uri: java.net.URI,
val headers: Map[String, scala.collection.Seq[String]],
publisher: Publisher[HttpResponseBodyPart],
val useLaxCookieEncoder: Boolean
) extends StandaloneWSResponse
with CookieBuilder {

def this(
client: StandaloneAhcWSClient,
status: Int,
statusText: String,
uri: java.net.URI,
headers: Map[String, scala.collection.Seq[String]],
publisher: Publisher[HttpResponseBodyPart],
useLaxCookieEncoder: Boolean
) = {
this(
client,
status,
statusText,
uri,
publisher,
useLaxCookieEncoder
)
origHeaders = headers
}

private var origHeaders: Map[String, scala.collection.Seq[String]] = Map.empty

/**
* Get the underlying response object.
*/
override def underlying[T]: T = publisher.asInstanceOf[T]

override lazy val headers: Map[String, scala.collection.Seq[String]] = {
val mutableMap = mutable.TreeMap[String, scala.collection.Seq[String]]()(CaseInsensitiveOrdered)
origHeaders.keys.foreach { name =>
mutableMap.updateWith(name) {
case Some(value) => Some(value ++ origHeaders.getOrElse(name, Seq.empty))
case None => Some(origHeaders.getOrElse(name, Seq.empty))
}
}
TreeMap[String, scala.collection.Seq[String]]()(CaseInsensitiveOrdered) ++ mutableMap
}

/**
* Get all the cookies.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ class AhcWSResponseSpec extends Specification with DefaultBodyReadables with Def
headers.contains("Bar") must beTrue
}

"get headers map which retrieves headers case insensitively (for streamed responses)" in {
val srcHeaders = Map("Foo" -> Seq("a"), "foo" -> Seq("b"), "FOO" -> Seq("b"), "Bar" -> Seq("baz"))
val response = new StreamedResponse(null, 200, "", null, srcHeaders, null, true)
val headers = response.headers
headers("foo") must_== Seq("a", "b", "b")
headers("BAR") must_== Seq("baz")
}

"get a single header" in {
val ahcResponse: AHCResponse = mock[AHCResponse]
val ahcHeaders = new DefaultHttpHeaders(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ class AhcWSResponseSpec extends Specification with DefaultBodyReadables with Def
headers.get("BAR").asScala must_== Seq("baz")
}

"get headers map which retrieves headers case insensitively (for streamed responses)" in {
val srcHeaders = Map("Foo" -> Seq("a"), "foo" -> Seq("b"), "FOO" -> Seq("b"), "Bar" -> Seq("baz"))
val response = new StreamedResponse(null, 200, "", null, srcHeaders, null, true)
val headers = response.getHeaders
headers.get("foo").asScala must_== Seq("a", "b", "b")
headers.get("BAR").asScala must_== Seq("baz")
}

"get a single header" in {
val srcResponse = mock[Response]
val srcHeaders = new DefaultHttpHeaders()
Expand Down

0 comments on commit a98acd4

Please sign in to comment.