Skip to content

Commit

Permalink
Fix DispatchTrie validation and reporting of invalid routes (#70)
Browse files Browse the repository at this point in the history
This regressed in #52, resulting in both false positives (where a `GET` and a `POST` shared the same route, giving an unnecessary error) and false negatives (where multiple `GET`s sharing the same route failed to create an error). The basic problem was that since combining the various HTTP methods into a single routing trie, the old logic comparing uniqueness/duplication/etc. was no longer correct in the new combined trie.

This PR fixes it by doing a `groupBy` to split up the entries in the combined trie by HTTP method, before running essentially the same validation.

We augment the test suite, tightening up cask/test/src/test/cask/DispatchTrieTests.scala to make it stricter, checking exact error messages to ensure we get not just any failure but the *correct* failure when the validation code triggers. This should hopefully catch this sort of regression in future.
  • Loading branch information
lihaoyi authored May 6, 2022
1 parent d6ef66a commit e8184c9
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 27 deletions.
76 changes: 60 additions & 16 deletions cask/src/cask/internal/DispatchTrie.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package cask.internal
import collection.mutable
object DispatchTrie{
def construct[T](index: Int,
inputs: collection.Seq[(collection.IndexedSeq[String], T, Boolean)]): DispatchTrie[T] = {
def construct[T, V](index: Int,
inputs: collection.Seq[(collection.IndexedSeq[String], T, Boolean)])
(validationGroups: T => Seq[V]): DispatchTrie[T] = {
val continuations = mutable.Map.empty[String, mutable.Buffer[(collection.IndexedSeq[String], T, Boolean)]]

val terminals = mutable.Buffer.empty[(collection.IndexedSeq[String], T, Boolean)]
Expand All @@ -17,29 +18,67 @@ object DispatchTrie{
}
}

for(group <- inputs.flatMap(t => validationGroups(t._2)).distinct) {
val groupTerminals = terminals.flatMap{case (path, v, allowSubpath) =>
validationGroups(v)
.filter(_ == group)
.map{group => (path, v, allowSubpath, group)}
}

val groupContinuations = continuations
.map { case (k, vs) =>
k -> vs.flatMap { case (path, v, allowSubpath) =>
validationGroups(v)
.filter(_ == group)
.map { group => (path, v, allowSubpath, group) }
}
}
.filter(_._2.nonEmpty)

validateGroup(groupTerminals, groupContinuations)
}

DispatchTrie[T](
current = terminals.headOption.map(x => x._2 -> x._3),
children = continuations
.map{ case (k, vs) => (k, construct(index + 1, vs)(validationGroups))}
.toMap
)
}

def validateGroup[T, V](terminals: collection.Seq[(collection.Seq[String], T, Boolean, V)],
continuations: mutable.Map[String, mutable.Buffer[(collection.IndexedSeq[String], T, Boolean, V)]]) = {
val wildcards = continuations.filter(_._1(0) == ':')
if (terminals.length > 1){

def renderTerminals = terminals
.map{case (path, v, allowSubpath, group) => s"$group${renderPath(path)}"}
.mkString(", ")

def renderContinuations = continuations.toSeq
.flatMap(_._2)
.map{case (path, v, allowSubpath, group) => s"$group${renderPath(path)}"}
.mkString(", ")

if (terminals.length > 1) {
throw new Exception(
"More than one endpoint has the same path: " +
terminals.map(_._1.map(_.mkString("/"))).mkString(", ")
s"More than one endpoint has the same path: $renderTerminals"
)
} else if(wildcards.size >= 1 && continuations.size > 1) {
}

if (wildcards.size >= 1 && continuations.size > 1) {
throw new Exception(
"Routes overlap with wildcards: " +
(wildcards ++ continuations).flatMap(_._2).map(_._1.mkString("/"))
s"Routes overlap with wildcards: $renderContinuations"
)
}else if (terminals.headOption.exists(_._3) && continuations.size == 1){
}

if (terminals.headOption.exists(_._3) && continuations.size == 1) {
throw new Exception(
"Routes overlap with subpath capture: " +
(wildcards ++ continuations).flatMap(_._2).map(_._1.mkString("/"))
)
}else{
DispatchTrie[T](
current = terminals.headOption.map(x => x._2 -> x._3),
children = continuations.map{ case (k, vs) => (k, construct(index + 1, vs))}.toMap
s"Routes overlap with subpath capture: $renderTerminals, $renderContinuations"
)
}
}

def renderPath(p: collection.Seq[String]) = " /" + p.mkString("/")
}

/**
Expand Down Expand Up @@ -72,4 +111,9 @@ case class DispatchTrie[T](current: Option[(T, Boolean)],

}
}

def map[V](f: T => V): DispatchTrie[V] = DispatchTrie(
current.map{case (t, v) => (f(t), v)},
children.map { case (k, v) => (k, v.map(f))}
)
}
4 changes: 2 additions & 2 deletions cask/src/cask/main/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,12 @@ object Main{
}

val dispatchInputs = flattenedRoutes.groupBy(_._1).map { case (segments, values) =>
val methodMap = values.map(_._2).flatten.toMap
val methodMap = values.map(_._2).flatten
val hasSubpath = values.map(_._3).contains(true)
(segments, methodMap, hasSubpath)
}.toSeq

DispatchTrie.construct(0, dispatchInputs)
DispatchTrie.construct(0, dispatchInputs)(_.map(_._1)).map(_.toMap)
}

def writeResponse(exchange: HttpServerExchange, response: Response.Raw) = {
Expand Down
87 changes: 78 additions & 9 deletions cask/test/src/test/cask/DispatchTrieTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ object DispatchTrieTests extends TestSuite {
"hello" - {
val x = DispatchTrie.construct(0,
Seq((Vector("hello"), 1, false))
)
)(Seq(_))

assert(
x.lookup(List("hello"), Map()) == Some((1, Map(), Nil)),
Expand All @@ -22,7 +22,7 @@ object DispatchTrieTests extends TestSuite {
(Vector("hello", "world"), 1, false),
(Vector("hello", "cow"), 2, false)
)
)
)(Seq(_))
assert(
x.lookup(List("hello", "world"), Map()) == Some((1, Map(), Nil)),
x.lookup(List("hello", "cow"), Map()) == Some((2, Map(), Nil)),
Expand All @@ -34,7 +34,7 @@ object DispatchTrieTests extends TestSuite {
"bindings" - {
val x = DispatchTrie.construct(0,
Seq((Vector(":hello", ":world"), 1, false))
)
)(Seq(_))
assert(
x.lookup(List("hello", "world"), Map()) == Some((1, Map("hello" -> "hello", "world" -> "world"), Nil)),
x.lookup(List("world", "hello"), Map()) == Some((1, Map("hello" -> "world", "world" -> "hello"), Nil)),
Expand All @@ -47,7 +47,7 @@ object DispatchTrieTests extends TestSuite {
"path" - {
val x = DispatchTrie.construct(0,
Seq((Vector("hello"), 1, true))
)
)(Seq(_))

assert(
x.lookup(List("hello", "world"), Map()) == Some((1,Map(), Seq("world"))),
Expand All @@ -58,44 +58,113 @@ object DispatchTrieTests extends TestSuite {
}

"errors" - {
intercept[Exception]{
test - {
DispatchTrie.construct(0,
Seq(
(Vector("hello", ":world"), 1, false),
(Vector("hello", "world"), 2, false)
)
)(Seq(_))

val ex = intercept[Exception]{
DispatchTrie.construct(0,
Seq(
(Vector("hello", ":world"), 1, false),
(Vector("hello", "world"), 1, false)
)
)(Seq(_))
}

assert(
ex.getMessage ==
"Routes overlap with wildcards: 1 /hello/:world, 1 /hello/world"
)
}
intercept[Exception]{
test - {
DispatchTrie.construct(0,
Seq(
(Vector("hello", ":world"), 1, false),
(Vector("hello", "world", "omg"), 2, false)
)
)(Seq(_))

val ex = intercept[Exception]{
DispatchTrie.construct(0,
Seq(
(Vector("hello", ":world"), 1, false),
(Vector("hello", "world", "omg"), 1, false)
)
)(Seq(_))
}

assert(
ex.getMessage ==
"Routes overlap with wildcards: 1 /hello/:world, 1 /hello/world/omg"
)
}
intercept[Exception]{
test - {
DispatchTrie.construct(0,
Seq(
(Vector("hello"), 1, true),
(Vector("hello", "cow", "omg"), 2, false)
)
)(Seq(_))

val ex = intercept[Exception]{
DispatchTrie.construct(0,
Seq(
(Vector("hello"), 1, true),
(Vector("hello", "cow", "omg"), 1, false)
)
)(Seq(_))
}

assert(
ex.getMessage ==
"Routes overlap with subpath capture: 1 /hello, 1 /hello/cow/omg"
)
}
intercept[Exception]{
test - {
DispatchTrie.construct(0,
Seq(
(Vector("hello", ":world"), 1, false),
(Vector("hello", ":cow"), 2, false)
)
)(Seq(_))

val ex = intercept[Exception]{
DispatchTrie.construct(0,
Seq(
(Vector("hello", ":world"), 1, false),
(Vector("hello", ":cow"), 1, false)
)
)(Seq(_))
}

assert(
ex.getMessage ==
"Routes overlap with wildcards: 1 /hello/:world, 1 /hello/:cow"
)
}
intercept[Exception]{
test - {
DispatchTrie.construct(0,
Seq(
(Vector("hello", "world"), 1, false),
(Vector("hello", "world"), 2, false)
)
)(Seq(_))

val ex = intercept[Exception]{
DispatchTrie.construct(0,
Seq(
(Vector("hello", "world"), 1, false),
(Vector("hello", "world"), 1, false)
)
)(Seq(_))
}
assert(
ex.getMessage ==
"More than one endpoint has the same path: 1 /hello/world, 1 /hello/world"
)
}
}
Expand Down
5 changes: 5 additions & 0 deletions example/variableRoutes/app/src/VariableRoutes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,10 @@ object VariableRoutes extends cask.MainRoutes{
s"Subpath ${request.remainingPathSegments}"
}

@cask.post("/path", subpath = true)
def postShowSubpath(request: cask.Request) = {
s"POST Subpath ${request.remainingPathSegments}"
}

initialize()
}
3 changes: 3 additions & 0 deletions example/variableRoutes/app/test/src/ExampleTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ object ExampleTests extends TestSuite{

requests.get(s"$host/path/one/two/three").text() ==>
"Subpath List(one, two, three)"

requests.post(s"$host/path/one/two/three").text() ==>
"POST Subpath List(one, two, three)"
}

}
Expand Down

0 comments on commit e8184c9

Please sign in to comment.