From 6133dd831c2201067cb4382d9c563f83785790cb Mon Sep 17 00:00:00 2001 From: Jad Hamza Date: Tue, 12 Oct 2021 15:53:12 +0200 Subject: [PATCH] Fix some GenC export issues and add cCode.pack annotation --- core/src/main/scala/stainless/genc/CAST.scala | 2 +- .../stainless/genc/CASTDependencies.scala | 2 +- .../main/scala/stainless/genc/CPrinter.scala | 40 ++++++++++---- .../scala/stainless/genc/ir/ClassLifter.scala | 2 +- .../src/main/scala/stainless/genc/ir/IR.scala | 2 +- .../scala/stainless/genc/ir/Transformer.scala | 2 +- .../main/scala/stainless/genc/package.scala | 3 +- .../stainless/genc/phases/ExtraOps.scala | 5 ++ .../stainless/genc/phases/IR2CPhase.scala | 6 +-- .../stainless/genc/phases/Scala2IRPhase.scala | 54 ++++++++++++++----- .../library/stainless/annotation/cCode.scala | 4 ++ 11 files changed, 90 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/stainless/genc/CAST.scala b/core/src/main/scala/stainless/genc/CAST.scala index 4ffba36811..4f274746e2 100644 --- a/core/src/main/scala/stainless/genc/CAST.scala +++ b/core/src/main/scala/stainless/genc/CAST.scala @@ -93,7 +93,7 @@ object CAST { // C Abstract Syntax Tree case class FunType(ret: Type, params: Seq[Type]) extends Type - case class Struct(id: Id, fields: Seq[Var], isExported: Boolean) extends DataType { + case class Struct(id: Id, fields: Seq[Var], isExported: Boolean, isPacked: Boolean) extends DataType { require(fields.nonEmpty, s"Fields of struct $id should be non empty") } diff --git a/core/src/main/scala/stainless/genc/CASTDependencies.scala b/core/src/main/scala/stainless/genc/CASTDependencies.scala index 8fb36d9695..aab5fe242e 100644 --- a/core/src/main/scala/stainless/genc/CASTDependencies.scala +++ b/core/src/main/scala/stainless/genc/CASTDependencies.scala @@ -23,7 +23,7 @@ class CASTTraverser(implicit ctx: inox.Context) { case FixedArrayType(base, _) => Seq(base) - case Struct(id, fields, _) => + case Struct(id, fields, _, _) => id +: fields case Fun(id, returnType, params, Left(block), _, _) => diff --git a/core/src/main/scala/stainless/genc/CPrinter.scala b/core/src/main/scala/stainless/genc/CPrinter.scala index b850ef3eb7..ac0d777242 100644 --- a/core/src/main/scala/stainless/genc/CPrinter.scala +++ b/core/src/main/scala/stainless/genc/CPrinter.scala @@ -63,7 +63,7 @@ class CPrinter( sep = "\n\n") } |${nary( - decls.filter(!_._2.contains(External)).map { case (decl, modes) => + decls.filter(decl => !decl._2.contains(External) && !decl._2.contains(Export)).map { case (decl, modes) => modes.foldLeft(TTree(decl) : WrapperTree) { case (acc, Static) => StaticStorage(acc) case (acc, Volatile) => VolatileStorage(acc) @@ -135,6 +135,18 @@ class CPrinter( closing = "\n\n", sep = "\n\n") } + |${nary( + decls.filter(decl => decl._2.contains(Export)).map { case (decl, modes) => + modes.foldLeft(TTree(decl) : WrapperTree) { + case (acc, Static) => StaticStorage(acc) + case (acc, Volatile) => VolatileStorage(acc) + case (acc, _) => acc + } + }, + opening = separator("global variables"), + closing = ";\n\n", + sep = ";\n") + } |${nary( functions.filter(_.isExported) map FunDecl, opening = separator("function declarations"), @@ -185,7 +197,7 @@ class CPrinter( case Pointer(base) => c"$base*" - case Struct(id, _, _) => c"$id" + case Struct(id, _, _, _) => c"$id" case Union(id, _, _) => c"$id" @@ -319,14 +331,22 @@ class CPrinter( | ${nary(literals, sep = ",\n")} |} $id;""" - case DataTypeDecl(t: DataType) => - val kind = t match { - case _: Struct => "struct" - case _: Union => "union" - } - c"""|typedef $kind { - | ${nary(t.fields, sep = ";\n", closing = ";")} - |} ${t.id};""" + case DataTypeDecl(u: Union) => + c"""|typedef union { + | ${nary(u.fields, sep = ";\n", closing = ";")} + |} ${u.id};""" + + case DataTypeDecl(s: Struct) if s.isPacked => + c"""|#pragma pack(1) + |typedef struct { + | ${nary(s.fields, sep = ";\n", closing = ";")} + |} ${s.id}; + |#pragma pack()""" + + case DataTypeDecl(s: Struct) => + c"""|typedef struct { + | ${nary(s.fields, sep = ";\n", closing = ";")} + |} ${s.id};""" case FieldInit(id, value) => c".$id = $value" } diff --git a/core/src/main/scala/stainless/genc/ir/ClassLifter.scala b/core/src/main/scala/stainless/genc/ir/ClassLifter.scala index fcd805dbd2..7d7d4ab14c 100644 --- a/core/src/main/scala/stainless/genc/ir/ClassLifter.scala +++ b/core/src/main/scala/stainless/genc/ir/ClassLifter.scala @@ -70,7 +70,7 @@ final class ClassLifter(val ctx: inox.Context) extends Transformer(NIR, LIR) { vd } - val cd = to.ClassDef(id, parent, fields, isAbstract, cd0.isExported) + val cd = to.ClassDef(id, parent, fields, isAbstract, cd0.isExported, cd0.isPacked) // Actually register the classes/arrays now that we have the corresponding ClassDef valFieldsToRegister foreach { case (id, ct) => diff --git a/core/src/main/scala/stainless/genc/ir/IR.scala b/core/src/main/scala/stainless/genc/ir/IR.scala index bac471540c..d3592cc316 100644 --- a/core/src/main/scala/stainless/genc/ir/IR.scala +++ b/core/src/main/scala/stainless/genc/ir/IR.scala @@ -84,7 +84,7 @@ private[genc] sealed trait IR { ir => def toVal = FunVal(this) } - case class ClassDef(id: Id, parent: Option[ClassDef], fields: Seq[ValDef], isAbstract: Boolean, isExported: Boolean) extends Def { + case class ClassDef(id: Id, parent: Option[ClassDef], fields: Seq[ValDef], isAbstract: Boolean, isExported: Boolean, isPacked: Boolean) extends Def { require( // Parent must be abstract if any (parent forall { _.isAbstract }) && diff --git a/core/src/main/scala/stainless/genc/ir/Transformer.scala b/core/src/main/scala/stainless/genc/ir/Transformer.scala index 5333cedc89..74ce2957b3 100644 --- a/core/src/main/scala/stainless/genc/ir/Transformer.scala +++ b/core/src/main/scala/stainless/genc/ir/Transformer.scala @@ -103,7 +103,7 @@ abstract class Transformer[From <: IR, To <: IR](final val from: From, final val } protected def recImpl(cd: ClassDef, parent: Option[to.ClassDef])(implicit env: Env): to.ClassDef = - to.ClassDef(cd.id, parent, cd.fields map rec, cd.isAbstract, cd.isExported) + to.ClassDef(cd.id, parent, cd.fields map rec, cd.isAbstract, cd.isExported, cd.isPacked) protected def rec(vd: ValDef)(implicit env: Env): to.ValDef = to.ValDef(vd.id, rec(vd.typ), vd.isVar) diff --git a/core/src/main/scala/stainless/genc/package.scala b/core/src/main/scala/stainless/genc/package.scala index 5ac47e5d9a..4def3b68dc 100644 --- a/core/src/main/scala/stainless/genc/package.scala +++ b/core/src/main/scala/stainless/genc/package.scala @@ -25,6 +25,7 @@ package object genc { // declaration mode for global variables sealed abstract class DeclarationMode case object Static extends DeclarationMode // static annotation - case object Volatile extends DeclarationMode // static annotation + case object Volatile extends DeclarationMode // volatile annotation case object External extends DeclarationMode // no declaration in the produced code + case object Export extends DeclarationMode // print in header file } diff --git a/core/src/main/scala/stainless/genc/phases/ExtraOps.scala b/core/src/main/scala/stainless/genc/phases/ExtraOps.scala index d9a98476ab..40e5e2b939 100644 --- a/core/src/main/scala/stainless/genc/phases/ExtraOps.scala +++ b/core/src/main/scala/stainless/genc/phases/ExtraOps.scala @@ -21,6 +21,7 @@ private[genc] object ExtraOps { def isManuallyDefined = hasAnnotation(manualDefAnnotation) def isExtern = fa.flags contains Extern def isDropped = hasAnnotation("cCode.drop") || fa.flags.contains(Ghost) + def isVal: Boolean = fa.isInstanceOf[Outer] && fa.asInstanceOf[Outer].fd.isVal def extAnnotations: Map[String, Seq[Any]] = fa.flags.collect { case Annotation(s, args) => s -> args @@ -50,6 +51,9 @@ private[genc] object ExtraOps { def isDropped = hasAnnotation("cCode.drop") || fd.flags.contains(Ghost) def isExported = hasAnnotation("cCode.export") def isManuallyDefined = hasAnnotation(manualDefAnnotation) + def isVal = + (fd.flags.exists(_.name == "accessor") || fd.flags.exists { case IsField(_) => true case _ => false }) && + fd.tparams.isEmpty && fd.params.isEmpty def extAnnotations: Map[String, Seq[Any]] = fd.flags.collect { case Annotation(s, args) => s -> args @@ -78,6 +82,7 @@ private[genc] object ExtraOps { def isManuallyTyped = hasAnnotation(manualTypeAnnotation) def isDropped = hasAnnotation(droppedAnnotation) def isExported = hasAnnotation("cCode.export") + def isPacked = hasAnnotation("cCode.pack") def isGlobal = cd.flags.exists(_.name.startsWith("cCode.global")) def isGlobalDefault = cd.flags.exists(_.name == "cCode.global") def isGlobalUninitialized = cd.flags.exists(_.name == "cCode.globalUninitialized") diff --git a/core/src/main/scala/stainless/genc/phases/IR2CPhase.scala b/core/src/main/scala/stainless/genc/phases/IR2CPhase.scala index 9dddc75b27..33789da841 100644 --- a/core/src/main/scala/stainless/genc/phases/IR2CPhase.scala +++ b/core/src/main/scala/stainless/genc/phases/IR2CPhase.scala @@ -519,7 +519,7 @@ private class IR2CImpl()(implicit val ctx: inox.Context) { val unionType = getUnionFor(top) val union = C.Var(TaggedUnion.value, unionType) - C.Struct(rec(top.id), tag :: union :: Nil, top.isExported) + C.Struct(rec(top.id), tag :: union :: Nil, top.isExported, top.isPacked) } private def buildStructForCaseClass(cd: ClassDef): C.Struct = { @@ -532,7 +532,7 @@ private class IR2CImpl()(implicit val ctx: inox.Context) { Seq(C.Var(C.Id("extra"), C.Primitive(Int8Type))) } else cd.fields.map(rec(_)) - C.Struct(rec(cd.id), fields, cd.isExported) + C.Struct(rec(cd.id), fields, cd.isExported, cd.isPacked) } private object TaggedUnion { @@ -547,7 +547,7 @@ private class IR2CImpl()(implicit val ctx: inox.Context) { val data = C.Var(Array.data, C.Pointer(base)) val id = C.Id(repId(arrayType)) - val array = C.Struct(id, data :: length :: Nil, false) + val array = C.Struct(id, data :: length :: Nil, false, false) // This needs to get registered as a datatype as well register(array) diff --git a/core/src/main/scala/stainless/genc/phases/Scala2IRPhase.scala b/core/src/main/scala/stainless/genc/phases/Scala2IRPhase.scala index b30897fc45..8351bf7946 100644 --- a/core/src/main/scala/stainless/genc/phases/Scala2IRPhase.scala +++ b/core/src/main/scala/stainless/genc/phases/Scala2IRPhase.scala @@ -81,8 +81,12 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms: checkGlobalUsage() // Start the transformation from `@cCode.export` (and `@cCode.global`) functions and classes - for (fd <- symbols.functions.values if fd.isExported) - rec(Outer(fd), Seq())(Map.empty, Env(Map.empty, Map.empty, fd.isExported)) + for (fd <- symbols.functions.values if fd.isExported) { + if (fd.isVal) + registerVal(fd) + else + rec(Outer(fd), Seq())(Map.empty, Env(Map.empty, Map.empty, fd.isExported)) + } for (cd <- symbols.classes.values if cd.isExported || cd.isGlobal) rec(cd.typed.toType)(Map.empty) @@ -92,7 +96,7 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms: * Caches * ****************************************************************************************************/ - var declResults = new scala.collection.mutable.ListBuffer[(CIR.Decl, Seq[DeclarationMode])]() + val declResults = new scala.collection.mutable.ListBuffer[(CIR.Decl, Seq[DeclarationMode])]() // For functions, we associate each TypedFunDef to a CIR.FunDef for each "type context" (TypeMapping). // This is very important for (non-generic) functions nested in a generic function because for N @@ -124,6 +128,24 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms: private def convertVarInfoToArg(vi: VarInfo)(implicit tm: TypeMapping) = CIR.ValDef(rec(vi.vd.id), rec(vi.typ), vi.isVar) private def convertVarInfoToParam(vi: VarInfo)(implicit tm: TypeMapping) = CIR.Binding(convertVarInfoToArg(vi)) + + + val registered = MutableSet[FunDef]() + def registerVal(fd: FunDef): Unit = { + if (!registered(fd)) { + registered += fd + val newId = rec(fd.id, withUnique = !fd.isExported && !fd.isDropped) + val newType = rec(fd.returnType)(Map.empty) + val exporting = if (fd.isExported) Seq(Export) else Seq() + if (fd.isDropped) { + declResults += ((CIR.Decl(CIR.ValDef(newId, newType, false), None), exporting :+ External)) + } else { + val newBody = rec(fd.fullBody)(Env(Map.empty, Map.empty, fd.isExported), Map.empty) + declResults += ((CIR.Decl(CIR.ValDef(newId, newType, false), Some(newBody)), exporting)) + } + } + } + // Extract the ValDef from the known one private def buildBinding(vd: ValDef)(implicit env: Env, tm: TypeMapping): CIR.Binding = { val typ = instantiateType(vd.tpe, tm) @@ -241,7 +263,7 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms: val types = bases map rec val fields = types.zipWithIndex map { case (typ, i) => CIR.ValDef("_" + (i+1), typ, isVar = false) } val id = "Tuple" + buildIdPostfix(bases) - CIR.ClassDef(id, None, fields, isAbstract = false, isExported = false) + CIR.ClassDef(id, None, fields, isAbstract = false, isExported = false, isPacked = false) case _ => reporter.fatalError(typ.getPos, s"Unexpected ${typ.getClass} instead of TupleType") } @@ -489,7 +511,7 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms: val impl = fa.getManualDefinition CIR.FunBodyManual(impl.includes, impl.code) } else if (fa.isDropped) { - CIR.FunDropped(fa.flags.exists(_.name == "accessor") || fa.flags.exists { case IsField(_) => true case _ => false }) + CIR.FunDropped(fa.isVal) } else { // Build the new environment from context and parameters val ctxKeys: Seq[(ValDef, Type)] = ctxDBAbs map { c => c.vd -> instantiateType(c.typ, tm1) } @@ -588,7 +610,7 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms: CIR.ValDef(rec(vd.id, withUnique = mangling), rec(typ), vd.flags.contains(IsVar)) }) - val clazz = CIR.ClassDef(id, parent, fields, cd.isAbstract, cd.isExported) + val clazz = CIR.ClassDef(id, parent, fields, cd.isAbstract, cd.isExported, cd.isPacked) val newAcc = acc + (ct -> clazz) if (cd.isGlobal) { assert(parent.isEmpty, "Classes annotated with `@cCode.global` cannot have parents") @@ -706,15 +728,21 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms: // We don't have to traverse the nested function now because we already have their contexts rec(body)(env.copy(lfds = env.lfds ++ lfds.map(lfd => lfd.id -> lfd)), tm0) + case FunctionInvocation(id, Seq(), Seq()) if syms.getFunction(id).isVal => + val fd = syms.getFunction(id) + registerVal(fd) + CIR.Binding(CIR.ValDef(rec(id, !fd.isExported && !fd.isDropped), rec(fd.returnType), false)) + case fi @ FunctionInvocation(id, tps, args) => val fd = syms.getFunction(id) - if (fd.isExported && fd.hasPrecondition) { - reporter.warning(fi.getPos, - s"Exported functions (${fd.id.asString}) generate C assertions for requires, " + - "so invoking them from within Stainless is not recommended as Stainless already checks " + - "that the requires are respected" - ) - } + // FIXME: requires do not generate assertions at the moment + // if (fd.isExported && fd.hasPrecondition) { + // reporter.warning(fi.getPos, + // s"Exported functions (${fd.id.asString}) generate C assertions for requires, " + + // "so invoking them from within Stainless is not recommended as Stainless already checks " + + // "that the requires are respected" + // ) + // } val tfd = fd.typed(tps) val fun = rec(Outer(fd), tps)(tm0, env.copy(inExported = fd.isExported)) implicit val tm1 = tm0 ++ tfd.tpSubst diff --git a/frontends/library/stainless/annotation/cCode.scala b/frontends/library/stainless/annotation/cCode.scala index b13071b5e6..38ca6598b7 100644 --- a/frontends/library/stainless/annotation/cCode.scala +++ b/frontends/library/stainless/annotation/cCode.scala @@ -57,6 +57,10 @@ object cCode { @ignore class export extends StaticAnnotation + /* Make sure struct is "packed" when compiled to C (no padding between fields of structs in memory) */ + @ignore + class pack extends StaticAnnotation + /* * Allows the user to define a type (e.g. case class) as a typeDef to an * existing type with an optional include file.