Skip to content

Commit

Permalink
Generalise path templates when generating ZIO Http RoutePattern-s (#4051
Browse files Browse the repository at this point in the history
)
  • Loading branch information
adamw authored Sep 18, 2024
1 parent c783120 commit 1cc1b38
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,25 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](
basicRequest.get(uri"$baseUri/p1/x/p3").send(backend).map(_.body shouldBe Right("2: x")) >>
basicRequest.get(uri"$baseUri/p1/y/p3").send(backend).map(_.body shouldBe Right("2: y")) >>
basicRequest.get(uri"$baseUri/p1/p2/p4").send(backend).map(_.code shouldBe StatusCode.NotFound)
},
// #4050
testServer(
"two endpoints with fixed path & path capture as the middle component, different methods",
NonEmptyList.of(
route(
List[ServerEndpoint[Any, F]](
endpoint.get.in("p1" / "p2").out(stringBody).serverLogic(_ => pureResult("1".asRight[Unit])),
endpoint.delete.in("p1" / path[String]("p")).out(stringBody).serverLogic((v: String) => pureResult(s"2: $v".asRight[Unit]))
)
)
)
) { (backend, baseUri) =>
basicRequest.get(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe Right("1")) >>
basicRequest.get(uri"$baseUri/p1/x").send(backend).map(_.code shouldBe StatusCode.MethodNotAllowed) >>
basicRequest.delete(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe Right("2: p2")) >>
basicRequest.delete(uri"$baseUri/p1/p3").send(backend).map(_.body shouldBe Right("2: p3")) >>
basicRequest.get(uri"$baseUri/p1/p2/p3").send(backend).map(_.code shouldBe StatusCode.NotFound) >>
basicRequest.delete(uri"$baseUri/p1/p2/p3").send(backend).map(_.code shouldBe StatusCode.NotFound)
}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import sttp.tapir.ztapir._
import zio._
import zio.http.codec.PathCodec
import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _}
import scala.util.chaining._

