Skip to content

Commit

Permalink
Cosmos: Stop nesting results in extra JSON object (#34044)
Browse files Browse the repository at this point in the history
Closes #25527

Co-authored-by: Arthur Vickers <[email protected]>
  • Loading branch information
roji and ajcvickers authored Jul 21, 2024
1 parent 594688b commit d41ba67
Show file tree
Hide file tree
Showing 32 changed files with 1,323 additions and 1,166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ public static bool TryExtractArray(
Limit: null,
Orderings: [],
IsDistinct: false,
UsesSingleValueProjection: true,
Projection: [{ Expression: var a }]
},
}
Expand Down
47 changes: 25 additions & 22 deletions src/EFCore.Cosmos/Query/Internal/CosmosQuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,28 @@ protected override Expression VisitScalarSubquery(ScalarSubqueryExpression scala
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitProjection(ProjectionExpression projectionExpression)
=> VisitProjection(projectionExpression, objectProjectionStyle: false);
{
GenerateProjection(projectionExpression, objectProjectionStyle: false);
return projectionExpression;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected virtual Expression VisitProjection(ProjectionExpression projectionExpression, bool objectProjectionStyle)
private void GenerateProjection(ProjectionExpression projectionExpression, bool objectProjectionStyle)
{
// If the SELECT has a single projection with IsValueProjection, prepend the VALUE keyword (without VALUE, Cosmos projects a JSON
// object containing the value).
if (projectionExpression.IsValueProjection)
{
_sqlBuilder.Append("VALUE ");
Visit(projectionExpression.Expression);
return;
}

if (objectProjectionStyle)
{
_sqlBuilder.Append('"').Append(projectionExpression.Alias).Append("\" : ");
Expand All @@ -232,8 +244,6 @@ protected virtual Expression VisitProjection(ProjectionExpression projectionExpr
{
_sqlBuilder.Append(" AS " + projectionExpression.Alias);
}

return projectionExpression;
}

/// <summary>
Expand Down Expand Up @@ -277,34 +287,28 @@ protected override Expression VisitSelect(SelectExpression selectExpression)
_sqlBuilder.Append("DISTINCT ");
}

if (selectExpression.Projection is { Count: > 0 } projection)
if (selectExpression.Projection is { Count: > 0 } projections)
{
// If the SELECT projects a single value out, we just project that with the Cosmos VALUE keyword (without VALUE,
// Cosmos projects a JSON object containing the value).
// TODO: Ideally, just always use VALUE for all single-projection SELECTs - but this like requires shaper changes.
if (selectExpression.UsesSingleValueProjection && projection is [var singleProjection])
{
_sqlBuilder.Append("VALUE ");
Check.DebugAssert(
projections.Count == 1 || !projections.Any(p => p.IsValueProjection),
"Multiple projections with IsValueProjection");

Visit(singleProjection.Expression);
}
// Otherwise, we'll project a JSON object; Cosmos has two syntaxes for doing so:
// If there's only one projection, we simply project it directly (SELECT VALUE c["Id"]); this happens in GenerateProjection().
// Otherwise, we'll project a JSON object wrapping the multiple projections. Cosmos has two syntaxes for doing so:
// 1. Project out a JSON object as a value (SELECT VALUE { 'a': a, 'b': b }), or
// 2. Project a set of properties with optional AS+aliases (SELECT 'a' AS a, 'b' AS b)
// Both methods produce the exact same results; we usually prefer the 1st, but in some cases we use the 2nd.
else if ((projection.Count > 1
// Cosmos does not support "AS Value" projections, specifically for the alias "Value"
|| projection is [{ Alias: string alias }] && alias.Equals("value", StringComparison.OrdinalIgnoreCase))
&& projection.Any(p => !string.IsNullOrEmpty(p.Alias) && p.Alias != p.Name)
&& !projection.Any(p => p.Expression is SqlFunctionExpression)) // Aggregates are not allowed
if (projections.Count > 1
&& projections.Any(p => !string.IsNullOrEmpty(p.Alias) && p.Alias != p.Name)
&& !projections.Any(p => p.Expression is SqlFunctionExpression)) // Aggregates are not allowed
{
_sqlBuilder.AppendLine("VALUE").AppendLine("{").IncrementIndent();
GenerateList(projection, e => VisitProjection(e, objectProjectionStyle: true), joinAction: sql => sql.AppendLine(","));
GenerateList(projections, e => GenerateProjection(e, objectProjectionStyle: true), joinAction: sql => sql.AppendLine(","));
_sqlBuilder.AppendLine().DecrementIndent().Append("}");
}
else
{
GenerateList(projection, e => Visit(e));
GenerateList(projections, e => Visit(e));
}
}
else
Expand Down Expand Up @@ -752,7 +756,6 @@ void VisitContainerExpression(Expression containerExpression)
Limit: null,
Orderings: [],
IsDistinct: false,
UsesSingleValueProjection: true,
Projection.Count: 1
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ public partial class CosmosShapedQueryCompilingExpressionVisitor
{
private sealed class CosmosProjectionBindingRemovingExpressionVisitor(
SelectExpression selectExpression,
ParameterExpression jObjectParameter,
ParameterExpression jTokenParameter,
bool trackQueryResults)
: CosmosProjectionBindingRemovingExpressionVisitorBase(jObjectParameter, trackQueryResults)
: CosmosProjectionBindingRemovingExpressionVisitorBase(jTokenParameter, trackQueryResults)
{
protected override ProjectionExpression GetProjection(ProjectionBindingExpression projectionBindingExpression)
=> selectExpression.Projection[GetProjectionIndex(projectionBindingExpression)];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
public partial class CosmosShapedQueryCompilingExpressionVisitor
{
private abstract class CosmosProjectionBindingRemovingExpressionVisitorBase(
ParameterExpression jObjectParameter,
ParameterExpression jTokenParameter,
bool trackQueryResults)
: ExpressionVisitor
{
Expand Down Expand Up @@ -73,16 +73,40 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)

// Values injected by JObjectInjectingExpressionVisitor
var projectionExpression = ((UnaryExpression)binaryExpression.Right).Operand;
if (projectionExpression is ProjectionBindingExpression projectionBindingExpression)

if (projectionExpression is UnaryExpression
{
NodeType: ExpressionType.Convert,
Operand: UnaryExpression operand
})
{
var projection = GetProjection(projectionBindingExpression);
projectionExpression = projection.Expression;
storeName = projection.Alias;
// Unwrap EntityProjectionExpression when the root entity is not projected
// That is, this is handling the projection of a non-root entity type.
projectionExpression = operand.Operand;
}
else if (projectionExpression is UnaryExpression { NodeType: ExpressionType.Convert } convertExpression)

switch (projectionExpression)
{
// Unwrap EntityProjectionExpression when the root entity is not projected
projectionExpression = ((UnaryExpression)convertExpression.Operand).Operand;
// ProjectionBindingExpression may represent a named token to be obtained from a containing JObject, or
// it may be that the token is not nested in a JObject if the query was generated using the SQL VALUE clause.
case ProjectionBindingExpression projectionBindingExpression:
{
var projection = GetProjection(projectionBindingExpression);
projectionExpression = projection.Expression;
if (!projection.IsValueProjection)
{
storeName = projection.Alias;
}
break;
}

case ObjectArrayAccessExpression e:
storeName = e.PropertyName;
break;

case EntityProjectionExpression e:
storeName = e.PropertyName;
break;
}

Expression innerAccessExpression;
Expand All @@ -91,13 +115,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
case ObjectArrayAccessExpression objectArrayProjectionExpression:
innerAccessExpression = objectArrayProjectionExpression.Object;
_projectionBindings[objectArrayProjectionExpression] = parameterExpression;
storeName ??= objectArrayProjectionExpression.PropertyName;
break;

case EntityProjectionExpression entityProjectionExpression:
var accessExpression = entityProjectionExpression.Object;
_projectionBindings[accessExpression] = parameterExpression;
storeName ??= entityProjectionExpression.PropertyName;

switch (accessExpression)
{
Expand All @@ -107,13 +129,12 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
(innerObjectAccessExpression.Navigation.DeclaringEntityType, innerAccessExpression);
break;
case ObjectReferenceExpression:
innerAccessExpression = jObjectParameter;
innerAccessExpression = jTokenParameter;
break;
default:
throw new InvalidOperationException(
CoreStrings.TranslationFailed(binaryExpression.Print()));
}

break;

default:
Expand Down Expand Up @@ -174,7 +195,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var projection = GetProjection(projectionBindingExpression);

innerExpression = Convert(
CreateReadJTokenExpression(jObjectParameter, projection.Alias),
CreateReadJTokenExpression(jTokenParameter, projection.Alias),
typeof(JObject));
}
else
Expand Down Expand Up @@ -222,9 +243,10 @@ protected override Expression VisitExtension(Expression extensionExpression)
var projection = GetProjection(projectionBindingExpression);

return CreateGetValueExpression(
jObjectParameter,
projection.Alias,
projectionBindingExpression.Type, (projection.Expression as SqlExpression)?.TypeMapping);
jTokenParameter,
projection.IsValueProjection ? null : projection.Alias,
projectionBindingExpression.Type,
(projection.Expression as SqlExpression)?.TypeMapping);
}

case CollectionShaperExpression collectionShaperExpression:
Expand Down Expand Up @@ -584,13 +606,13 @@ private static Expression CreateReadJTokenExpression(Expression jObjectExpressio
=> Call(jObjectExpression, GetItemMethodInfo, Constant(propertyName));

private Expression CreateGetValueExpression(
Expression jObjectExpression,
Expression jTokenExpression,
IProperty property,
Type type)
{
if (property.Name == StoreKeyConvention.JObjectPropertyName)
{
return _projectionBindings[jObjectExpression];
return _projectionBindings[jTokenExpression];
}

var entityType = property.DeclaringType as IEntityType;
Expand All @@ -603,7 +625,7 @@ private Expression CreateGetValueExpression(
{
if (ownership is { IsUnique: false } && property.IsOrdinalKeyProperty())
{
var ordinalExpression = _ordinalParameterBindings[jObjectExpression];
var ordinalExpression = _ordinalParameterBindings[jTokenExpression];
if (ordinalExpression.Type != type)
{
ordinalExpression = Convert(ordinalExpression, type);
Expand All @@ -616,19 +638,19 @@ private Expression CreateGetValueExpression(
if (principalProperty != null)
{
Expression ownerJObjectExpression = null;
if (_ownerMappings.TryGetValue(jObjectExpression, out var ownerInfo))
if (_ownerMappings.TryGetValue(jTokenExpression, out var ownerInfo))
{
Check.DebugAssert(
principalProperty.DeclaringType.IsAssignableFrom(ownerInfo.EntityType),
$"{principalProperty.DeclaringType} is not assignable from {ownerInfo.EntityType}");

ownerJObjectExpression = ownerInfo.JObjectExpression;
}
else if (jObjectExpression is ObjectReferenceExpression objectReferenceExpression)
else if (jTokenExpression is ObjectReferenceExpression objectReferenceExpression)
{
ownerJObjectExpression = objectReferenceExpression;
}
else if (jObjectExpression is ObjectAccessExpression objectAccessExpression)
else if (jTokenExpression is ObjectAccessExpression objectAccessExpression)
{
ownerJObjectExpression = objectAccessExpression.Object;
}
Expand All @@ -653,15 +675,15 @@ private Expression CreateGetValueExpression(
&& !property.IsShadowProperty())
{
var readExpression = CreateGetValueExpression(
jObjectExpression, storeName, type.MakeNullable(), property.GetTypeMapping());
jTokenExpression, storeName, type.MakeNullable(), property.GetTypeMapping());

var nonNullReadExpression = readExpression;
if (nonNullReadExpression.Type != type)
{
nonNullReadExpression = Convert(nonNullReadExpression, type);
}

var ordinalExpression = _ordinalParameterBindings[jObjectExpression];
var ordinalExpression = _ordinalParameterBindings[jTokenExpression];
if (ordinalExpression.Type != type)
{
ordinalExpression = Convert(ordinalExpression, type);
Expand All @@ -674,36 +696,41 @@ private Expression CreateGetValueExpression(
}

return Convert(
CreateGetValueExpression(jObjectExpression, storeName, type.MakeNullable(), property.GetTypeMapping()),
CreateGetValueExpression(jTokenExpression, storeName, type.MakeNullable(), property.GetTypeMapping()),
type);
}

private Expression CreateGetValueExpression(
Expression jObjectExpression,
Expression jTokenExpression,
string storeName,
Type type,
CoreTypeMapping typeMapping = null)
{
Check.DebugAssert(type.IsNullableType(), "Must read nullable type from JObject.");

var innerExpression = jObjectExpression switch
var innerExpression = jTokenExpression switch
{
_ when _projectionBindings.TryGetValue(jObjectExpression, out var innerVariable)
_ when _projectionBindings.TryGetValue(jTokenExpression, out var innerVariable)
=> innerVariable,

ObjectReferenceExpression objectReference
=> CreateGetValueExpression(jObjectParameter, objectReference.Name, typeof(JObject)),
ObjectReferenceExpression
=> jTokenParameter,

ObjectAccessExpression objectAccessExpression
=> CreateGetValueExpression(
objectAccessExpression.Object,
((IAccessExpression)objectAccessExpression.Object).PropertyName,
typeof(JObject)),

_ => jObjectExpression
_ => jTokenExpression
};

var jTokenExpression = CreateReadJTokenExpression(innerExpression, storeName);
jTokenExpression = storeName == null
? innerExpression
: CreateReadJTokenExpression(
innerExpression.Type == typeof(JObject)
? innerExpression
: Convert(innerExpression, typeof(JObject)), storeName);

Expression valueExpression;
var converter = typeMapping?.Converter;
Expand Down Expand Up @@ -774,9 +801,6 @@ private static Expression ConvertJTokenToType(Expression jTokenExpression, Type
ToObjectWithSerializerMethodInfo.MakeGenericMethod(type),
jTokenExpression);

private static T SafeToObject<T>(JToken token)
=> token == null || token.Type == JTokenType.Null ? default : token.ToObject<T>();

private static T SafeToObjectWithSerializer<T>(JToken token)
=> token == null || token.Type == JTokenType.Null ? default : token.ToObject<T>(CosmosClientWrapper.Serializer);
}
Expand Down
Loading

0 comments on commit d41ba67

Please sign in to comment.