From 8c22bfa162fabbc0088bf63a2e235be30f0b1664 Mon Sep 17 00:00:00 2001 From: Alexandre Archambault Date: Tue, 12 Jan 2021 16:52:24 +0100 Subject: [PATCH] Add Scala 3 stuff (3.0.0-M1 for now) --- .../ammonite/compiler/AmmonitePhase.scala | 264 +++++++++ .../scala-3/ammonite/compiler/Compiler.scala | 470 ++++++++++++++++ .../ammonite/compiler/CompilerBuilder.scala | 56 ++ .../compiler/CompilerExtensions.scala | 41 ++ .../compiler/CompilerLifecycleManager.scala | 154 ++++++ .../ammonite/compiler/DottyParser.scala | 73 +++ .../ammonite/compiler/Extensions.scala | 9 + .../scala-3/ammonite/compiler/Parsers.scala | 311 +++++++++++ .../ammonite/compiler/Preprocessor.scala | 311 +++++++++++ .../compiler/SyntaxHighlighting.scala | 129 +++++ .../ammonite/compiler/tools/desugar.scala | 3 + .../ammonite/compiler/tools/source.scala | 9 + .../dotty/ammonite/compiler/Completion.scala | 507 ++++++++++++++++++ .../compiler/DirectoryClassPath.scala | 32 ++ .../compiler/WhiteListClassPath.scala | 52 ++ build.sc | 210 ++++++-- 16 files changed, 2588 insertions(+), 43 deletions(-) create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/AmmonitePhase.scala create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/Compiler.scala create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/CompilerBuilder.scala create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/CompilerExtensions.scala create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/CompilerLifecycleManager.scala create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/DottyParser.scala create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/Extensions.scala create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/Parsers.scala create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/Preprocessor.scala create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/SyntaxHighlighting.scala create mode 100644 
amm/compiler/src/main/scala-3/ammonite/compiler/tools/desugar.scala create mode 100644 amm/compiler/src/main/scala-3/ammonite/compiler/tools/source.scala create mode 100644 amm/compiler/src/main/scala-3/dotty/ammonite/compiler/Completion.scala create mode 100644 amm/compiler/src/main/scala-3/dotty/ammonite/compiler/DirectoryClassPath.scala create mode 100644 amm/compiler/src/main/scala-3/dotty/ammonite/compiler/WhiteListClassPath.scala diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/AmmonitePhase.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/AmmonitePhase.scala new file mode 100644 index 000000000..333645aa7 --- /dev/null +++ b/amm/compiler/src/main/scala-3/ammonite/compiler/AmmonitePhase.scala @@ -0,0 +1,264 @@ +package ammonite.compiler + +import ammonite.util.{ImportData, Imports, Name => AmmName, Printer, Util} + +import dotty.tools.dotc +import dotty.tools.dotc.core.StdNames.nme +import dotc.ast.Trees._ +import dotc.ast.{tpd, untpd} +import dotc.core.Flags +import dotc.core.Contexts._ +import dotc.core.Names.Name +import dotc.core.Phases.Phase +import dotc.core.Symbols.{NoSymbol, Symbol, newSymbol} +import dotc.core.Types.{TermRef, Type, TypeTraverser} + +import scala.collection.mutable + +class AmmonitePhase( + userCodeNestingLevel: => Int, + needsUsedEarlierDefinitions: => Boolean +) extends Phase: + import tpd._ + + def phaseName: String = "ammonite" + + private var myImports = new mutable.ListBuffer[(Boolean, String, String, Seq[AmmName])] + private var usedEarlierDefinitions0 = new mutable.ListBuffer[String] + + def importData: Seq[ImportData] = + val grouped = myImports + .toList + .distinct + .groupBy { case (a, b, c, d) => (b, c, d) } + .mapValues(_.map(_._1)) + + val open = for { + ((fromName, toName, importString), items) <- grouped + if !CompilerUtil.ignoredNames(fromName) + } yield { + val importType = items match{ + case Seq(true) => ImportData.Type + case Seq(false) => ImportData.Term + case Seq(_, _) => ImportData.TermType 
+ } + + ImportData(AmmName(fromName), AmmName(toName), importString, importType) + } + + open.toVector.sortBy(x => Util.encodeScalaSourcePath(x.prefix)) + + def usedEarlierDefinitions: Seq[String] = + usedEarlierDefinitions0.toList.distinct + + private def saneSym(name: Name, sym: Symbol)(using Context): Boolean = + !name.decode.toString.contains('$') && + sym.exists && + // !sym.is(Flags.Synthetic) && + !scala.util.Try(sym.is(Flags.Private)).toOption.getOrElse(true) && + !scala.util.Try(sym.is(Flags.Protected)).toOption.getOrElse(true) && + // sym.is(Flags.Public) && + !CompilerUtil.ignoredSyms(sym.toString) && + !CompilerUtil.ignoredNames(name.decode.toString) + + private def saneSym(sym: Symbol)(using Context): Boolean = + saneSym(sym.name, sym) + + private def processTree(t: tpd.Tree)(using Context): Unit = { + val sym = t.symbol + val name = t match { + case t: tpd.ValDef => t.name + case _ => sym.name + } + if (saneSym(name, sym)) { + val name = sym.name.decode.toString + myImports.addOne((sym.isType, name, name, Nil)) + } + } + + private def processImport(i: tpd.Import)(using Context): Unit = { + val expr = i.expr + val selectors = i.selectors + + // Most of that logic was adapted from AmmonitePlugin, the Scala 2 counterpart + // of this file. + + val prefix = + val (_ :: nameListTail, symbolHead :: _) = { + def rec(expr: tpd.Tree): List[(Name, Symbol)] = { + expr match { + case s @ tpd.Select(lhs, _) => (s.symbol.name -> s.symbol) :: rec(lhs) + case i @ tpd.Ident(name) => List(name -> i.symbol) + case t @ tpd.This(pkg) => List(pkg.name -> t.symbol) + } + } + rec(expr).reverse.unzip + } + + val headFullPath = symbolHead.fullName.decode.toString.split('.') + .map(n => if (n.endsWith("$")) n.stripSuffix("$") else n) // meh + // prefix package imports with `_root_` to try and stop random + // variables from interfering with them. 
If someone defines a value + // called `_root_`, this will still break, but that's their problem + val rootPrefix = if(symbolHead.denot.is(Flags.Package)) Seq("_root_") else Nil + val tailPath = nameListTail.map(_.decode.toString) + + (rootPrefix ++ headFullPath ++ tailPath).map(AmmName(_)) + + def isMask(sel: untpd.ImportSelector) = sel.name != nme.WILDCARD && sel.rename == nme.WILDCARD + + val renameMap = + + /** + * A map of each name importable from `expr`, to a `Seq[Boolean]` + * containing a `true` if there's a type-symbol you can import, `false` + * if there's a non-type symbol and both if there are both type and + * non-type symbols that are importable for that name + */ + val importableIsTypes = + expr.tpe + .allMembers + .map(_.symbol) + .filter(saneSym(_)) + .groupBy(_.name.decode.toString) + .mapValues(_.map(_.isType).toVector) + + val renamings = for{ + t @ untpd.ImportSelector(name, renameTree, _) <- selectors + if !isMask(t) + // getOrElse just in case... + isType <- importableIsTypes.getOrElse(name.name.decode.toString, Nil) + Ident(rename) <- Option(renameTree) + } yield ((isType, rename.decode.toString), name.name.decode.toString) + + renamings.toMap + + + def isUnimportableUnlessRenamed(sym: Symbol): Boolean = + sym eq NoSymbol + + @scala.annotation.tailrec + def transformImport(selectors: List[untpd.ImportSelector], sym: Symbol): List[Symbol] = + selectors match { + case Nil => Nil + case sel :: Nil if sel.isWildcard => + if (isUnimportableUnlessRenamed(sym)) Nil + else List(sym) + case (sel @ untpd.ImportSelector(from, to, _)) :: _ + if from.name == (if (from.isTerm) sym.name.toTermName else sym.name.toTypeName) => + if (isMask(sel)) Nil + else List( + newSymbol(sym.owner, sel.rename, sym.flags, sym.info, sym.privateWithin, sym.coord) + ) + case _ :: rest => transformImport(rest, sym) + } + + val symNames = + for { + sym <- expr.tpe.allMembers.map(_.symbol).flatMap(transformImport(selectors, _)) + if saneSym(sym) + } yield (sym.isType, 
sym.name.decode.toString) + + val syms = for { + // For some reason `info.allImportedSymbols` does not show imported + // type aliases when they are imported directly e.g. + // + // import scala.reflect.macros.Context + // + // As opposed to via import scala.reflect.macros._. + // Thus we need to combine allImportedSymbols with the renameMap + (isType, sym) <- (symNames.toList ++ renameMap.keys).distinct + } yield (isType, renameMap.getOrElse((isType, sym), sym), sym, prefix) + + myImports ++= syms + } + + private def updateUsedEarlierDefinitions( + wrapperSym: Symbol, + stats: List[tpd.Tree] + )(using Context): Unit = { + /* + * We list the variables from the first wrapper + * used from the user code. + * + * E.g. if, after wrapping, the code looks like + * ``` + * class cmd2 { + * + * val cmd0 = ??? + * val cmd1 = ??? + * + * import cmd0.{ + * n + * } + * + * class Helper { + * // user-typed code + * val n0 = n + 1 + * } + * } + * ``` + * this would process the tree of `val n0 = n + 1`, find `n` as a tree like + * `cmd2.this.cmd0.n`, and put `cmd0` in `uses`. + */ + + val typeTraverser: TypeTraverser = new TypeTraverser { + def traverse(tpe: Type) = tpe match { + case tr: TermRef if tr.prefix.typeSymbol == wrapperSym => + tr.designator match { + case n: Name => usedEarlierDefinitions0 += n.decode.toString + case s: Symbol => usedEarlierDefinitions0 += s.name.decode.toString + case _ => // can this happen? 
+ } + case _ => + traverseChildren(tpe) + } + } + + val traverser: TreeTraverser = new TreeTraverser { + def traverse(tree: Tree)(using Context) = tree match { + case tpd.Select(node, name) if node.symbol == wrapperSym => + usedEarlierDefinitions0 += name.decode.toString + case tt @ tpd.TypeTree() => + typeTraverser.traverse(tt.tpe) + case _ => + traverseChildren(tree) + } + } + + for (tree <- stats) + traverser.traverse(tree) + } + + private def unpkg(tree: tpd.Tree): List[tpd.Tree] = + tree match { + case PackageDef(_, elems) => elems.flatMap(unpkg) + case _ => List(tree) + } + + def run(using Context): Unit = + val elems = unpkg(ctx.compilationUnit.tpdTree) + def mainStats(trees: List[tpd.Tree]): List[tpd.Tree] = + trees + .reverseIterator + .collectFirst { + case TypeDef(name, rhs0: Template) => rhs0.body + } + .getOrElse(Nil) + + val rootStats = mainStats(elems) + val stats = (1 until userCodeNestingLevel) + .foldLeft(rootStats)((trees, _) => mainStats(trees)) + + if (needsUsedEarlierDefinitions) { + val wrapperSym = elems.last.symbol + updateUsedEarlierDefinitions(wrapperSym, stats) + } + + stats.foreach { + case i: Import => processImport(i) + case t: tpd.DefDef => processTree(t) + case t: tpd.ValDef => processTree(t) + case t: tpd.TypeDef => processTree(t) + case _ => + } diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/Compiler.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/Compiler.scala new file mode 100644 index 000000000..53ed6544f --- /dev/null +++ b/amm/compiler/src/main/scala-3/ammonite/compiler/Compiler.scala @@ -0,0 +1,470 @@ +package ammonite.compiler + +import java.net.URL +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} +import java.io.OutputStream + +import ammonite.compiler.iface.{ + Compiler => ICompiler, + CompilerBuilder => ICompilerBuilder, + CompilerLifecycleManager => ICompilerLifecycleManager, + Preprocessor => IPreprocessor, + _ +} +import ammonite.util.{ImportData, Imports, Printer} 
+import ammonite.util.Util.newLine + +import dotty.tools.dotc +import dotc.{CompilationUnit, Compiler => DottyCompiler, Run} +import dotc.ast.{tpd, untpd} +import dotc.ast.Positioned +import dotc.classpath +import dotc.config.{CompilerCommand, JavaPlatform} +import dotc.core.Contexts._ +import dotc.core.{Flags, MacroClassLoader, Mode} +import dotc.core.Comments.{ContextDoc, ContextDocstrings} +import dotc.core.Phases.{Phase, unfusedPhases} +import dotc.core.Symbols.{defn, Symbol} +import dotc.fromtasty.TastyFileUtil +import dotc.interactive.Completion +import dotc.report +import dotc.reporting +import dotc.transform.{PostTyper, Staging} +import dotc.typer.FrontEnd +import dotc.util.{Property, SourceFile, SourcePosition} +import dotc.util.Spans.Span +import dotty.tools.io.{AbstractFile, ClassPath, ClassRepresentation, File, VirtualDirectory} +import dotty.tools.repl.CollectTopLevelImports + +class Compiler( + dynamicClassPath: AbstractFile, + initialClassPath: Seq[URL], + classPath: Seq[URL], + macroClassLoader: ClassLoader, + whiteList: Set[Seq[String]], + dependencyCompleteOpt: => Option[String => (Int, Seq[String])] = None, + contextInit: FreshContext => Unit = _ => (), + settings: Seq[String] = Nil, + reporter: Option[ICompilerBuilder.Message => Unit] = None +) extends ICompiler: + self => + + import Compiler.{enumerateVdFiles, files} + + private val outputDir = new VirtualDirectory("(memory)") + + private def initCtx: Context = + val base: ContextBase = + new ContextBase: + override protected def newPlatform(using Context) = + new JavaPlatform: + private var classPath0: ClassPath = null + override def classPath(using Context) = + if (classPath0 == null) + classPath0 = classpath.AggregateClassPath(Seq( + asDottyClassPath(initialClassPath, whiteListed = true), + asDottyClassPath(self.classPath), + classpath.ClassPathFactory.newClassPath(dynamicClassPath) + )) + classPath0 + base.initialCtx + + private def sourcesRequired = false + + private lazy val 
MacroClassLoaderKey = + val cls = macroClassLoader.loadClass("dotty.tools.dotc.core.MacroClassLoader$") + val fld = cls.getDeclaredField("MacroClassLoaderKey") + fld.setAccessible(true) + fld.get(null).asInstanceOf[Property.Key[ClassLoader]] + + // Originally adapted from + // https://github.com/lampepfl/dotty/blob/3.0.0-M3/ + // compiler/src/dotty/tools/dotc/Driver.scala/#L67-L81 + private def setup(args: Array[String], rootCtx: Context): (List[String], Context) = + given ictx as FreshContext = rootCtx.fresh + val summary = CompilerCommand.distill(args) + ictx.setSettings(summary.sstate) + ictx.setProperty(MacroClassLoaderKey, macroClassLoader) + Positioned.init + + if !ictx.settings.YdropComments.value || ictx.mode.is(Mode.ReadComments) then + ictx.setProperty(ContextDoc, new ContextDocstrings) + val fileNames = CompilerCommand.checkUsage(summary, sourcesRequired) + contextInit(ictx) + (fileNames, ictx) + + private def asDottyClassPath( + cp: Seq[URL], + whiteListed: Boolean = false + )(using Context): ClassPath = + val (dirs, jars) = cp.partition { url => + url.getProtocol == "file" && Files.isDirectory(Paths.get(url.toURI)) + } + + val dirsCp = dirs.map(u => classpath.ClassPathFactory.newClassPath(AbstractFile.getURL(u))) + val jarsCp = jars + .filter(ammonite.util.Classpath.canBeOpenedAsJar) + .map(u => classpath.ZipAndJarClassPathFactory.create(AbstractFile.getURL(u))) + + if (whiteListed) new dotty.ammonite.compiler.WhiteListClasspath(dirsCp ++ jarsCp, whiteList) + else classpath.AggregateClassPath(dirsCp ++ jarsCp) + + // Originally adapted from + // https://github.com/lampepfl/dotty/blob/3.0.0-M3/ + // compiler/src/dotty/tools/repl/ReplDriver.scala/#L67-L73 + /** Create a fresh and initialized context with IDE mode enabled */ + lazy val initialCtx = + val rootCtx = initCtx.fresh.addMode(Mode.ReadPositions | Mode.Interactive | Mode.ReadComments) + rootCtx.setSetting(rootCtx.settings.YcookComments, true) + // FIXME Disabled for the tests to 
pass + rootCtx.setSetting(rootCtx.settings.color, "never") + // FIXME We lose possible custom openStream implementations on the URLs of initialClassPath and + // classPath + val initialClassPath0 = initialClassPath + // .filter(!_.toURI.toASCIIString.contains("fansi_2.13")) + // .filter(!_.toURI.toASCIIString.contains("pprint_2.13")) + rootCtx.setSetting(rootCtx.settings.outputDir, outputDir) + + val (_, ictx) = setup(settings.toArray, rootCtx) + ictx.base.initialize()(using ictx) + ictx + + private var userCodeNestingLevel = -1 + + // Originally adapted from + // https://github.com/lampepfl/dotty/blob/3.0.0-M3/ + // compiler/src/dotty/tools/repl/ReplCompiler.scala/#L34-L39 + val compiler = + new DottyCompiler: + override protected def frontendPhases: List[List[Phase]] = List( + List(new FrontEnd), // List(new REPLFrontEnd), + List(new AmmonitePhase(userCodeNestingLevel, userCodeNestingLevel == 2)), + List(new Staging), + List(new PostTyper) + ) + + // Originally adapted from + // https://github.com/lampepfl/dotty/blob/3.0.0-M3/ + // compiler/src/dotty/tools/repl/Rendering.scala/#L97-L103 + /** Formats errors using the `messageRenderer` */ + private def formatError(dia: reporting.Diagnostic)(implicit ctx: Context): reporting.Diagnostic = + new reporting.Diagnostic( + Compiler.messageRenderer.messageAndPos( + dia.msg, + dia.pos, + Compiler.messageRenderer.diagnosticLevel(dia) + ), + dia.pos, + dia.level + ) + + def compile( + src: Array[Byte], + printer: Printer, + importsLen: Int, + userCodeNestingLevel: Int, + fileName: String + ): Option[ICompiler.Output] = + val sourceFile = SourceFile.virtual(fileName, new String(src, StandardCharsets.UTF_8)) + // println(s"Compiling\n${new String(src, StandardCharsets.UTF_8)}\n") + + self.userCodeNestingLevel = userCodeNestingLevel + + val reporter0 = reporter match { + case None => + Compiler.newStoreReporter() + case Some(rep) => + val simpleReporter = new dotc.interfaces.SimpleReporter { + def report(diag: 
dotc.interfaces.Diagnostic) = { + val severity = diag.level match { + case dotc.interfaces.Diagnostic.ERROR => "ERROR" + case dotc.interfaces.Diagnostic.WARNING => "WARNING" + case dotc.interfaces.Diagnostic.INFO => "INFO" + case _ => "INFO" // should not happen + } + val pos = Some(diag.position).filter(_.isPresent).map(_.get) + val start = pos.fold(0)(_.start) + val end = pos.fold(new String(src, "UTF-8").length)(_.end) + val msg = ICompilerBuilder.Message(severity, start, end, diag.message) + rep(msg) + } + } + reporting.Reporter.fromSimpleReporter(simpleReporter) + } + val run = new Run(compiler, initialCtx.fresh.setReporter(reporter0)) + implicit val ctx: Context = run.runContext.withSource(sourceFile) + + val unit = + new CompilationUnit(ctx.source): + // as done in + // https://github.com/lampepfl/dotty/blob/3.0.0-M3/ + // compiler/src/dotty/tools/repl/ReplCompillationUnit.scala/#L8 + override def isSuspendable: Boolean = false + ctx + .run + .compileUnits(unit :: Nil) + + val result = + if (ctx.reporter.hasErrors) Left(reporter.fold(ctx.reporter.removeBufferedMessages)(_ => Nil)) + else Right(unit) + + result match { + case Left(errors) => + errors + .map(formatError) + .map(_.msg.toString) + .foreach(printer.error) + None + case Right(unit) => + val newImports = unfusedPhases.collectFirst { + case p: AmmonitePhase => p.importData + }.getOrElse(Seq.empty[ImportData]) + val usedEarlierDefinitions = unfusedPhases.collectFirst { + case p: AmmonitePhase => p.usedEarlierDefinitions + }.getOrElse(Seq.empty[String]) + val fileCount = enumerateVdFiles(outputDir).length + val classes = files(outputDir).toArray + Compiler.addToClasspath(classes, dynamicClassPath) + outputDir.clear() + val output = ICompiler.Output( + classes.toVector, + Imports(newImports), + Some(usedEarlierDefinitions) + ) + Some(output) + } + + def objCompiler = compiler + + def preprocessor(fileName: String, markGeneratedSections: Boolean): IPreprocessor = + new Preprocessor( + 
initialCtx.fresh.withSource(SourceFile.virtual(fileName, "")), + markGeneratedSections: Boolean + ) + + // Originally adapted from + // https://github.com/lampepfl/dotty/blob/3.0.0-M3/ + // compiler/src/dotty/tools/repl/ReplCompiler.scala/#L224-L286 + def tryTypeCheck( + src: Array[Byte], + fileName: String + ) = + val sourceFile = SourceFile.virtual(fileName, new String(src, StandardCharsets.UTF_8)) + + val reporter0 = Compiler.newStoreReporter() + val run = new Run( + compiler, + initCtx.fresh + .addMode(Mode.ReadPositions | Mode.Interactive) + .setReporter(reporter0) + .setSetting(initialCtx.settings.YstopAfter, List("typer")) + ) + implicit val ctx: Context = run.runContext.withSource(sourceFile) + + val unit = + new CompilationUnit(ctx.source): + override def isSuspendable: Boolean = false + ctx + .run + .compileUnits(unit :: Nil, ctx) + + (unit.tpdTree, ctx) + + def complete( + offset: Int, + previousImports: String, + snippet: String + ): (Int, Seq[String], Seq[String]) = { + + val prefix = previousImports + newLine + + "object AutocompleteWrapper{ val expr: _root_.scala.Unit = {" + newLine + val suffix = newLine + "()}}" + val allCode = prefix + snippet + suffix + val index = offset + prefix.length + + + // Originally based on + // https://github.com/lampepfl/dotty/blob/3.0.0-M1/ + // compiler/src/dotty/tools/repl/ReplDriver.scala/#L179-L191 + + val (tree, ctx0) = tryTypeCheck(allCode.getBytes("UTF-8"), "") + val ctx = ctx0.fresh + val file = SourceFile.virtual("", allCode, maybeIncomplete = true) + val unit = CompilationUnit(file)(using ctx) + unit.tpdTree = { + given Context = ctx + import tpd._ + tree match { + case PackageDef(_, p) => + p.collectFirst { + case TypeDef(_, tmpl: Template) => + tmpl.body + .collectFirst { case dd: ValDef if dd.name.show == "expr" => dd } + .getOrElse(???) + }.getOrElse(???) + case _ => ??? 
+ } + } + val ctx1 = ctx.fresh.setCompilationUnit(unit) + val srcPos = SourcePosition(file, Span(index)) + val (start, completions) = dotty.ammonite.compiler.Completion.completions( + srcPos, + dependencyCompleteOpt + )(using ctx1) + + val blacklistedPackages = Set("shaded") + + def deepCompletion(name: String): List[String] = { + given Context = ctx1 + def rec(t: Symbol): Seq[Symbol] = { + if (blacklistedPackages(t.name.toString)) + Nil + else { + val children = + if (t.is(Flags.Package) || t.is(Flags.PackageVal) || t.is(Flags.PackageClass)) + t.denot.info.allMembers.map(_.symbol).filter(_ != t).flatMap(rec) + else Nil + + t +: children.toSeq + } + } + + for { + member <- defn.RootClass.denot.info.allMembers.map(_.symbol).toList + sym <- rec(member) + // Scala 2 comment: sketchy name munging because I don't know how to do this properly + // Note lack of back-quoting support. + strippedName = sym.name.toString.stripPrefix("package$").stripSuffix("$") + if strippedName.startsWith(name) + (pref, _) = sym.fullName.toString.splitAt(sym.fullName.toString.lastIndexOf('.') + 1) + out = pref + strippedName + if out != "" + } yield out + } + + def blacklisted(s: Symbol) = { + given Context = ctx1 + val blacklist = Set( + "scala.Predef.any2stringadd.+", + "scala.Any.##", + "java.lang.Object.##", + "scala.", + "scala.", + "scala.", + "scala.", + "scala.Predef.StringFormat.formatted", + "scala.Predef.Ensuring.ensuring", + "scala.Predef.ArrowAssoc.->", + "scala.Predef.ArrowAssoc.→", + "java.lang.Object.synchronized", + "java.lang.Object.ne", + "java.lang.Object.eq", + "java.lang.Object.wait", + "java.lang.Object.notifyAll", + "java.lang.Object.notify", + "java.lang.Object.clone", + "java.lang.Object.finalize" + ) + + blacklist(s.showFullName) || + s.isOneOf(Flags.GivenOrImplicit) || + // Cache objects, which you should probably never need to + // access directly, and apart from that have annoyingly long names + 
"cache[a-f0-9]{32}".r.findPrefixMatchOf(s.name.decode.toString).isDefined || + // s.isDeprecated || + s.name.decode.toString == "" || + s.name.decode.toString.contains('$') + } + + val filteredCompletions = completions.filter { c => + c.symbols.isEmpty || c.symbols.exists(!blacklisted(_)) + } + val signatures = { + given Context = ctx1 + for { + c <- filteredCompletions + s <- c.symbols + isMethod = s.denot.is(Flags.Method) + if isMethod + } yield s"def ${s.name}${s.denot.info.widenTermRefExpr.show}" + } + (start - prefix.length, filteredCompletions.map(_.label.replace(".package$.", ".")), signatures) + } + +object Compiler: + + /** Create empty outer store reporter */ + def newStoreReporter(): reporting.StoreReporter = + new reporting.StoreReporter(null) + with reporting.UniqueMessagePositions with reporting.HideNonSensicalMessages + + private def enumerateVdFiles(d: VirtualDirectory): Iterator[AbstractFile] = + val (subs, files) = d.iterator.partition(_.isDirectory) + files ++ subs.map(_.asInstanceOf[VirtualDirectory]).flatMap(enumerateVdFiles) + + private def files(d: VirtualDirectory): Iterator[(String, Array[Byte])] = + for (x <- enumerateVdFiles(d) if x.name.endsWith(".class") || x.name.endsWith(".tasty")) yield { + val segments = x.path.split("/").toList.tail + (x.path.stripPrefix("(memory)/"), x.toByteArray) + } + + private def writeDeep( + d: AbstractFile, + path: List[String] + ): OutputStream = path match { + case head :: Nil => d.fileNamed(path.head).output + case head :: rest => + writeDeep( + d.subdirectoryNamed(head), //.asInstanceOf[VirtualDirectory], + rest + ) + // We should never write to an empty path, and one of the above cases + // should catch this and return before getting here + case Nil => ??? 
+ } + + def addToClasspath(classFiles: Traversable[(String, Array[Byte])], + dynamicClasspath: AbstractFile): Unit = { + + for((name, bytes) <- classFiles){ + val output = writeDeep(dynamicClasspath, name.split('/').toList) + output.write(bytes) + output.close() + } + + } + + /** A `MessageRenderer` for the REPL without file positions */ + private[compiler] val messageRenderer = + new reporting.MessageRendering: + override def sourceLines( + pos: SourcePosition, + diagnosticLevel: String + )(using Context): (List[String], List[String], Int) = { + val (srcBefore, srcAfter, offset) = super.sourceLines(pos, diagnosticLevel) + val updatedSrcBefore = srcBefore.map { line => + val chars = line.toCharArray + var i = 0 + var updated = false + while (i < chars.length) { + if (chars(i) == '|') + i = chars.length + else if (chars(i).isDigit) { + chars(i) = ' ' + updated = true + } + i += 1 + } + if (updated) new String(chars) + else line + } + (updatedSrcBefore, srcAfter, offset) + } + // TODO Add this back for scripts + override def posStr( + pos: SourcePosition, + diagnosticLevel: String, + message: reporting.Message + )(using Context): String = + "" diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/CompilerBuilder.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/CompilerBuilder.scala new file mode 100644 index 000000000..d4cd8d4d4 --- /dev/null +++ b/amm/compiler/src/main/scala-3/ammonite/compiler/CompilerBuilder.scala @@ -0,0 +1,56 @@ +package ammonite.compiler + +import java.net.URL +import java.nio.file.Path + +import ammonite.compiler.iface.{ + Compiler => ICompiler, + CompilerBuilder => ICompilerBuilder, + CompilerLifecycleManager => ICompilerLifecycleManager, + _ +} +import ammonite.util.Frame +import dotty.tools.io.AbstractFile + +object CompilerBuilder extends ICompilerBuilder: + + def create( + initialClassPath: Seq[URL], + classPath: Seq[URL], + dynamicClassPath: Seq[(String, Array[Byte])], + evalClassLoader: ClassLoader, + pluginClassLoader: 
ClassLoader, + reporter: Option[ICompilerBuilder.Message => Unit], + settings: Seq[String], + classPathWhiteList: Set[Seq[String]], + lineNumberModifier: Boolean + ): ICompiler = { + val tempDir = AbstractFile.getDirectory(os.temp.dir().toNIO) + Compiler.addToClasspath(dynamicClassPath, tempDir) + new Compiler( + tempDir, + initialClassPath, + classPath, + evalClassLoader, + classPathWhiteList, + settings = settings, + reporter = reporter + ) + } + + def scalaVersion = dotty.tools.dotc.config.Properties.versionNumberString + + def newManager( + rtCacheDir: Option[Path], + headFrame: => Frame, + dependencyCompleter: => Option[String => (Int, Seq[String])], + whiteList: Set[Seq[String]], + initialClassLoader: ClassLoader + ): ICompilerLifecycleManager = + new CompilerLifecycleManager( + rtCacheDir, + headFrame, + dependencyCompleter, + whiteList, + initialClassLoader + ) diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/CompilerExtensions.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/CompilerExtensions.scala new file mode 100644 index 000000000..1c980b427 --- /dev/null +++ b/amm/compiler/src/main/scala-3/ammonite/compiler/CompilerExtensions.scala @@ -0,0 +1,41 @@ +package ammonite.compiler + +import ammonite.interp.api.InterpAPI +import ammonite.repl.api.ReplAPI + +object CompilerExtensions { + + implicit class CompilerInterpAPIExtensions(private val api: InterpAPI) extends AnyVal { + + private def compilerManager = api._compilerManager.asInstanceOf[CompilerLifecycleManager] + + /** + * Configures the current compiler, or if the compiler hasn't been initialized + * yet, registers the configuration callback and applies it to the compiler + * when it ends up being initialized later + */ + def configureCompiler(c: dotty.tools.dotc.Compiler => Unit): Unit = + compilerManager.configureCompiler(c) + + /** + * Pre-configures the next compiler context. Useful for tuning options that are + * used during parsing. 
+ */ + def preConfigureCompiler(c: dotty.tools.dotc.core.Contexts.FreshContext => Unit): Unit = + compilerManager.preConfigureCompiler(c) + } + + implicit class CompilerReplAPIExtensions(private val api: ReplAPI) extends AnyVal { + + private def compilerManager = api._compilerManager.asInstanceOf[CompilerLifecycleManager] + + def initialContext: dotty.tools.dotc.core.Contexts.Context = + compilerManager.compiler.initialCtx + /** + * Access the compiler to do crazy things if you really want to! + */ + def compiler: dotty.tools.dotc.Compiler = + compilerManager.compiler.compiler + } + +} diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/CompilerLifecycleManager.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/CompilerLifecycleManager.scala new file mode 100644 index 000000000..96b1e87c5 --- /dev/null +++ b/amm/compiler/src/main/scala-3/ammonite/compiler/CompilerLifecycleManager.scala @@ -0,0 +1,154 @@ +package ammonite.compiler + +import ammonite.util.Util._ +import ammonite.util.{Classpath, ImportTree, Printer} +import ammonite.util.Util.ClassFiles + +import java.nio.file.Path + +import dotty.tools.dotc +import dotc.{Compiler => DottyCompiler} +import dotc.core.Contexts.FreshContext +import dotty.tools.io.AbstractFile +import scala.collection.mutable + + +/** + * Wraps up the `Compiler` and `Pressy`, ensuring that they get properly + * initialized before use. 
Mostly deals with ensuring the object lifecycles + * are properly dealt with; `Compiler` and `Pressy` are the ones which deal + * with the compiler's nasty APIs + * + * Exposes a simple API where you can just call methods like `compilerClass` + * `configureCompiler` any-how and not worry about ensuring the necessary + * compiler objects are initialized, or worry about initializing them more + * than necessary + */ +class CompilerLifecycleManager( + rtCacheDir: Option[Path], + headFrame: => ammonite.util.Frame, + dependencyCompleteOpt: => Option[String => (Int, Seq[String])], + classPathWhitelist: Set[Seq[String]], + initialClassLoader: ClassLoader +) extends ammonite.compiler.iface.CompilerLifecycleManager { + + def scalaVersion = dotc.config.Properties.versionNumberString + + def forceInit(): Unit = init(force = true) + def init(): Unit = init(force = false) + + + private[this] object Internal{ + val dynamicClasspath = AbstractFile.getDirectory(os.temp.dir().toNIO) + var compiler: ammonite.compiler.Compiler = null + val onCompilerInit = mutable.Buffer.empty[DottyCompiler => Unit] + val onSettingsInit = mutable.Buffer.empty[FreshContext => Unit] // TODO Pass a SettingsState too + var preConfiguredSettingsChanged: Boolean = false + var compilationCount = 0 + var (lastFrame, lastFrameVersion) = (headFrame, headFrame.version) + } + + + import Internal._ + + + // Public to expose it in the REPL so people can poke at it at runtime + // Not for use within Ammonite! Use one of the other methods to ensure + // that `Internal.compiler` is properly initialized before use. + def compiler: ammonite.compiler.Compiler = Internal.compiler + def compilationCount = Internal.compilationCount + + // def pressy: Pressy = Internal.pressy + + def preprocess(fileName: String) = synchronized{ + init() + compiler.preprocessor(fileName) + } + + + // We lazily force the compiler to be re-initialized by setting the + // compilerStale flag. 
Otherwise, if we re-initialized the compiler eagerly, + // we end up sometimes re-initializing it multiple times unnecessarily before + // it gets even used once. Empirically, this cuts down the number of compiler + // re-initializations by about 2/3, each of which costs about 30ms and + // probably creates a pile of garbage + + def init(force: Boolean = false) = synchronized{ + if (compiler == null || + (headFrame ne lastFrame) || + headFrame.version != lastFrameVersion || + Internal.preConfiguredSettingsChanged || + force) { + + lastFrame = headFrame + lastFrameVersion = headFrame.version + + val initialClassPath = Classpath.classpath(initialClassLoader, rtCacheDir) + val headFrameClassPath = + Classpath.classpath(headFrame.classloader, rtCacheDir) + + Internal.compiler = new Compiler( + Internal.dynamicClasspath, + initialClassPath, + headFrameClassPath, + headFrame.classloader, + classPathWhitelist, + dependencyCompleteOpt = dependencyCompleteOpt, + contextInit = c => onSettingsInit.foreach(_(c)) + ) + onCompilerInit.foreach(_(compiler.compiler)) + + Internal.preConfiguredSettingsChanged = false + } + } + + def complete( + offset: Int, + previousImports: String, + snippet: String + ): (Int, Seq[String], Seq[String]) = synchronized{ + init() + Internal.compiler.complete(offset, previousImports, snippet) + } + + def compileClass( + processed: ammonite.compiler.iface.Preprocessor.Output, + printer: Printer, + fileName: String + ): Option[ammonite.compiler.iface.Compiler.Output] = synchronized{ + // Enforce the invariant that every piece of code Ammonite ever compiles, + // gets run within the `ammonite` package. 
It's further namespaced into + // things like `ammonite.$file` or `ammonite.$sess`, but it has to be + // within `ammonite` + assert(processed.code.trim.startsWith("package ammonite")) + + init() + val compiled = compiler.compile( + processed.code.getBytes(scala.util.Properties.sourceEncoding), + printer, + processed.prefixCharLength, + processed.userCodeNestingLevel, + fileName + ) + Internal.compilationCount += 1 + compiled + } + + def configureCompiler(callback: DottyCompiler => Unit) = synchronized{ + onCompilerInit.append(callback) + if (compiler != null){ + callback(compiler.compiler) + } + } + + def preConfigureCompiler(callback: FreshContext => Unit) = + synchronized { + onSettingsInit.append(callback) + preConfiguredSettingsChanged = true + } + + def addToClasspath(classFiles: ClassFiles): Unit = synchronized { + Compiler.addToClasspath(classFiles, dynamicClasspath) + } + def shutdownPressy() = () // N/A in Scala 3 +} diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/DottyParser.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/DottyParser.scala new file mode 100644 index 000000000..d00dd3e61 --- /dev/null +++ b/amm/compiler/src/main/scala-3/ammonite/compiler/DottyParser.scala @@ -0,0 +1,73 @@ +package ammonite.compiler + +import dotty.tools.dotc +import dotc.ast.untpd +import dotc.core.Contexts.Context +import dotc.core.Flags +import dotc.core.StdNames.nme +import dotc.parsing.Parsers.{Location, Parser} +import dotc.parsing.Tokens +import dotc.reporting.IllegalStartOfStatement +import dotc.util.SourceFile + +import scala.collection.mutable + +class DottyParser(source: SourceFile)(using Context) extends Parser(source) { + + // From + // https://github.com/lampepfl/dotty/blob/3.0.0-M3/ + // compiler/src/dotty/tools/dotc/parsing/Parsers.scala/#L67-L71 + extension (buf: mutable.ListBuffer[untpd.Tree]): + def +++=(x: untpd.Tree) = x match { + case x: untpd.Thicket => buf ++= x.trees + case x => buf += x + } + + private val 
oursLocalModifierTokens = Tokens.localModifierTokens + Tokens.PRIVATE + + override def localDef( + start: Int, + implicitMods: untpd.Modifiers = untpd.EmptyModifiers + ): untpd.Tree = { + var mods = defAnnotsMods(oursLocalModifierTokens) + for (imod <- implicitMods.mods) mods = addMod(mods, imod) + if (mods.is(Flags.Final)) + // A final modifier means the local definition is "class-like". + // FIXME: Deal with modifiers separately + tmplDef(start, mods) + else + defOrDcl(start, mods) + } + + // Adapted from + // https://github.com/lampepfl/dotty/blob/3.0.0-M3/ + // compiler/src/dotty/tools/dotc/parsing/Parsers.scala/#L3882-L3904 + // Unlike it, we accept private modifiers for top-level definitions. + override def blockStatSeq(): List[untpd.Tree] = checkNoEscapingPlaceholders { + val stats = new mutable.ListBuffer[untpd.Tree] + var exitOnError = false + while (!isStatSeqEnd && in.token != Tokens.CASE && !exitOnError) { + setLastStatOffset() + if (in.token == Tokens.IMPORT) + stats ++= importClause(Tokens.IMPORT, mkImport()) + else if (isExprIntro) + stats += expr(Location.InBlock) + else if in.token == Tokens.IMPLICIT && !in.inModifierPosition() then + stats += closure( + in.offset, + Location.InBlock, + modifiers(scala.collection.immutable.BitSet(Tokens.IMPLICIT)) + ) + else if isIdent(nme.extension) && followingIsExtension() then + stats += extension() + else if isDefIntro(oursLocalModifierTokens, excludedSoftModifiers = Set(nme.`opaque`)) then + stats +++= localDef(in.offset) + else if (!isStatSep && (in.token != Tokens.CASE)) { + exitOnError = mustStartStat + syntaxErrorOrIncomplete(IllegalStartOfStatement(isModifier)) + } + acceptStatSepUnlessAtEnd(stats, Tokens.CASE) + } + stats.toList + } +} diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/Extensions.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/Extensions.scala new file mode 100644 index 000000000..423384521 --- /dev/null +++ 
b/amm/compiler/src/main/scala-3/ammonite/compiler/Extensions.scala @@ -0,0 +1,9 @@ +package ammonite.compiler + +import ammonite.interp.api.InterpAPI + +object Extensions { + + implicit class CompilerInterAPIExtensions(private val self: InterpAPI) extends AnyVal + +} diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/Parsers.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/Parsers.scala new file mode 100644 index 000000000..47a504343 --- /dev/null +++ b/amm/compiler/src/main/scala-3/ammonite/compiler/Parsers.scala @@ -0,0 +1,311 @@ +package ammonite.compiler + +import java.util.Map + +import ammonite.compiler.iface.{Compiler => _, Parser => IParser, _} +import ammonite.util.ImportTree +import ammonite.util.Util.CodeSource + +import dotty.tools.dotc +import dotc.ast.untpd +import dotc.CompilationUnit +import dotc.core.Contexts.{ctx, Context, ContextBase} +import dotc.parsing.Tokens +import dotc.util.SourceFile + +import scala.collection.mutable + +class Parsers extends IParser { + + // FIXME Get via Compiler? + private lazy val initCtx: Context = + (new ContextBase).initialCtx + + // From + // https://github.com/lampepfl/dotty/blob/3.0.0-M3/ + // compiler/src/dotty/tools/repl/ParseResult.scala/#L115-L120 + private def parseStats(using Context): List[untpd.Tree] = { + val parser = new DottyParser(ctx.source) + val stats = parser.blockStatSeq() + parser.accept(Tokens.EOF) + stats + } + + // Adapted from + // https://github.com/lampepfl/dotty/blob/3.0.0-M3/ + // compiler/src/dotty/tools/repl/ParseResult.scala/#L163-L184 + /** Check if the input is incomplete. + * + * This can be used in order to check if a newline can be inserted without + * having to evaluate the expression. 
+ */ + private def isComplete(sourceCode: String)(using Context): Boolean = + val reporter = Compiler.newStoreReporter() + val source = SourceFile.virtual("", sourceCode, maybeIncomplete = true) + val unit = CompilationUnit(source, mustExist = false) + val localCtx = ctx.fresh + .setCompilationUnit(unit) + .setReporter(reporter) + var needsMore = false + reporter.withIncompleteHandler((_, _) => needsMore = true) { + parseStats(using localCtx) + } + reporter.hasErrors || !needsMore + + private val BlockPat = """(?s)^\s*\{(.*)\}\s*$""".r + + def split( + code: String, + ignoreIncomplete: Boolean = true, + fileName: String = "(console)" + ): Option[Either[String, Seq[String]]] = + doSplit(code, ignoreIncomplete, fileName) + .map(_.map(_.map(_._2))) + + private def doSplit( + code: String, + ignoreIncomplete: Boolean, + fileName: String + ): Option[Either[String, Seq[(Int, String)]]] = { + val code0 = code match { + case BlockPat(wrapped) => wrapped + case _ => code + } + + given Context = initCtx + val reporter = Compiler.newStoreReporter() + val source = SourceFile.virtual("", code0, maybeIncomplete = true) + val unit = CompilationUnit(source, mustExist = false) + val localCtx = ctx.fresh + .setCompilationUnit(unit) + .setReporter(reporter) + var needsMore = false + val stats = reporter.withIncompleteHandler((_, _) => needsMore = true) { + parseStats(using localCtx) + } + + val nl = System.lineSeparator + def errors = reporter + .removeBufferedMessages + .map { e => + val maybeMsg = scala.util.Try { + Compiler.messageRenderer.messageAndPos( + e.msg, + e.pos, + Compiler.messageRenderer.diagnosticLevel(e) + ) + } + Compiler.messageRenderer.stripColor(maybeMsg.getOrElse("???")) + } + .mkString(nl) + + if (reporter.hasErrors) + Some(Left(s"$fileName$nl$errors")) + else if (needsMore) + None + else { + val startIndices = stats.toArray.map(_.startPos(using localCtx).point) + def startEndIndices = startIndices.iterator + .zip(startIndices.iterator.drop(1) ++ 
Iterator(code0.length)) + val stmts = startEndIndices.map { + case (start, end) => + code0.substring(start, end) + }.toVector + val statsAndStmts = stats.zip(stmts).zip(startIndices).iterator + + val stmts0 = new mutable.ListBuffer[(Int, String)] + var current = Option.empty[(untpd.Tree, String, Int)] + while (statsAndStmts.hasNext) { + val next = statsAndStmts.next() + val ((nextStat, nextStmt), nextIdx) = next + (current, nextStat) match { + case (Some((_: untpd.Import, stmt, idx)), _: untpd.Import) + if stmt.startsWith("import ") && !nextStmt.startsWith("import ") => + current = Some((nextStat, stmt + nextStmt, idx)) + case _ => + current.foreach { case (_, stmt, idx) => stmts0.+=((idx, stmt)) } + current = Some((nextStat, nextStmt, nextIdx)) + } + } + current.foreach { case (_, stmt, idx) => stmts0.+=((idx, stmt)) } + + Some(Right(stmts0.toList)) + } + } + + private def importExprs(i: untpd.Import): Seq[String] = { + def exprs(t: untpd.Tree): List[String] = + t match { + case untpd.Ident(name) => name.decode.toString :: Nil + case untpd.Select(qual, name) => name.decode.toString :: exprs(qual) + case _ => Nil // ??? 
+ } + exprs(i.expr).reverse + } + + def importHooks(statement: String): (String, Seq[ImportTree]) = { + + given Context = initCtx + val reporter = Compiler.newStoreReporter() + val source = SourceFile.virtual("", statement, maybeIncomplete = true) + val unit = CompilationUnit(source, mustExist = false) + val localCtx = ctx.fresh + .setCompilationUnit(unit) + .setReporter(reporter) + var needsMore = false + val stats = reporter.withIncompleteHandler((_, _) => needsMore = true) { + parseStats(using localCtx) + } + + if (reporter.hasErrors || needsMore) + (statement, Nil) + else { + var updatedStatement = statement + var importTrees = Array.newBuilder[ImportTree] + stats.foreach { + case i: untpd.Import => + val exprs = importExprs(i) + if (exprs.headOption.exists(_.startsWith("$"))) { + val start = i.startPos.point + val end = { + var initialEnd = i.endPos.point + // kind of meh + // In statements like 'import $file.foo.{a, b}', endPos points at 'b' rather than '}', + // so we work around that here. 
+ if (updatedStatement.iterator.drop(start).take(initialEnd - start).contains('{')) { + while (updatedStatement.length > initialEnd && + updatedStatement.charAt(initialEnd).isWhitespace) + initialEnd = initialEnd + 1 + if (updatedStatement.length > initialEnd && + updatedStatement.charAt(initialEnd) == '}') + initialEnd = initialEnd + 1 + } + initialEnd + } + val selectors = i.selectors.map { sel => + val from = sel.name.decode.toString + val to = sel.rename.decode.toString + from -> Some(to).filter(_ != from) + } + val updatedImport = updatedStatement.substring(start, end).takeWhile(_ != '.') + ".$" + updatedStatement = updatedStatement.patch( + start, + updatedImport + (" ") * (end - start - updatedImport.length), + end - start + ) + + val prefixLen = if (updatedStatement.startsWith("import ")) "import ".length else 0 + importTrees += ImportTree( + exprs, + Some(selectors).filter(_.nonEmpty), + start + prefixLen, end + ) + } + case _ => + } + (updatedStatement, importTrees.result) + } + } + + def parseImportHooksWithIndices( + source: CodeSource, + statements: Seq[(Int, String)] + ): (Seq[String], Seq[ImportTree]) = { + + val (updatedStatements, trees) = statements.map { + case (startIdx, stmt) => + val (hookedStmts, parsedTrees) = importHooks(stmt) + + val updatedParsedTrees = parsedTrees.map { importTree => + importTree.copy( + start = startIdx + importTree.start, + end = startIdx + importTree.end + ) + } + + (hookedStmts, updatedParsedTrees) + }.unzip + + (updatedStatements, trees.flatten) + } + + private val scriptSplitPattern = "(?m)^\\s*@[\\s\\n\\r]+".r + + def splitScript( + rawCode: String, + fileName: String + ): Either[String, IndexedSeq[(String, Seq[String])]] = + scriptBlocksWithStartIndices(rawCode, fileName) + .left.map(_.getMessage) + .map(_.map(b => (b.ncomment, b.codeWithStartIndices.map(_._2))).toVector) + + def scriptBlocksWithStartIndices( + rawCode: String, + fileName: String + ): Either[IParser.ScriptSplittingError, Seq[IParser.ScriptBlock]] 
= { + + val bounds = { + def allBounds = Iterator(0) ++ scriptSplitPattern.findAllMatchIn(rawCode).flatMap { m => + Iterator(m.start, m.end) + } ++ Iterator(rawCode.length) + allBounds + .grouped(2) + .map { case Seq(start, end) => (start, end) } + .toVector + } + + val blocks = bounds.zipWithIndex.map { + case ((start, end), idx) => + val blockCode = rawCode.substring(start, end) + doSplit(blockCode, false, fileName) match { + case None => + Right( + IParser.ScriptBlock( + start, + "", + Seq((start, blockCode)) + ) + ) + case Some(Left(err)) => Left(err) + case Some(Right(stmts)) => + Right( + IParser.ScriptBlock( + start, + blockCode.take(stmts.headOption.fold(0)(_._1)), + stmts.map { case (idx, stmt) => (idx + start, stmt) } + ) + ) + } + } + + val errors = blocks.collect { case Left(err) => err } + if (errors.isEmpty) + Right(blocks.collect { case Right(elem) => elem }) + else + Left(new IParser.ScriptSplittingError(errors.mkString(System.lineSeparator))) + } + + def defaultHighlight(buffer: Vector[Char], + comment: fansi.Attrs, + `type`: fansi.Attrs, + literal: fansi.Attrs, + keyword: fansi.Attrs, + reset: fansi.Attrs): Vector[Char] = { + val valDef = reset + val annotation = reset + new SyntaxHighlighting( + reset, + comment, + keyword, + valDef, + literal, + `type`, + annotation, + ).highlight(buffer.mkString)(using initCtx).toVector + } + + def isObjDef(code: String): Boolean = + false // TODO +} + +object Parsers extends Parsers diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/Preprocessor.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/Preprocessor.scala new file mode 100644 index 000000000..d869b5b96 --- /dev/null +++ b/amm/compiler/src/main/scala-3/ammonite/compiler/Preprocessor.scala @@ -0,0 +1,311 @@ +package ammonite.compiler + +import java.util.function.{Function => JFunction} + +import ammonite.compiler.iface.{Compiler => _, Parser => _, Preprocessor => IPreprocessor, _} +import ammonite.util.{Imports, Name, Res} +import 
ammonite.util.Util.CodeSource +import pprint.Util + +import dotty.tools.dotc +import dotc.ast.untpd +import dotc.core.Contexts._ +import dotc.core.{Flags, Names} +import dotc.parsing.Parsers.Parser +import dotc.parsing.Tokens +import dotc.util.SourceFile + +class Preprocessor( + ctx: Context, + markGeneratedSections: Boolean +) extends IPreprocessor { + + // FIXME Quite some duplication with DefaultProcessor for Scala 2.x + + private case class Expanded(code: String, printer: Seq[String]) + + private def parse(source: String): Either[Seq[String], List[untpd.Tree]] = { + val reporter = Compiler.newStoreReporter() + val sourceFile = SourceFile.virtual("foo", source) + val parseCtx = ctx.fresh.setReporter(reporter).withSource(sourceFile) + val parser = new DottyParser(sourceFile)(using parseCtx) + val stats = parser.blockStatSeq() + parser.accept(Tokens.EOF) + if (reporter.hasErrors) { + val errorsStr = reporter + .allErrors + // .map(rendering.formatError) + .map(e => scala.util.Try(e.msg.toString).toOption.getOrElse("???")) + Left(errorsStr) + } else + Right(stats) + } + + def transform( + stmts: Seq[String], + resultIndex: String, + leadingSpaces: String, + codeSource: CodeSource, + indexedWrapper: Name, + imports: Imports, + printerTemplate: String => String, + extraCode: String, + skipEmpty: Boolean, + markScript: Boolean, + codeWrapper: CodeWrapper + ): Res[IPreprocessor.Output] = { + + // println(s"transformOrNull(${stmts.toSeq})") + + // All code Ammonite compiles must be rooted in some package within + // the `ammonite` top-level package + assert(codeSource.pkgName.head == Name("ammonite")) + + expandStatements(stmts, resultIndex, skipEmpty).map { + case Expanded(code, printer) => + val (wrappedCode, importsLength, userCodeNestingLevel) = wrapCode( + codeSource, indexedWrapper, leadingSpaces + code, + printerTemplate(printer.mkString(", ")), + imports, extraCode, markScript, codeWrapper + ) + IPreprocessor.Output(wrappedCode, importsLength, 
userCodeNestingLevel) + } + } + + private def expandStatements( + stmts: Seq[String], + wrapperIndex: String, + skipEmpty: Boolean + ): Res[Expanded] = + stmts match{ + // In the REPL, we do not process empty inputs at all, to avoid + // unnecessarily incrementing the command counter + // + // But in scripts, we process empty inputs and create an empty object, + // to ensure that when the time comes to cache/load the class it exists + case Nil if skipEmpty => Res.Skip + case postSplit => + Res(complete(stmts.mkString(""), wrapperIndex, postSplit)) + + } + + private def wrapCode( + codeSource: CodeSource, + indexedWrapperName: Name, + code: String, + printCode: String, + imports: Imports, + extraCode: String, + markScript: Boolean, + codeWrapper: CodeWrapper + ) = { + + //we need to normalize topWrapper and bottomWrapper in order to ensure + //the snippets always use the platform-specific newLine + val extraCode0 = + if (markScript) extraCode + "/**/" + else extraCode + val (topWrapper, bottomWrapper, userCodeNestingLevel) = + codeWrapper(code, codeSource, imports, printCode, indexedWrapperName, extraCode0) + val (topWrapper0, bottomWrapper0) = + if (markScript) (topWrapper + "/**/ /**/" + bottomWrapper) + else (topWrapper, bottomWrapper) + val importsLen = topWrapper0.length + + (topWrapper0 + code + bottomWrapper0, importsLen, userCodeNestingLevel) + } + + // Large parts of the logic below is adapted from DefaultProcessor, + // the Scala 2 counterpart of this file. 
+ + private def isPrivate(tree: untpd.Tree): Boolean = + tree match { + case m: untpd.MemberDef => m.mods.is(Flags.Private) + case _ => false + } + + private def Processor(cond: PartialFunction[(String, String, untpd.Tree), Expanded]) = + (code: String, name: String, tree: untpd.Tree) => cond.lift(name, code, tree) + + private def pprintSignature(ident: String, customMsg: Option[String]): String = + val customCode = customMsg.fold("_root_.scala.None")(x => s"""_root_.scala.Some("$x")""") + s""" + _root_.ammonite + .repl + .ReplBridge + .value + .Internal + .print($ident, ${Util.literalize(ident)}, $customCode) + """ + private def definedStr(definitionLabel: String, name: String) = + s""" + _root_.ammonite + .repl + .ReplBridge + .value + .Internal + .printDef("$definitionLabel", ${Util.literalize(name)}) + """ + private def pprint(ident: String) = pprintSignature(ident, None) + + /** + * Processors for declarations which all have the same shape + */ + private def DefProc(definitionLabel: String)(cond: PartialFunction[untpd.Tree, Names.Name]) = + (code: String, name: String, tree: untpd.Tree) => + cond.lift(tree).map{ name => + val printer = + if (isPrivate(tree)) Nil + else Seq(definedStr(definitionLabel, Name.backtickWrap(name.decode.toString))) + Expanded( + code, + printer + ) + } + + private val ObjectDef = DefProc("object"){case m: untpd.ModuleDef => m.name} + private val ClassDef = DefProc("class"){ + case m: untpd.TypeDef if m.isClassDef && !m.mods.flags.is(Flags.Trait) => + m.name + } + private val TraitDef = DefProc("trait"){ + case m: untpd.TypeDef if m.isClassDef && m.mods.flags.is(Flags.Trait) => + m.name + } + private val DefDef = DefProc("function"){ case m: untpd.DefDef => m.name } + private val TypeDef = DefProc("type"){ case m: untpd.TypeDef => m.name } + + private val VarDef = Processor { case (name, code, t: untpd.ValDef) => + Expanded( + //Only wrap rhs in function if it is not a function + //Wrapping functions causes type inference errors. 
+ code, + // Try to leave out all synthetics; we don't actually have proper + // synthetic flags right now, because we're dumb-parsing it and not putting + // it through a full compilation + if (isPrivate(t) || t.name.decode.toString.contains("$")) Nil + else if (!t.mods.flags.is(Flags.Lazy)) Seq(pprint(Name.backtickWrap(t.name.decode.toString))) + else Seq(pprintSignature(Name.backtickWrap(t.name.decode.toString), Some(""))) + ) + } + + private val PatDef = Processor { case (name, code, t: untpd.PatDef) => + val isLazy = t.mods.flags.is(Flags.Lazy) + val printers = + if (isPrivate(t)) Nil + else + t.pats + .flatMap { + case untpd.Tuple(trees) => trees + case elem => List(elem) + } + .flatMap { + case untpd.Ident(name) => + val decoded = name.decode.toString + if (decoded.contains("$")) Nil + else if (isLazy) Seq(pprintSignature(Name.backtickWrap(decoded), Some(""))) + else Seq(pprint(Name.backtickWrap(decoded))) + case _ => Nil // can this happen? + } + Expanded(code, printers) + } + + private val Import = Processor { + case (name, code, tree: untpd.Import) => + val Array(keyword, body) = code.split(" ", 2) + val tq = "\"\"\"" + Expanded(code, Seq( + s""" + _root_.ammonite + .repl + .ReplBridge + .value + .Internal + .printImport(${Util.literalize(body)}) + """ + )) + } + + private val Expr = Processor { + //Expressions are lifted to anon function applications so they will be JITed + case (name, code, tree) => + val expandedCode = + if (markGeneratedSections) + s"/**/val $name = /**/$code" + else + s"val $name = $code" + Expanded( + expandedCode, + if (isPrivate(tree)) Nil else Seq(pprint(name)) + ) + } + + private val decls = Seq[(String, String, untpd.Tree) => Option[Expanded]]( + ObjectDef, ClassDef, TraitDef, DefDef, TypeDef, VarDef, PatDef, Import, Expr + ) + + private def complete( + code: String, + resultIndex: String, + postSplit: Seq[String] + ): Either[String, Expanded] = { + val reParsed = postSplit.map(p => (parse(p), p)) + val errors = 
reParsed.collect{case (Left(e), _) => e }.flatten + if (errors.length != 0) Left(errors.mkString(System.lineSeparator())) + else { + val allDecls = for { + ((Right(trees), code), i) <- reParsed.zipWithIndex if trees.nonEmpty + } yield { + // Suffix the name of the result variable with the index of + // the tree if there is more than one statement in this command + val suffix = if (reParsed.length > 1) "_" + i else "" + def handleTree(t: untpd.Tree) = { + // println(s"handleTree($t)") + val it = decls.iterator.flatMap(_.apply(code, "res" + resultIndex + suffix, t)) + if (it.hasNext) + it.next() + else { + sys.error(s"Don't know how to handle ${t.getClass}: $t") + } + } + trees match { + case Seq(tree) => handleTree(tree) + + // This handles the multi-import case `import a.b, c.d` + case trees if trees.forall(_.isInstanceOf[untpd.Import]) => handleTree(trees(0)) + + // AFAIK this can only happen for pattern-matching multi-assignment, + // which for some reason parse into a list of statements. 
In such a + // scenario, aggregate all their printers, but only output the code once + case trees => + val printers = for { + tree <- trees + if tree.isInstanceOf[untpd.ValDef] + Expanded(_, printers) = handleTree(tree) + printer <- printers + } yield printer + + Expanded(code, printers) + } + } + + val expanded = allDecls match{ + case Seq(first, rest@_*) => + val allDeclsWithComments = Expanded(first.code, first.printer) +: rest + allDeclsWithComments.reduce { (a, b) => + Expanded( + // We do not need to separate the code with our own semi-colons + // or newlines, as each expanded code snippet itself comes with + // it's own trailing newline/semicolons as a result of the + // initial split + a.code + b.code, + a.printer ++ b.printer + ) + } + case Nil => Expanded("", Nil) + } + + Right(expanded) + } + } +} diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/SyntaxHighlighting.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/SyntaxHighlighting.scala new file mode 100644 index 000000000..4ee9ef053 --- /dev/null +++ b/amm/compiler/src/main/scala-3/ammonite/compiler/SyntaxHighlighting.scala @@ -0,0 +1,129 @@ +package ammonite.compiler + +// Originally adapted from +// https://github.com/lampepfl/dotty/blob/3.0.0-M3/ +// compiler/src/dotty/tools/dotc/printing/SyntaxHighlighting.scala + +import dotty.tools.dotc +import dotc.CompilationUnit +import dotc.ast.untpd +import dotc.core.Contexts._ +import dotc.core.StdNames._ +import dotc.parsing.Parsers.Parser +import dotc.parsing.Scanners.Scanner +import dotc.parsing.Tokens._ +import dotc.reporting.Reporter +import dotc.util.Spans.Span +import dotc.util.SourceFile + +import java.util.Arrays + +/** This object provides functions for syntax highlighting in the REPL */ +class SyntaxHighlighting( + noAttrs: fansi.Attrs, + commentAttrs: fansi.Attrs, + keywordAttrs: fansi.Attrs, + valDefAttrs: fansi.Attrs, + literalAttrs: fansi.Attrs, + typeAttrs: fansi.Attrs, + annotationAttrs: fansi.Attrs, +) { + + def 
highlight(in: String)(using Context): String = { + def freshCtx = ctx.fresh.setReporter(Reporter.NoReporter) + if (in.isEmpty || ctx.settings.color.value == "never") in + else { + val source = SourceFile.virtual("", in) + + given Context = freshCtx + .setCompilationUnit(CompilationUnit(source, mustExist = false)(using freshCtx)) + + val colors = Array.fill(in.length)(0L) + + def highlightRange(from: Int, to: Int, attr: fansi.Attrs) = + Arrays.fill(colors, from, to, attr.applyMask) + + def highlightPosition(span: Span, attr: fansi.Attrs) = + if (span.exists && span.start >= 0 && span.end <= in.length) + highlightRange(span.start, span.end, attr) + + val scanner = new Scanner(source) + while (scanner.token != EOF) { + val start = scanner.offset + val token = scanner.token + val name = scanner.name + val isSoftModifier = scanner.isSoftModifierInModifierPosition + scanner.nextToken() + val end = scanner.lastOffset + + // Branch order is important. For example, + // `true` is at the same time a keyword and a literal + token match { + case _ if literalTokens.contains(token) => + highlightRange(start, end, literalAttrs) + + case STRINGPART => + // String interpolation parts include `$` but + // we don't highlight it, hence the `-1` + highlightRange(start, end - 1, literalAttrs) + + case _ if alphaKeywords.contains(token) || isSoftModifier => + highlightRange(start, end, keywordAttrs) + + case IDENTIFIER if name == nme.??? 
=> + highlightRange(start, end, fansi.Color.Red) + + case _ => + } + } + + for (span <- scanner.commentSpans) + highlightPosition(span, commentAttrs) + + object TreeHighlighter extends untpd.UntypedTreeTraverser { + import untpd._ + + def ignored(tree: NameTree) = { + val name = tree.name.toTermName + // trees named and have weird positions + name == nme.ERROR || name == nme.CONSTRUCTOR + } + + def highlightAnnotations(tree: MemberDef): Unit = + for (annotation <- tree.mods.annotations) + highlightPosition(annotation.span, annotationAttrs) + + def highlight(trees: List[Tree])(using Context): Unit = + trees.foreach(traverse) + + def traverse(tree: Tree)(using Context): Unit = { + tree match { + case tree: NameTree if ignored(tree) => + () + case tree: ValOrDefDef => + highlightAnnotations(tree) + highlightPosition(tree.nameSpan, valDefAttrs) + case tree: MemberDef /* ModuleDef | TypeDef */ => + highlightAnnotations(tree) + highlightPosition(tree.nameSpan, typeAttrs) + case tree: Ident if tree.isType => + highlightPosition(tree.span, typeAttrs) + case _: TypTree => + highlightPosition(tree.span, typeAttrs) + case _ => + } + traverseChildren(tree) + } + } + + val parser = new DottyParser(source) + val trees = parser.blockStatSeq() + TreeHighlighter.highlight(trees) + + // if (colorAt.last != NoColor) + // highlighted.append(NoColor) + + fansi.Str.fromArrays(in.toCharArray, colors).render + } + } +} diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/tools/desugar.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/tools/desugar.scala new file mode 100644 index 000000000..8a27985b7 --- /dev/null +++ b/amm/compiler/src/main/scala-3/ammonite/compiler/tools/desugar.scala @@ -0,0 +1,3 @@ +package ammonite.compiler.tools + +object desugar diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/tools/source.scala b/amm/compiler/src/main/scala-3/ammonite/compiler/tools/source.scala new file mode 100644 index 000000000..b8e4bf9ca --- /dev/null +++ 
b/amm/compiler/src/main/scala-3/ammonite/compiler/tools/source.scala @@ -0,0 +1,9 @@ +package ammonite.compiler.tools + +import ammonite.util.Util.Location + +object source{ + + def load(f: => Any): Location = ??? + +} diff --git a/amm/compiler/src/main/scala-3/dotty/ammonite/compiler/Completion.scala b/amm/compiler/src/main/scala-3/dotty/ammonite/compiler/Completion.scala new file mode 100644 index 000000000..fb42b39b0 --- /dev/null +++ b/amm/compiler/src/main/scala-3/dotty/ammonite/compiler/Completion.scala @@ -0,0 +1,507 @@ +package dotty.ammonite.compiler + +// Originally adapted from +// https://github.com/lampepfl/dotty/blob/3.0.0-M1/ +// compiler/src/dotty/tools/dotc/interactive/Completion.scala +// Then tweaked for deep completion, import $ivy completion, … + +import java.nio.charset.Charset + +import dotty.tools.dotc +import dotc.ast.Trees._ +import dotc.ast.untpd +import dotc.config.Printers.interactiv +import dotc.core.Contexts._ +import dotc.core.CheckRealizable +import dotc.core.Decorators._ +import dotc.core.Denotations.SingleDenotation +import dotc.core.Flags._ +import dotc.core.Names.{Name, TermName, termName} +import dotc.core.NameKinds.SimpleNameKind +import dotc.core.NameOps._ +import dotc.core.Symbols.{NoSymbol, Symbol, defn} +import dotc.core.Scopes +import dotc.core.StdNames.{nme, tpnme} +import dotc.core.TypeError +import dotc.core.Types.{NameFilter, NamedType, NoType, Type} +import dotc.interactive._ +import dotc.util.{NameTransformer, NoSourcePosition, SourcePosition} + +import scala.collection.mutable +import scala.internal.Chars.{isOperatorPart, isScalaLetter} + +/** + * One of the results of a completion query. + * + * @param label The label of this completion result, or the text that this completion result + * should insert in the scope where the completion request happened. + * @param description The description of this completion result: the fully qualified name for + * types, or the type for terms. 
+ * @param symbols The symbols that are matched by this completion result. + */ +case class Completion(name: Name, description: String, symbols: List[Symbol]) { + def label: String = { + + // adapted from + // https://github.com/scala/scala/blob/decbd53f1bde4600c8ff860f30a79f028a8e431d/ + // src/reflect/scala/reflect/internal/Printers.scala#L573-L584 + val bslash = '\\' + val isDot = (x: Char) => x == '.' + val brackets = List('[',']','(',')','{','}') + + def quotedName(name: Name): String = { + val s = name.decode + val term = name.toTermName + if (nme.keywords(term) && term != nme.USCOREkw) s"`$s`" + else s.toString + } + + val decName = name.decode.toString + def addBackquotes(s: String) = { + val hasSpecialChar = decName.exists { ch => + brackets.contains(ch) || ch.isWhitespace || isDot(ch) + } + def isOperatorLike = (name.isOperatorName || decName.exists(isOperatorPart)) && + decName.exists(isScalaLetter) && + !decName.contains(bslash) + if (hasSpecialChar || isOperatorLike) s"`$s`" + else s + } + + if (name == nme.CONSTRUCTOR) "this" + else addBackquotes(quotedName(name)) + } +} + +object Completion { + + import dotc.ast.tpd._ + + /** Get possible completions from tree at `pos` + * + * @return offset and list of symbols for possible completions + */ + def completions( + pos: SourcePosition, + dependencyCompleteOpt: Option[String => (Int, Seq[String])] + )(using Context): (Int, List[Completion]) = { + val path = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span) + computeCompletions(pos, path, dependencyCompleteOpt)(using Interactive.contextOfPath(path)) + } + + /** + * Inspect `path` to determine what kinds of symbols should be considered. + * + * If the path starts with: + * - a `RefTree`, then accept symbols of the same kind as its name; + * - a renaming import, and the cursor is on the renamee, accept both terms and types; + * - an import, accept both terms and types; + * + * Otherwise, provide no completion suggestion. 
+ */ + private def completionMode(path: List[Tree], pos: SourcePosition): Mode = + path match { + case (ref: RefTree) :: _ => + if (ref.name.isTermName) Mode.Term + else if (ref.name.isTypeName) Mode.Type + else Mode.None + + case (sel: untpd.ImportSelector) :: _ => + if sel.imported.span.contains(pos.span) then Mode.Import + else Mode.None // Can't help completing the renaming + + case Import(_, _) :: _ => + Mode.Import + + case _ => + Mode.None + } + + /** + * Inspect `path` to determine the completion prefix. Only symbols whose name start with the + * returned prefix should be considered. + */ + private def completionPrefix(path: List[untpd.Tree], pos: SourcePosition): String = + path match { + case (sel: untpd.ImportSelector) :: _ => + completionPrefix(sel.imported :: Nil, pos) + + case Import(expr, selectors) :: _ => + selectors.find(_.span.contains(pos.span)).map { selector => + completionPrefix(selector :: Nil, pos) + }.getOrElse("") + + case (ref: untpd.RefTree) :: _ => + if (ref.name == nme.ERROR) "" + else ref.name.toString.take(pos.span.point - ref.span.point) + + case _ => + "" + } + + /** Inspect `path` to determine the offset where the completion result should be inserted. */ + private def completionOffset(path: List[Tree]): Int = + path match { + case (ref: RefTree) :: _ => ref.span.point + case _ => 0 + } + + /** Create a new `CompletionBuffer` for completing at `pos`. 
*/ + private def completionBuffer(path: List[Tree], pos: SourcePosition): CompletionBuffer = { + val mode = completionMode(path, pos) + val prefix = completionPrefix(path, pos) + new CompletionBuffer(mode, prefix, pos) + } + + private def computeCompletions( + pos: SourcePosition, + path: List[Tree], + dependencyCompleteOpt: Option[String => (Int, Seq[String])] + )(using Context): (Int, List[Completion]) = { + + val offset = completionOffset(path) + val buffer = completionBuffer(path, pos) + + var extra = List.empty[Completion] + + if (buffer.mode != Mode.None) + path match { + case Select(qual, _) :: _ => buffer.addMemberCompletions(qual) + case Import(Ident(name), _) :: _ if name.decode.toString == "$ivy" => + dependencyCompleteOpt match { + case None => 0 -> Seq.empty[(String, Option[String])] + case Some(complete) => + val input = buffer.prefix + val (pos, completions) = complete(input) + val input0 = input.take(pos) + extra = completions.distinct.toList + .map(s => Completion(termName(input0 + s), "", Nil)) + } + case Import(expr, _) :: _ => buffer.addMemberCompletions(expr) + // (Dotty comment) TODO: distinguish given from plain imports + case (_: untpd.ImportSelector) :: Import(expr, _) :: _ => buffer.addMemberCompletions(expr) + case _ => + buffer.addScopeCompletions + // Too slow for now + // if (buffer.getCompletions.isEmpty) + // buffer.addDeepCompletions + } + + val completionList = extra ++ buffer.getCompletions + + interactiv.println(i"""completion with pos = $pos, + | offset = ${offset}, + | prefix = ${buffer.prefix}, + | term = ${buffer.mode.is(Mode.Term)}, + | type = ${buffer.mode.is(Mode.Type)} + | results = $completionList%, %""") + (pos.span.start - buffer.prefix.length, completionList) + } + + private class CompletionBuffer(val mode: Mode, val prefix: String, pos: SourcePosition) { + + private val completions = new RenameAwareScope + + /** + * Return the list of symbols that should be included in completion results. 
+ * + * If several symbols share the same name, the type symbols appear before term symbols inside + * the same `Completion`. + */ + def getCompletions(using Context): List[Completion] = { + val nameToSymbols = completions.mappings.toList + nameToSymbols.map { case (name, symbols) => + val typesFirst = symbols.sortWith((s1, s2) => s1.isType && !s2.isType) + val desc = description(typesFirst) + // kind of meh, not sure how to make that more reliable in Scala 3 + Completion(name, desc, typesFirst) + } + } + + /** + * A description for completion result that represents `symbols`. + * + * If `symbols` contains a single symbol, show its full name in case it's a type, or its type if + * it's a term. + * + * When there are multiple symbols, show their kinds. + */ + private def description(symbols: List[Symbol])(using Context): String = + symbols match { + case sym :: Nil => + if (sym.isType) sym.showFullName + else sym.info.widenTermRefExpr.show + + case sym :: _ => + symbols.map(ctx.printer.kindString).mkString("", " and ", s" ${sym.name.show}") + + case Nil => + "" + } + + /** + * Add symbols that are currently in scope to `info`: the members of the current class and the + * symbols that have been imported. + */ + def addScopeCompletions(using Context): Unit = { + if (ctx.owner.isClass) { + addAccessibleMembers(ctx.owner.thisType) + ctx.owner.asClass.classInfo.selfInfo match { + case selfSym: Symbol => add(selfSym, selfSym.name) + case _ => + } + } + else if (ctx.scope != null) ctx.scope.foreach(s => add(s, s.name)) + + addImportCompletions + + var outer = ctx.outer + while ((outer.owner `eq` ctx.owner) && (outer.scope `eq` ctx.scope)) { + addImportCompletions(using outer) + outer = outer.outer + } + if (outer `ne` NoContext) addScopeCompletions(using outer) + } + + /** + * Find all the members of `qual` and add the ones that pass the include filters to `info`. 
+ * + * If `info.mode` is `Import`, the members added via implicit conversion on `qual` are not + * considered. + */ + def addMemberCompletions(qual: Tree)(using Context): Unit = + if (!qual.tpe.widenDealias.isNothing) { + addAccessibleMembers(qual.tpe) + if (!mode.is(Mode.Import) && !qual.tpe.isNullType) + // Implicit conversions do not kick in when importing + // and for `NullClass` they produce unapplicable completions (for unclear reasons) + implicitConversionTargets(qual)(using ctx.fresh.setExploreTyperState()) + .foreach(addAccessibleMembers) + } + + /** + * If `sym` exists, no symbol with the same name is already included, and it satisfies the + * inclusion filter, then add it to the completions. + */ + private def add(sym: Symbol, nameInScope: Name, deep: Boolean = false)(using Context) = + if (sym.exists && + (deep || completionsFilter(NoType, nameInScope)) && + !completions.lookup(nameInScope).exists && + include(sym, nameInScope, deep)) + completions.enter(sym, nameInScope) + + /** Lookup members `name` from `site`, and try to add them to the completion list. */ + private def addMember(site: Type, name: Name, nameInScope: Name)(using Context) = + if (!completions.lookup(nameInScope).exists) + for (alt <- site.member(name).alternatives) add(alt.symbol, nameInScope) + + /** Include in completion sets only symbols that + * 1. start with given name prefix, and + * 2. is not absent (info is not NoType) + * 3. are not a primary constructor, + * 4. have an existing source symbol, + * 5. are the module class in case of packages, + * 6. are mutable accessors, to exclude setters for `var`, + * 7. symbol is not a package object + * 8. symbol is not an artifact of the compiler + * 9. 
have same term/type kind as name prefix given so far + */ + private def include( + sym: Symbol, + nameInScope: Name, + deep: Boolean = false + )(using Context): Boolean = + (deep || nameInScope.startsWith(prefix)) && + !sym.isAbsent() && + !sym.isPrimaryConstructor && + sym.sourceSymbol.exists && + (!sym.is(Package) || sym.is(ModuleClass)) && + !sym.isAllOf(Mutable | Accessor) && + !sym.isPackageObject && + !sym.is(Artifact) && + ( + (mode.is(Mode.Term) && sym.isTerm) + || (mode.is(Mode.Type) && (sym.isType || sym.isStableMember)) + ) + + /** + * Find all the members of `site` that are accessible and which should be included in `info`. + * + * @param site The type to inspect. + * @return The members of `site` that are accessible and pass the include filter of `info`. + */ + private def accessibleMembers(site: Type)(using Context): Seq[Symbol] = site match { + case site: NamedType if site.symbol.is(Package) => + extension (tpe: Type) + def accessibleSymbols = tpe + .decls + .toList + .filter(sym => sym.isAccessibleFrom(site, superAccess = false)) + + val packageDecls = site.accessibleSymbols + val packageObjectsDecls = packageDecls + .filter(_.isPackageObject) + .flatMap(_.thisType.accessibleSymbols) + packageDecls ++ packageObjectsDecls + case _ => + def appendMemberSyms(name: Name, buf: mutable.Buffer[SingleDenotation]): Unit = + try buf ++= site.member(name).alternatives + catch { case ex: TypeError => } + site.memberDenots(completionsFilter, appendMemberSyms).collect { + case mbr if include(mbr.symbol, mbr.symbol.name) => + mbr.accessibleFrom(site, superAccess = true).symbol + case _ => NoSymbol + }.filter(_.exists) + } + + /** Add all the accessible members of `site` in `info`. */ + private def addAccessibleMembers(site: Type)(using Context): Unit = + for (mbr <- accessibleMembers(site)) addMember(site, mbr.name, mbr.name) + + /** + * Add in `info` the symbols that are imported by `ctx.importInfo`. 
If this is a wildcard
+     * import, all the accessible members of the import's `site` are included.
+     */
+    private def addImportCompletions(using Context): Unit = {
+      val imp = ctx.importInfo
+      if (imp != null) {
+        def addImport(name: TermName, nameInScope: TermName) = {
+          addMember(imp.site, name, nameInScope)
+          addMember(imp.site, name.toTypeName, nameInScope.toTypeName)
+        }
+        imp.reverseMapping.foreachBinding { (nameInScope, original) =>
+          if (original != nameInScope || !imp.excluded.contains(original))
+            addImport(original, nameInScope)
+        }
+        if (imp.isWildcardImport)
+          for (mbr <- accessibleMembers(imp.site) if !imp.excluded.contains(mbr.name.toTermName))
+            addMember(imp.site, mbr.name, mbr.name)
+      }
+    }
+    private def blacklisted(s: Symbol)(using Context) = {
+      val blacklist = Set(
+        "scala.Predef.any2stringadd.+",
+        "scala.Any.##",
+        "java.lang.Object.##",
+        "scala.<byname>",
+        "scala.<empty>",
+        "scala.<repeated>",
+        "scala.<repeated...>",
+        "scala.Predef.StringFormat.formatted",
+        "scala.Predef.Ensuring.ensuring",
+        "scala.Predef.ArrowAssoc.->",
+        "scala.Predef.ArrowAssoc.→",
+        "java.lang.Object.synchronized",
+        "java.lang.Object.ne",
+        "java.lang.Object.eq",
+        "java.lang.Object.wait",
+        "java.lang.Object.notifyAll",
+        "java.lang.Object.notify",
+        "java.lang.Object.clone",
+        "java.lang.Object.finalize"
+      )
+
+      blacklist(s.showFullName) ||
+      s.isOneOf(GivenOrImplicit) ||
+      // Cache objects, which you should probably never need to
+      // access directly, and apart from that have annoyingly long names
+      "cache[a-f0-9]{32}".r.findPrefixMatchOf(s.name.decode.toString).isDefined ||
+      // s.isDeprecated ||
+      s.name.decode.toString == "<init>" ||
+      s.name.decode.toString.contains('$')
+    }
+    def addDeepCompletions(using Context): Unit = {
+
+      val blacklistedPackages = Set("shaded")
+
+      def allMembers(s: Symbol) =
+        try s.info.allMembers
+        catch {
+          case _: dotc.core.TypeError => Nil
+        }
+      def rec(t: Symbol): Seq[Symbol] = {
+        if (blacklistedPackages(t.name.toString))
+          Nil
+        else {
+          val fullName =
t.fullName.toString + val children = + if (t.is(Package) || t.is(PackageVal) || t.is(PackageClass)) { + allMembers(t).map(_.symbol).filter(!blacklisted(_)).filter(_ != t).flatMap(rec) + } else Nil + + t +: children.toSeq + } + } + + for { + member <- allMembers(defn.RootClass).map(_.symbol).filter(!blacklisted(_)).toList + sym <- rec(member) + if sym.name.toString.startsWith(prefix) + } add(sym, sym.fullName, deep = true) + } + + /** + * Given `qual` of type T, finds all the types S such that there exists an implicit conversion + * from T to S. + * + * @param qual The argument to which the implicit conversion should be applied. + * @return The set of types that `qual` can be converted to. + */ + private def implicitConversionTargets(qual: Tree)(using Context): Set[Type] = { + val typer = ctx.typer + val conversions = new typer.ImplicitSearch(defn.AnyType, qual, pos.span).allImplicits + val targets = conversions.map(_.widen.finalResultType) + interactiv.println(i"implicit conversion targets considered: ${targets.toList}%, %") + targets + } + + /** Filter for names that should appear when looking for completions. */ + private object completionsFilter extends NameFilter { + def apply(pre: Type, name: Name)(using Context): Boolean = + !name.isConstructorName && name.toTermName.info.kind == SimpleNameKind + def isStable = true + } + } + + /** + * The completion mode: defines what kinds of symbols should be included in the completion + * results. 
+ */ + private class Mode(val bits: Int) extends AnyVal { + def is(other: Mode): Boolean = (bits & other.bits) == other.bits + def |(other: Mode): Mode = new Mode(bits | other.bits) + } + private object Mode { + /** No symbol should be included */ + val None: Mode = new Mode(0) + + /** Term symbols are allowed */ + val Term: Mode = new Mode(1) + + /** Type and stable term symbols are allowed */ + val Type: Mode = new Mode(2) + + /** Both term and type symbols are allowed */ + val Import: Mode = new Mode(4) | Term | Type + } + + /** A scope that tracks renames of the entered symbols. + * Useful for providing completions for renamed symbols + * in the REPL and the IDE. + */ + private class RenameAwareScope extends Scopes.MutableScope { + private val nameToSymbols: mutable.Map[TermName, List[Symbol]] = mutable.Map.empty + + /** Enter the symbol `sym` in this scope, recording a potential renaming. */ + def enter[T <: Symbol](sym: T, name: Name)(using Context): T = { + val termName = name.stripModuleClassSuffix.toTermName + nameToSymbols += termName -> (sym :: nameToSymbols.getOrElse(termName, Nil)) + newScopeEntry(name, sym) + sym + } + + /** Get the names that are known in this scope, along with the list of symbols they refer to. 
*/ + def mappings: Map[TermName, List[Symbol]] = nameToSymbols.toMap + } +} + diff --git a/amm/compiler/src/main/scala-3/dotty/ammonite/compiler/DirectoryClassPath.scala b/amm/compiler/src/main/scala-3/dotty/ammonite/compiler/DirectoryClassPath.scala new file mode 100644 index 000000000..f7029fa6b --- /dev/null +++ b/amm/compiler/src/main/scala-3/dotty/ammonite/compiler/DirectoryClassPath.scala @@ -0,0 +1,32 @@ +package dotty.ammonite.compiler + +import java.io.{File => JFile} +import dotty.tools.dotc.classpath +import dotty.tools.io.{AbstractFile, PlainFile, ClassPath, ClassRepresentation, EfficientClassPath} +import classpath.FileUtils._ + +case class DirectoryClassPath(dir: JFile) + extends classpath.JFileDirectoryLookup[classpath.ClassFileEntryImpl] + with classpath.NoSourcePaths { + override def findClass(className: String): Option[ClassRepresentation] = + findClassFile(className) map classpath.ClassFileEntryImpl + + def findClassFile(className: String): Option[AbstractFile] = { + val relativePath = classpath.FileUtils.dirPath(className) + val classFile = new JFile(dir, relativePath + ".class") + if (classFile.exists) { + val wrappedClassFile = new dotty.tools.io.File(classFile.toPath) + val abstractClassFile = new PlainFile(wrappedClassFile) + Some(abstractClassFile) + } + else None + } + + protected def createFileEntry(file: AbstractFile): classpath.ClassFileEntryImpl = + classpath.ClassFileEntryImpl(file) + protected def isMatchingFile(f: JFile): Boolean = + f.isClass + + private[dotty] def classes(inPackage: classpath.PackageName): Seq[classpath.ClassFileEntry] = + files(inPackage) +} diff --git a/amm/compiler/src/main/scala-3/dotty/ammonite/compiler/WhiteListClassPath.scala b/amm/compiler/src/main/scala-3/dotty/ammonite/compiler/WhiteListClassPath.scala new file mode 100644 index 000000000..0be6c02d5 --- /dev/null +++ b/amm/compiler/src/main/scala-3/dotty/ammonite/compiler/WhiteListClassPath.scala @@ -0,0 +1,52 @@ +package dotty.ammonite.compiler + 
+import ammonite.util.Util + +import dotty.tools.dotc.classpath.{ClassPathEntries, PackageName} +import dotty.tools.io.ClassPath + +class WhiteListClasspath(aggregates: Seq[ClassPath], whitelist: Set[Seq[String]]) + extends dotty.tools.dotc.classpath.AggregateClassPath(aggregates) { + override def findClassFile(name: String) = { + val tokens = name.split('.') + if (Util.lookupWhiteList(whitelist, tokens.init ++ Seq(tokens.last + ".class"))) { + super.findClassFile(name) + } + else None + } + /*override def findClass(name: String) = { + val tokens = name.split('.') + if (Util.lookupWhiteList(whitelist, tokens.init ++ Seq(tokens.last + ".class"))) { + super.findClass(name) + } + else None + } + override def packages(inPackage: PackageName) = + super.packages(inPackage) + + override def classes(inPackage: PackageName) = + super.classes(inPackage).filter { t => + Util.lookupWhiteList(whitelist, inPackage.dottedString.split('.') ++ Seq(t.name + ".class")) + } + + override def sources(inPackage: PackageName) = + super.sources(inPackage).filter { t => + Util.lookupWhiteList(whitelist, inPackage.dottedString.split('.') ++ Seq(t.name + ".class")) + } + + override def hasPackage(pkg: PackageName): Boolean = + super.hasPackage(pkg)*/ + + override def list(inPackage: PackageName) = { + val superList = super.list(inPackage) + ClassPathEntries( + superList.packages.filter{ p => Util.lookupWhiteList(whitelist, p.name.split('.')) }, + superList.classesAndSources.filter{ t => + Util.lookupWhiteList(whitelist, inPackage.dottedString.split('.') ++ Seq(t.name + ".class")) + } + ) + } + + override def toString: String = + s"WhiteListClasspath($aggregates, ${whitelist.size} white-listed elements)" +} \ No newline at end of file diff --git a/build.sc b/build.sc index 70809c978..78b878e68 100644 --- a/build.sc +++ b/build.sc @@ -20,15 +20,37 @@ val commitsSinceTaggedVersion = { .toInt } +// When this version is used as cross scala version, either +// cross2_3Version or actual3Version is 
picked as actual Scala version. +// Modules picking one or the other can depend on each other, thanks to +// the dotty compatibility. +// Beware that this requires both versions to have the same tasty format +// version. For example, 2.13.4 and 3.0.0-M1 do, while 2.13.4 and 3.0.0-M{2,3} +// don't. +val special3Version = "3" +val actual3Version = "3.0.0-M1" +val cross2_3Version = "2.13.4" + + +// Same as https://github.com/lihaoyi/mill/blob/0.9.3/scalalib/src/Dep.scala/#L55, +// except we also fix the suffix when scalaVersion is like 2.13.x, +// so that downstream Scala 3 project still pick the _2.13 dependency. +def withDottyCompat(dep: Dep, scalaVersion: String): Dep = + dep.cross match { + case cross: CrossVersion.Binary if scalaVersion.startsWith("3.") || scalaVersion.startsWith("2.13.") => + val compatSuffix = "_2.13" + dep.copy(cross = CrossVersion.Constant(value = compatSuffix, platformed = dep.cross.platformed)) + case _ => dep + } -val binCrossScalaVersions = Seq("2.12.13", "2.13.4") +val binCrossScalaVersions = Seq("2.12.13", "2.13.4", cross2_3Version, special3Version).distinct def isScala2_12_10OrLater(sv: String): Boolean = { (sv.startsWith("2.12.") && sv.stripPrefix("2.12.").length > 1) || (sv.startsWith("2.13.") && sv != "2.13.0") } val fullCrossScalaVersions = Seq( "2.12.1", "2.12.2", "2.12.3", "2.12.4", "2.12.6", "2.12.7", "2.12.8", "2.12.9", "2.12.10", "2.12.11", "2.12.12", "2.12.13", - "2.13.0", "2.13.1", "2.13.2", "2.13.3", "2.13.4" -) + "2.13.0", "2.13.1", "2.13.2", "2.13.3", "2.13.4", cross2_3Version, special3Version +).distinct val latestAssemblies = binCrossScalaVersions.map(amm(_).assembly) @@ -48,14 +70,78 @@ val (buildVersion, unstable) = scala.util.Try( val bspVersion = "2.0.0-M6" -trait AmmInternalModule extends mill.scalalib.CrossSbtModule{ +// Adapted from https://github.com/lihaoyi/mill/blob/0.9.3/scalalib/src/MiscModule.scala/#L80-L100 +// Compared to the original code, we added the custom Resolver, +// and ensure 
`scalaVersion()` rather than `crossScalaVersion` is used +// when computing paths, as the former is always a valid Scala version, +// while the latter can be a sentinel, like "3" (see special3Version above). +trait CrossSbtModule extends mill.scalalib.SbtModule with mill.scalalib.CrossModuleBase { outer => + + // Adapted from https://github.com/lihaoyi/mill/blob/0.9.3/scalalib/src/MiscModule.scala/#L20-L36 + // Compared to it, we just accept crossScalaVersion with no '.' in it, + // like "3" (see special3Version above). + import mill.define.Cross.Resolver + override implicit def crossSbtModuleResolver: Resolver[CrossModuleBase] = new Resolver[CrossModuleBase]{ + def resolve[V <: CrossModuleBase](c: Cross[V]): V = { + crossScalaVersion.split('.') + .inits + .takeWhile(_.length >= 1) + .flatMap( prefix => + c.items.map(_._2).find(_.crossScalaVersion.split('.').startsWith(prefix)) + ) + .collectFirst{case x => x} + .getOrElse( + throw new Exception( + s"Unable to find compatible cross version between $crossScalaVersion and "+ + c.items.map(_._2.crossScalaVersion).mkString(",") + ) + ) + } + } + + override def sources = T.sources { + super.sources() ++ + mill.scalalib.CrossModuleBase.scalaVersionPaths( + scalaVersion(), + s => millSourcePath / 'src / 'main / s"scala-$s" + ) + + } + trait Tests extends super.Tests { + override def millSourcePath = outer.millSourcePath + override def sources = T.sources { + super.sources() ++ + mill.scalalib.CrossModuleBase.scalaVersionPaths( + scalaVersion(), + s => millSourcePath / 'src / 'test / s"scala-$s" + ) + } + } +} + +trait AmmInternalModule extends CrossSbtModule{ def artifactName = "ammonite-" + millOuterCtx.segments.parts.mkString("-").stripPrefix("amm-") def testFramework = "utest.runner.Framework" - def scalacOptions = Seq("-P:acyclic:force") - def compileIvyDeps = Agg(ivy"com.lihaoyi::acyclic:0.2.0") - def scalacPluginIvyDeps = Agg(ivy"com.lihaoyi::acyclic:0.2.0") + def isScala2 = T { 
scalaVersion().startsWith("2.") } + def scalacOptions = T { + val acyclicOptions = + if (isScala2()) Seq("-P:acyclic:force") + else Nil + val tastyReaderOptions = + if (scalaVersion() == cross2_3Version) Seq("-Ytasty-reader") + else Nil + acyclicOptions ++ tastyReaderOptions + } + def compileIvyDeps = T { + if (isScala2()) Agg(ivy"com.lihaoyi::acyclic:0.2.0") + else Agg[Dep]() + } + def scalacPluginIvyDeps = T { + if (isScala2()) Agg(ivy"com.lihaoyi::acyclic:0.2.0") + else Agg[Dep]() + } trait Tests extends super.Tests{ - def ivyDeps = Agg(ivy"com.lihaoyi::utest:0.7.3") + def ivyDeps = Agg(withDottyCompat(ivy"com.lihaoyi::utest:0.7.3", scalaVersion())) def testFrameworks = Seq("utest.runner.Framework") def forkArgs = Seq("-Xmx2g", "-Dfile.encoding=UTF8") } @@ -72,23 +158,40 @@ trait AmmInternalModule extends mill.scalalib.CrossSbtModule{ } else Nil - val extraDir2 = PathRef( - if (isScala2_12_10OrLater(sv)) millSourcePath / "src" / "main" / "scala-2.12.10-2.13.1+" - else millSourcePath / "src" / "main" / "scala-not-2.12.10-2.13.1+" - ) + val extraDir2 = + if (isScala2()) + Seq(PathRef( + if (isScala2_12_10OrLater(sv)) millSourcePath / "src" / "main" / "scala-2.12.10-2.13.1+" + else millSourcePath / "src" / "main" / "scala-not-2.12.10-2.13.1+" + )) + else Nil val extraDir3 = - if (sv.startsWith("2.13.") && sv != "2.13.0") - PathRef(millSourcePath / "src" / "main" / "scala-2.13.1+") - else if (sv.startsWith("2.12.") && sv.stripPrefix("2.12.").toInt >= 13) - PathRef(millSourcePath / "src" / "main" / "scala-2.12.13+") - else - PathRef(millSourcePath / "src" / "main" / "scala-not-2.12.13+-2.13.1+") - - super.sources() ++ extraDir ++ Seq(extraDir2, extraDir3) + if (isScala2()) { + val dir = + if (sv.startsWith("2.13.") && sv != "2.13.0") + millSourcePath / "src" / "main" / "scala-2.13.1+" + else if (sv.startsWith("2.12.") && sv.stripPrefix("2.12.").toInt >= 13) + millSourcePath / "src" / "main" / "scala-2.12.13+" + else + millSourcePath / "src" / "main" / 
"scala-not-2.12.13+-2.13.1+" + Seq(PathRef(dir)) + } else Nil + val extraDir4 = + if (sv.startsWith("2.13.") || sv.startsWith("3.")) + Seq(PathRef(millSourcePath / "src" / "main" / "scala-2.13-or-3")) + else Nil + + super.sources() ++ extraDir ++ extraDir2 ++ extraDir3 ++ extraDir4 } def externalSources = T{ resolveDeps(allIvyDeps, sources = true)() } + def supports3: Boolean = false + def scalaVersion = T{ + if (crossScalaVersion == special3Version) + (if (supports3) actual3Version else cross2_3Version) + else crossScalaVersion + } } trait AmmModule extends AmmInternalModule with PublishModule{ def publishVersion = buildVersion @@ -142,8 +245,8 @@ trait AmmDependenciesResourceFileModule extends JavaModule{ object ops extends Cross[OpsModule](binCrossScalaVersions:_*) class OpsModule(val crossScalaVersion: String) extends AmmModule{ def ivyDeps = Agg( - ivy"com.lihaoyi::os-lib:0.7.1", - ivy"org.scala-lang.modules::scala-collection-compat:2.1.2" + withDottyCompat(ivy"com.lihaoyi::os-lib:0.7.1", scalaVersion()), + withDottyCompat(ivy"org.scala-lang.modules::scala-collection-compat:2.3.1", scalaVersion()) ) def scalacOptions = super.scalacOptions().filter(!_.contains("acyclic")) object test extends Tests @@ -156,7 +259,7 @@ class TerminalModule(val crossScalaVersion: String) extends AmmModule{ ivy"com.lihaoyi::fansi:0.2.8" ) def compileIvyDeps = Agg( - ivy"org.scala-lang:scala-reflect:$crossScalaVersion" + ivy"org.scala-lang:scala-reflect:${scalaVersion()}" ) object test extends Tests @@ -167,11 +270,12 @@ object amm extends Cross[MainModule](fullCrossScalaVersions:_*){ class UtilModule(val crossScalaVersion: String) extends AmmModule{ def moduleDeps = Seq(ops()) def ivyDeps = Agg( - ivy"com.lihaoyi::pprint:0.6.0", - ivy"com.lihaoyi::fansi:0.2.9", + withDottyCompat(ivy"com.lihaoyi::pprint:0.6.0", scalaVersion()), + withDottyCompat(ivy"com.lihaoyi::fansi:0.2.9", scalaVersion()), + withDottyCompat(ivy"org.scala-lang.modules::scala-collection-compat:2.3.1", 
scalaVersion()) ) def compileIvyDeps = Agg( - ivy"org.scala-lang:scala-reflect:$crossScalaVersion" + ivy"org.scala-lang:scala-reflect:${scalaVersion()}" ) } @@ -182,7 +286,7 @@ object amm extends Cross[MainModule](fullCrossScalaVersions:_*){ def ivyDeps = Agg( ivy"com.lihaoyi::upickle:1.2.0", ivy"com.lihaoyi::requests:0.6.5", - ivy"com.lihaoyi::mainargs:0.1.4", + withDottyCompat(ivy"com.lihaoyi::mainargs:0.1.4", scalaVersion()) ) } @@ -200,16 +304,33 @@ object amm extends Cross[MainModule](fullCrossScalaVersions:_*){ } } class CompilerModule(val crossScalaVersion: String) extends AmmModule{ - def moduleDeps = Seq(amm.compiler.interface(), amm.util(), amm.repl.api()) + def supports3 = true + def moduleDeps = + if (crossScalaVersion == special3Version) + Seq( + amm.compiler.interface(cross2_3Version), + amm.util(cross2_3Version), + amm.repl.api(cross2_3Version) + ) + else + Seq(amm.compiler.interface(), amm.util(), amm.repl.api()) def crossFullScalaVersion = true def ivyDeps = T { - Agg( - ivy"org.scala-lang:scala-compiler:${scalaVersion()}", - ivy"com.lihaoyi::scalaparse:2.3.0", - ivy"org.scala-lang.modules::scala-xml:2.0.0-M3", - ivy"org.javassist:javassist:3.21.0-GA", - ivy"com.github.javaparser:javaparser-core:3.2.5" - ) + val scalaSpecificDeps = + if (isScala2()) + Agg( + ivy"org.scala-lang:scala-compiler:${scalaVersion()}", + ivy"com.lihaoyi::scalaparse:2.3.0", + ivy"org.scala-lang.modules::scala-xml:2.0.0-M3" + ) + else + Agg[Dep]( + ivy"org.scala-lang:scala3-compiler_${scalaVersion()}:${scalaVersion()}" + ) + scalaSpecificDeps ++ Agg( + ivy"org.javassist:javassist:3.21.0-GA", + ivy"com.github.javaparser:javaparser-core:3.2.5" + ) } def exposedClassPath = T{ @@ -227,7 +348,7 @@ object amm extends Cross[MainModule](fullCrossScalaVersions:_*){ def crossFullScalaVersion = true def dependencyResourceFileName = "amm-interp-api-dependencies.txt" def ivyDeps = Agg( - ivy"org.scala-lang:scala-reflect:$crossScalaVersion", + 
ivy"org.scala-lang:scala-reflect:${scalaVersion()}", ivy"io.get-coursier:interface:0.0.21" ) } @@ -238,7 +359,7 @@ object amm extends Cross[MainModule](fullCrossScalaVersions:_*){ def ivyDeps = Agg( ivy"ch.epfl.scala:bsp4j:$bspVersion", ivy"org.scalameta::trees:4.4.6", - ivy"org.scala-lang:scala-reflect:$crossScalaVersion", + ivy"org.scala-lang:scala-reflect:${scalaVersion()}", ivy"org.scala-lang.modules::scala-xml:1.2.0" ) } @@ -261,7 +382,8 @@ object amm extends Cross[MainModule](fullCrossScalaVersions:_*){ interp.api() ) def ivyDeps = Agg( - ivy"com.lihaoyi::mainargs:0.1.4" + withDottyCompat(ivy"com.lihaoyi::mainargs:0.1.4", scalaVersion()), + withDottyCompat(ivy"com.lihaoyi::pprint:0.6.0", scalaVersion()) ) def generatedSources = T{ @@ -294,7 +416,7 @@ object amm extends Cross[MainModule](fullCrossScalaVersions:_*){ object test extends Tests with AmmDependenciesResourceFileModule { def crossScalaVersion = ReplModule.this.crossScalaVersion - def scalaVersion = ReplModule.this.crossScalaVersion + def scalaVersion = ReplModule.this.scalaVersion def dependencyResourceFileName = "amm-test-dependencies.txt" def moduleDeps = super.moduleDeps ++ Seq(amm.compiler()) @@ -318,7 +440,7 @@ object amm extends Cross[MainModule](fullCrossScalaVersions:_*){ resolveDeps(ivyDeps, sources = true)()).distinct } def ivyDeps = super.ivyDeps() ++ amm.compiler().ivyDeps() ++ Agg( - ivy"org.scalaz::scalaz-core:7.2.27" + withDottyCompat(ivy"org.scalaz::scalaz-core:7.2.27", scalaVersion()) ) } } @@ -501,7 +623,7 @@ object integration extends Cross[IntegrationModule](fullCrossScalaVersions:_*) class IntegrationModule(val crossScalaVersion: String) extends AmmInternalModule{ def moduleDeps = Seq(ops(), amm()) def ivyDeps = T{ - if (crossScalaVersion.startsWith("2.13.")) + if (scalaVersion().startsWith("2.13.")) Agg(ivy"com.lihaoyi::cask:0.6.0") else Agg.empty @@ -528,7 +650,7 @@ class SshdModule(val crossScalaVersion: String) extends AmmModule{ // slf4j-nop makes sshd server use logger that 
writes into the void ivy"org.slf4j:slf4j-nop:1.7.12", ivy"com.jcraft:jsch:0.1.54", - ivy"org.scalacheck::scalacheck:1.14.0" + withDottyCompat(ivy"org.scalacheck::scalacheck:1.14.0", scalaVersion()) ) } } @@ -629,7 +751,9 @@ def publishExecutable() = { println("MASTER COMMIT: Publishing Executable for Scala " + version) //Prepare executable - val scalaBinaryVersion = version.take(version.lastIndexOf(".")) + val scalaBinaryVersion = + (if (version == special3Version) actual3Version else version) + .take(version.lastIndexOf(".")) upload( jar.path, latestTaggedVersion,