Skip to content

Commit

Permalink
F#: Fix code generation for enum style discriminated unions and for d…
Browse files Browse the repository at this point in the history
…iscriminated unions with 4 of more cases. (#9095)
  • Loading branch information
gfix authored Jul 31, 2024
1 parent e34a9b6 commit 4ff9f62
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 30 deletions.
42 changes: 16 additions & 26 deletions src/Orleans.CodeGenerator/SyntaxGeneration/FSharpUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,13 @@ public FSharpUnionCaseTypeDescription(Compilation compilation, INamedTypeSymbol

private static IEnumerable<IMemberDescription> GetUnionCaseDataMembers(LibraryTypes libraryTypes, INamedTypeSymbol symbol)
{
List<IPropertySymbol> dataMembers = new();
foreach (var property in symbol.GetDeclaredInstanceMembers<IPropertySymbol>())
List<IFieldSymbol> dataMembers = new();
foreach (var field in symbol.GetDeclaredInstanceMembers<IFieldSymbol>())
{
if (!property.Name.StartsWith("Item", System.StringComparison.Ordinal))
if (field.Name.StartsWith("item", System.StringComparison.Ordinal) || field.Name.Equals("_tag", System.StringComparison.Ordinal))
{
continue;
dataMembers.Add(field);
}

dataMembers.Add(property);
}

dataMembers.Sort(FSharpUnionCasePropertyNameComparer.Default);
Expand All @@ -134,11 +132,11 @@ private static IEnumerable<IMemberDescription> GetUnionCaseDataMembers(LibraryTy
}
}

private class FSharpUnionCasePropertyNameComparer : IComparer<IPropertySymbol>
private class FSharpUnionCasePropertyNameComparer : IComparer<IFieldSymbol>
{
public static FSharpUnionCasePropertyNameComparer Default { get; } = new FSharpUnionCasePropertyNameComparer();

public int Compare(IPropertySymbol x, IPropertySymbol y)
public int Compare(IFieldSymbol x, IFieldSymbol y)
{
var xName = x.Name;
var yName = y.Name;
Expand All @@ -159,30 +157,28 @@ public int Compare(IPropertySymbol x, IPropertySymbol y)
private class FSharpUnionCaseFieldDescription : IMemberDescription, ISerializableMember
{
private readonly LibraryTypes _libraryTypes;
private readonly IPropertySymbol _property;
private readonly IFieldSymbol _field;

public FSharpUnionCaseFieldDescription(LibraryTypes libraryTypes, IPropertySymbol property, uint ordinal)
public FSharpUnionCaseFieldDescription(LibraryTypes libraryTypes, IFieldSymbol field, uint ordinal)
{
_libraryTypes = libraryTypes;
FieldId = ordinal;
_property = property;
_field = field;
}

public uint FieldId { get; }

public bool IsShallowCopyable => _libraryTypes.IsShallowCopyable(Type) || _property.HasAnyAttribute(_libraryTypes.ImmutableAttributes);
public bool IsShallowCopyable => _libraryTypes.IsShallowCopyable(Type) || _field.HasAnyAttribute(_libraryTypes.ImmutableAttributes);

public bool IsValueType => Type.IsValueType;

public IMemberDescription Member => this;

public ITypeSymbol Type => _property.Type;
public ITypeSymbol Type => _field.Type;

public INamedTypeSymbol ContainingType => _property.ContainingType;

public ISymbol Symbol => _property;
public INamedTypeSymbol ContainingType => _field.ContainingType;

public string FieldName => _property.Name.ToLowerInvariant();
public ISymbol Symbol => _field;

/// <summary>
/// Gets the name of the setter field.
Expand All @@ -193,15 +189,9 @@ public FSharpUnionCaseFieldDescription(LibraryTypes libraryTypes, IPropertySymbo
/// Gets syntax representing the type of this field.
/// </summary>
public TypeSyntax TypeSyntax => Type.TypeKind == TypeKind.Dynamic
? PredefinedType(Token(SyntaxKind.ObjectKeyword))
? PredefinedType(Token(SyntaxKind.ObjectKeyword))
: GetTypeSyntax(Type);

/// <summary>
/// Gets the <see cref="Property"/> which this field is the backing property for, or
/// <see langword="null" /> if this is not the backing field of an auto-property.
/// </summary>
private IPropertySymbol Property => _property;

public string AssemblyName => Type.ContainingAssembly.ToDisplayName();
public string TypeName => Type.ToDisplayName();
public string TypeNameIdentifier => Type.GetValidIdentifier();
Expand All @@ -215,7 +205,7 @@ public FSharpUnionCaseFieldDescription(LibraryTypes libraryTypes, IPropertySymbo
/// </summary>
/// <param name="instance">The instance of the containing type.</param>
/// <returns>Syntax for retrieving the value of this field.</returns>
public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(Property.Name);
public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(_field.Name);

/// <summary>
/// Returns syntax for setting the value of this field.
Expand All @@ -239,7 +229,7 @@ public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax va
public FieldAccessorDescription GetGetterFieldDescription() => null;

public FieldAccessorDescription GetSetterFieldDescription()
=> SerializableMember.GetFieldAccessor(ContainingType, TypeSyntax, FieldName, SetterFieldName, _libraryTypes, true);
=> SerializableMember.GetFieldAccessor(ContainingType, TypeSyntax, _field.Name, SetterFieldName, _libraryTypes, true);
}
}

Expand Down
37 changes: 35 additions & 2 deletions test/Grains/TestFSharp/Types.fs
Original file line number Diff line number Diff line change
@@ -1,7 +1,29 @@
namespace UnitTests.FSharpTypes

open System.Runtime.CompilerServices
open Orleans

[<Immutable; GenerateSerializer>]
type EnumStyleDU =
| Case1
| Case2
| Case3

[<Immutable; GenerateSerializer>]
type MixCaseDU =
| Case1
| Case2 of string

[<Immutable; GenerateSerializer>]
type RecursiveDU =
| Case1
| Case2 of RecursiveDU

[<Immutable; GenerateSerializer>]
type GenericDU<'T> =
| Case1 of 'T
| Case2

[<Immutable; GenerateSerializer>]
type SingleCaseDU =
| Case1 of int
Expand All @@ -25,6 +47,14 @@ type QuadrupleCaseDU =
| Case3 of char
| Case4 of byte

[<Immutable; GenerateSerializer>]
type QuintupleCaseDU =
| Case1
| Case2 of int
| Case3
| Case4 of byte
| Case5 of string

[<Immutable; GenerateSerializer>]
type Record = { [<Id(1u)>] A: SingleCaseDU } with
static member ofInt x = { A = SingleCaseDU.ofInt x }
Expand Down Expand Up @@ -60,8 +90,11 @@ type DiscriminatedUnion =

static member set s = SetFieldCase s
static member emptySet() = SetFieldCase Set.empty
static member nonEmptySet() = Set.ofList [1; 2; 3] |> SetFieldCase
static member nonEmptySet() = Set.ofList [1; 2; 3] |> SetFieldCase

static member map m = MapFieldCase m
static member emptyMap() = MapFieldCase Map.empty
static member nonEmptyMap() = Map.ofList [0, "zero"; 1, "one"] |> MapFieldCase
static member nonEmptyMap() = Map.ofList [0, "zero"; 1, "one"] |> MapFieldCase

[<InternalsVisibleTo("TestFSharpGrainInterfaces")>]
do ()
101 changes: 99 additions & 2 deletions test/Orleans.Serialization.FSharp.Tests/SerializationTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,72 @@ type FSharpSerializationTests(fixture: DefaultClusterFixture) =
Assert.Equal((), roundtripped)
Assert.Equal((), copy)

[<Fact; TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_EnumStyleDU () =
let case1 = EnumStyleDU.Case1
let case2 = EnumStyleDU.Case2

let roundtrippedCase1 = cluster.RoundTripSerializationForTesting case1
let roundtrippedCase2 = cluster.RoundTripSerializationForTesting case2
let copyCase1 = cluster.DeepCopy case1
let copyCase2 = cluster.DeepCopy case2
Assert.Equal(case1, roundtrippedCase1)
Assert.Equal(case2, roundtrippedCase2)
Assert.Equal(case1, copyCase1)
Assert.Equal(case2, copyCase2)

[<Fact; TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_MixCaseDU () =
let case1 = MixCaseDU.Case1
let case2 = MixCaseDU.Case2 "Case2"

let roundtrippedCase1 = cluster.RoundTripSerializationForTesting case1
let roundtrippedCase2 = cluster.RoundTripSerializationForTesting case2
let copyCase1 = cluster.DeepCopy case1
let copyCase2 = cluster.DeepCopy case2
Assert.Equal(case1, roundtrippedCase1)
Assert.Equal(case2, roundtrippedCase2)
Assert.Equal(case1, copyCase1)
Assert.Equal(case2, copyCase2)

[<Fact; TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_RecursiveDU () =
let case1 = RecursiveDU.Case1
let case2 = RecursiveDU.Case2 (RecursiveDU.Case2 RecursiveDU.Case1)

let roundtrippedCase1 = cluster.RoundTripSerializationForTesting case1
let roundtrippedCase2 = cluster.RoundTripSerializationForTesting case2
let copyCase1 = cluster.DeepCopy case1
let copyCase2 = cluster.DeepCopy case2
Assert.Equal(case1, roundtrippedCase1)
Assert.Equal(case2, roundtrippedCase2)
Assert.Equal(case1, copyCase1)
Assert.Equal(case2, copyCase2)

[<Fact; TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_GenericDU () =
let case1String = GenericDU.Case1 "string"
let case1Int = GenericDU.Case1 99
let case1Case2 = GenericDU.Case1 GenericDU.Case2
let case2 = GenericDU.Case2

let roundtrippedCase1String = cluster.RoundTripSerializationForTesting case1String
let roundtrippedCase1Int = cluster.RoundTripSerializationForTesting case1Int
let roundtrippedCase1Case2 = cluster.RoundTripSerializationForTesting case1Case2
let roundtrippedCase2 = cluster.RoundTripSerializationForTesting case2
let copyCase1String = cluster.DeepCopy case1String
let copyCase1Int = cluster.DeepCopy case1Int
let copyCase1Case2 = cluster.DeepCopy case1Case2
let copyCase2 = cluster.DeepCopy case2
Assert.Equal(case1String, roundtrippedCase1String)
Assert.Equal(case1Int, roundtrippedCase1Int)
Assert.Equal(case1Case2, roundtrippedCase1Case2)
Assert.Equal(case2, roundtrippedCase2)
Assert.Equal(case1String, copyCase1String)
Assert.Equal(case1Int, copyCase1Int)
Assert.Equal(case1Case2, copyCase1Case2)
Assert.Equal(case2, copyCase2)

[<Fact; TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_SingleCaseDiscriminatedUnion () =
let du = SingleCaseDU.Case1 1
Expand Down Expand Up @@ -62,7 +128,7 @@ type FSharpSerializationTests(fixture: DefaultClusterFixture) =
Assert.Equal(case2, copyCase2)
Assert.Equal(case3, copyCase3)

[<Fact(Skip = "DUs with 4 or more cases fail when trying to instantiate Case{2-4}-classes via RuntimeHelpers.GetUninitializedObject when deserializing")>]
[<Fact>]
[<TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_QuadrupleCaseDiscriminatedUnion () =
let case1 = QuadrupleCaseDU.Case1 "case 1"
Expand All @@ -86,4 +152,35 @@ type FSharpSerializationTests(fixture: DefaultClusterFixture) =
Assert.Equal(case1, copyCase1);
Assert.Equal(case2, copyCase2);
Assert.Equal(case3, copyCase3);
Assert.Equal(case4, copyCase4);
Assert.Equal(case4, copyCase4)

[<Fact>]
[<TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_QuintupleCaseDiscriminatedUnion () =
let case1 = QuintupleCaseDU.Case1
let case2 = QuintupleCaseDU.Case2 2
let case3 = QuintupleCaseDU.Case3
let case4 = QuintupleCaseDU.Case4 1uy
let case5 = QuintupleCaseDU.Case5 "case 5"

let roundtrippedCase1 = cluster.RoundTripSerializationForTesting case1
let roundtrippedCase2 = cluster.RoundTripSerializationForTesting case2
let roundtrippedCase3 = cluster.RoundTripSerializationForTesting case3
let roundtrippedCase4 = cluster.RoundTripSerializationForTesting case4
let roundtrippedCase5 = cluster.RoundTripSerializationForTesting case5
let copyCase1 = cluster.DeepCopy case1
let copyCase2 = cluster.DeepCopy case2
let copyCase3 = cluster.DeepCopy case3
let copyCase4 = cluster.DeepCopy case4
let copyCase5 = cluster.DeepCopy case5

Assert.Equal(case1, roundtrippedCase1);
Assert.Equal(case2, roundtrippedCase2);
Assert.Equal(case3, roundtrippedCase3);
Assert.Equal(case4, roundtrippedCase4);
Assert.Equal(case5, roundtrippedCase5);
Assert.Equal(case1, copyCase1);
Assert.Equal(case2, copyCase2);
Assert.Equal(case3, copyCase3);
Assert.Equal(case4, copyCase4)
Assert.Equal(case5, copyCase5)

0 comments on commit 4ff9f62

Please sign in to comment.