Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some GenC export issues and add cCode.pack annotation #1190

Merged
merged 1 commit into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/main/scala/stainless/genc/CAST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ object CAST { // C Abstract Syntax Tree

case class FunType(ret: Type, params: Seq[Type]) extends Type

case class Struct(id: Id, fields: Seq[Var], isExported: Boolean) extends DataType {
case class Struct(id: Id, fields: Seq[Var], isExported: Boolean, isPacked: Boolean) extends DataType {
require(fields.nonEmpty, s"Fields of struct $id should be non empty")
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/stainless/genc/CASTDependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class CASTTraverser(implicit ctx: inox.Context) {
case FixedArrayType(base, _) =>
Seq(base)

case Struct(id, fields, _) =>
case Struct(id, fields, _, _) =>
id +: fields

case Fun(id, returnType, params, Left(block), _, _) =>
Expand Down
40 changes: 30 additions & 10 deletions core/src/main/scala/stainless/genc/CPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class CPrinter(
sep = "\n\n")
}
|${nary(
decls.filter(!_._2.contains(External)).map { case (decl, modes) =>
decls.filter(decl => !decl._2.contains(External) && !decl._2.contains(Export)).map { case (decl, modes) =>
modes.foldLeft(TTree(decl) : WrapperTree) {
case (acc, Static) => StaticStorage(acc)
case (acc, Volatile) => VolatileStorage(acc)
Expand Down Expand Up @@ -135,6 +135,18 @@ class CPrinter(
closing = "\n\n",
sep = "\n\n")
}
|${nary(
decls.filter(decl => decl._2.contains(Export)).map { case (decl, modes) =>
modes.foldLeft(TTree(decl) : WrapperTree) {
case (acc, Static) => StaticStorage(acc)
case (acc, Volatile) => VolatileStorage(acc)
case (acc, _) => acc
}
},
opening = separator("global variables"),
closing = ";\n\n",
sep = ";\n")
}
|${nary(
functions.filter(_.isExported) map FunDecl,
opening = separator("function declarations"),
Expand Down Expand Up @@ -185,7 +197,7 @@ class CPrinter(

case Pointer(base) => c"$base*"

case Struct(id, _, _) => c"$id"
case Struct(id, _, _, _) => c"$id"

case Union(id, _, _) => c"$id"

Expand Down Expand Up @@ -319,14 +331,22 @@ class CPrinter(
| ${nary(literals, sep = ",\n")}
|} $id;"""

case DataTypeDecl(t: DataType) =>
val kind = t match {
case _: Struct => "struct"
case _: Union => "union"
}
c"""|typedef $kind {
| ${nary(t.fields, sep = ";\n", closing = ";")}
|} ${t.id};"""
case DataTypeDecl(u: Union) =>
c"""|typedef union {
| ${nary(u.fields, sep = ";\n", closing = ";")}
|} ${u.id};"""

case DataTypeDecl(s: Struct) if s.isPacked =>
c"""|#pragma pack(1)
|typedef struct {
| ${nary(s.fields, sep = ";\n", closing = ";")}
|} ${s.id};
|#pragma pack()"""

case DataTypeDecl(s: Struct) =>
c"""|typedef struct {
| ${nary(s.fields, sep = ";\n", closing = ";")}
|} ${s.id};"""

case FieldInit(id, value) => c".$id = $value"
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/stainless/genc/ir/ClassLifter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ final class ClassLifter(val ctx: inox.Context) extends Transformer(NIR, LIR) {
vd
}

val cd = to.ClassDef(id, parent, fields, isAbstract, cd0.isExported)
val cd = to.ClassDef(id, parent, fields, isAbstract, cd0.isExported, cd0.isPacked)

// Actually register the classes/arrays now that we have the corresponding ClassDef
valFieldsToRegister foreach { case (id, ct) =>
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/stainless/genc/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ private[genc] sealed trait IR { ir =>
def toVal = FunVal(this)
}

case class ClassDef(id: Id, parent: Option[ClassDef], fields: Seq[ValDef], isAbstract: Boolean, isExported: Boolean) extends Def {
case class ClassDef(id: Id, parent: Option[ClassDef], fields: Seq[ValDef], isAbstract: Boolean, isExported: Boolean, isPacked: Boolean) extends Def {
require(
// Parent must be abstract if any
(parent forall { _.isAbstract }) &&
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/stainless/genc/ir/Transformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ abstract class Transformer[From <: IR, To <: IR](final val from: From, final val
}

protected def recImpl(cd: ClassDef, parent: Option[to.ClassDef])(implicit env: Env): to.ClassDef =
to.ClassDef(cd.id, parent, cd.fields map rec, cd.isAbstract, cd.isExported)
to.ClassDef(cd.id, parent, cd.fields map rec, cd.isAbstract, cd.isExported, cd.isPacked)

protected def rec(vd: ValDef)(implicit env: Env): to.ValDef = to.ValDef(vd.id, rec(vd.typ), vd.isVar)

Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/stainless/genc/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ package object genc {
// declaration mode for global variables
sealed abstract class DeclarationMode
case object Static extends DeclarationMode // static annotation
case object Volatile extends DeclarationMode // static annotation
case object Volatile extends DeclarationMode // volatile annotation
case object External extends DeclarationMode // no declaration in the produced code
case object Export extends DeclarationMode // print in header file
}
5 changes: 5 additions & 0 deletions core/src/main/scala/stainless/genc/phases/ExtraOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ private[genc] object ExtraOps {
def isManuallyDefined = hasAnnotation(manualDefAnnotation)
def isExtern = fa.flags contains Extern
def isDropped = hasAnnotation("cCode.drop") || fa.flags.contains(Ghost)
def isVal: Boolean = fa.isInstanceOf[Outer] && fa.asInstanceOf[Outer].fd.isVal

def extAnnotations: Map[String, Seq[Any]] = fa.flags.collect {
case Annotation(s, args) => s -> args
Expand Down Expand Up @@ -50,6 +51,9 @@ private[genc] object ExtraOps {
def isDropped = hasAnnotation("cCode.drop") || fd.flags.contains(Ghost)
def isExported = hasAnnotation("cCode.export")
def isManuallyDefined = hasAnnotation(manualDefAnnotation)
def isVal =
(fd.flags.exists(_.name == "accessor") || fd.flags.exists { case IsField(_) => true case _ => false }) &&
fd.tparams.isEmpty && fd.params.isEmpty

def extAnnotations: Map[String, Seq[Any]] = fd.flags.collect {
case Annotation(s, args) => s -> args
Expand Down Expand Up @@ -78,6 +82,7 @@ private[genc] object ExtraOps {
def isManuallyTyped = hasAnnotation(manualTypeAnnotation)
def isDropped = hasAnnotation(droppedAnnotation)
def isExported = hasAnnotation("cCode.export")
def isPacked = hasAnnotation("cCode.pack")
def isGlobal = cd.flags.exists(_.name.startsWith("cCode.global"))
def isGlobalDefault = cd.flags.exists(_.name == "cCode.global")
def isGlobalUninitialized = cd.flags.exists(_.name == "cCode.globalUninitialized")
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/stainless/genc/phases/IR2CPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ private class IR2CImpl()(implicit val ctx: inox.Context) {
val unionType = getUnionFor(top)
val union = C.Var(TaggedUnion.value, unionType)

C.Struct(rec(top.id), tag :: union :: Nil, top.isExported)
C.Struct(rec(top.id), tag :: union :: Nil, top.isExported, top.isPacked)
}

private def buildStructForCaseClass(cd: ClassDef): C.Struct = {
Expand All @@ -532,7 +532,7 @@ private class IR2CImpl()(implicit val ctx: inox.Context) {
Seq(C.Var(C.Id("extra"), C.Primitive(Int8Type)))
} else cd.fields.map(rec(_))

C.Struct(rec(cd.id), fields, cd.isExported)
C.Struct(rec(cd.id), fields, cd.isExported, cd.isPacked)
}

private object TaggedUnion {
Expand All @@ -547,7 +547,7 @@ private class IR2CImpl()(implicit val ctx: inox.Context) {
val data = C.Var(Array.data, C.Pointer(base))
val id = C.Id(repId(arrayType))

val array = C.Struct(id, data :: length :: Nil, false)
val array = C.Struct(id, data :: length :: Nil, false, false)

// This needs to get registered as a datatype as well
register(array)
Expand Down
54 changes: 41 additions & 13 deletions core/src/main/scala/stainless/genc/phases/Scala2IRPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,12 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms:
checkGlobalUsage()

// Start the transformation from `@cCode.export` (and `@cCode.global`) functions and classes
for (fd <- symbols.functions.values if fd.isExported)
rec(Outer(fd), Seq())(Map.empty, Env(Map.empty, Map.empty, fd.isExported))
for (fd <- symbols.functions.values if fd.isExported) {
if (fd.isVal)
registerVal(fd)
else
rec(Outer(fd), Seq())(Map.empty, Env(Map.empty, Map.empty, fd.isExported))
}

for (cd <- symbols.classes.values if cd.isExported || cd.isGlobal)
rec(cd.typed.toType)(Map.empty)
Expand All @@ -92,7 +96,7 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms:
* Caches *
****************************************************************************************************/

var declResults = new scala.collection.mutable.ListBuffer[(CIR.Decl, Seq[DeclarationMode])]()
val declResults = new scala.collection.mutable.ListBuffer[(CIR.Decl, Seq[DeclarationMode])]()

// For functions, we associate each TypedFunDef to a CIR.FunDef for each "type context" (TypeMapping).
// This is very important for (non-generic) functions nested in a generic function because for N
Expand Down Expand Up @@ -124,6 +128,24 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms:
private def convertVarInfoToArg(vi: VarInfo)(implicit tm: TypeMapping) = CIR.ValDef(rec(vi.vd.id), rec(vi.typ), vi.isVar)
private def convertVarInfoToParam(vi: VarInfo)(implicit tm: TypeMapping) = CIR.Binding(convertVarInfoToArg(vi))



val registered = MutableSet[FunDef]()
def registerVal(fd: FunDef): Unit = {
if (!registered(fd)) {
registered += fd
val newId = rec(fd.id, withUnique = !fd.isExported && !fd.isDropped)
val newType = rec(fd.returnType)(Map.empty)
val exporting = if (fd.isExported) Seq(Export) else Seq()
if (fd.isDropped) {
declResults += ((CIR.Decl(CIR.ValDef(newId, newType, false), None), exporting :+ External))
} else {
val newBody = rec(fd.fullBody)(Env(Map.empty, Map.empty, fd.isExported), Map.empty)
declResults += ((CIR.Decl(CIR.ValDef(newId, newType, false), Some(newBody)), exporting))
}
}
}

// Extract the ValDef from the known one
private def buildBinding(vd: ValDef)(implicit env: Env, tm: TypeMapping): CIR.Binding = {
val typ = instantiateType(vd.tpe, tm)
Expand Down Expand Up @@ -241,7 +263,7 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms:
val types = bases map rec
val fields = types.zipWithIndex map { case (typ, i) => CIR.ValDef("_" + (i+1), typ, isVar = false) }
val id = "Tuple" + buildIdPostfix(bases)
CIR.ClassDef(id, None, fields, isAbstract = false, isExported = false)
CIR.ClassDef(id, None, fields, isAbstract = false, isExported = false, isPacked = false)

case _ => reporter.fatalError(typ.getPos, s"Unexpected ${typ.getClass} instead of TupleType")
}
Expand Down Expand Up @@ -489,7 +511,7 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms:
val impl = fa.getManualDefinition
CIR.FunBodyManual(impl.includes, impl.code)
} else if (fa.isDropped) {
CIR.FunDropped(fa.flags.exists(_.name == "accessor") || fa.flags.exists { case IsField(_) => true case _ => false })
CIR.FunDropped(fa.isVal)
} else {
// Build the new environment from context and parameters
val ctxKeys: Seq[(ValDef, Type)] = ctxDBAbs map { c => c.vd -> instantiateType(c.typ, tm1) }
Expand Down Expand Up @@ -588,7 +610,7 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms:
CIR.ValDef(rec(vd.id, withUnique = mangling), rec(typ), vd.flags.contains(IsVar))
})

val clazz = CIR.ClassDef(id, parent, fields, cd.isAbstract, cd.isExported)
val clazz = CIR.ClassDef(id, parent, fields, cd.isAbstract, cd.isExported, cd.isPacked)
val newAcc = acc + (ct -> clazz)
if (cd.isGlobal) {
assert(parent.isEmpty, "Classes annotated with `@cCode.global` cannot have parents")
Expand Down Expand Up @@ -706,15 +728,21 @@ private class S2IRImpl(val context: inox.Context, val ctxDB: FunCtxDB, val syms:
// We don't have to traverse the nested function now because we already have their contexts
rec(body)(env.copy(lfds = env.lfds ++ lfds.map(lfd => lfd.id -> lfd)), tm0)

case FunctionInvocation(id, Seq(), Seq()) if syms.getFunction(id).isVal =>
val fd = syms.getFunction(id)
registerVal(fd)
CIR.Binding(CIR.ValDef(rec(id, !fd.isExported && !fd.isDropped), rec(fd.returnType), false))

case fi @ FunctionInvocation(id, tps, args) =>
val fd = syms.getFunction(id)
if (fd.isExported && fd.hasPrecondition) {
reporter.warning(fi.getPos,
s"Exported functions (${fd.id.asString}) generate C assertions for requires, " +
"so invoking them from within Stainless is not recommended as Stainless already checks " +
"that the requires are respected"
)
}
// FIXME: requires do not generate assertions at the moment
// if (fd.isExported && fd.hasPrecondition) {
// reporter.warning(fi.getPos,
// s"Exported functions (${fd.id.asString}) generate C assertions for requires, " +
// "so invoking them from within Stainless is not recommended as Stainless already checks " +
// "that the requires are respected"
// )
// }
val tfd = fd.typed(tps)
val fun = rec(Outer(fd), tps)(tm0, env.copy(inExported = fd.isExported))
implicit val tm1 = tm0 ++ tfd.tpSubst
Expand Down
4 changes: 4 additions & 0 deletions frontends/library/stainless/annotation/cCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ object cCode {
@ignore
class export extends StaticAnnotation

/* Make sure struct is "packed" when compiled to C (no padding between fields of structs in memory) */
@ignore
class pack extends StaticAnnotation

/*
* Allows the user to define a type (e.g. case class) as a typeDef to an
* existing type with an optional include file.
Expand Down