Skip to content

Commit

Permalink
Scalafmt: fix usages to pass the filename properly
Browse files Browse the repository at this point in the history
Create a new, package-private method .formatCode with the implementation
based on the existing .format.

Make sure not to modify the existing interfaces of .format overloads, as
they are externally expected (via reflection).

Fixes scalameta#1628.
  • Loading branch information
kitbellew committed Jan 28, 2020
1 parent acd6f44 commit e67c95a
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ abstract class MacroBenchmark(parallel: Boolean, maxFiles: Int)
@Benchmark
def scalafmt(): Unit = {
files.foreach { file =>
Try(Scalafmt.format(file))
Try(Scalafmt.formatCode(file))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import org.scalafmt.rewrite.{RedundantBraces, SortImports}
trait FormatBenchmark {
def formatRewrite(code: String): String = {
Scalafmt
.format(
.formatCode(
code,
style = ScalafmtConfig.default.copy(
rewrite = RewriteSettings(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ abstract class MicroBenchmark(path: String*) extends FormatBenchmark {

@Benchmark
def scalafmt(): String = {
Scalafmt.format(code).get
Scalafmt.formatCode(code).get
}

@Benchmark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Scalafmt210 {
if (filename.endsWith(".sbt")) SRunner.sbt
else SRunner.default
val style = scalafmtStyle.copy(runner = runner)
Scalafmt.format(code, style) match {
Scalafmt.formatCode(code, style, filename = filename) match {
case Formatted.Success(formattedCode) => formattedCode
case error =>
error match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ object ScalafmtCoreRunner extends ScalafmtRunner {
val scalafmtConfig =
if (inputMethod.isSbt || inputMethod.isSc) config.forSbt
else config
val formatResult = Scalafmt.format(
val formatResult = Scalafmt.formatCode(
input,
scalafmtConfig,
options.range,
Expand Down
13 changes: 12 additions & 1 deletion scalafmt-core/shared/src/main/scala/org/scalafmt/Scalafmt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ object Scalafmt {
private val WindowsLineEnding = "\r\n"
private val UnixLineEnding = "\n"

// XXX: don't modify signature, scalafmt-dynamic expects it via reflection
/**
* Format Scala code using scalafmt.
*
Expand All @@ -43,6 +44,15 @@ object Scalafmt {
style: ScalafmtConfig,
range: Set[Range],
filename: String
): Formatted = {
formatCode(code, style, range, filename)
}

private[scalafmt] def formatCode(
code: String,
style: ScalafmtConfig = ScalafmtConfig.default,
range: Set[Range] = Set.empty,
filename: String = "<input>"
): Formatted = {
try {
val runner = style.runner
Expand Down Expand Up @@ -82,12 +92,13 @@ object Scalafmt {
}
}

// XXX: don't modify signature, scalafmt-dynamic expects it via reflection
def format(
code: String,
style: ScalafmtConfig = ScalafmtConfig.default,
range: Set[Range] = Set.empty[Range]
): Formatted = {
format(code, style, range, "<input>")
formatCode(code, style, range)
}

def parseHoconConfig(configString: String): Configured[ScalafmtConfig] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class ScalafmtTest extends scalatest.funsuite.AnyFunSuite {
config: ScalafmtConfig = ScalafmtConfig.default
): Unit = {
test(logger.revealWhitespace(original).take(30)) {
val obtained = Scalafmt.format(original, config).get
val obtained = Scalafmt.formatCode(original, config).get
if (obtained != expected) logger.elem(obtained)
assert(obtained == expected)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class FidelityTest extends AnyFunSuite with FormatAssertions {
examples.foreach { example =>
test(example.filename) {
val formatted =
Scalafmt.format(example.code, ScalafmtConfig.default).get
Scalafmt.formatCode(example.code, filename = example.filename).get
assertFormatPreservesAst(example.code, formatted)(
scala.meta.parsers.Parse.parseSource,
Scala211
Expand Down
38 changes: 23 additions & 15 deletions scalafmt-tests/src/test/scala/org/scalafmt/FormatTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,22 @@ class FormatTests

def run(t: DiffTest, parse: Parse[_ <: Tree]): Unit = {
val runner = scalafmtRunner(t.style.runner).copy(parser = parse)
val obtained =
Scalafmt.format(t.original, t.style.copy(runner = runner)) match {
case Formatted.Failure(e)
if t.style.onTestFailure.nonEmpty && e.getMessage.contains(
e.getMessage
) =>
t.expected
case Formatted.Failure(e: Incomplete) => e.formattedCode
case Formatted.Failure(e: SearchStateExploded) =>
logger.elem(e)
e.partialOutput
case x => x.get
}
val obtained = Scalafmt.formatCode(
t.original,
t.style.copy(runner = runner),
filename = t.filename
) match {
case Formatted.Failure(e)
if t.style.onTestFailure.nonEmpty && e.getMessage.contains(
e.getMessage
) =>
t.expected
case Formatted.Failure(e: Incomplete) => e.formattedCode
case Formatted.Failure(e: SearchStateExploded) =>
logger.elem(e)
e.partialOutput
case x => x.get
}
debugResults += saveResult(t, obtained, onlyOne)
if (t.style.rewrite.rules.isEmpty &&
!t.style.assumeStandardLibraryStripMargin &&
Expand All @@ -74,8 +77,13 @@ class FormatTests
t.style.runner.dialect
)
}
val formattedAgain =
Scalafmt.format(obtained, t.style.copy(runner = runner)).get
val formattedAgain = Scalafmt
.format(
obtained,
t.style.copy(runner = runner),
filename = t.filename
)
.get
// getFormatOutput(t.style, true) // uncomment to debug
assertNoDiff(formattedAgain, obtained, "Idempotency violated")
if (!onlyManual) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,13 @@ trait HasTests extends AnyFunSuiteLike with FormatAssertions {

def defaultRun(t: DiffTest, parse: Parse[_ <: Tree]): Unit = {
val runner = scalafmtRunner(t.style.runner).copy(parser = parse)
val obtained =
Scalafmt.format(t.original, t.style.copy(runner = runner)).get
val obtained = Scalafmt
.formatCode(
t.original,
t.style.copy(runner = runner),
filename = t.filename
)
.get
if (t.style.rewrite.rules.isEmpty) {
assertFormatPreservesAst(t.original, obtained)(
parse,
Expand Down

0 comments on commit e67c95a

Please sign in to comment.