diff --git a/plugin/src/main/scala/org/scalafmt/sbt/ScalafmtPlugin.scala b/plugin/src/main/scala/org/scalafmt/sbt/ScalafmtPlugin.scala index 411bc12..5bf4176 100644 --- a/plugin/src/main/scala/org/scalafmt/sbt/ScalafmtPlugin.scala +++ b/plugin/src/main/scala/org/scalafmt/sbt/ScalafmtPlugin.scala @@ -176,15 +176,19 @@ object ScalafmtPlugin extends AutoPlugin { @inline private def asRelative(file: File): File = file.relativeTo(baseDir).getOrElse(file) - private def filterFiles(sources: Seq[File]): Seq[File] = { - val filter = getFileFilter() - sources.distinct.filter { file => - val path = file.toPath.toAbsolutePath - scalafmtSession.matchesProjectFilters(path) && filter(path) + private def filterFiles(sources: Seq[File], dirs: Seq[File]): Seq[File] = { + val filter = getFileFilter(dirs) + sources.distinct.flatMap { file => + val canonFile = file.getCanonicalFile + val path = canonFile.toPath.toAbsolutePath + val ok = scalafmtSession.matchesProjectFilters(path) && filter(path) + if (ok) Some(canonFile) else None } } - private def getFileFilter(): Path => Boolean = { + private def getFileFilter(dirs: Seq[File]): Path => Boolean = { + // dirs don't have to be within baseDir but within the same git tree + def absDirs = dirs.map(x => AbsoluteFile(x.toPath)) def gitOps = GitOps.FactoryImpl(AbsoluteFile(baseDir.toPath)) def getFromFiles(getFiles: => Seq[AbsoluteFile], gitCmd: => String) = { def gitMessage = s"[git $gitCmd] ($baseDir)" @@ -199,12 +203,12 @@ object ScalafmtPlugin extends AutoPlugin { } if (filterMode == FilterMode.diffDirty) - getFromFiles(gitOps.status(), "status") + getFromFiles(gitOps.status(absDirs: _*), "status") else if (filterMode.startsWith(FilterMode.diffRefPrefix)) { val branch = filterMode.substring(FilterMode.diffRefPrefix.length) - getFromFiles(gitOps.diff(branch), s"diff $branch") + getFromFiles(gitOps.diff(branch, absDirs: _*), s"diff $branch") } else if (filterMode != FilterMode.none && scalafmtSession.isGitOnly) - getFromFiles(gitOps.lsTree(), "ls-files") + getFromFiles(gitOps.lsTree(absDirs: _*), "ls-files") else { log.debug("considering all files (no git)") _ => true @@ -244,8 +248,8 @@ object ScalafmtPlugin extends AutoPlugin { res } - def formatTrackedSources(sources: Seq[File]): Unit = { - val filteredSources = filterFiles(sources) + def formatTrackedSources(sources: Seq[File], dirs: Seq[File]): Unit = { + val filteredSources = filterFiles(sources, dirs) trackSourcesAndConfig(cacheStoreFactory, filteredSources) { (outDiff, configChanged, prev) => val filesToFormat: Seq[File] = @@ -261,8 +265,8 @@ object ScalafmtPlugin extends AutoPlugin { } } - def formatSources(sources: Seq[File]): Unit = - formatFilteredSources(filterFiles(sources)) + def formatSources(sources: Seq[File], dirs: Seq[File]): Unit = + formatFilteredSources(filterFiles(sources, dirs)) private def formatFilteredSources(sources: Seq[File]): Unit = { if (sources.nonEmpty) @@ -274,8 +278,8 @@ object ScalafmtPlugin extends AutoPlugin { if (cnt > 0) log.info(s"Reformatted $cnt Scala sources") } - def checkTrackedSources(sources: Seq[File]): Unit = { - val filteredSources = filterFiles(sources) + def checkTrackedSources(sources: Seq[File], dirs: Seq[File]): Unit = { + val filteredSources = filterFiles(sources, dirs) val result = trackSourcesAndConfig(cacheStoreFactory, filteredSources) { (outDiff, configChanged, prev) => val filesToCheck: Seq[File] = @@ -300,8 +304,8 @@ object ScalafmtPlugin extends AutoPlugin { throwOnFailure(result) } - def checkSources(sources: Seq[File]): Unit = - throwOnFailure(checkFilteredSources(filterFiles(sources))) + def checkSources(sources: Seq[File], dirs: Seq[File]): Unit = + throwOnFailure(checkFilteredSources(filterFiles(sources, dirs))) private def checkFilteredSources(sources: Seq[File]): ScalafmtAnalysis = { if (sources.nonEmpty) { @@ -393,57 +397,69 @@ object ScalafmtPlugin extends AutoPlugin { } } - private def scalafmtTask(sources: Seq[File], session: FormatSession) = + private def scalafmtTask( + sources: Seq[File], + dirs: Seq[File], + session: FormatSession + ) = Def.task { - session.formatTrackedSources(sources) + session.formatTrackedSources(sources, dirs) } tag (ScalafmtTagPack: _*) - private def scalafmtCheckTask(sources: Seq[File], session: FormatSession) = + private def scalafmtCheckTask( + sources: Seq[File], + dirs: Seq[File], + session: FormatSession + ) = Def.task { - session.checkTrackedSources(sources) + session.checkTrackedSources(sources, dirs) } tag (ScalafmtTagPack: _*) private def getScalafmtSourcesTask( - f: (Seq[File], FormatSession) => InitTask + f: (Seq[File], Seq[File], FormatSession) => InitTask ) = Def.taskDyn[Unit] { val sources = (unmanagedSources in scalafmt).?.value.getOrElse(Seq.empty) - getScalafmtTask(f)(sources, scalaConfig.value) + val dirs = (unmanagedSourceDirectories in scalafmt).?.value.getOrElse(Nil) + getScalafmtTask(f)(sources, dirs, scalaConfig.value) } private def scalafmtSbtTask( sources: Seq[File], + dirs: Seq[File], session: FormatSession ) = Def.task { - session.formatSources(sources) + session.formatSources(sources, dirs) } tag (ScalafmtTagPack: _*) private def scalafmtSbtCheckTask( sources: Seq[File], + dirs: Seq[File], session: FormatSession ) = Def.task { - session.checkSources(sources) + session.checkSources(sources, dirs) } tag (ScalafmtTagPack: _*) private def getScalafmtSbtTasks( - func: (Seq[File], FormatSession) => InitTask + func: (Seq[File], Seq[File], FormatSession) => InitTask ) = Def.taskDyn { joinScalafmtTasks(func)( - (sbtSources.value, sbtConfig.value), - (metabuildSources.value, scalaConfig.value) + (sbtSources.value, Nil, sbtConfig.value), + (metabuildSources.value, Nil, scalaConfig.value) ) } private def joinScalafmtTasks( - func: (Seq[File], FormatSession) => InitTask - )(tuples: (Seq[File], Path)*) = { - val tasks = tuples - .map { case (files, config) => getScalafmtTask(func)(files, config) } + func: (Seq[File], Seq[File], FormatSession) => InitTask + )(tuples: (Seq[File], Seq[File], Path)*) = { + val tasks = tuples.map { case (files, dirs, config) => + getScalafmtTask(func)(files, dirs, config) + } Def.sequential(tasks.tail.toList, tasks.head) } private def getScalafmtTask( - func: (Seq[File], FormatSession) => InitTask - )(files: Seq[File], config: Path) = Def.taskDyn[Unit] { + func: (Seq[File], Seq[File], FormatSession) => InitTask + )(files: Seq[File], dirs: Seq[File], config: Path) = Def.taskDyn[Unit] { if (files.isEmpty) Def.task(Unit) else { val session = new FormatSession( @@ -460,7 +476,7 @@ object ScalafmtPlugin extends AutoPlugin { scalafmtDetailedError.value ) ) - func(files, session) + func(files, dirs, session) } } @@ -505,7 +521,7 @@ object ScalafmtPlugin extends AutoPlugin { scalafmtFailOnErrors.value, scalafmtDetailedError.value ) - ).formatSources(absFiles) + ).formatSources(absFiles, Nil) } ) diff --git a/plugin/src/sbt-test/scalafmt-sbt/sbt/test b/plugin/src/sbt-test/scalafmt-sbt/sbt/test index 3e73192..a305e2f 100644 --- a/plugin/src/sbt-test/scalafmt-sbt/sbt/test +++ b/plugin/src/sbt-test/scalafmt-sbt/sbt/test @@ -202,7 +202,7 @@ $ exec git -C p19 add "jvm/src/main/scala/TestGood.scala" > p19/scalafmtCheck $ copy-file changes/invalid.scala p19/shared/src/main/scala/TestInvalid1.scala $ exec git -C p19 add "shared/src/main/scala/TestInvalid1.scala" -> p19/scalafmtCheck +-> p19/scalafmtCheck $ copy-file changes/target/managed.scala project/target/managed.scala $ copy-file changes/x/Something.scala project/x/Something.scala