Skip to content

Commit

Permalink
Support indexing into collections in the Rust backend
Browse files Browse the repository at this point in the history
<small>By submitting this pull request, I confirm that my contribution is made under the terms of the [MIT
license](https://github.com/dafny-lang/dafny/blob/master/LICENSE.txt).</small>
  • Loading branch information
shadaj committed Aug 23, 2023
1 parent fb8b5f2 commit f3d96c4
Show file tree
Hide file tree
Showing 6 changed files with 4,797 additions and 4,571 deletions.
3 changes: 2 additions & 1 deletion Source/DafnyCore/Compilers/Dafny/AST.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ module {:extern "DAST"} DAST {

datatype Ident = Ident(id: string)

datatype Class = Class(name: string, typeParams: seq<Type>, superClasses: seq<Type>, fields: seq<Field>, body: seq<ClassItem>)
datatype Class = Class(name: string, enclosingModule: Ident, typeParams: seq<Type>, superClasses: seq<Type>, fields: seq<Field>, body: seq<ClassItem>)

datatype Trait = Trait(name: string, typeParams: seq<Type>, body: seq<ClassItem>)

Expand Down Expand Up @@ -71,6 +71,7 @@ module {:extern "DAST"} DAST {
BinOp(op: string, left: Expression, right: Expression) |
Select(expr: Expression, field: string, isConstant: bool, onDatatype: bool) |
SelectFn(expr: Expression, field: string, onDatatype: bool, isStatic: bool, arity: nat) |
Index(expr: Expression, idx: Expression) |
TupleSelect(expr: Expression, index: nat) |
Call(on: Expression, name: Ident, typeArgs: seq<Type>, args: seq<Expression>) |
Lambda(params: seq<Formal>, retType: Type, body: seq<Statement>) |
Expand Down
17 changes: 10 additions & 7 deletions Source/DafnyCore/Compilers/Dafny/ASTBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,24 @@ public object Finish() {
interface ClassContainer {
void AddClass(Class item);

public ClassBuilder Class(string name, List<DAST.Type> typeParams, List<DAST.Type> superClasses) {
return new ClassBuilder(this, name, typeParams, superClasses);
public ClassBuilder Class(string name, string enclosingModule, List<DAST.Type> typeParams, List<DAST.Type> superClasses) {
return new ClassBuilder(this, name, enclosingModule, typeParams, superClasses);
}
}

class ClassBuilder : ClassLike {
readonly ClassContainer parent;
readonly string name;
readonly string enclosingModule;
readonly List<DAST.Type> typeParams;
readonly List<DAST.Type> superClasses;
readonly List<DAST.Field> fields = new();
readonly List<DAST.Method> body = new();

public ClassBuilder(ClassContainer parent, string name, List<DAST.Type> typeParams, List<DAST.Type> superClasses) {
public ClassBuilder(ClassContainer parent, string name, string enclosingModule, List<DAST.Type> typeParams, List<DAST.Type> superClasses) {
this.parent = parent;
this.name = name;
this.enclosingModule = enclosingModule;
this.typeParams = typeParams;
this.superClasses = superClasses;
}
Expand All @@ -96,6 +98,7 @@ public void AddField(DAST.Formal item, DAST.Expression defaultValue) {
public object Finish() {
parent.AddClass((Class)Class.create(
Sequence<Rune>.UnicodeFromString(this.name),
Sequence<Rune>.UnicodeFromString(this.enclosingModule),
Sequence<DAST.Type>.FromArray(this.typeParams.ToArray()),
Sequence<DAST.Type>.FromArray(this.superClasses.ToArray()),
Sequence<DAST.Field>.FromArray(this.fields.ToArray()),
Expand Down Expand Up @@ -199,21 +202,21 @@ public object Finish() {
interface DatatypeContainer {
void AddDatatype(Datatype item);

public DatatypeBuilder Datatype(string name, ISequence<Rune> enclosingModule, List<DAST.Type> typeParams, List<DAST.DatatypeCtor> ctors, bool isCo) {
public DatatypeBuilder Datatype(string name, string enclosingModule, List<DAST.Type> typeParams, List<DAST.DatatypeCtor> ctors, bool isCo) {
return new DatatypeBuilder(this, name, enclosingModule, typeParams, ctors, isCo);
}
}

class DatatypeBuilder : ClassLike {
readonly DatatypeContainer parent;
readonly string name;
readonly ISequence<Rune> enclosingModule;
readonly string enclosingModule;
readonly List<DAST.Type> typeParams;
readonly List<DAST.DatatypeCtor> ctors;
readonly bool isCo;
readonly List<DAST.Method> body = new();

public DatatypeBuilder(DatatypeContainer parent, string name, ISequence<Rune> enclosingModule, List<DAST.Type> typeParams, List<DAST.DatatypeCtor> ctors, bool isCo) {
public DatatypeBuilder(DatatypeContainer parent, string name, string enclosingModule, List<DAST.Type> typeParams, List<DAST.DatatypeCtor> ctors, bool isCo) {
this.parent = parent;
this.name = name;
this.typeParams = typeParams;
Expand All @@ -233,7 +236,7 @@ public void AddField(DAST.Formal item, DAST.Expression defaultValue) {
public object Finish() {
parent.AddDatatype((Datatype)Datatype.create(
Sequence<Rune>.UnicodeFromString(this.name),
this.enclosingModule,
Sequence<Rune>.UnicodeFromString(this.enclosingModule),
Sequence<DAST.Type>.FromArray(typeParams.ToArray()),
Sequence<DAST.DatatypeCtor>.FromArray(ctors.ToArray()),
Sequence<DAST.Method>.FromArray(body.ToArray()),
Expand Down
33 changes: 29 additions & 4 deletions Source/DafnyCore/Compilers/Dafny/Compiler-dafny.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ protected override IClassWriter CreateClass(string moduleName, string name, bool
typeParams.Add((DAST.Type)DAST.Type.create_TypeArg(Sequence<Rune>.UnicodeFromString(IdProtect(tp.GetCompileName(Options)))));
}

return new ClassWriter(this, builder.Class(name, typeParams, superClasses.Select(t => GenType(t)).ToList()));
return new ClassWriter(this, builder.Class(name, moduleName, typeParams, superClasses.Select(t => GenType(t)).ToList()));
} else {
throw new InvalidOperationException();
}
Expand Down Expand Up @@ -207,7 +207,7 @@ protected override IClassWriter DeclareDatatype(DatatypeDecl dt, ConcreteSyntaxT

return new ClassWriter(this, builder.Datatype(
dt.GetCompileName(Options),
Sequence<Rune>.UnicodeFromString(dt.EnclosingModuleDefinition.GetCompileName(Options)),
dt.EnclosingModuleDefinition.GetCompileName(Options),
typeParams,
ctors,
dt is CoDatatypeDecl
Expand Down Expand Up @@ -800,7 +800,19 @@ protected override ILvalue IdentLvalue(string var) {
}

protected override ILvalue SeqSelectLvalue(SeqSelectExpr ll, ConcreteSyntaxTree wr, ConcreteSyntaxTree wStmts) {
throw new NotImplementedException();
var sourceBuf = new ExprBuffer(null);
EmitExpr(ll.Seq, false, new BuilderSyntaxTree<ExprContainer>(sourceBuf), wStmts);

var indexBuf = new ExprBuffer(null);
EmitExpr(ll.E0, false, new BuilderSyntaxTree<ExprContainer>(indexBuf), wStmts);

return new ExprLvalue(
(DAST.Expression)DAST.Expression.create_Index(
sourceBuf.Finish(),
indexBuf.Finish()
),
null
);
}

protected override ILvalue MultiSelectLvalue(MultiSelectExpr ll, ConcreteSyntaxTree wr, ConcreteSyntaxTree wStmts) {
Expand Down Expand Up @@ -1370,7 +1382,20 @@ protected override void EmitExprAsNativeInt(Expression expr, bool inLetExprBody,

protected override void EmitIndexCollectionSelect(Expression source, Expression index, bool inLetExprBody,
ConcreteSyntaxTree wr, ConcreteSyntaxTree wStmts) {
throw new NotImplementedException();
var sourceBuf = new ExprBuffer(null);
EmitExpr(source, inLetExprBody, new BuilderSyntaxTree<ExprContainer>(sourceBuf), wStmts);

var indexBuf = new ExprBuffer(null);
EmitExpr(index, inLetExprBody, new BuilderSyntaxTree<ExprContainer>(indexBuf), wStmts);

if (wr is BuilderSyntaxTree<ExprContainer> builder) {
builder.Builder.AddExpr((DAST.Expression)DAST.Expression.create_Index(
sourceBuf.Finish(),
indexBuf.Finish()
));
} else {
throw new InvalidOperationException();
}
}

protected override void EmitIndexCollectionUpdate(Expression source, Expression index, Expression value,
Expand Down
38 changes: 33 additions & 5 deletions Source/DafnyCore/Compilers/Rust/Dafny-compiler-rust.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ module {:extern "DCOMP"} DCOMP {
defaultImpl := defaultImpl + "}\n";

var printImpl := "impl " + constrainedTypeParams + " ::dafny_runtime::DafnyPrint for r#" + c.name + typeParams + " {\n" + "fn fmt_print(&self, __fmt_print_formatter: &mut ::std::fmt::Formatter, _in_seq: bool) -> std::fmt::Result {\n";
printImpl := printImpl + "write!(__fmt_print_formatter, \"r#" + c.name + "(" + (if |c.fields| > 0 then "" else ")") + "\")?;";
printImpl := printImpl + "write!(__fmt_print_formatter, \"" + c.enclosingModule.id + "." + c.name + (if |c.fields| > 0 then "(" else "") + "\")?;";
var i := 0;
while i < |c.fields| {
var field := c.fields[i];
Expand All @@ -144,7 +144,11 @@ module {:extern "DCOMP"} DCOMP {
printImpl := printImpl + "\n::dafny_runtime::DafnyPrint::fmt_print(::std::ops::Deref::deref(&(self.r#" + field.formal.name + ".borrow())), __fmt_print_formatter, false)?;";
i := i + 1;
}
printImpl := printImpl + "\nwrite!(__fmt_print_formatter, \")\")?;\nOk(())\n}\n}\n";

if |c.fields| > 0 {
printImpl := printImpl + "\nwrite!(__fmt_print_formatter, \")\")?;";
}
printImpl := printImpl + "\nOk(())\n}\n}\n";

var ptrPartialEqImpl := "impl " + constrainedTypeParams + " ::std::cmp::PartialEq for r#" + c.name + typeParams + " {\n";
ptrPartialEqImpl := ptrPartialEqImpl + "fn eq(&self, other: &Self) -> bool {\n";
Expand Down Expand Up @@ -1341,12 +1345,13 @@ module {:extern "DCOMP"} DCOMP {
isErased := true;
}
case UnOp(Cardinality, e) => {
var recursiveGen, _, recErased, recIdents := GenExpr(e, params, false);
var recursiveGen, recOwned, recErased, recIdents := GenExpr(e, params, false);
if !recErased {
recursiveGen := "::dafny_runtime::DafnyErasable::erase_owned(" + recursiveGen + ")";
var eraseFn := if recOwned then "erase_owned" else "erase";
recursiveGen := "::dafny_runtime::DafnyErasable::" + eraseFn + "(" + recursiveGen + ")";
}

s := "(" + recursiveGen + ").len()";
s := "::dafny_runtime::BigInt::from((" + recursiveGen + ").len())";
isOwned := true;
readIdents := recIdents;
isErased := true;
Expand Down Expand Up @@ -1432,6 +1437,29 @@ module {:extern "DCOMP"} DCOMP {
isErased := false;
readIdents := recIdents;
}
case Index(on, idx) => {
var onString, onOwned, onErased, recIdents := GenExpr(on, params, false);
if !onErased {
var eraseFn := if onOwned then "erase_owned" else "erase";
onString := "::dafny_runtime::DafnyErasable::" + eraseFn + "(" + onString + ")";
}

var idxString, _, idxErased, recIdentsIdx := GenExpr(idx, params, true);
if !idxErased {
idxString := "::dafny_runtime::DafnyErasable::erase_owned(" + idxString + ")";
}

s := "(" + onString + ")" + "[<usize as ::dafny_runtime::NumCast>::from(" + idxString + ").unwrap()]";
if mustOwn {
s := "(" + s + ").clone()";
isOwned := true;
} else {
isOwned := false;
}

isErased := true;
readIdents := recIdents + recIdentsIdx;
}
case TupleSelect(on, idx) => {
var onString, _, tupErased, recIdents := GenExpr(on, params, false);
s := "(" + onString + ")." + natToString(idx);
Expand Down
Loading

0 comments on commit f3d96c4

Please sign in to comment.