Skip to content

Commit

Permalink
Add Ktfmt support (#3620)
Browse files Browse the repository at this point in the history
Fixes #3612.

This PR is based on the logic done in #3531, so it looks quite similar.

Co-authored-by: 0xnm <[email protected]>
  • Loading branch information
0xnm and 0xnm authored Sep 28, 2024
1 parent e068f58 commit 55f02d9
Show file tree
Hide file tree
Showing 11 changed files with 386 additions and 0 deletions.
25 changes: 25 additions & 0 deletions example/kotlinlib/linting/3-ktfmt/build.mill
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package build

import mill._
import mill.util.Jvm
import mill.api.Loose
import kotlinlib.KotlinModule
import kotlinlib.contrib.ktfmt.KtfmtModule

object `package` extends RootModule with KotlinModule with KtfmtModule {

def kotlinVersion = "1.9.24"

}

/** Usage

> ./mill ktfmt --format=false # run ktfmt to produce a list of files which should be formatter
...src/example/FooWrong.kt...
> ./mill ktfmt # running without arguments will format all files
Done formatting ...src/example/FooWrong.kt
> ./mill ktfmt # after fixing the violations, ktfmt no longer prints any file

> ./mill mill.kotlinlib.contrib.ktfmt.KtfmtModule/ __.sources # alternatively, use external module to check/format

*/
11 changes: 11 additions & 0 deletions example/kotlinlib/linting/3-ktfmt/src/example/FooRight.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package example

class FooRight {
fun someFun(one: String, two: String) = Unit

companion object {
const val LINE =
"veryyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy" +
"yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyylongline"
}
}
12 changes: 12 additions & 0 deletions example/kotlinlib/linting/3-ktfmt/src/example/FooWrong.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package example

class FooWrong {

fun someFun(one: String, two: String) = Unit

companion object {
const val LINE =
"veryyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy" +
"yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyylongline"
}
}
22 changes: 22 additions & 0 deletions kotlinlib/src/mill/kotlinlib/contrib/ktfmt/KtfmtArgs.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package mill.kotlinlib.contrib.ktfmt

import mainargs.{ParserForClass, main}

/**
* Arguments for [[KtfmtModule.ktfmt]].
*
* @param style formatting style to apply, can be either "kotlin", "meta" or "google". Default is "kotlin".
* @param format if auto-formatting should be done. Default is "true"
* @param removeUnusedImports flag to remove unused imports if auto-formatting is applied. Default is "true".
*/
@main(doc = "arguments for KtfmtModule.ktfmt")
case class KtfmtArgs(
style: String = "kotlin",
format: Boolean = true,
removeUnusedImports: Boolean = true
)

object KtfmtArgs {

implicit val PFC: ParserForClass[KtfmtArgs] = ParserForClass[KtfmtArgs]
}
139 changes: 139 additions & 0 deletions kotlinlib/src/mill/kotlinlib/contrib/ktfmt/KtfmtModule.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package mill
package kotlinlib.contrib.ktfmt

import mainargs.Leftover
import mill.api.{Loose, PathRef}
import mill.define.{Discover, ExternalModule}
import mill.kotlinlib.{DepSyntax, KotlinModule}
import mill.main.Tasks
import mill.util.Jvm

trait KtfmtBaseModule extends KotlinModule {

/**
* Classpath for running Ktfmt.
*/
def ktfmtClasspath: T[Loose.Agg[PathRef]] = T {
defaultResolver().resolveDeps(
Agg(ivy"com.facebook:ktfmt:${ktfmtVersion()}")
)
}

/**
* Ktfmt version.
*/
def ktfmtVersion: T[String] = T {
"0.52"
}

/**
* Additional arguments for Ktfmt. Check
* [[https://github.com/facebook/ktfmt/blob/main/core/src/main/java/com/facebook/ktfmt/cli/ParsedArgs.kt#L51 available options]].
*/
def ktfmtOptions: T[Seq[String]] = T {
Seq.empty[String]
}
}

/**
* Performs formatting checks on Kotlin source files using [[https://github.com/facebook/ktfmt Ktfmt]].
*/
trait KtfmtModule extends KtfmtBaseModule {

/**
* Runs [[https://github.com/facebook/ktfmt Ktfmt]]
*
* @param ktfmtArgs arguments for the [[https://github.com/facebook/ktfmt Ktfmt]].
* @param sources list of sources to run the tool on. If not provided, default module sources will be taken.
*/
def ktfmt(
@mainargs.arg ktfmtArgs: KtfmtArgs,
@mainargs.arg(positional = true) sources: Leftover[String]
): Command[Unit] = Task.Command {
val _sources = if (sources.value.isEmpty) {
this.sources()
} else {
sources.value.iterator.map(rel => PathRef(millSourcePath / os.RelPath(rel)))
}
KtfmtModule.ktfmtAction(
ktfmtArgs.style,
ktfmtArgs.format,
ktfmtArgs.removeUnusedImports,
_sources,
ktfmtClasspath(),
ktfmtOptions()
)
}
}

object KtfmtModule extends ExternalModule with KtfmtBaseModule with TaskModule {

def kotlinVersion = "1.9.24"

lazy val millDiscover: Discover = Discover[this.type]

override def defaultCommandName(): String = "formatAll"

/**
* Runs [[https://github.com/facebook/ktfmt Ktfmt]].
*
* @param ktfmtArgs formatting arguments
* @param sources list of [[KotlinModule]] to process
*/
def formatAll(
@mainargs.arg ktfmtArgs: KtfmtArgs,
@mainargs.arg(positional = true) sources: Tasks[Seq[PathRef]]
): Command[Unit] = Task.Command {
val _sources = T.sequence(sources.value)().iterator.flatten
ktfmtAction(
ktfmtArgs.style,
ktfmtArgs.format,
ktfmtArgs.removeUnusedImports,
_sources,
ktfmtClasspath(),
ktfmtOptions()
)
}

private def ktfmtAction(
style: String,
format: Boolean,
removeUnusedImports: Boolean,
sources: IterableOnce[PathRef],
classPath: Loose.Agg[PathRef],
options: Seq[String]
)(implicit ctx: api.Ctx): Unit = {

ctx.log.info("running ktfmt ...")

val args = Seq.newBuilder[String]
args ++= options
args += (style match {
case "kotlin" => "--kotlinlang-style"
case "google" => "--google-style"
case "meta" => "--meta-style"
case _ => throw new IllegalArgumentException(s"Unknown style ktfmt style: $style")
})
if (!format) {
args += "--dry-run"
}
if (!removeUnusedImports) {
args += "--do-not-remove-unused-imports"
}
args ++= sources.iterator.map(_.path.toString())

val exitCode = Jvm.callSubprocess(
mainClass = "com.facebook.ktfmt.cli.Main",
classPath = classPath.map(_.path),
mainArgs = args.result(),
workingDir = millSourcePath, // allow passing relative paths for sources like src/a/b
streamOut = true,
check = false
).exitCode

if (exitCode == 0) {} // do nothing
else {
throw new RuntimeException(s"ktfmt exited abnormally with exit code = $exitCode")
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import java.lang.RuntimeException

class Example {
fun sample(arg: String) {
println(arg)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class Example {
fun sample(arg: String) {
println(arg)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class Example {
fun sample(arg: String) {
println(arg)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class Example {
fun sample(arg: String) {
println(arg)
}
}
5 changes: 5 additions & 0 deletions kotlinlib/test/resources/contrib/ktfmt/before/src/Example.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import java.lang.RuntimeException

class Example {
fun sample(arg: String) {println(arg)}
}
150 changes: 150 additions & 0 deletions kotlinlib/test/src/mill/kotlinlib/contrib/ktfmt/KtfmtModuleTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package mill.kotlinlib.contrib.ktfmt

import mainargs.Leftover
import mill.{T, api}
import mill.kotlinlib.KotlinModule
import mill.main.Tasks
import mill.testkit.{TestBaseModule, UnitTester}
import utest.{TestSuite, Tests, assert, test}

object KtfmtModuleTests extends TestSuite {
def tests: Tests = Tests {

val (before, after) = {
val root = os.Path(sys.env("MILL_TEST_RESOURCE_FOLDER")) / "contrib" / "ktfmt"
(root / "before", root / "after")
}

test("ktfmt - kotlin style") {
assert(
checkState(
afterFormat(before, style = "kotlin"),
after / "style" / "kotlin"
)
)
}

test("ktfmt - google style") {
assert(
checkState(
afterFormat(before, style = "google"),
after / "style" / "google"
)
)
}

test("ktfmt - meta style") {
assert(
checkState(
afterFormat(before, style = "meta"),
after / "style" / "meta"
)
)
}

test("ktfmt - dry-run") {
checkState(
afterFormat(before, format = true),
before
)
}

test("ktfmt - don't remove unused imports") {
checkState(
afterFormat(before, removeUnusedImports = false),
after / "imports"
)
}

test("ktfmt - explicit files") {
checkState(
afterFormat(before, sources = Seq("src/Example.kt")),
after / "style" / "kotlin"
)
}

test("formatAll") {

assert(
checkState(
afterFormatAll(before),
after / "style" / "kotlin"
)
)
}
}

def checkState(actualFiles: Seq[os.Path], expectedRoot: os.Path): Boolean = {
val expectedFiles = walkFiles(expectedRoot)
actualFiles.length == expectedFiles.length &&
actualFiles.iterator.zip(expectedFiles.iterator).forall {
case (actualFile, expectedFile) =>
val actual = os.read(actualFile)
val expected = os.read(expectedFile)
actual == expected
}
}

def afterFormat(
moduleRoot: os.Path,
style: String = "kotlin",
format: Boolean = true,
removeUnusedImports: Boolean = true,
sources: Seq[String] = Seq.empty
): Seq[os.Path] = {

object module extends TestBaseModule with KotlinModule with KtfmtModule {
override def kotlinVersion: T[String] = "1.9.24"
}

val eval = UnitTester(module, moduleRoot)

eval(module.ktfmt(
KtfmtArgs(
style = style,
format = format,
removeUnusedImports = removeUnusedImports
),
sources = Leftover(sources: _*)
)).fold(
{
case api.Result.Exception(cause, _) => throw cause
case failure => throw failure
},
{ _ =>
val Right(sources) = eval(module.sources)

sources.value.flatMap(ref => walkFiles(ref.path))
}
)
}

def afterFormatAll(modulesRoot: os.Path, format: Boolean = true): Seq[os.Path] = {

object module extends TestBaseModule with KotlinModule {
override def kotlinVersion: T[String] = "1.9.24"
}

val eval = UnitTester(module, modulesRoot)
eval(KtfmtModule.formatAll(
KtfmtArgs(
style = "kotlin",
format = format,
removeUnusedImports = true
),
sources = Tasks(Seq(module.sources))
)).fold(
{
case api.Result.Exception(cause, _) => throw cause
case failure => throw failure
},
{ _ =>
val Right(sources) = eval(module.sources)
sources.value.flatMap(ref => walkFiles(ref.path))
}
)
}

def walkFiles(root: os.Path): Seq[os.Path] =
os.walk(root).filter(os.isFile)
}

0 comments on commit 55f02d9

Please sign in to comment.