Skip to content

Commit

Permalink
Support tail recursion and nullable references in the Rust backend
Browse files Browse the repository at this point in the history
<small>By submitting this pull request, I confirm that my contribution is made under the terms of the [MIT
license](https://github.com/dafny-lang/dafny/blob/master/LICENSE.txt).</small>
  • Loading branch information
shadaj committed Aug 23, 2023
1 parent f3d96c4 commit 9439dab
Show file tree
Hide file tree
Showing 8 changed files with 8,081 additions and 6,948 deletions.
7 changes: 6 additions & 1 deletion Source/DafnyCore/AST/Types/Types.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2321,13 +2321,18 @@ public static UserDefinedType FromTopLevelDeclWithAllBooleanTypeParameters(TopLe
/// Return the upcast of "receiverType" that has base type "member.EnclosingClass".
/// Assumes that "receiverType" normalizes to a UserDefinedFunction with a .ResolveClass that is a subtype
/// of "member.EnclosingClass".
/// Preserves non-null-ness of "receiverType" if it is a non-null reference.
/// Otherwise:
/// Return "receiverType" (expanded).
/// </summary>
public static Type UpcastToMemberEnclosingType(Type receiverType, MemberDecl/*?*/ member) {
Contract.Requires(receiverType != null);
if (member != null && member.EnclosingClass != null && !(member.EnclosingClass is ValuetypeDecl)) {
return receiverType.AsParentType(member.EnclosingClass);
if (receiverType.IsNonNullRefType) {
return CreateNonNullType(receiverType.AsParentType(member.EnclosingClass));
} else {
return receiverType.AsParentType(member.EnclosingClass);
}
}
return receiverType.NormalizeExpandKeepConstraints();
}
Expand Down
3 changes: 3 additions & 0 deletions Source/DafnyCore/Compilers/Dafny/AST.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module {:extern "DAST"} DAST {

datatype Type =
Path(seq<Ident>, typeArgs: seq<Type>, resolved: ResolvedType) |
Nullable(Type) |
Tuple(seq<Type>) |
Array(element: Type) |
Seq(element: Type) |
Expand Down Expand Up @@ -49,6 +50,8 @@ module {:extern "DAST"} DAST {
Call(on: Expression, name: string, typeArgs: seq<Type>, args: seq<Expression>, outs: Optional<seq<Ident>>) |
Return(expr: Expression) |
EarlyReturn() |
TailRecursive(body: seq<Statement>) |
JumpTailCallStart() |
Halt() |
Print(Expression)

Expand Down
35 changes: 35 additions & 0 deletions Source/DafnyCore/Compilers/Dafny/ASTBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,12 @@ public WhileBuilder While() {
return ret;
}

public TailRecursiveBuilder TailRecursive() {
var ret = new TailRecursiveBuilder();
AddBuildable(ret);
return ret;
}

public CallStmtBuilder Call() {
var ret = new CallStmtBuilder();
AddBuildable(ret);
Expand Down Expand Up @@ -640,6 +646,35 @@ public DAST.Statement Build() {
}
}

class TailRecursiveBuilder : StatementContainer, BuildableStatement {
readonly List<object> body = new();

public TailRecursiveBuilder() { }

public void AddStatement(DAST.Statement item) {
body.Add(item);
}

public void AddBuildable(BuildableStatement item) {
body.Add(item);
}

public List<object> ForkList() {
var ret = new List<object>();
this.body.Add(ret);
return ret;
}

public DAST.Statement Build() {
List<DAST.Statement> builtStatements = new();
StatementContainer.RecursivelyBuild(body, builtStatements);

return (DAST.Statement)DAST.Statement.create_TailRecursive(
Sequence<DAST.Statement>.FromArray(builtStatements.ToArray())
);
}
}

class CallStmtBuilder : ExprContainer, BuildableStatement {
object on = null;
string name = null;
Expand Down
65 changes: 55 additions & 10 deletions Source/DafnyCore/Compilers/Dafny/Compiler-dafny.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,29 @@ private static string MangleName(string name) {
}

protected override ConcreteSyntaxTree EmitCoercionIfNecessary(Type from, Type to, IToken tok, ConcreteSyntaxTree wr) {
if (from.AsSubsetType == null && to.AsSubsetType != null) {
if (from != null && to != null && from.IsNonNullRefType != to.IsNonNullRefType) {
if (wr is BuilderSyntaxTree<ExprContainer> stmt) {
var nullConvert = stmt.Builder.Convert(GenType(from), GenType(to));

if (from is UserDefinedType fromUdt && fromUdt.ResolvedClass is NonNullTypeDecl fromNonNull) {
from = fromNonNull.RhsWithArgument(fromUdt.TypeArgs);
}

if (to is UserDefinedType toUdt && toUdt.ResolvedClass is NonNullTypeDecl toNonNull) {
to = toNonNull.RhsWithArgument(toUdt.TypeArgs);
}

return EmitCoercionIfNecessary(from, to, tok, new BuilderSyntaxTree<ExprContainer>(nullConvert));
} else {
return base.EmitCoercionIfNecessary(from, to, tok, wr);
}
} else if (from != null && to != null && from.AsSubsetType == null && to.AsSubsetType != null) {
if (wr is BuilderSyntaxTree<ExprContainer> stmt) {
return new BuilderSyntaxTree<ExprContainer>(stmt.Builder.Convert(GenType(from), GenType(to)));
} else {
return base.EmitCoercionIfNecessary(from, to, tok, wr);
}
} else if (from.AsSubsetType != null && to.AsSubsetType == null) {
} else if (from != null && to != null && from.AsSubsetType != null && to.AsSubsetType == null) {
if (wr is BuilderSyntaxTree<ExprContainer> stmt) {
return new BuilderSyntaxTree<ExprContainer>(stmt.Builder.Convert(GenType(from), GenType(to)));
} else {
Expand Down Expand Up @@ -515,12 +531,27 @@ protected override string TypeDescriptor(Type type, ConcreteSyntaxTree wr, IToke
return type.ToString();
}

protected override ConcreteSyntaxTree EmitMethodReturns(Method m, ConcreteSyntaxTree wr) {
var beforeReturnBlock = wr.Fork();
EmitReturn(m.Outs, wr);
return beforeReturnBlock;
}

protected override ConcreteSyntaxTree EmitTailCallStructure(MemberDecl member, ConcreteSyntaxTree wr) {
throw new NotImplementedException();
if (wr is BuilderSyntaxTree<StatementContainer> stmtContainer) {
var recBuilder = stmtContainer.Builder.TailRecursive();
return new BuilderSyntaxTree<StatementContainer>(recBuilder);
} else {
throw new InvalidOperationException();
}
}

protected override void EmitJumpToTailCallStart(ConcreteSyntaxTree wr) {
throw new NotImplementedException();
if (wr is BuilderSyntaxTree<StatementContainer> stmtContainer) {
stmtContainer.Builder.AddStatement((DAST.Statement)DAST.Statement.create_JumpTailCallStart());
} else {
throw new InvalidOperationException();
}
}

internal override string TypeName(Type type, ConcreteSyntaxTree wr, IToken tok, MemberDecl/*?*/ member = null) {
Expand Down Expand Up @@ -701,8 +732,12 @@ protected override void EmitCallReturnOuts(List<string> outTmps, ConcreteSyntaxT

protected override void TrCallStmt(CallStmt s, string receiverReplacement, ConcreteSyntaxTree wr, ConcreteSyntaxTree wrStmts, ConcreteSyntaxTree wrStmtsAfterCall) {
if (wr is BuilderSyntaxTree<StatementContainer> stmtContainer) {
var callBuilder = stmtContainer.Builder.Call();
base.TrCallStmt(s, receiverReplacement, new BuilderSyntaxTree<ExprContainer>(callBuilder), wrStmts, wrStmtsAfterCall);
if (s.Method == enclosingMethod && enclosingMethod.IsTailRecursive) {
base.TrCallStmt(s, receiverReplacement, wr, wrStmts, wrStmtsAfterCall);
} else {
var callBuilder = stmtContainer.Builder.Call();
base.TrCallStmt(s, receiverReplacement, new BuilderSyntaxTree<ExprContainer>(callBuilder), wrStmts, wrStmtsAfterCall);
}
} else {
throw new InvalidOperationException("Cannot call statement in this context: " + currentBuilder);
}
Expand Down Expand Up @@ -795,6 +830,10 @@ public ConcreteSyntaxTree EmitWrite(ConcreteSyntaxTree wr) {
}
}

protected override void EmitAssignment(string lhs, Type/*?*/ lhsType, string rhs, Type/*?*/ rhsType, ConcreteSyntaxTree wr) {
throw new InvalidOperationException("Cannot use stringified version of assignment");
}

protected override ILvalue IdentLvalue(string var) {
return new BuilderLvalue(var);
}
Expand Down Expand Up @@ -1152,11 +1191,11 @@ private ISequence<ISequence<Rune>> PathFromTopLevel(TopLevelDecl topLevel) {
private DAST.Type TypeNameASTFromTopLevel(TopLevelDecl topLevel, List<Type> typeArgs) {
var path = PathFromTopLevel(topLevel);

// TODO(shadaj): do something with nullable references
// bool nonNull = false;
bool nonNull = true;
if (topLevel is NonNullTypeDecl non) {
// nonNull = true;
topLevel = non.Rhs.AsTopLevelTypeWithMembers;
} else if (topLevel is ClassLikeDecl) {
nonNull = false;
}

ResolvedType resolvedType;
Expand All @@ -1175,11 +1214,17 @@ private DAST.Type TypeNameASTFromTopLevel(TopLevelDecl topLevel, List<Type> type
throw new InvalidOperationException(topLevel.GetType().ToString());
}

return (DAST.Type)DAST.Type.create_Path(
DAST.Type baseType = (DAST.Type)DAST.Type.create_Path(
path,
Sequence<DAST.Type>.FromArray(typeArgs.Select(m => GenType(m)).ToArray()),
resolvedType
);

if (nonNull) {
return baseType;
} else {
return (DAST.Type)DAST.Type.create_Nullable(baseType);
}
}

public override ConcreteSyntaxTree Expr(Expression expr, bool inLetExprBody, ConcreteSyntaxTree wStmts) {
Expand Down
Loading

0 comments on commit 9439dab

Please sign in to comment.