trait ZioHttpInterpreter[R] {
def zioHttpServerOptions: ZioHttpServerOptions[R] = ZioHttpServerOptions.default
Expand Down Expand Up @@ -61,8 +62,8 @@ trait ZioHttpInterpreter[R] {
// here we'll keep the endpoint together with the meta-data needed to create the zio-http routing information
case class ServerEndpointWithPattern(
index: Int,
pathTemplate: String,
routePattern: RoutePattern[_],
pathTemplate: Vector[String],
routePattern: RoutePattern[Any], // the Any here is a way to work around the type checker
endpoint: ZServerEndpoint[R & R2, ZioStreams with WebSockets]
)

Expand All @@ -72,13 +73,13 @@ trait ZioHttpInterpreter[R] {

// Creating the path template - no-trailing-slash inputs are treated as wildcard inputs, as they are usually
// accompanied by endpoints which handle wildcard path inputs, when the `/` is present (to serve files). They
// need to end up in the same group (see below), so that they are disambiguated by tapir's logic.
val pathTemplate = inputs.foldLeft("") { case (p, component) =>
// need to end up in the same group (see below), so that they are disambiguated by Tapir's logic.
val pathTemplate = inputs.foldLeft(Vector.empty[String]) { case (p, component) =>
component match {
case _: EndpointInput.PathCapture[_] => p + "/?"
case _: EndpointInput.PathsCapture[_] => p + "/..."
case i: EndpointInput.ExtractFromRequest[_] if i.attribute(NoTrailingSlash.Attribute).getOrElse(false) => p + "/..."
case i: EndpointInput.FixedPath[_] => p + "/" + i.s
case _: EndpointInput.PathCapture[_] => p :+ "?"
case _: EndpointInput.PathsCapture[_] => p :+ "..."
case i: EndpointInput.ExtractFromRequest[_] if i.attribute(NoTrailingSlash.Attribute).getOrElse(false) => p :+ "..."
case i: EndpointInput.FixedPath[_] => p :+ s"{${i.s}}"
case _ => p
}
}
Expand All @@ -94,7 +95,7 @@ trait ZioHttpInterpreter[R] {
case _ => false
}

val routePattern = if (hasPath) {
val routePattern: RoutePattern[Any] = if (hasPath) {
val initialPattern = RoutePattern(Method.ANY, PathCodec.empty).asInstanceOf[RoutePattern[Any]]
// The second tuple parameter specifies if PathCodec.trailing should be added to the route's pattern. It can
// be added either because of a PathsCapture, or because of an noTrailingSlash input.
Expand All @@ -109,7 +110,7 @@ trait ZioHttpInterpreter[R] {
}
}

if (addTrailing) p / PathCodec.trailing else p
if (addTrailing) (p / PathCodec.trailing).asInstanceOf[RoutePattern[Any]] else p
} else {
// if there are no path inputs, we return a catch-all
RoutePattern(Method.ANY, PathCodec.trailing).asInstanceOf[RoutePattern[Any]]
Expand All @@ -118,21 +119,59 @@ trait ZioHttpInterpreter[R] {
ServerEndpointWithPattern(index, pathTemplate, routePattern, se)
}

// Grouping the endpoints by path template. This way, if there are multiple endpoints with/without trailing slash or
// with path wildcards, they will end up in the same group, and they will be disambiguated by the tapir logic.
// That's because there's not way currently to create a zio-http route pattern which would match on
// no-trailing-slashes. A group also includes multiple endpoints with different methods, but same path.
val widenedSesGroupedByPathPrefixTemplate = widenedSes.zipWithIndex
.map { case (se, index) => toPattern(se, index) }
.groupBy(_.pathTemplate)
.toList
.map(_._2)
// we try to maintain the order of endpoints as passed by the user; this order might be changed if there are
// endpoints with/without trailing slashes, or with different methods, which are not passed as subsequent
// values in the original `ses` list
.sortBy(_.map(_.index).min)

val handlers: List[Route[R & R2, Response]] = widenedSesGroupedByPathPrefixTemplate.map { sesWithPattern =>
/** `t1` and `t2` are both path templates as created by `toPattern` above. Each path template is a vector of: ? | ... | {string}. This
* method checks if `t1` is at least as general as `t2`, that is if each request that matches `t2` also matches `t1`
*/
def isAtLeastAsGeneralAs(t1: Vector[String], t2: Vector[String]): Boolean = (t1, t2) match {
case ("..." +: _, _) => true
case (_, "..." +: _) => false
case ("?" +: tail1, "?" +: tail2) => isAtLeastAsGeneralAs(tail1, tail2)
case ("?" +: tail1, _ +: tail2) => isAtLeastAsGeneralAs(tail1, tail2)
case (_ +: _, "?" +: _) => false
case (p1 +: tail1, p2 +: tail2) => (p1 == p2) && isAtLeastAsGeneralAs(tail1, tail2)
case (Vector(), Vector()) => true
case _ => false
}

/** For each server endpoint, find the most general template among all the templates in the list, and use it for the endpoint, along
* with the `RoutePattern` corresponding to that template.
*/
def generaliseTemplates(endpoints: List[ServerEndpointWithPattern]): List[ServerEndpointWithPattern] = {
// de-duplicating the path templates
val allTemplates: List[(Vector[String], RoutePattern[Any])] = endpoints.map(se => (se.pathTemplate, se.routePattern)).toMap.toList
endpoints.map { se =>
val mostGeneral: (Vector[String], RoutePattern[Any]) =
allTemplates.foldLeft((se.pathTemplate, se.routePattern)) {
case ((mostGeneralTemplate, mostGeneralPattern), (template, pattern)) =>
if (template != mostGeneralTemplate && isAtLeastAsGeneralAs(template, mostGeneralTemplate)) {
(template, pattern)
} else {
(mostGeneralTemplate, mostGeneralPattern)
}
}
se.copy(pathTemplate = mostGeneral._1, routePattern = mostGeneral._2)
}
}

// Generating a path tempalte for each endpoint, and then finding the most general template among all of the
// endpoints. Once this is done, grouping the endpoints by path template. This way, if there are multiple endpoints
// with/without trailing slash or with path wildcards, they will end up in the same group, and they will be
// disambiguated by the Tapir logic. That's because there's no way currently to create a zio-http route pattern
// which would match on no-trailing-slashes. A group also includes multiple endpoints with different methods, but
// same path.
val widenedSesGroupedByPathTemplate =
widenedSes.zipWithIndex
.map { case (se, index) => toPattern(se, index) }
.pipe(generaliseTemplates)
.groupBy(_.pathTemplate)
.toList
.map(_._2)
// we try to maintain the order of endpoints as passed by the user; this order might be changed if there are
// endpoints with/without trailing slashes, or with different methods, which are not passed as subsequent
// values in the original `ses` list
.sortBy(_.map(_.index).min)

val handlers: List[Route[R & R2, Response]] = widenedSesGroupedByPathTemplate.map { sesWithPattern =>
val pattern = sesWithPattern.head.routePattern
val endpoints = sesWithPattern.sortBy(_.index).map(_.endpoint)
// The pattern that we generate should be the same for all endpoints in a group
Expand Down

0 comments on commit 1cc1b38

Please sign in to comment.