From d865ac19f9661e0ed0b55ec60b56180d1c9084d7 Mon Sep 17 00:00:00 2001 From: Mario Bucev Date: Thu, 15 Feb 2024 18:09:33 +0100 Subject: [PATCH] Add support for exported methods --- .../invalid/ExportedMethods1.scala | 20 + .../invalid/ExportedMethods2.scala | 11 + .../invalid/ExportedMethods3.scala | 11 + .../invalid/ExportedMethodsExt.scala | 87 ++++ .../valid/ExportedMethods.scala | 384 ++++++++++++++++++ .../valid/ExportedMethodsExt.scala | 81 ++++ .../frontends/dotc/CodeExtraction.scala | 89 +++- .../frontends/dotc/DottyCompiler.scala | 9 +- .../frontends/dotc/StainlessExtraction.scala | 5 +- .../frontends/dotc/StainlessPlugin.scala | 6 +- .../stainless/frontends/dotc/Utils.scala | 65 +++ 11 files changed, 762 insertions(+), 6 deletions(-) create mode 100644 frontends/benchmarks/dotty-specific/invalid/ExportedMethods1.scala create mode 100644 frontends/benchmarks/dotty-specific/invalid/ExportedMethods2.scala create mode 100644 frontends/benchmarks/dotty-specific/invalid/ExportedMethods3.scala create mode 100644 frontends/benchmarks/dotty-specific/invalid/ExportedMethodsExt.scala create mode 100644 frontends/benchmarks/dotty-specific/valid/ExportedMethods.scala create mode 100644 frontends/benchmarks/dotty-specific/valid/ExportedMethodsExt.scala create mode 100644 frontends/dotty/src/main/scala/stainless/frontends/dotc/Utils.scala diff --git a/frontends/benchmarks/dotty-specific/invalid/ExportedMethods1.scala b/frontends/benchmarks/dotty-specific/invalid/ExportedMethods1.scala new file mode 100644 index 0000000000..42150aeca2 --- /dev/null +++ b/frontends/benchmarks/dotty-specific/invalid/ExportedMethods1.scala @@ -0,0 +1,20 @@ +object ExportedMethods1 { + case class Counter(var x: BigInt) { + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + } + + case class Base(cnt: Counter) { + export cnt.* + } + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt): Unit = { + add(y) // invalid + } + } +} \ No newline at end of file diff --git a/frontends/benchmarks/dotty-specific/invalid/ExportedMethods2.scala b/frontends/benchmarks/dotty-specific/invalid/ExportedMethods2.scala new file mode 100644 index 0000000000..cabeaf8df5 --- /dev/null +++ b/frontends/benchmarks/dotty-specific/invalid/ExportedMethods2.scala @@ -0,0 +1,11 @@ +object ExportedMethods2 { + import ExportedMethodsExt.SimpleCounter.* + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt): Unit = { + add(y) // invalid + } + } +} \ No newline at end of file diff --git a/frontends/benchmarks/dotty-specific/invalid/ExportedMethods3.scala b/frontends/benchmarks/dotty-specific/invalid/ExportedMethods3.scala new file mode 100644 index 0000000000..cdad66e335 --- /dev/null +++ b/frontends/benchmarks/dotty-specific/invalid/ExportedMethods3.scala @@ -0,0 +1,11 @@ +object ExportedMethods3 { + import ExportedMethodsExt.CounterWithInvariant.* + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt): Unit = { + x = y + } + } +} \ No newline at end of file diff --git a/frontends/benchmarks/dotty-specific/invalid/ExportedMethodsExt.scala b/frontends/benchmarks/dotty-specific/invalid/ExportedMethodsExt.scala new file mode 100644 index 0000000000..09594e14d2 --- /dev/null +++ b/frontends/benchmarks/dotty-specific/invalid/ExportedMethodsExt.scala @@ -0,0 +1,87 @@ +object ExportedMethodsExt { + + // This object is used for the other ExportedMethods + // but we need to have at least one invalid VC to pass the "invalid" test suite + def dummyInvalid(x: BigInt): Unit = { + assert(x == 0) + } + + object SimpleCounter { + case class Counter(var x: BigInt) { + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + case class Base(cnt: Counter) { + export cnt.* + } + } + + object CounterWithInvariant { + case class Counter(var x: BigInt) { + require(x >= 0) + + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + case class Base(cnt: Counter) { + export cnt.* + } + } + + object AbstractCounter { + abstract case class Counter() { + var x: BigInt + + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + case class Base(cnt: Counter) { + export cnt.* + } + } + + object AbstractBaseAndCounter { + abstract case class Counter() { + var x: BigInt + + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + abstract case class Base() { + val cnt: Counter + export cnt.* + } + } +} \ No newline at end of file diff --git a/frontends/benchmarks/dotty-specific/valid/ExportedMethods.scala b/frontends/benchmarks/dotty-specific/valid/ExportedMethods.scala new file mode 100644 index 0000000000..78c1f97c16 --- /dev/null +++ b/frontends/benchmarks/dotty-specific/valid/ExportedMethods.scala @@ -0,0 +1,384 @@ +object ExportedMethods { + object Local { + object SimpleCounter { + def accessExported(w: Wrapper, y: BigInt): Unit = { + require(y >= 0) + w.add(y) + val abc = w.x + w.x = y + } + + def useCounterFromBaseOutside(b: Base, y: BigInt): Unit = { + require(y >= 0) + b.x += y + val abc = b.x + b.add(y) + } + + case class Counter(var x: BigInt) { + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + case class Base(cnt: Counter) { + export cnt.* + + def useCounterFromBase(y: BigInt): Unit = { + require(y >= 0) + x += y + val abc = x + add(y) + } + } + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt, z: BigInt): Unit = { + require(y >= 0) + require(z >= 0) + x = z + val abc = x + add(y) + add(z) + } + + def parametricAddWith[T](y: BigInt, t: T): Unit = { + require(y >= 0) + parametricAdd(y, t) + } + } + } + object CounterWithInvariant { + def accessExported(w: Wrapper, y: BigInt): Unit = { + require(y >= 0) + w.add(y) + val abc = w.x + w.x = y + } + + def useCounterFromBaseOutside(b: Base, y: BigInt): Unit = { + require(y >= 0) + b.x += y + val abc = b.x + b.add(y) + } + + case class Counter(var x: BigInt) { + require(x >= 0) + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + case class Base(cnt: Counter) { + export cnt.* + + def useCounterFromBase(y: BigInt): Unit = { + require(y >= 0) + x += y + val abc = x + add(y) + } + } + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt, z: BigInt): Unit = { + require(y >= 0) + require(z >= 0) + x = z + val abc = x + add(y) + add(z) + } + + def parametricAddWith[T](y: BigInt, t: T): Unit = { + require(y >= 0) + parametricAdd(y, t) + } + } + } + + object AbstractCounter { + + def accessExported(w: Wrapper, y: BigInt): Unit = { + require(y >= 0) + w.add(y) + val abc = w.x + w.x = y + } + + def useCounterFromBaseOutside(b: Base, y: BigInt): Unit = { + require(y >= 0) + b.x += y + val abc = b.x + b.add(y) + } + + abstract case class Counter() { + var x: BigInt + + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + case class Base(cnt: Counter) { + export cnt.* + + def useCounterFromBase(y: BigInt): Unit = { + require(y >= 0) + x += y + val abc = x + add(y) + } + } + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt, z: BigInt): Unit = { + require(y >= 0) + require(z >= 0) + x = z + val abc = x + add(y) + add(z) + } + + def parametricAddWith[T](y: BigInt, t: T): Unit = { + require(y >= 0) + parametricAdd(y, t) + } + } + } + + object AbstractBaseAndCounter { + + def accessExported(w: Wrapper, y: BigInt): Unit = { + require(y >= 0) + w.add(y) + val abc = w.x + w.x = y + } + + def useCounterFromBaseOutside(b: Base, y: BigInt): Unit = { + require(y >= 0) + b.x += y + val abc = b.x + b.add(y) + } + + abstract case class Counter() { + var x: BigInt + + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + abstract case class Base() { + val cnt: Counter + export cnt.* + + def useCounterFromBase(y: BigInt): Unit = { + require(y >= 0) + x += y + val abc = x + add(y) + } + } + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt, z: BigInt): Unit = { + require(y >= 0) + require(z >= 0) + x = z + val abc = x + add(y) + add(z) + } + + def parametricAddWith[T](y: BigInt, t: T): Unit = { + require(y >= 0) + parametricAdd(y, t) + } + } + } + } + + object Ext { + object SimpleCounter { + import ExportedMethodsExt.SimpleCounter.* + + def accessExported(w: Wrapper, y: BigInt): Unit = { + require(y >= 0) + w.add(y) + val abc = w.x + w.x = y + } + + def useCounterFromBaseOutside(b: Base, y: BigInt): Unit = { + require(y >= 0) + b.x += y + val abc = b.x + b.add(y) + } + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt, z: BigInt): Unit = { + require(y >= 0) + require(z >= 0) + x = z + val abc = x + add(y) + add(z) + } + + def parametricAddWith[T](y: BigInt, t: T): Unit = { + require(y >= 0) + parametricAdd(y, t) + } + } + } + + object CounterWithInvariant { + import ExportedMethodsExt.CounterWithInvariant.* + + def accessExported(w: Wrapper, y: BigInt): Unit = { + require(y >= 0) + w.add(y) + val abc = w.x + w.x = y + } + + def useCounterFromBaseOutside(b: Base, y: BigInt): Unit = { + require(y >= 0) + b.x += y + val abc = b.x + b.add(y) + } + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt, z: BigInt): Unit = { + require(y >= 0) + require(z >= 0) + x = z + val abc = x + add(y) + add(z) + } + + def parametricAddWith[T](y: BigInt, t: T): Unit = { + require(y >= 0) + parametricAdd(y, t) + } + } + } + + object AbstractCounter { + import ExportedMethodsExt.AbstractCounter.* + + def accessExported(w: Wrapper, y: BigInt): Unit = { + require(y >= 0) + w.add(y) + val abc = w.x + w.x = y + } + + def useCounterFromBaseOutside(b: Base, y: BigInt): Unit = { + require(y >= 0) + b.x += y + val abc = b.x + b.add(y) + } + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt, z: BigInt): Unit = { + require(y >= 0) + require(z >= 0) + x = z + val abc = x + add(y) + add(z) + } + + def parametricAddWith[T](y: BigInt, t: T): Unit = { + require(y >= 0) + parametricAdd(y, t) + } + } + } + + object AbstractBaseAndCounter { + import ExportedMethodsExt.AbstractBaseAndCounter.* + + def accessExported(w: Wrapper, y: BigInt): Unit = { + require(y >= 0) + w.add(y) + val abc = w.x + w.x = y + } + + def useCounterFromBaseOutside(b: Base, y: BigInt): Unit = { + require(y >= 0) + b.x += y + val abc = b.x + b.add(y) + } + + case class Wrapper(base: Base) { + export base.* + + def addWith(y: BigInt, z: BigInt): Unit = { + require(y >= 0) + require(z >= 0) + x = z + val abc = x + add(y) + add(z) + } + + def parametricAddWith[T](y: BigInt, t: T): Unit = { + require(y >= 0) + parametricAdd(y, t) + } + } + } + } +} \ No newline at end of file diff --git a/frontends/benchmarks/dotty-specific/valid/ExportedMethodsExt.scala b/frontends/benchmarks/dotty-specific/valid/ExportedMethodsExt.scala new file mode 100644 index 0000000000..37106f179b --- /dev/null +++ b/frontends/benchmarks/dotty-specific/valid/ExportedMethodsExt.scala @@ -0,0 +1,81 @@ +object ExportedMethodsExt { + + object SimpleCounter { + case class Counter(var x: BigInt) { + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + case class Base(cnt: Counter) { + export cnt.* + } + } + + object CounterWithInvariant { + case class Counter(var x: BigInt) { + require(x >= 0) + + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + case class Base(cnt: Counter) { + export cnt.* + } + } + + object AbstractCounter { + abstract case class Counter() { + var x: BigInt + + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + case class Base(cnt: Counter) { + export cnt.* + } + } + + object AbstractBaseAndCounter { + abstract case class Counter() { + var x: BigInt + + def add(y: BigInt): Unit = { + require(y >= 0) + x += y + } + + def parametricAdd[T](y: BigInt, t: T): Unit = { + require(y >= 0) + x += y + } + } + + abstract case class Base() { + val cnt: Counter + export cnt.* + } + } +} \ No newline at end of file diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala index 5ec52cdad5..c663277ddc 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/CodeExtraction.scala @@ -17,15 +17,20 @@ import core.Types._ import core.Flags._ import core.Constants._ import core.NameKinds +import core.NameOps._ import util.{NoSourcePosition, SourcePosition} import stainless.ast.SymbolIdentifier import extraction.xlang.{trees => xt} +import Utils._ import scala.collection.mutable.{Map => MutableMap} import scala.collection.immutable.ListMap import scala.language.implicitConversions -class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using override val dottyCtx: DottyContext) +class CodeExtraction(inoxCtx: inox.Context, + symbolMapping: SymbolMapping, + exportedSymsMapping: ExportedSymbolsMapping) + (using override val dottyCtx: DottyContext) extends ASTExtractors { import AuxiliaryExtractors._ @@ -374,6 +379,9 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using case t if (t.symbol is Synthetic) && !canExtractSynthetic(t.symbol) => // ignore + case ExFunctionDef(fsym, _, _, _, _) if fsym is Exported => + // ignore + // Normal function case dd @ ExFunctionDef(fsym, tparams, vparams, tpt, rhs) => val fd0 = extractFunction(fsym, dd, tparams, vparams, rhs) @@ -395,6 +403,9 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using case t @ ExNonCtorMutableFieldDef(_, _, _) => outOfSubsetError(t, "Mutable fields in static containers such as objects are not supported") + case Export(_, _) => + // ignore + case other => reporter.warning(other.sourcePos, s"Stainless does not support the following tree in static containers:\n$other") } @@ -553,6 +564,9 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using if hasExternFields && (isCopyMethod(fsym) || isDefaultGetter(fsym)) => () // ignore + case ExFunctionDef(fsym, _, _, _, _) if fsym is Exported => + // ignore + // Normal methods case dd @ ExFunctionDef(fsym, tparams, vparams, _, rhs) => methods :+= extractFunction(fsym, dd, tparams, vparams, rhs)(using tpCtx) @@ -571,6 +585,9 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using case d if d.symbol is Synthetic => // ignore + case Export(_, _) => + // ignore + case other => reporter.warning(other.sourcePos, s"In class $id, Stainless does not support:\n$other") } @@ -789,7 +806,6 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using .copy(isExtern = dctx.isExtern || (flags contains xt.Extern)) lazy val retType = extractType(tree.tpt)(using nctx) - val (finalBody, returnType) = if (isAbstract) { (xt.NoTree(retType).setPos(sym.sourcePos), retType) } else { @@ -1246,6 +1262,61 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using private def extractTree(tr: tpd.Tree)(using dctx: DefContext): xt.Expr = (tr match { case SingletonTypeTree(tree) => extractTree(tree) + case ExExportedSymbol(path, recv0, tps, args) => + def ownerType(sym: Symbol): xt.ClassType | xt.LocalClassType = { + stripAnnotationsExceptStrictBV(extractType(sym.owner.typeRef)(using dctx.setResolveTypes(true), tr.sourcePos)) match { + case ct: (xt.ClassType | xt.LocalClassType) => ct + case _ => outOfSubsetError(tr, s"Stainless does not support use of exported symbol in this context:\n${tr.show}") + } + } + def mkSelection(recv: xt.Expr, sym: Symbol): xt.Expr = { + // Selection across exported symbol that works whether the class + // is abstract (method invocation) or concrete (field selection). + // Inspired by `extractCall` + val ct = ownerType(sym) + val isCtorField = (sym is CaseAccessor) || (sym is ParamAccessor) + val isNonCtorField = sym.isField && !isCtorField + assert(isCtorField || isNonCtorField) + if (isCtorField) { + // Class is concrete, so this is a simple field selection + classSelector(ct, recv, getIdentifier(sym)).setPos(tr.sourcePos) + } else { + // Class is abstract, so we must issue a method call *using* `getFieldAccessorIdentifier` + methodInvocation(ct, recv, getFieldAccessorIdentifier(sym), Seq.empty, Seq.empty).setPos(tr.sourcePos) + } + } + + val last = path.last + val isCtorField = (last is CaseAccessor) || (last is ParamAccessor) + val isNonCtorField = last.isField && !isCtorField + val recRecv0 = extractTree(recv0) + if (isCtorField) { + // Class is concrete, however assignment of fields is done through the setter (e.g. `myField_=`), + // therefore, we must use the underlying symbol in such case. + assert(tps.isEmpty && args.size <= 1) + val isSetter = last.name.isSetterName + val newPath = { + if (isSetter) path.init + else path.init :+ last.underlyingSymbol + } + val recv = newPath.foldLeft(recRecv0)(mkSelection) + if (isSetter) xt.FieldAssignment(recv, getIdentifier(last.underlyingSymbol), extractTree(args.head)).setPos(tr.sourcePos) + else recv + } else { + // Either a normal method, an abstract class field selection or an abstract class field assignment. + // The latter two are distinguished by `isNonCtorField` being true. + // If so, we use `getFieldAccessorIdentifier` to get the right Stainless symbol. + val ct = ownerType(last) + val recv = path.init.foldLeft(recRecv0)(mkSelection) + val iden = { + if (isNonCtorField) getFieldAccessorIdentifier(last) + else getIdentifier(last) + } + val recArgs = extractArgs(last, args) + val recTps = tps.map(extractType) + methodInvocation(ct, recv, iden, recTps, recArgs).setPos(tr.sourcePos) + } + case ExLambda(vparams, rhs) => val vds = vparams map (vd => xt.ValDef( FreshIdentifier(vd.symbol.name.toString), @@ -2530,4 +2601,18 @@ class CodeExtraction(inoxCtx: inox.Context, symbolMapping: SymbolMapping)(using } traverser.traverse(tree) } + + object ExExportedSymbol { + def unapply(tr: tpd.Tree): Option[(Seq[Symbol], tpd.Tree, Seq[tpd.Tree], Seq[tpd.Tree])] = { + val sym = tr.symbol + exportedSymsMapping.get(sym) match { + case Some(path) => + tr match { + case ExCall(Some(recv), _, tps, args) => Some((path, recv, tps, args)) + case _ => outOfSubsetError(tr, s"Stainless does not support use of exported symbol in this context:\n${tr.show}") + } + case None => None + } + } + } } diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala index 006c0bbb01..0bf52ed603 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala @@ -14,6 +14,7 @@ import core.Phases._ import transform._ import typer._ import frontend.{CallBack, Frontend, FrontendFactory, ThreadedFrontend} +import Utils._ import java.io.File import java.net.URL @@ -40,11 +41,17 @@ class DottyCompiler(ctx: inox.Context, callback: CallBack) extends Compiler { // Note: this must not be instantiated within `run`, because we need the underlying `symbolMapping` in `StainlessExtraction` // to be shared across multiple compilation unit. private val extraction = new StainlessExtraction(ctx) + private var exportedSymsMapping: ExportedSymbolsMapping = ExportedSymbolsMapping.empty // This method id called for every compilation unit, and in the same thread. override def run(using dottyCtx: DottyContext): Unit = - extraction.extractUnit.foreach(extracted => + extraction.extractUnit(exportedSymsMapping).foreach(extracted => callback(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs)) + + override def runOn(units: List[CompilationUnit])(using dottyCtx: DottyContext): List[CompilationUnit] = { + exportedSymsMapping = exportedSymbolsMapping(ctx, this.start, units) + super.runOn(units) + } } // Pick all phases until `including` (with its group included) diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala index 359861872e..9c1984384d 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala @@ -14,17 +14,18 @@ import typer._ import extraction.xlang.{trees => xt} import frontend.{CallBack, Frontend, FrontendFactory, ThreadedFrontend, UnsupportedCodeException} +import Utils._ case class ExtractedUnit(file: String, unit: xt.UnitDef, classes: Seq[xt.ClassDef], functions: Seq[xt.FunDef], typeDefs: Seq[xt.TypeDef]) class StainlessExtraction(val inoxCtx: inox.Context) { private val symbolMapping = new SymbolMapping - def extractUnit(using ctx: DottyContext): Option[ExtractedUnit] = { + def extractUnit(exportedSymsMapping: ExportedSymbolsMapping)(using ctx: DottyContext): Option[ExtractedUnit] = { // Remark: the method `extractUnit` is called for each compilation unit (which corresponds more or less to a Scala file) // Therefore, the symbolMapping instances needs to be shared accross compilation unit. // Since `extractUnit` is called within the same thread, we do not need to synchronize accesses to symbolMapping. - val extraction = new CodeExtraction(inoxCtx, symbolMapping) + val extraction = new CodeExtraction(inoxCtx, symbolMapping, exportedSymsMapping) import extraction._ val unit = ctx.compilationUnit diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessPlugin.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessPlugin.scala index 1da1db6e42..1bc86fe39a 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessPlugin.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessPlugin.scala @@ -9,11 +9,13 @@ import dotc.util._ import Contexts.{Context => DottyContext} import plugins._ import Phases._ +import Symbols._ import transform._ import reporting._ import inox.{Context, DebugSection, utils => InoxPosition} import stainless.frontend import stainless.frontend.{CallBack, Frontend} +import Utils._ object StainlessPlugin { val PluginName = "stainless" @@ -76,11 +78,12 @@ class StainlessPlugin extends StandardPlugin { private var extraction: Option[StainlessExtraction] = None private var callback: Option[CallBack] = None + private var exportedSymsMapping: ExportedSymbolsMapping = ExportedSymbolsMapping.empty // This method id called for every compilation unit, and in the same thread. // It is called within super.runOn. override def run(using DottyContext): Unit = - extraction.get.extractUnit.foreach(extracted => + extraction.get.extractUnit(exportedSymsMapping).foreach(extracted => callback.get(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs)) override def runOn(units: List[CompilationUnit])(using dottyCtx: DottyContext): List[CompilationUnit] = { @@ -105,6 +108,7 @@ class StainlessPlugin extends StandardPlugin { // Not pretty at all... Oh well... callback = Some(cb) extraction = Some(new StainlessExtraction(inoxCtx)) + exportedSymsMapping = Utils.exportedSymbolsMapping(inoxCtx, this.start, units) cb.beginExtractions() val unitRes = super.runOn(units) diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/Utils.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/Utils.scala new file mode 100644 index 0000000000..c44f4d3a64 --- /dev/null +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/Utils.scala @@ -0,0 +1,65 @@ +package stainless +package frontends.dotc + +import dotty.tools.dotc +import dotc._ +import core._ +import Symbols._ +import dotc.util._ +import Contexts.{Context => DottyContext} +import ast.tpd +import Flags._ +import transform._ + +object Utils { + + case class ExportedSymbolsMapping private(private val mapping: Map[Symbol, (Option[Symbol], Symbol)]) { + def add(from: Symbol, recv: Option[Symbol], to: Symbol): ExportedSymbolsMapping = { + val newMapping = mapping + (from -> (recv, to)) + ExportedSymbolsMapping(newMapping) + } + + def get(sym: Symbol): Option[Seq[Symbol]] = { + def loop(sym: Symbol, acc: Seq[Symbol]): Seq[Symbol] = { + mapping.get(sym) match { + case Some((recv, fwd)) => + val newPath = acc ++ recv.toSeq + loop(fwd, newPath) + case None => acc :+ sym + } + } + + if (mapping.contains(sym)) Some(loop(sym, Seq.empty)) + else None + } + } + + object ExportedSymbolsMapping { + def empty: ExportedSymbolsMapping = ExportedSymbolsMapping(Map.empty) + } + + def exportedSymbolsMapping(ctx: inox.Context, start: Int, units: List[CompilationUnit])(using dottyCtx: DottyContext): ExportedSymbolsMapping = { + var mapping = ExportedSymbolsMapping.empty + + class Traverser(override val dottyCtx: DottyContext) extends tpd.TreeTraverser with ASTExtractors { + import StructuralExtractors._ + import ExpressionExtractors._ + + override def traverse(tree: tpd.Tree)(using DottyContext): Unit = { + tree match { + case ExFunctionDef(sym, _, _, _, ExCall(recv, fwd, _, _)) if sym is Exported => + mapping = mapping.add(sym, recv.map(_.symbol), fwd) + case _ => traverseChildren(tree) + } + } + } + + import dotty.tools.dotc.typer.ImportInfo.withRootImports + for (unit <- units) { + val newCtx = dottyCtx.fresh.setPhase(start).setCompilationUnit(unit).withRootImports + val traverser = new Traverser(newCtx) + traverser.traverse(unit.tpdTree) + } + mapping + } +}