From 4ed4d86557378c19136830d16ee889df680ff406 Mon Sep 17 00:00:00 2001 From: adamw Date: Wed, 18 Sep 2024 16:44:05 +0200 Subject: [PATCH] Generalise path templates when generating ZIO Http RoutePattern-s --- .../tapir/server/tests/ServerBasicTests.scala | 19 ++++ .../server/ziohttp/ZioHttpInterpreter.scala | 89 +++++++++++++------ 2 files changed, 83 insertions(+), 25 deletions(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index 0c36b93340..544c8d4773 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -706,6 +706,25 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( basicRequest.get(uri"$baseUri/p1/x/p3").send(backend).map(_.body shouldBe Right("2: x")) >> basicRequest.get(uri"$baseUri/p1/y/p3").send(backend).map(_.body shouldBe Right("2: y")) >> basicRequest.get(uri"$baseUri/p1/p2/p4").send(backend).map(_.code shouldBe StatusCode.NotFound) + }, + // #4050 + testServer( + "two endpoints with fixed path & path capture as the middle component, different methods", + NonEmptyList.of( + route( + List[ServerEndpoint[Any, F]]( + endpoint.get.in("p1" / "p2").out(stringBody).serverLogic(_ => pureResult("1".asRight[Unit])), + endpoint.delete.in("p1" / path[String]("p")).out(stringBody).serverLogic((v: String) => pureResult(s"2: $v".asRight[Unit])) + ) + ) + ) + ) { (backend, baseUri) => + basicRequest.get(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe Right("1")) >> + basicRequest.get(uri"$baseUri/p1/x").send(backend).map(_.code shouldBe StatusCode.MethodNotAllowed) >> + basicRequest.delete(uri"$baseUri/p1/p2").send(backend).map(_.body shouldBe Right("2: p2")) >> + basicRequest.delete(uri"$baseUri/p1/p3").send(backend).map(_.body shouldBe Right("2: p3")) >> + basicRequest.get(uri"$baseUri/p1/p2/p3").send(backend).map(_.code shouldBe StatusCode.NotFound) >> + basicRequest.delete(uri"$baseUri/p1/p2/p3").send(backend).map(_.code shouldBe StatusCode.NotFound) } ) diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala index 9bc8e45f87..e03ecf7c28 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala @@ -14,6 +14,7 @@ import sttp.tapir.ztapir._ import zio._ import zio.http.codec.PathCodec import zio.http.{Header => ZioHttpHeader, Headers => ZioHttpHeaders, _} +import scala.util.chaining._ trait ZioHttpInterpreter[R] { def zioHttpServerOptions: ZioHttpServerOptions[R] = ZioHttpServerOptions.default @@ -61,8 +62,8 @@ trait ZioHttpInterpreter[R] { // here we'll keep the endpoint together with the meta-data needed to create the zio-http routing information case class ServerEndpointWithPattern( index: Int, - pathTemplate: String, - routePattern: RoutePattern[_], + pathTemplate: Vector[String], + routePattern: RoutePattern[Any], // the Any here is a way to work around the type checker endpoint: ZServerEndpoint[R & R2, ZioStreams with WebSockets] ) @@ -72,13 +73,13 @@ trait ZioHttpInterpreter[R] { // Creating the path template - no-trailing-slash inputs are treated as wildcard inputs, as they are usually // accompanied by endpoints which handle wildcard path inputs, when the `/` is present (to serve files). They - // need to end up in the same group (see below), so that they are disambiguated by tapir's logic. - val pathTemplate = inputs.foldLeft("") { case (p, component) => + // need to end up in the same group (see below), so that they are disambiguated by Tapir's logic. + val pathTemplate = inputs.foldLeft(Vector.empty[String]) { case (p, component) => component match { - case _: EndpointInput.PathCapture[_] => p + "/?" - case _: EndpointInput.PathsCapture[_] => p + "/..." - case i: EndpointInput.ExtractFromRequest[_] if i.attribute(NoTrailingSlash.Attribute).getOrElse(false) => p + "/..." - case i: EndpointInput.FixedPath[_] => p + "/" + i.s + case _: EndpointInput.PathCapture[_] => p :+ "?" + case _: EndpointInput.PathsCapture[_] => p :+ "..." + case i: EndpointInput.ExtractFromRequest[_] if i.attribute(NoTrailingSlash.Attribute).getOrElse(false) => p :+ "..." + case i: EndpointInput.FixedPath[_] => p :+ s"{${i.s}}" case _ => p } } @@ -94,7 +95,7 @@ trait ZioHttpInterpreter[R] { case _ => false } - val routePattern = if (hasPath) { + val routePattern: RoutePattern[Any] = if (hasPath) { val initialPattern = RoutePattern(Method.ANY, PathCodec.empty).asInstanceOf[RoutePattern[Any]] // The second tuple parameter specifies if PathCodec.trailing should be added to the route's pattern. It can // be added either because of a PathsCapture, or because of an noTrailingSlash input. @@ -109,7 +110,7 @@ trait ZioHttpInterpreter[R] { } } - if (addTrailing) p / PathCodec.trailing else p + if (addTrailing) (p / PathCodec.trailing).asInstanceOf[RoutePattern[Any]] else p } else { // if there are no path inputs, we return a catch-all RoutePattern(Method.ANY, PathCodec.trailing).asInstanceOf[RoutePattern[Any]] @@ -118,21 +119,59 @@ trait ZioHttpInterpreter[R] { ServerEndpointWithPattern(index, pathTemplate, routePattern, se) } - // Grouping the endpoints by path template. This way, if there are multiple endpoints with/without trailing slash or - // with path wildcards, they will end up in the same group, and they will be disambiguated by the tapir logic. - // That's because there's not way currently to create a zio-http route pattern which would match on - // no-trailing-slashes. A group also includes multiple endpoints with different methods, but same path. - val widenedSesGroupedByPathPrefixTemplate = widenedSes.zipWithIndex - .map { case (se, index) => toPattern(se, index) } - .groupBy(_.pathTemplate) - .toList - .map(_._2) - // we try to maintain the order of endpoints as passed by the user; this order might be changed if there are - // endpoints with/without trailing slashes, or with different methods, which are not passed as subsequent - // values in the original `ses` list - .sortBy(_.map(_.index).min) - - val handlers: List[Route[R & R2, Response]] = widenedSesGroupedByPathPrefixTemplate.map { sesWithPattern => + /** `t1` and `t2` are both path templates as created by `toPattern` above. Each path template is a vector of: ? | ... | {string}. This + * method checks if `t1` is at least as general as `t2`, that is if each request that matches `t2` also matches `t1` + */ + def isAtLeastAsGeneralAs(t1: Vector[String], t2: Vector[String]): Boolean = (t1, t2) match { + case ("..." +: _, _) => true + case (_, "..." +: _) => false + case ("?" +: tail1, "?" +: tail2) => isAtLeastAsGeneralAs(tail1, tail2) + case ("?" +: tail1, _ +: tail2) => isAtLeastAsGeneralAs(tail1, tail2) + case (_ +: _, "?" +: _) => false + case (p1 +: tail1, p2 +: tail2) => (p1 == p2) && isAtLeastAsGeneralAs(tail1, tail2) + case (Vector(), Vector()) => true + case _ => false + } + + /** For each server endpoint, find the most general template among all the templates in the list, and use it for the endpoint, along + * with the `RoutePattern` corresponding to that template. + */ + def generaliseTemplates(endpoints: List[ServerEndpointWithPattern]): List[ServerEndpointWithPattern] = { + // de-duplicating the path templates + val allTemplates: List[(Vector[String], RoutePattern[Any])] = endpoints.map(se => (se.pathTemplate, se.routePattern)).toMap.toList + endpoints.map { se => + val mostGeneral: (Vector[String], RoutePattern[Any]) = + allTemplates.foldLeft((se.pathTemplate, se.routePattern)) { + case ((mostGeneralTemplate, mostGeneralPattern), (template, pattern)) => + if (template != mostGeneralTemplate && isAtLeastAsGeneralAs(template, mostGeneralTemplate)) { + (template, pattern) + } else { + (mostGeneralTemplate, mostGeneralPattern) + } + } + se.copy(pathTemplate = mostGeneral._1, routePattern = mostGeneral._2) + } + } + + // Generating a path tempalte for each endpoint, and then finding the most general template among all of the + // endpoints. Once this is done, grouping the endpoints by path template. This way, if there are multiple endpoints + // with/without trailing slash or with path wildcards, they will end up in the same group, and they will be + // disambiguated by the Tapir logic. That's because there's no way currently to create a zio-http route pattern + // which would match on no-trailing-slashes. A group also includes multiple endpoints with different methods, but + // same path. + val widenedSesGroupedByPathTemplate = + widenedSes.zipWithIndex + .map { case (se, index) => toPattern(se, index) } + .pipe(generaliseTemplates) + .groupBy(_.pathTemplate) + .toList + .map(_._2) + // we try to maintain the order of endpoints as passed by the user; this order might be changed if there are + // endpoints with/without trailing slashes, or with different methods, which are not passed as subsequent + // values in the original `ses` list + .sortBy(_.map(_.index).min) + + val handlers: List[Route[R & R2, Response]] = widenedSesGroupedByPathTemplate.map { sesWithPattern => val pattern = sesWithPattern.head.routePattern val endpoints = sesWithPattern.sortBy(_.index).map(_.endpoint) // The pattern that we generate should be the same for all endpoints in a group