Skip to content

Commit

Permalink
[http-client-csharp] fix: generate deserialization switch for nested …
Browse files Browse the repository at this point in the history
…discriminators (#4982)

fixes: #4979
  • Loading branch information
jorgerangel-msft authored Nov 6, 2024
1 parent d1bcbf2 commit aa531cf
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -445,19 +445,19 @@ internal MethodProvider BuildDeserializationMethod()
return new MethodProvider
(
new MethodSignature(methodName, null, signatureModifiers, _model.Type, null, [_jsonElementDeserializationParam, _serializationOptionsParameter]),
_model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Abstract) && _inputModel.DiscriminatedSubtypes.Count > 0 ? BuildAbstractDeserializationMethodBody() : BuildDeserializationMethodBody(),
_inputModel.DiscriminatedSubtypes.Count > 0 ? BuildDiscriminatedModelDeserializationMethodBody() : BuildDeserializationMethodBody(),
this
);
}

private MethodBodyStatement[] BuildAbstractDeserializationMethodBody()
private MethodBodyStatement[] BuildDiscriminatedModelDeserializationMethodBody()
{
var unknownVariant = _model.DerivedModels.First(m => m.IsUnknownDiscriminatorModel);
bool onlyContainsUnknownDerivedModel = _model.DerivedModels.Count == 1;
var discriminator = _model.CanonicalView.Properties.Where(p => p.IsDiscriminator).FirstOrDefault();
var deserializeDiscriminatedModelsConditions = BuildDiscriminatedModelsCondition(
discriminator,
GetAbstractSwitchCases(unknownVariant),
GetDiscriminatorSwitchCases(unknownVariant),
onlyContainsUnknownDerivedModel,
_jsonElementParameterSnippet);

Expand Down Expand Up @@ -488,7 +488,7 @@ private static MethodBodyStatement BuildDiscriminatedModelsCondition(
return MethodBodyStatement.Empty;
}

private SwitchCaseStatement[] GetAbstractSwitchCases(ModelProvider unknownVariant)
private SwitchCaseStatement[] GetDiscriminatorSwitchCases(ModelProvider unknownVariant)
{
SwitchCaseStatement[] cases = new SwitchCaseStatement[_model.DerivedModels.Count - 1];
int index = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,5 +254,56 @@ public void TestBuildJsonModelCreateMethodProperlyDoesNotCastForUnknown()
"this.JsonModelCreateCore(ref reader, options)",
invocationExpression!.ToDisplayString());
}

// This test validates that a discriminated sub-type with its own discriminator property
// properly generates the deserialization method to deserialize into its' discriminated sub-types
[Test]
public void TestNestedDiscriminatedModelWithOwnDiscriminator()
{
var oakTreeModel = InputFactory.Model(
"oakTree",
discriminatedKind: "oak",
properties:
[
InputFactory.Property("treeType", InputPrimitiveType.String, isRequired: true),
]);
var treeModel = InputFactory.Model(
"tree",
discriminatedKind: "tree",
properties:
[
InputFactory.Property("treeType", InputPrimitiveType.String, isRequired: true, isDiscriminator: true),
],
discriminatedModels: new Dictionary<string, InputModelType>() { { "oak", oakTreeModel } });
var baseModel = InputFactory.Model(
"plant",
properties:
[
InputFactory.Property("foo", InputPrimitiveType.String, isRequired: true, isDiscriminator: true),
],
discriminatedModels: new Dictionary<string, InputModelType>() { { "tree", treeModel } });

MockHelpers.LoadMockPlugin(inputModels: () => [baseModel, treeModel]);
var baseModelProvider = ClientModelPlugin.Instance.OutputLibrary.TypeProviders.OfType<ModelProvider>()
.FirstOrDefault(t => t.Name == "Plant");
var treeModelProvider = ClientModelPlugin.Instance.OutputLibrary.TypeProviders.OfType<ModelProvider>()
.FirstOrDefault(t => t.Name == "Tree");
Assert.IsNotNull(baseModelProvider);
Assert.IsNotNull(treeModelProvider);

// validate the base discriminator deserialization method has the switch statement
var baseDeserializationMethod = baseModelProvider!.SerializationProviders.FirstOrDefault()!.Methods
.FirstOrDefault(m => m.Signature.Name == "DeserializePlant");
Assert.IsTrue(baseDeserializationMethod?.BodyStatements!.ToDisplayString().Contains(
$"if (element.TryGetProperty(\"foo\"u8, out global::System.Text.Json.JsonElement discriminator))"));

var treeModelSerializationProvider = treeModelProvider!.SerializationProviders.FirstOrDefault();
Assert.IsNotNull(treeModelSerializationProvider);

// validate the deserialization methods for the tree model
var writer = new TypeProviderWriter(treeModelSerializationProvider!);
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Text.Json;
using Sample;

namespace Sample.Models
{
/// <summary></summary>
public partial class Tree : global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.Tree>
{
internal Tree()
{
}

void global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.Tree>.Write(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
writer.WriteStartObject();
this.JsonModelWriteCore(writer, options);
writer.WriteEndObject();
}

/// <param name="writer"> The JSON writer. </param>
/// <param name="options"> The client options for reading and writing models. </param>
protected override void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWriter writer, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.Tree>)this).GetFormatFromOptions(options) : options.Format;
if ((format != "J"))
{
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Tree)} does not support writing '{format}' format.");
}
base.JsonModelWriteCore(writer, options);
writer.WritePropertyName("treeType"u8);
writer.WriteStringValue(TreeType);
}

