Skip to content

Commit

Permalink
Fix default commands in nested cross modules (#2039)
Browse files Browse the repository at this point in the history
This fixes lookup of default commands in (nested) cross modules.

Before this change, default commands in cross modules (defined via
`TaskModule`) were ignored, which is an bug IMO.

This fixes issue #2027

I took the opportunity to split up some of the gnarly `Resolve` code
into multiple files. It's still hard to understand, yet a bit more
navigable.

Pull request: #2039
  • Loading branch information
lefou authored Sep 22, 2022
1 parent c60bc5c commit 1cb86db
Show file tree
Hide file tree
Showing 8 changed files with 390 additions and 278 deletions.
28 changes: 28 additions & 0 deletions main/src/mill/main/LevenshteinDistance.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package mill.main

/**
* Compute the Levenshtein Distance.
*/
// Using as trait to keep binary compatibility within Mill 0.10
// TODO: make it an object in Mill 0.11
trait LevenshteinDistance {
def minimum(i1: Int, i2: Int, i3: Int) = math.min(math.min(i1, i2), i3)

/**
* Short Levenshtein distance algorithm, based on
*
* https://rosettacode.org/wiki/Levenshtein_distance#Scala
*/
def editDistance(s1: String, s2: String) = {
val dist = Array.tabulate(s2.length + 1, s1.length + 1) { (j, i) =>
if (j == 0) i else if (i == 0) j else 0
}

for (j <- 1 to s2.length; i <- 1 to s1.length)
dist(j)(i) =
if (s2(j - 1) == s1(i - 1)) dist(j - 1)(i - 1)
else minimum(dist(j - 1)(i) + 1, dist(j)(i - 1) + 1, dist(j - 1)(i - 1) + 1)

dist(s2.length)(s1.length)
}
}
318 changes: 48 additions & 270 deletions main/src/mill/main/Resolve.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,254 +2,13 @@ package mill.main

import mill.define._
import mill.define.TaskModule
import ammonite.util.Res
import mainargs.{MainData, TokenGrouping}
import mill.main.ResolveMetadata.singleModuleMeta

import scala.collection.immutable
import scala.reflect.ClassTag

object ResolveMetadata extends Resolve[String] {
def singleModuleMeta(obj: Module, discover: Discover[_], isRootModule: Boolean): Seq[String] = {
val modules = obj.millModuleDirectChildren.map(_.toString)
val targets =
obj
.millInternal
.reflectAll[NamedTask[_]]
.map(_.toString)
val commands =
for {
(cls, entryPoints) <- discover.value
if cls.isAssignableFrom(obj.getClass)
ep <- entryPoints
} yield
if (isRootModule) ep._2.name
else s"$obj.${ep._2.name}"

modules ++ targets ++ commands
}

def endResolveLabel(
obj: Module,
last: String,
discover: Discover[_],
rest: Seq[String]
): Either[String, Seq[String]] = {
def direct = singleModuleMeta(obj, discover, obj.millModuleSegments.value.isEmpty)
last match {
case "__" =>
Right(
// Filter out our own module in
obj.millInternal.modules.flatMap(m => singleModuleMeta(m, discover, m == obj))
)
case "_" => Right(direct)
case _ =>
direct.find(_.split('.').last == last) match {
case None =>
Resolve.errorMsgLabel(direct, Seq(Segment.Label(last)), obj.millModuleSegments.value)
case Some(s) => Right(Seq(s))
}
}
}

def endResolveCross(
obj: Module,
last: List[String],
discover: Discover[_],
rest: Seq[String]
): Either[String, List[String]] = {
obj match {
case c: Cross[Module] =>
last match {
case List("__") => Right(c.items.map(_._2.toString))
case items =>
c.items
.filter(_._1.length == items.length)
.filter(_._1.zip(last).forall { case (a, b) => b == "_" || a.toString == b })
.map(_._2.toString) match {
case Nil =>
Resolve.errorMsgCross(
c.items.map(_._1.map(_.toString)),
last,
obj.millModuleSegments.value
)
case res => Right(res)
}

}
case _ =>
Left(
Resolve.unableToResolve(Segment.Cross(last), obj.millModuleSegments.value) +
Resolve.hintListLabel(obj.millModuleSegments.value)
)
}
}
}

object ResolveSegments extends Resolve[Segments] {

override def endResolveCross(
obj: Module,
last: List[String],
discover: Discover[_],
rest: Seq[String]
): Either[String, Seq[Segments]] = {
obj match {
case c: Cross[Module] =>
last match {
case List("__") => Right(c.items.map(_._2.millModuleSegments))
case items =>
c.items
.filter(_._1.length == items.length)
.filter(_._1.zip(last).forall { case (a, b) => b == "_" || a.toString == b })
.map(_._2.millModuleSegments) match {
case Nil =>
Resolve.errorMsgCross(
c.items.map(_._1.map(_.toString)),
last,
obj.millModuleSegments.value
)
case res => Right(res)
}
}
case _ =>
Left(
Resolve.unableToResolve(Segment.Cross(last), obj.millModuleSegments.value) +
Resolve.hintListLabel(obj.millModuleSegments.value)
)
}
}

def endResolveLabel(
obj: Module,
last: String,
discover: Discover[_],
rest: Seq[String]
): Either[String, Seq[Segments]] = {
val target =
obj
.millInternal
.reflectSingle[Target[_]](last)
.map(t => Right(t.ctx.segments))

val command =
Resolve
.invokeCommand(obj, last, discover.asInstanceOf[Discover[Module]], rest)
.headOption
.map(_.map(_.ctx.segments))

val module =
obj.millInternal
.reflectNestedObjects[Module]
.find(_.millOuterCtx.segment == Segment.Label(last))
.map(m => Right(m.millModuleSegments))

command orElse target orElse module match {
case None =>
Resolve.errorMsgLabel(
singleModuleMeta(obj, discover, obj.millModuleSegments.value.isEmpty),
Seq(Segment.Label(last)),
obj.millModuleSegments.value
)

case Some(either) => either.right.map(Seq(_))
}
}
}

object ResolveTasks extends Resolve[NamedTask[Any]] {

def endResolveCross(
obj: Module,
last: List[String],
discover: Discover[_],
rest: Seq[String]
): Either[String, Seq[NamedTask[Any]]] = {
obj match {
case c: Cross[Module] =>
Resolve.runDefault(obj, Segment.Cross(last), discover, rest).flatten.headOption match {
case None =>
Left(
"Cannot find default task to evaluate for module " +
Segments((obj.millModuleSegments.value :+ Segment.Cross(last)): _*).render
)
case Some(v) => v.map(Seq(_))
}
case _ =>
Left(
Resolve.unableToResolve(Segment.Cross(last), obj.millModuleSegments.value) +
Resolve.hintListLabel(obj.millModuleSegments.value)
)
}
}

def endResolveLabel(
obj: Module,
last: String,
discover: Discover[_],
rest: Seq[String]
): Either[String, Seq[NamedTask[Any]]] = last match {
case "__" =>
Right(
obj.millInternal.modules
.filter(_ != obj)
.flatMap(m => m.millInternal.reflectAll[NamedTask[_]])
)
case "_" => Right(obj.millInternal.reflectAll[NamedTask[_]])

case _ =>
val target =
obj
.millInternal
.reflectSingle[NamedTask[_]](last)
.map(Right(_))

val command = Resolve.invokeCommand(
obj,
last,
discover.asInstanceOf[Discover[Module]],
rest
).headOption

command orElse target orElse Resolve.runDefault(
obj,
Segment.Label(last),
discover,
rest
).flatten.headOption match {
case None =>
Resolve.errorMsgLabel(
singleModuleMeta(obj, discover, obj.millModuleSegments.value.isEmpty),
Seq(Segment.Label(last)),
obj.millModuleSegments.value
)

// Contents of `either` *must* be a `Task`, because we only select
// methods returning `Task` in the discovery process
case Some(either) => either.map(Seq(_))
}
}
}

object Resolve {
def minimum(i1: Int, i2: Int, i3: Int) = math.min(math.min(i1, i2), i3)

/**
* Short Levenshtein distance algorithm, based on
*
* https://rosettacode.org/wiki/Levenshtein_distance#Scala
*/
def editDistance(s1: String, s2: String) = {
val dist = Array.tabulate(s2.length + 1, s1.length + 1) { (j, i) =>
if (j == 0) i else if (i == 0) j else 0
}

for (j <- 1 to s2.length; i <- 1 to s1.length)
dist(j)(i) =
if (s2(j - 1) == s1(i - 1)) dist(j - 1)(i - 1)
else minimum(dist(j - 1)(i) + 1, dist(j)(i - 1) + 1, dist(j - 1)(i - 1) + 1)

dist(s2.length)(s1.length)
}
object Resolve extends LevenshteinDistance {

def unableToResolve(last: Segment, revSelectorsSoFar: Seq[Segment]): String = {
unableToResolve(Segments((last +: revSelectorsSoFar).reverse: _*).render)
Expand Down Expand Up @@ -334,7 +93,12 @@ object Resolve {
)
}

def invokeCommand(target: Module, name: String, discover: Discover[Module], rest: Seq[String]) =
def invokeCommand(
target: Module,
name: String,
discover: Discover[Module],
rest: Seq[String]
): immutable.Iterable[Either[String, Command[_]]] =
for {
(cls, entryPoints) <- discover.value
if cls.isAssignableFrom(target.getClass)
Expand Down Expand Up @@ -370,24 +134,31 @@ object Resolve {
}
}

def runDefault(obj: Module, last: Segment, discover: Discover[_], rest: Seq[String]) = for {
child <- obj.millInternal.reflectNestedObjects[Module]
if child.millOuterCtx.segment == last
res <- child match {
case taskMod: TaskModule =>
Some(
invokeCommand(
child,
taskMod.defaultCommandName(),
discover.asInstanceOf[Discover[Module]],
rest
).headOption
)
case _ => None
}
} yield res

def runDefault(
obj: Module,
last: Segment,
discover: Discover[_],
rest: Seq[String]
): Array[Option[Either[String, Command[_]]]] = {
for {
child <- obj.millModuleDirectChildren
if child.millOuterCtx.segment == last
res <- child match {
case taskMod: TaskModule =>
Some(
invokeCommand(
child,
taskMod.defaultCommandName(),
discover.asInstanceOf[Discover[Module]],
rest
).headOption
)
case _ => None
}
} yield res
}.toArray
}

abstract class Resolve[R: ClassTag] {
def endResolveCross(
obj: Module,
Expand Down Expand Up @@ -417,9 +188,12 @@ abstract class Resolve[R: ClassTag] {
endResolveLabel(obj, last, discover, rest)

case head :: tail =>
def recurse(searchModules: Seq[Module], resolveFailureMsg: => Left[String, Nothing]) = {
def recurse(
searchModules: Seq[Module],
resolveFailureMsg: => Left[String, Nothing]
): Either[String, Seq[R]] = {
val matching = searchModules
.map(resolve(tail, _, discover, rest, remainingCrossSelectors))
.map(m => resolve(tail, m, discover, rest, remainingCrossSelectors))

matching match {
case Seq(Left(err)) => Left(err)
Expand Down Expand Up @@ -473,20 +247,24 @@ abstract class Resolve[R: ClassTag] {
case Segment.Cross(cross) =>
obj match {
case c: Cross[Module] =>
recurse(
val searchModules =
if (cross == Seq("__")) for ((_, v) <- c.items) yield v
else if (cross.contains("_")) {
for {
(k, v) <- c.items
if k.length == cross.length
if k.zip(cross).forall { case (l, r) => l == r || r == "_" }
} yield v
} else c.itemMap.get(cross.toList).toSeq,
Resolve.errorMsgCross(
c.items.map(_._1.map(_.toString)),
cross.map(_.toString),
obj.millModuleSegments.value
)
} else c.itemMap.get(cross.toList).toSeq

recurse(
searchModules = searchModules,
resolveFailureMsg =
Resolve.errorMsgCross(
c.items.map(_._1.map(_.toString)),
cross.map(_.toString),
obj.millModuleSegments.value
)
)
case _ =>
Left(
Expand Down
Loading

0 comments on commit 1cb86db

Please sign in to comment.