Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

F#: Fixing code generation for enum style discriminated unions and for discriminated unions with 4 of more cases #9095

Merged
merged 2 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 16 additions & 26 deletions src/Orleans.CodeGenerator/SyntaxGeneration/FSharpUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,13 @@ public FSharpUnionCaseTypeDescription(Compilation compilation, INamedTypeSymbol

private static IEnumerable<IMemberDescription> GetUnionCaseDataMembers(LibraryTypes libraryTypes, INamedTypeSymbol symbol)
{
List<IPropertySymbol> dataMembers = new();
foreach (var property in symbol.GetDeclaredInstanceMembers<IPropertySymbol>())
List<IFieldSymbol> dataMembers = new();
foreach (var field in symbol.GetDeclaredInstanceMembers<IFieldSymbol>())
{
if (!property.Name.StartsWith("Item", System.StringComparison.Ordinal))
if (field.Name.StartsWith("item", System.StringComparison.Ordinal) || field.Name.Equals("_tag", System.StringComparison.Ordinal))
{
continue;
dataMembers.Add(field);
}

dataMembers.Add(property);
}

dataMembers.Sort(FSharpUnionCasePropertyNameComparer.Default);
Expand All @@ -134,11 +132,11 @@ private static IEnumerable<IMemberDescription> GetUnionCaseDataMembers(LibraryTy
}
}

private class FSharpUnionCasePropertyNameComparer : IComparer<IPropertySymbol>
private class FSharpUnionCasePropertyNameComparer : IComparer<IFieldSymbol>
{
public static FSharpUnionCasePropertyNameComparer Default { get; } = new FSharpUnionCasePropertyNameComparer();

public int Compare(IPropertySymbol x, IPropertySymbol y)
public int Compare(IFieldSymbol x, IFieldSymbol y)
{
var xName = x.Name;
var yName = y.Name;
Expand All @@ -159,30 +157,28 @@ public int Compare(IPropertySymbol x, IPropertySymbol y)
private class FSharpUnionCaseFieldDescription : IMemberDescription, ISerializableMember
{
private readonly LibraryTypes _libraryTypes;
private readonly IPropertySymbol _property;
private readonly IFieldSymbol _field;

public FSharpUnionCaseFieldDescription(LibraryTypes libraryTypes, IPropertySymbol property, uint ordinal)
public FSharpUnionCaseFieldDescription(LibraryTypes libraryTypes, IFieldSymbol field, uint ordinal)
{
_libraryTypes = libraryTypes;
FieldId = ordinal;
_property = property;
_field = field;
}

public uint FieldId { get; }

public bool IsShallowCopyable => _libraryTypes.IsShallowCopyable(Type) || _property.HasAnyAttribute(_libraryTypes.ImmutableAttributes);
public bool IsShallowCopyable => _libraryTypes.IsShallowCopyable(Type) || _field.HasAnyAttribute(_libraryTypes.ImmutableAttributes);

public bool IsValueType => Type.IsValueType;

public IMemberDescription Member => this;

public ITypeSymbol Type => _property.Type;
public ITypeSymbol Type => _field.Type;

public INamedTypeSymbol ContainingType => _property.ContainingType;

public ISymbol Symbol => _property;
public INamedTypeSymbol ContainingType => _field.ContainingType;

public string FieldName => _property.Name.ToLowerInvariant();
public ISymbol Symbol => _field;

/// <summary>
/// Gets the name of the setter field.
Expand All @@ -193,15 +189,9 @@ public FSharpUnionCaseFieldDescription(LibraryTypes libraryTypes, IPropertySymbo
/// Gets syntax representing the type of this field.
/// </summary>
public TypeSyntax TypeSyntax => Type.TypeKind == TypeKind.Dynamic
? PredefinedType(Token(SyntaxKind.ObjectKeyword))
? PredefinedType(Token(SyntaxKind.ObjectKeyword))
: GetTypeSyntax(Type);

/// <summary>
/// Gets the <see cref="Property"/> which this field is the backing property for, or
/// <see langword="null" /> if this is not the backing field of an auto-property.
/// </summary>
private IPropertySymbol Property => _property;

public string AssemblyName => Type.ContainingAssembly.ToDisplayName();
public string TypeName => Type.ToDisplayName();
public string TypeNameIdentifier => Type.GetValidIdentifier();
Expand All @@ -215,7 +205,7 @@ public FSharpUnionCaseFieldDescription(LibraryTypes libraryTypes, IPropertySymbo
/// </summary>
/// <param name="instance">The instance of the containing type.</param>
/// <returns>Syntax for retrieving the value of this field.</returns>
public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(Property.Name);
public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(_field.Name);

/// <summary>
/// Returns syntax for setting the value of this field.
Expand All @@ -239,7 +229,7 @@ public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax va
public FieldAccessorDescription GetGetterFieldDescription() => null;

public FieldAccessorDescription GetSetterFieldDescription()
=> SerializableMember.GetFieldAccessor(ContainingType, TypeSyntax, FieldName, SetterFieldName, _libraryTypes, true);
=> SerializableMember.GetFieldAccessor(ContainingType, TypeSyntax, _field.Name, SetterFieldName, _libraryTypes, true);
}
}

Expand Down
37 changes: 35 additions & 2 deletions test/Grains/TestFSharp/Types.fs
Original file line number Diff line number Diff line change
@@ -1,7 +1,29 @@
namespace UnitTests.FSharpTypes

open System.Runtime.CompilerServices
open Orleans

[<Immutable; GenerateSerializer>]
type EnumStyleDU =
| Case1
| Case2
| Case3

[<Immutable; GenerateSerializer>]
type MixCaseDU =
| Case1
| Case2 of string

[<Immutable; GenerateSerializer>]
type RecursiveDU =
| Case1
| Case2 of RecursiveDU

[<Immutable; GenerateSerializer>]
type GenericDU<'T> =
| Case1 of 'T
| Case2

[<Immutable; GenerateSerializer>]
type SingleCaseDU =
| Case1 of int
Expand All @@ -25,6 +47,14 @@ type QuadrupleCaseDU =
| Case3 of char
| Case4 of byte

[<Immutable; GenerateSerializer>]
type QuintupleCaseDU =
| Case1
| Case2 of int
| Case3
| Case4 of byte
| Case5 of string

[<Immutable; GenerateSerializer>]
type Record = { [<Id(1u)>] A: SingleCaseDU } with
static member ofInt x = { A = SingleCaseDU.ofInt x }
Expand Down Expand Up @@ -60,8 +90,11 @@ type DiscriminatedUnion =

static member set s = SetFieldCase s
static member emptySet() = SetFieldCase Set.empty
static member nonEmptySet() = Set.ofList [1; 2; 3] |> SetFieldCase
static member nonEmptySet() = Set.ofList [1; 2; 3] |> SetFieldCase

static member map m = MapFieldCase m
static member emptyMap() = MapFieldCase Map.empty
static member nonEmptyMap() = Map.ofList [0, "zero"; 1, "one"] |> MapFieldCase
static member nonEmptyMap() = Map.ofList [0, "zero"; 1, "one"] |> MapFieldCase

[<InternalsVisibleTo("TestFSharpGrainInterfaces")>]
do ()
101 changes: 99 additions & 2 deletions test/Orleans.Serialization.FSharp.Tests/SerializationTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,72 @@ type FSharpSerializationTests(fixture: DefaultClusterFixture) =
Assert.Equal((), roundtripped)
Assert.Equal((), copy)

[<Fact; TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_EnumStyleDU () =
let case1 = EnumStyleDU.Case1
let case2 = EnumStyleDU.Case2

let roundtrippedCase1 = cluster.RoundTripSerializationForTesting case1
let roundtrippedCase2 = cluster.RoundTripSerializationForTesting case2
let copyCase1 = cluster.DeepCopy case1
let copyCase2 = cluster.DeepCopy case2
Assert.Equal(case1, roundtrippedCase1)
Assert.Equal(case2, roundtrippedCase2)
Assert.Equal(case1, copyCase1)
Assert.Equal(case2, copyCase2)

[<Fact; TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_MixCaseDU () =
let case1 = MixCaseDU.Case1
let case2 = MixCaseDU.Case2 "Case2"

let roundtrippedCase1 = cluster.RoundTripSerializationForTesting case1
let roundtrippedCase2 = cluster.RoundTripSerializationForTesting case2
let copyCase1 = cluster.DeepCopy case1
let copyCase2 = cluster.DeepCopy case2
Assert.Equal(case1, roundtrippedCase1)
Assert.Equal(case2, roundtrippedCase2)
Assert.Equal(case1, copyCase1)
Assert.Equal(case2, copyCase2)

[<Fact; TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_RecursiveDU () =
let case1 = RecursiveDU.Case1
let case2 = RecursiveDU.Case2 (RecursiveDU.Case2 RecursiveDU.Case1)

let roundtrippedCase1 = cluster.RoundTripSerializationForTesting case1
let roundtrippedCase2 = cluster.RoundTripSerializationForTesting case2
let copyCase1 = cluster.DeepCopy case1
let copyCase2 = cluster.DeepCopy case2
Assert.Equal(case1, roundtrippedCase1)
Assert.Equal(case2, roundtrippedCase2)
Assert.Equal(case1, copyCase1)
Assert.Equal(case2, copyCase2)

[<Fact; TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_GenericDU () =
let case1String = GenericDU.Case1 "string"
let case1Int = GenericDU.Case1 99
let case1Case2 = GenericDU.Case1 GenericDU.Case2
let case2 = GenericDU.Case2

let roundtrippedCase1String = cluster.RoundTripSerializationForTesting case1String
let roundtrippedCase1Int = cluster.RoundTripSerializationForTesting case1Int
let roundtrippedCase1Case2 = cluster.RoundTripSerializationForTesting case1Case2
let roundtrippedCase2 = cluster.RoundTripSerializationForTesting case2
let copyCase1String = cluster.DeepCopy case1String
let copyCase1Int = cluster.DeepCopy case1Int
let copyCase1Case2 = cluster.DeepCopy case1Case2
let copyCase2 = cluster.DeepCopy case2
Assert.Equal(case1String, roundtrippedCase1String)
Assert.Equal(case1Int, roundtrippedCase1Int)
Assert.Equal(case1Case2, roundtrippedCase1Case2)
Assert.Equal(case2, roundtrippedCase2)
Assert.Equal(case1String, copyCase1String)
Assert.Equal(case1Int, copyCase1Int)
Assert.Equal(case1Case2, copyCase1Case2)
Assert.Equal(case2, copyCase2)

[<Fact; TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_SingleCaseDiscriminatedUnion () =
let du = SingleCaseDU.Case1 1
Expand Down Expand Up @@ -62,7 +128,7 @@ type FSharpSerializationTests(fixture: DefaultClusterFixture) =
Assert.Equal(case2, copyCase2)
Assert.Equal(case3, copyCase3)

[<Fact(Skip = "DUs with 4 or more cases fail when trying to instantiate Case{2-4}-classes via RuntimeHelpers.GetUninitializedObject when deserializing")>]
[<Fact>]
[<TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_QuadrupleCaseDiscriminatedUnion () =
let case1 = QuadrupleCaseDU.Case1 "case 1"
Expand All @@ -86,4 +152,35 @@ type FSharpSerializationTests(fixture: DefaultClusterFixture) =
Assert.Equal(case1, copyCase1);
Assert.Equal(case2, copyCase2);
Assert.Equal(case3, copyCase3);
Assert.Equal(case4, copyCase4);
Assert.Equal(case4, copyCase4)

[<Fact>]
[<TestCategory("BVT"); TestCategory("Serialization")>]
let Serialization_Roundtrip_FSharp_QuintupleCaseDiscriminatedUnion () =
let case1 = QuintupleCaseDU.Case1
let case2 = QuintupleCaseDU.Case2 2
let case3 = QuintupleCaseDU.Case3
let case4 = QuintupleCaseDU.Case4 1uy
let case5 = QuintupleCaseDU.Case5 "case 5"

let roundtrippedCase1 = cluster.RoundTripSerializationForTesting case1
let roundtrippedCase2 = cluster.RoundTripSerializationForTesting case2
let roundtrippedCase3 = cluster.RoundTripSerializationForTesting case3
let roundtrippedCase4 = cluster.RoundTripSerializationForTesting case4
let roundtrippedCase5 = cluster.RoundTripSerializationForTesting case5
let copyCase1 = cluster.DeepCopy case1
let copyCase2 = cluster.DeepCopy case2
let copyCase3 = cluster.DeepCopy case3
let copyCase4 = cluster.DeepCopy case4
let copyCase5 = cluster.DeepCopy case5

Assert.Equal(case1, roundtrippedCase1);
Assert.Equal(case2, roundtrippedCase2);
Assert.Equal(case3, roundtrippedCase3);
Assert.Equal(case4, roundtrippedCase4);
Assert.Equal(case5, roundtrippedCase5);
Assert.Equal(case1, copyCase1);
Assert.Equal(case2, copyCase2);
Assert.Equal(case3, copyCase3);
Assert.Equal(case4, copyCase4)
Assert.Equal(case5, copyCase5)
Loading