global::Sample.Models.Tree global::System.ClientModel.Primitives.IJsonModel<global::Sample.Models.Tree>.Create(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.Tree)this.JsonModelCreateCore(ref reader, options));

/// <param name="reader"> The JSON reader. </param>
/// <param name="options"> The client options for reading and writing models. </param>
protected override global::Sample.Models.Plant JsonModelCreateCore(ref global::System.Text.Json.Utf8JsonReader reader, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.Tree>)this).GetFormatFromOptions(options) : options.Format;
if ((format != "J"))
{
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Tree)} does not support reading '{format}' format.");
}
using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.ParseValue(ref reader);
return global::Sample.Models.Tree.DeserializeTree(document.RootElement, options);
}

internal static global::Sample.Models.Tree DeserializeTree(global::System.Text.Json.JsonElement element, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
if ((element.ValueKind == global::System.Text.Json.JsonValueKind.Null))
{
return null;
}
if (element.TryGetProperty("treeType"u8, out global::System.Text.Json.JsonElement discriminator))
{
switch (discriminator.GetString())
{
case "oak":
return global::Sample.Models.OakTree.DeserializeOakTree(element, options);
}
}
return global::Sample.Models.UnknownTree.DeserializeUnknownTree(element, options);
}

global::System.BinaryData global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.Tree>.Write(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => this.PersistableModelWriteCore(options);

/// <param name="options"> The client options for reading and writing models. </param>
protected override global::System.BinaryData PersistableModelWriteCore(global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.Tree>)this).GetFormatFromOptions(options) : options.Format;
switch (format)
{
case "J":
return global::System.ClientModel.Primitives.ModelReaderWriter.Write(this, options);
default:
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Tree)} does not support writing '{options.Format}' format.");
}
}

global::Sample.Models.Tree global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.Tree>.Create(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => ((global::Sample.Models.Tree)this.PersistableModelCreateCore(data, options));

/// <param name="data"> The data to parse. </param>
/// <param name="options"> The client options for reading and writing models. </param>
protected override global::Sample.Models.Plant PersistableModelCreateCore(global::System.BinaryData data, global::System.ClientModel.Primitives.ModelReaderWriterOptions options)
{
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.Tree>)this).GetFormatFromOptions(options) : options.Format;
switch (format)
{
case "J":
using (global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(data))
{
return global::Sample.Models.Tree.DeserializeTree(document.RootElement, options);
}
default:
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.Tree)} does not support reading '{options.Format}' format.");
}
}

string global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.Tree>.GetFormatFromOptions(global::System.ClientModel.Primitives.ModelReaderWriterOptions options) => "J";

/// <param name="tree"> The <see cref="global::Sample.Models.Tree"/> to serialize into <see cref="global::System.ClientModel.BinaryContent"/>. </param>
public static implicit operator BinaryContent(global::Sample.Models.Tree tree)
{
if ((tree == null))
{
return null;
}
return global::System.ClientModel.BinaryContent.Create(tree, global::Sample.ModelSerializationExtensions.WireOptions);
}

/// <param name="result"> The <see cref="global::System.ClientModel.ClientResult"/> to deserialize the <see cref="global::Sample.Models.Tree"/> from. </param>
public static explicit operator Tree(global::System.ClientModel.ClientResult result)
{
using global::System.ClientModel.Primitives.PipelineResponse response = result.GetRawResponse();
using global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(response.Content);
return global::Sample.Models.Tree.DeserializeTree(document.RootElement, global::Sample.ModelSerializationExtensions.WireOptions);
}
}
}

0 comments on commit aa531cf

Please sign in to comment.