From aa531cfafced44988828865e39ffa8f1adc9943f Mon Sep 17 00:00:00 2001 From: Jorge Rangel <102122018+jorgerangel-msft@users.noreply.github.com> Date: Wed, 6 Nov 2024 11:54:57 -0600 Subject: [PATCH] [http-client-csharp] fix: generate deserialization switch for nested discriminators (#4982) fixes: https://github.com/microsoft/typespec/issues/4979 --- .../MrwSerializationTypeDefinition.cs | 8 +- .../DiscriminatorTests.cs | 51 +++++++ ...dDiscriminatedModelWithOwnDiscriminator.cs | 127 ++++++++++++++++++ 3 files changed, 182 insertions(+), 4 deletions(-) create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/DiscriminatorTests/TestNestedDiscriminatedModelWithOwnDiscriminator.cs diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 55a5f5c06a..c4f19fa92d 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -445,19 +445,19 @@ internal MethodProvider BuildDeserializationMethod() return new MethodProvider ( new MethodSignature(methodName, null, signatureModifiers, _model.Type, null, [_jsonElementDeserializationParam, _serializationOptionsParameter]), - _model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Abstract) && _inputModel.DiscriminatedSubtypes.Count > 0 ? BuildAbstractDeserializationMethodBody() : BuildDeserializationMethodBody(), + _inputModel.DiscriminatedSubtypes.Count > 0 ? BuildDiscriminatedModelDeserializationMethodBody() : BuildDeserializationMethodBody(), this ); } - private MethodBodyStatement[] BuildAbstractDeserializationMethodBody() + private MethodBodyStatement[] BuildDiscriminatedModelDeserializationMethodBody() { var unknownVariant = _model.DerivedModels.First(m => m.IsUnknownDiscriminatorModel); bool onlyContainsUnknownDerivedModel = _model.DerivedModels.Count == 1; var discriminator = _model.CanonicalView.Properties.Where(p => p.IsDiscriminator).FirstOrDefault(); var deserializeDiscriminatedModelsConditions = BuildDiscriminatedModelsCondition( discriminator, - GetAbstractSwitchCases(unknownVariant), + GetDiscriminatorSwitchCases(unknownVariant), onlyContainsUnknownDerivedModel, _jsonElementParameterSnippet); @@ -488,7 +488,7 @@ private static MethodBodyStatement BuildDiscriminatedModelsCondition( return MethodBodyStatement.Empty; } - private SwitchCaseStatement[] GetAbstractSwitchCases(ModelProvider unknownVariant) + private SwitchCaseStatement[] GetDiscriminatorSwitchCases(ModelProvider unknownVariant) { SwitchCaseStatement[] cases = new SwitchCaseStatement[_model.DerivedModels.Count - 1]; int index = 0; diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/DiscriminatorTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/DiscriminatorTests.cs index 19d1da448b..c137f22f75 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/DiscriminatorTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/DiscriminatorTests.cs @@ -254,5 +254,56 @@ public void TestBuildJsonModelCreateMethodProperlyDoesNotCastForUnknown() "this.JsonModelCreateCore(ref reader, options)", invocationExpression!.ToDisplayString()); } + + // This test validates that a discriminated sub-type with its own discriminator property + // properly generates the deserialization method to deserialize into its' discriminated sub-types + [Test] + public void TestNestedDiscriminatedModelWithOwnDiscriminator() + { + var oakTreeModel = InputFactory.Model( + "oakTree", + discriminatedKind: "oak", + properties: + [ + InputFactory.Property("treeType", InputPrimitiveType.String, isRequired: true), + ]); + var treeModel = InputFactory.Model( + "tree", + discriminatedKind: "tree", + properties: + [ + InputFactory.Property("treeType", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + ], + discriminatedModels: new Dictionary() { { "oak", oakTreeModel } }); + var baseModel = InputFactory.Model( + "plant", + properties: + [ + InputFactory.Property("foo", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + ], + discriminatedModels: new Dictionary() { { "tree", treeModel } }); + + MockHelpers.LoadMockPlugin(inputModels: () => [baseModel, treeModel]); + var baseModelProvider = ClientModelPlugin.Instance.OutputLibrary.TypeProviders.OfType() + .FirstOrDefault(t => t.Name == "Plant"); + var treeModelProvider = ClientModelPlugin.Instance.OutputLibrary.TypeProviders.OfType() + .FirstOrDefault(t => t.Name == "Tree"); + Assert.IsNotNull(baseModelProvider); + Assert.IsNotNull(treeModelProvider); + + // validate the base discriminator deserialization method has the switch statement + var baseDeserializationMethod = baseModelProvider!.SerializationProviders.FirstOrDefault()!.Methods + .FirstOrDefault(m => m.Signature.Name == "DeserializePlant"); + Assert.IsTrue(baseDeserializationMethod?.BodyStatements!.ToDisplayString().Contains( + $"if (element.TryGetProperty(\"foo\"u8, out global::System.Text.Json.JsonElement discriminator))")); + + var treeModelSerializationProvider = treeModelProvider!.SerializationProviders.FirstOrDefault(); + Assert.IsNotNull(treeModelSerializationProvider); + + // validate the deserialization methods for the tree model + var writer = new TypeProviderWriter(treeModelSerializationProvider!); + var file = writer.Write(); + Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content); + } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/DiscriminatorTests/TestNestedDiscriminatedModelWithOwnDiscriminator.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/DiscriminatorTests/TestNestedDiscriminatedModelWithOwnDiscriminator.cs new file mode 100644 index 0000000000..988cb9f580 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/test/Providers/MrwSerializationTypeDefinitions/TestData/DiscriminatorTests/TestNestedDiscriminatedModelWithOwnDiscriminator.cs @@ -0,0 +1,127 @@ +// + +#nullable disable + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Text.Json; +using Sample; + +namespace Sample.Models +{ + /// + public partial class Tree : global::System.ClientModel.Primitives.IJsonModel + { + internal Tree() + { + } + + void global::System.ClientModel.Primitives.IJsonModel.Write(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + writer.WriteStartObject(); + this.JsonModelWriteCore(writer, options); + writer.WriteEndObject(); + } + + /// The JSON writer. + /// The client options for reading and writing models. + protected override void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if ((format != "J")) + { + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Tree)} does not support writing '{format}' format."); + } + base.JsonModelWriteCore(writer, options); + writer.WritePropertyName("treeType"u8); + writer.WriteStringValue(TreeType); + } + + global::Sample.Models.Tree global::System.ClientModel.Primitives.IJsonModel.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.Tree)this.JsonModelCreateCore(ref reader, options)); + + /// The JSON reader. + /// The client options for reading and writing models. + protected override global::Sample.Models.Plant JsonModelCreateCore(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + if ((format != "J")) + { + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Tree)} does not support reading '{format}' format."); + } + using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.ParseValue(ref reader); + return global::Sample.Models.Tree.DeserializeTree(document.RootElement, options); + } + + internal static global::Sample.Models.Tree DeserializeTree(global::System.Text.Json.JsonElement element, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + if ((element.ValueKind == global::System.Text.Json.JsonValueKind.Null)) + { + return null; + } + if (element.TryGetProperty("treeType"u8, out global::System.Text.Json.JsonElement discriminator)) + { + switch (discriminator.GetString()) + { + case "oak": + return global::Sample.Models.OakTree.DeserializeOakTree(element, options); + } + } + return global::Sample.Models.UnknownTree.DeserializeUnknownTree(element, options); + } + + global::System.BinaryData global::System.ClientModel.Primitives.IPersistableModel.Write(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelWriteCore(options); + + /// The client options for reading and writing models. + protected override global::System.BinaryData PersistableModelWriteCore(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + switch (format) + { + case "J": + return global::System.ClientModel.Primitives.ModelReaderWriter.Write(this, options); + default: + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Tree)} does not support writing '{options.Format}' format."); + } + } + + global::Sample.Models.Tree global::System.ClientModel.Primitives.IPersistableModel.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.Tree)this.PersistableModelCreateCore(data, options)); + + /// The data to parse. + /// The client options for reading and writing models. + protected override global::Sample.Models.Plant PersistableModelCreateCore(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) + { + string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel)this).GetFormatFromOptions(options) : options.Format; + switch (format) + { + case "J": + using (global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(data)) + { + return global::Sample.Models.Tree.DeserializeTree(document.RootElement, options); + } + default: + throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Tree)} does not support reading '{options.Format}' format."); + } + } + + string global::System.ClientModel.Primitives.IPersistableModel.GetFormatFromOptions(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => "J"; + + /// The to serialize into . + public static implicit operator BinaryContent(global::Sample.Models.Tree tree) + { + if ((tree == null)) + { + return null; + } + return global::System.ClientModel.BinaryContent.Create(tree, global::Sample.ModelSerializationExtensions.WireOptions); + } + + /// The to deserialize the from. + public static explicit operator Tree(global::System.ClientModel.ClientResult result) + { + using global::System.ClientModel.Primitives.PipelineResponse response = result.GetRawResponse(); + using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(response.Content); + return global::Sample.Models.Tree.DeserializeTree(document.RootElement, global::Sample.ModelSerializationExtensions.WireOptions); + } + } +}