diff --git a/cask/src/cask/internal/DispatchTrie.scala b/cask/src/cask/internal/DispatchTrie.scala index 9f1a9c3394..cc9f7e7252 100644 --- a/cask/src/cask/internal/DispatchTrie.scala +++ b/cask/src/cask/internal/DispatchTrie.scala @@ -1,8 +1,9 @@ package cask.internal import collection.mutable object DispatchTrie{ - def construct[T](index: Int, - inputs: collection.Seq[(collection.IndexedSeq[String], T, Boolean)]): DispatchTrie[T] = { + def construct[T, V](index: Int, + inputs: collection.Seq[(collection.IndexedSeq[String], T, Boolean)]) + (validationGroups: T => Seq[V]): DispatchTrie[T] = { val continuations = mutable.Map.empty[String, mutable.Buffer[(collection.IndexedSeq[String], T, Boolean)]] val terminals = mutable.Buffer.empty[(collection.IndexedSeq[String], T, Boolean)] @@ -17,29 +18,67 @@ object DispatchTrie{ } } + for(group <- inputs.flatMap(t => validationGroups(t._2)).distinct) { + val groupTerminals = terminals.flatMap{case (path, v, allowSubpath) => + validationGroups(v) + .filter(_ == group) + .map{group => (path, v, allowSubpath, group)} + } + + val groupContinuations = continuations + .map { case (k, vs) => + k -> vs.flatMap { case (path, v, allowSubpath) => + validationGroups(v) + .filter(_ == group) + .map { group => (path, v, allowSubpath, group) } + } + } + .filter(_._2.nonEmpty) + + validateGroup(groupTerminals, groupContinuations) + } + + DispatchTrie[T]( + current = terminals.headOption.map(x => x._2 -> x._3), + children = continuations + .map{ case (k, vs) => (k, construct(index + 1, vs)(validationGroups))} + .toMap + ) + } + + def validateGroup[T, V](terminals: collection.Seq[(collection.Seq[String], T, Boolean, V)], + continuations: mutable.Map[String, mutable.Buffer[(collection.IndexedSeq[String], T, Boolean, V)]]) = { val wildcards = continuations.filter(_._1(0) == ':') - if (terminals.length > 1){ + + def renderTerminals = terminals + .map{case (path, v, allowSubpath, group) => s"$group${renderPath(path)}"} + .mkString(", ") + + def renderContinuations = continuations.toSeq + .flatMap(_._2) + .map{case (path, v, allowSubpath, group) => s"$group${renderPath(path)}"} + .mkString(", ") + + if (terminals.length > 1) { throw new Exception( - "More than one endpoint has the same path: " + - terminals.map(_._1.map(_.mkString("/"))).mkString(", ") + s"More than one endpoint has the same path: $renderTerminals" ) - } else if(wildcards.size >= 1 && continuations.size > 1) { + } + + if (wildcards.size >= 1 && continuations.size > 1) { throw new Exception( - "Routes overlap with wildcards: " + - (wildcards ++ continuations).flatMap(_._2).map(_._1.mkString("/")) + s"Routes overlap with wildcards: $renderContinuations" ) - }else if (terminals.headOption.exists(_._3) && continuations.size == 1){ + } + + if (terminals.headOption.exists(_._3) && continuations.size == 1) { throw new Exception( - "Routes overlap with subpath capture: " + - (wildcards ++ continuations).flatMap(_._2).map(_._1.mkString("/")) - ) - }else{ - DispatchTrie[T]( - current = terminals.headOption.map(x => x._2 -> x._3), - children = continuations.map{ case (k, vs) => (k, construct(index + 1, vs))}.toMap + s"Routes overlap with subpath capture: $renderTerminals, $renderContinuations" ) } } + + def renderPath(p: collection.Seq[String]) = " /" + p.mkString("/") } /** @@ -72,4 +111,9 @@ case class DispatchTrie[T](current: Option[(T, Boolean)], } } + + def map[V](f: T => V): DispatchTrie[V] = DispatchTrie( + current.map{case (t, v) => (f(t), v)}, + children.map { case (k, v) => (k, v.map(f))} + ) } diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index 6019196262..597a73993c 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -154,12 +154,12 @@ object Main{ } val dispatchInputs = flattenedRoutes.groupBy(_._1).map { case (segments, values) => - val methodMap = values.map(_._2).flatten.toMap + val methodMap = values.map(_._2).flatten val hasSubpath = values.map(_._3).contains(true) (segments, methodMap, hasSubpath) }.toSeq - DispatchTrie.construct(0, dispatchInputs) + DispatchTrie.construct(0, dispatchInputs)(_.map(_._1)).map(_.toMap) } def writeResponse(exchange: HttpServerExchange, response: Response.Raw) = { diff --git a/cask/test/src/test/cask/DispatchTrieTests.scala b/cask/test/src/test/cask/DispatchTrieTests.scala index 77e0369607..1da0df3286 100644 --- a/cask/test/src/test/cask/DispatchTrieTests.scala +++ b/cask/test/src/test/cask/DispatchTrieTests.scala @@ -8,7 +8,7 @@ object DispatchTrieTests extends TestSuite { "hello" - { val x = DispatchTrie.construct(0, Seq((Vector("hello"), 1, false)) - ) + )(Seq(_)) assert( x.lookup(List("hello"), Map()) == Some((1, Map(), Nil)), @@ -22,7 +22,7 @@ object DispatchTrieTests extends TestSuite { (Vector("hello", "world"), 1, false), (Vector("hello", "cow"), 2, false) ) - ) + )(Seq(_)) assert( x.lookup(List("hello", "world"), Map()) == Some((1, Map(), Nil)), x.lookup(List("hello", "cow"), Map()) == Some((2, Map(), Nil)), @@ -34,7 +34,7 @@ object DispatchTrieTests extends TestSuite { "bindings" - { val x = DispatchTrie.construct(0, Seq((Vector(":hello", ":world"), 1, false)) - ) + )(Seq(_)) assert( x.lookup(List("hello", "world"), Map()) == Some((1, Map("hello" -> "hello", "world" -> "world"), Nil)), x.lookup(List("world", "hello"), Map()) == Some((1, Map("hello" -> "world", "world" -> "hello"), Nil)), @@ -47,7 +47,7 @@ object DispatchTrieTests extends TestSuite { "path" - { val x = DispatchTrie.construct(0, Seq((Vector("hello"), 1, true)) - ) + )(Seq(_)) assert( x.lookup(List("hello", "world"), Map()) == Some((1,Map(), Seq("world"))), @@ -58,44 +58,113 @@ object DispatchTrieTests extends TestSuite { } "errors" - { - intercept[Exception]{ + test - { DispatchTrie.construct(0, Seq( (Vector("hello", ":world"), 1, false), (Vector("hello", "world"), 2, false) ) + )(Seq(_)) + + val ex = intercept[Exception]{ + DispatchTrie.construct(0, + Seq( + (Vector("hello", ":world"), 1, false), + (Vector("hello", "world"), 1, false) + ) + )(Seq(_)) + } + + assert( + ex.getMessage == + "Routes overlap with wildcards: 1 /hello/:world, 1 /hello/world" ) } - intercept[Exception]{ + test - { DispatchTrie.construct(0, Seq( (Vector("hello", ":world"), 1, false), (Vector("hello", "world", "omg"), 2, false) ) + )(Seq(_)) + + val ex = intercept[Exception]{ + DispatchTrie.construct(0, + Seq( + (Vector("hello", ":world"), 1, false), + (Vector("hello", "world", "omg"), 1, false) + ) + )(Seq(_)) + } + + assert( + ex.getMessage == + "Routes overlap with wildcards: 1 /hello/:world, 1 /hello/world/omg" ) } - intercept[Exception]{ + test - { DispatchTrie.construct(0, Seq( (Vector("hello"), 1, true), (Vector("hello", "cow", "omg"), 2, false) ) + )(Seq(_)) + + val ex = intercept[Exception]{ + DispatchTrie.construct(0, + Seq( + (Vector("hello"), 1, true), + (Vector("hello", "cow", "omg"), 1, false) + ) + )(Seq(_)) + } + + assert( + ex.getMessage == + "Routes overlap with subpath capture: 1 /hello, 1 /hello/cow/omg" ) } - intercept[Exception]{ + test - { DispatchTrie.construct(0, Seq( (Vector("hello", ":world"), 1, false), (Vector("hello", ":cow"), 2, false) ) + )(Seq(_)) + + val ex = intercept[Exception]{ + DispatchTrie.construct(0, + Seq( + (Vector("hello", ":world"), 1, false), + (Vector("hello", ":cow"), 1, false) + ) + )(Seq(_)) + } + + assert( + ex.getMessage == + "Routes overlap with wildcards: 1 /hello/:world, 1 /hello/:cow" ) } - intercept[Exception]{ + test - { DispatchTrie.construct(0, Seq( (Vector("hello", "world"), 1, false), (Vector("hello", "world"), 2, false) ) + )(Seq(_)) + + val ex = intercept[Exception]{ + DispatchTrie.construct(0, + Seq( + (Vector("hello", "world"), 1, false), + (Vector("hello", "world"), 1, false) + ) + )(Seq(_)) + } + assert( + ex.getMessage == + "More than one endpoint has the same path: 1 /hello/world, 1 /hello/world" ) } } diff --git a/example/variableRoutes/app/src/VariableRoutes.scala b/example/variableRoutes/app/src/VariableRoutes.scala index a1c8a4fef5..c9e55bf08d 100644 --- a/example/variableRoutes/app/src/VariableRoutes.scala +++ b/example/variableRoutes/app/src/VariableRoutes.scala @@ -15,5 +15,10 @@ object VariableRoutes extends cask.MainRoutes{ s"Subpath ${request.remainingPathSegments}" } + @cask.post("/path", subpath = true) + def postShowSubpath(request: cask.Request) = { + s"POST Subpath ${request.remainingPathSegments}" + } + initialize() } diff --git a/example/variableRoutes/app/test/src/ExampleTests.scala b/example/variableRoutes/app/test/src/ExampleTests.scala index 78bf991668..c818ee8126 100644 --- a/example/variableRoutes/app/test/src/ExampleTests.scala +++ b/example/variableRoutes/app/test/src/ExampleTests.scala @@ -46,6 +46,9 @@ object ExampleTests extends TestSuite{ requests.get(s"$host/path/one/two/three").text() ==> "Subpath List(one, two, three)" + + requests.post(s"$host/path/one/two/three").text() ==> + "POST Subpath List(one, two, three)" } }