diff --git a/src/libraries/System.Text.Json/gen/Reflection/TypeWrapper.cs b/src/libraries/System.Text.Json/gen/Reflection/TypeWrapper.cs index 2d06d9ec50cd9..c084c6f614823 100644 --- a/src/libraries/System.Text.Json/gen/Reflection/TypeWrapper.cs +++ b/src/libraries/System.Text.Json/gen/Reflection/TypeWrapper.cs @@ -396,12 +396,6 @@ public override PropertyInfo[] GetProperties(BindingFlags bindingAttr) { if (item is IPropertySymbol propertySymbol) { - // Skip auto-generated properties on records. - if (_typeSymbol.IsRecord && propertySymbol.DeclaringSyntaxReferences.Length == 0) - { - continue; - } - // Skip if: if ( // we want a static property and this is not static diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/CompilationHelper.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/CompilationHelper.cs index 8e3105e119238..485a79cd333cd 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/CompilationHelper.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/CompilationHelper.cs @@ -336,6 +336,59 @@ public partial class MyJsonContext : JsonSerializerContext return CreateCompilation(source); } + public static Compilation CreateReferencedLibRecordCompilation() + { + string source = @" + using System.Text.Json.Serialization; + + namespace ReferencedAssembly + { + public record LibRecord(int Id) + { + public string Address1 { get; set; } + public string Address2 { get; set; } + public string City { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public string Name { get; set; } + [JsonInclude] + public string PhoneNumber; + [JsonInclude] + public string Country; + } + } +"; + + return CreateCompilation(source); + } + + public static Compilation CreateReferencedSimpleLibRecordCompilation() + { + string source = @" + using System.Text.Json.Serialization; + + namespace ReferencedAssembly + { + public record LibRecord + { + public int Id { get; set; } + public string Address1 { get; set; } + public string Address2 { get; set; } + public string City { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public string Name { get; set; } + [JsonInclude] + public string PhoneNumber; + [JsonInclude] + public string Country; + } + } +"; + + return CreateCompilation(source); + } + internal static void CheckDiagnosticMessages( DiagnosticSeverity level, ImmutableArray diagnostics, diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorTests.cs index 2cb749bb40bc6..80494bca5f29f 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Unit.Tests/JsonSourceGeneratorTests.cs @@ -449,22 +449,216 @@ public void UsePrivates() CheckFieldsPropertiesMethods(myType, expectedFieldNames, expectedPropertyNames, expectedMethodNames); } + [Fact] + public void Record() + { + // Compile the referenced assembly first. + Compilation referencedCompilation = CompilationHelper.CreateReferencedLibRecordCompilation(); + + // Emit the image of the referenced assembly. + byte[] referencedImage = CompilationHelper.CreateAssemblyImage(referencedCompilation); + + string source = @" + using System.Text.Json.Serialization; + + namespace HelloWorld + { + [JsonSerializable(typeof(AppRecord))] + internal partial class JsonContext : JsonSerializerContext + { + } + + public record AppRecord(int Id) + { + public string Address1 { get; set; } + public string Address2 { get; set; } + public string City { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public string Name { get; set; } + [JsonInclude] + public string PhoneNumber; + [JsonInclude] + public string Country; + } + }"; + + MetadataReference[] additionalReferences = { MetadataReference.CreateFromImage(referencedImage) }; + + Compilation compilation = CompilationHelper.CreateCompilation(source); + + JsonSourceGenerator generator = new JsonSourceGenerator(); + + Compilation newCompilation = CompilationHelper.RunGenerators(compilation, out ImmutableArray generatorDiags, generator); + + // Make sure compilation was successful. + CheckCompilationDiagnosticsErrors(generatorDiags); + CheckCompilationDiagnosticsErrors(newCompilation.GetDiagnostics()); + + Dictionary types = generator.GetSerializableTypes(); + + // Check base functionality of found types. + Assert.Equal(1, types.Count); + Type recordType = types["HelloWorld.AppRecord"]; + Assert.Equal("HelloWorld.AppRecord", recordType.FullName); + + // Check for received fields, properties and methods for NotMyType. + string[] expectedFieldsNames = { "Country", "PhoneNumber" }; + string[] expectedPropertyNames = { "Address1", "Address2", "City", "Id", "Name", "PostalCode", "State" }; + CheckFieldsPropertiesMethods(recordType, expectedFieldsNames, expectedPropertyNames); + + Assert.Equal(1, recordType.GetConstructors().Length); + } + + [Fact] + public void RecordInExternalAssembly() + { + // Compile the referenced assembly first. + Compilation referencedCompilation = CompilationHelper.CreateReferencedLibRecordCompilation(); + + // Emit the image of the referenced assembly. + byte[] referencedImage = CompilationHelper.CreateAssemblyImage(referencedCompilation); + + string source = @" + using System.Text.Json.Serialization; + using ReferencedAssembly; + + namespace HelloWorld + { + [JsonSerializable(typeof(LibRecord))] + internal partial class JsonContext : JsonSerializerContext + { + } + }"; + + MetadataReference[] additionalReferences = { MetadataReference.CreateFromImage(referencedImage) }; + + Compilation compilation = CompilationHelper.CreateCompilation(source, additionalReferences); + + JsonSourceGenerator generator = new JsonSourceGenerator(); + + Compilation newCompilation = CompilationHelper.RunGenerators(compilation, out ImmutableArray generatorDiags, generator); + + // Make sure compilation was successful. + CheckCompilationDiagnosticsErrors(generatorDiags); + CheckCompilationDiagnosticsErrors(newCompilation.GetDiagnostics()); + + Dictionary types = generator.GetSerializableTypes(); + + Assert.Equal(1, types.Count); + Type recordType = types["ReferencedAssembly.LibRecord"]; + Assert.Equal("ReferencedAssembly.LibRecord", recordType.FullName); + + string[] expectedFieldsNames = { "Country", "PhoneNumber" }; + string[] expectedPropertyNames = { "Address1", "Address2", "City", "Id", "Name", "PostalCode", "State" }; + CheckFieldsPropertiesMethods(recordType, expectedFieldsNames, expectedPropertyNames); + + Assert.Equal(1, recordType.GetConstructors().Length); + } + + [Fact] + public void RecordDerivedFromRecordInExternalAssembly() + { + // Compile the referenced assembly first. + Compilation referencedCompilation = CompilationHelper.CreateReferencedSimpleLibRecordCompilation(); + + // Emit the image of the referenced assembly. + byte[] referencedImage = CompilationHelper.CreateAssemblyImage(referencedCompilation); + + string source = @" + using System.Text.Json.Serialization; + using ReferencedAssembly; + + namespace HelloWorld + { + [JsonSerializable(typeof(AppRecord))] + internal partial class JsonContext : JsonSerializerContext + { + } + + internal record AppRecord : LibRecord + { + public string ExtraData { get; set; } + } + }"; + + MetadataReference[] additionalReferences = { MetadataReference.CreateFromImage(referencedImage) }; + + Compilation compilation = CompilationHelper.CreateCompilation(source, additionalReferences); + + JsonSourceGenerator generator = new JsonSourceGenerator(); + + Compilation newCompilation = CompilationHelper.RunGenerators(compilation, out ImmutableArray generatorDiags, generator); + + // Make sure compilation was successful. + CheckCompilationDiagnosticsErrors(generatorDiags); + CheckCompilationDiagnosticsErrors(newCompilation.GetDiagnostics()); + + Dictionary types = generator.GetSerializableTypes(); + + Assert.Equal(1, types.Count); + Type recordType = types["HelloWorld.AppRecord"]; + Assert.Equal("HelloWorld.AppRecord", recordType.FullName); + + string[] expectedFieldsNames = { "Country", "PhoneNumber" }; + string[] expectedPropertyNames = { "Address1", "Address2", "City", "ExtraData", "Id", "Name", "PostalCode", "State" }; + CheckFieldsPropertiesMethods(recordType, expectedFieldsNames, expectedPropertyNames, inspectBaseTypes: true); + + Assert.Equal(1, recordType.GetConstructors().Length); + } + private void CheckCompilationDiagnosticsErrors(ImmutableArray diagnostics) { Assert.Empty(diagnostics.Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); } - private void CheckFieldsPropertiesMethods(Type type, string[] expectedFields, string[] expectedProperties, string[] expectedMethods) + private void CheckFieldsPropertiesMethods( + Type type, + string[] expectedFields, + string[] expectedProperties, + string[] expectedMethods = null, + bool inspectBaseTypes = false) { BindingFlags bindingFlags = BindingFlags.Public | BindingFlags.Instance; - string[] receivedFields = type.GetFields(bindingFlags).Select(field => field.Name).OrderBy(s => s).ToArray(); - string[] receivedProperties = type.GetProperties(bindingFlags).Select(property => property.Name).OrderBy(s => s).ToArray(); + string[] receivedFields; + string[] receivedProperties; + + if (!inspectBaseTypes) + { + receivedFields = type.GetFields(bindingFlags).Select(field => field.Name).OrderBy(s => s).ToArray(); + receivedProperties = type.GetProperties(bindingFlags).Select(property => property.Name).OrderBy(s => s).ToArray(); + } + else + { + List fields = new List(); + List props = new List(); + + Type currentType = type; + while (currentType != null) + { + fields.AddRange(currentType.GetFields(bindingFlags).Select(property => property.Name).OrderBy(s => s).ToArray()); + props.AddRange(currentType.GetProperties(bindingFlags).Select(property => property.Name).OrderBy(s => s).ToArray()); + currentType = currentType.BaseType; + } + + receivedFields = fields.ToArray(); + receivedProperties = props.ToArray(); + } + string[] receivedMethods = type.GetMethods().Select(method => method.Name).OrderBy(s => s).ToArray(); + Array.Sort(receivedFields); + Array.Sort(receivedProperties); + Array.Sort(receivedMethods); + Assert.Equal(expectedFields, receivedFields); Assert.Equal(expectedProperties, receivedProperties); - Assert.Equal(expectedMethods, receivedMethods); + + if (expectedMethods != null) + { + Assert.Equal(expectedMethods, receivedMethods); + } } // TODO: add test guarding against (de)serializing static classes.