From c7605b653ae9a1dfb2c8ae51ddb6dd86cb5f7a4c Mon Sep 17 00:00:00 2001 From: Tom Andersen Date: Fri, 28 Feb 2025 13:11:33 -0500 Subject: [PATCH] Generic Stage and Refactor --- .../firebase/firestore/PipelineTest.java | 59 +++++++-- .../firebase/firestore/FirebaseFirestore.java | 5 +- .../com/google/firebase/firestore/Pipeline.kt | 11 +- .../firebase/firestore/UserDataReader.java | 22 ++++ .../google/firebase/firestore/model/Values.kt | 10 +- .../firebase/firestore/pipeline/Constant.kt | 58 ++++----- .../{accumulators.kt => aggregates.kt} | 37 +++--- .../{expression.kt => expressions.kt} | 78 ++++++++---- .../firebase/firestore/pipeline/stage.kt | 115 +++++++++++++----- 9 files changed, 274 insertions(+), 121 deletions(-) rename firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/{accumulators.kt => aggregates.kt} (50%) rename firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/{expression.kt => expressions.kt} (92%) diff --git a/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/PipelineTest.java b/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/PipelineTest.java index 21968864d48..50df7ef880d 100644 --- a/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/PipelineTest.java +++ b/firebase-firestore/src/androidTest/java/com/google/firebase/firestore/PipelineTest.java @@ -37,14 +37,13 @@ import static com.google.firebase.firestore.pipeline.Function.subtract; import static com.google.firebase.firestore.pipeline.Ordering.ascending; import static com.google.firebase.firestore.testutil.IntegrationTestUtil.waitFor; -import static java.util.Map.entry; import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.android.gms.tasks.Task; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.truth.Correspondence; -import com.google.firebase.firestore.pipeline.Accumulator; +import com.google.firebase.firestore.pipeline.AggregateExpr; import com.google.firebase.firestore.pipeline.AggregateStage; import com.google.firebase.firestore.pipeline.Constant; import com.google.firebase.firestore.pipeline.Field; @@ -227,7 +226,7 @@ public void aggregateResultsCountAll() { firestore .pipeline() .collection(randomCol) - .aggregate(Accumulator.countAll().as("count")) + .aggregate(AggregateExpr.countAll().as("count")) .execute(); assertThat(waitFor(execute).getResults()) .comparingElementsUsing(DATA_CORRESPONDENCE) @@ -243,8 +242,8 @@ public void aggregateResultsMany() { .collection(randomCol) .where(Function.eq("genre", "Science Fiction")) .aggregate( - Accumulator.countAll().as("count"), - Accumulator.avg("rating").as("avgRating"), + AggregateExpr.countAll().as("count"), + AggregateExpr.avg("rating").as("avgRating"), Field.of("rating").max().as("maxRating")) .execute(); assertThat(waitFor(execute).getResults()) @@ -261,7 +260,7 @@ public void groupAndAccumulateResults() { .collection(randomCol) .where(lt(Field.of("published"), 1984)) .aggregate( - AggregateStage.withAccumulators(Accumulator.avg("rating").as("avgRating")) + AggregateStage.withAccumulators(AggregateExpr.avg("rating").as("avgRating")) .withGroups("genre")) .where(gt("avgRating", 4.3)) .sort(Field.of("avgRating").descending()) @@ -274,6 +273,28 @@ public void groupAndAccumulateResults() { mapOfEntries(entry("avgRating", 4.4), entry("genre", "Science Fiction"))); } + @Test + public void groupAndAccumulateResultsGeneric() { + Task execute = + firestore + .pipeline() + .collection(randomCol) + .genericStage("where", lt(Field.of("published"), 1984)) + .genericStage( + "aggregate", + ImmutableMap.of("avgRating", AggregateExpr.avg("rating")), + ImmutableMap.of("genre", Field.of("genre"))) + .genericStage("where", gt("avgRating", 4.3)) + .genericStage("sort", Field.of("avgRating").descending()) + .execute(); + assertThat(waitFor(execute).getResults()) + .comparingElementsUsing(DATA_CORRESPONDENCE) + .containsExactly( + mapOfEntries(entry("avgRating", 4.7), entry("genre", "Fantasy")), + mapOfEntries(entry("avgRating", 4.5), entry("genre", "Romance")), + mapOfEntries(entry("avgRating", 4.4), entry("genre", "Science Fiction"))); + } + @Test @Ignore("Not supported yet") public void minAndMaxAccumulations() { @@ -282,7 +303,7 @@ public void minAndMaxAccumulations() { .pipeline() .collection(randomCol) .aggregate( - Accumulator.countAll().as("count"), + AggregateExpr.countAll().as("count"), Field.of("rating").max().as("maxRating"), Field.of("published").min().as("minPublished")) .execute(); @@ -781,6 +802,30 @@ public void testMapGetWithFieldNameIncludingNotation() { entry("nested", null))); } + @Test + public void testListEquals() { + Task execute = + randomCol + .pipeline() + .where(eq("tags", ImmutableList.of("philosophy", "crime", "redemption"))) + .execute(); + assertThat(waitFor(execute).getResults()) + .comparingElementsUsing(ID_CORRESPONDENCE) + .containsExactly("book6"); + } + + @Test + public void testMapEquals() { + Task execute = + randomCol + .pipeline() + .where(eq("awards", ImmutableMap.of("nobel", true, "nebula", false))) + .execute(); + assertThat(waitFor(execute).getResults()) + .comparingElementsUsing(ID_CORRESPONDENCE) + .containsExactly("book3"); + } + static Map.Entry entry(String key, T value) { return new Map.Entry() { private String k = key; diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/FirebaseFirestore.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/FirebaseFirestore.java index 114fc18da95..932be5983f5 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/FirebaseFirestore.java +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/FirebaseFirestore.java @@ -23,6 +23,7 @@ import androidx.annotation.Keep; import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import androidx.annotation.RestrictTo; import androidx.annotation.VisibleForTesting; import com.google.android.gms.tasks.Task; import com.google.android.gms.tasks.TaskCompletionSource; @@ -855,7 +856,9 @@ DatabaseId getDatabaseId() { return databaseId; } - UserDataReader getUserDataReader() { + @NonNull + @RestrictTo(RestrictTo.Scope.LIBRARY) + public UserDataReader getUserDataReader() { return userDataReader; } diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/Pipeline.kt b/firebase-firestore/src/main/java/com/google/firebase/firestore/Pipeline.kt index 6fa92b88b8b..6a3f765fd38 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/Pipeline.kt +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/Pipeline.kt @@ -20,9 +20,9 @@ import com.google.common.collect.FluentIterable import com.google.common.collect.ImmutableList import com.google.firebase.firestore.model.DocumentKey import com.google.firebase.firestore.model.SnapshotVersion -import com.google.firebase.firestore.pipeline.AccumulatorWithAlias import com.google.firebase.firestore.pipeline.AddFieldsStage import com.google.firebase.firestore.pipeline.AggregateStage +import com.google.firebase.firestore.pipeline.AggregateWithAlias import com.google.firebase.firestore.pipeline.BooleanExpr import com.google.firebase.firestore.pipeline.CollectionGroupSource import com.google.firebase.firestore.pipeline.CollectionSource @@ -30,6 +30,8 @@ import com.google.firebase.firestore.pipeline.DatabaseSource import com.google.firebase.firestore.pipeline.DistinctStage import com.google.firebase.firestore.pipeline.DocumentsSource import com.google.firebase.firestore.pipeline.Field +import com.google.firebase.firestore.pipeline.GenericArg +import com.google.firebase.firestore.pipeline.GenericStage import com.google.firebase.firestore.pipeline.LimitStage import com.google.firebase.firestore.pipeline.OffsetStage import com.google.firebase.firestore.pipeline.Ordering @@ -84,9 +86,12 @@ internal constructor( internal fun toPipelineProto(): com.google.firestore.v1.Pipeline = com.google.firestore.v1.Pipeline.newBuilder() - .addAllStages(stages.map(Stage::toProtoStage)) + .addAllStages(stages.map { it.toProtoStage(firestore.userDataReader) }) .build() + fun genericStage(name: String, vararg params: Any) = + append(GenericStage(name, params.map(GenericArg::from))) + fun addFields(vararg fields: Selectable): Pipeline = append(AddFieldsStage(fields)) fun removeFields(vararg fields: Field): Pipeline = append(RemoveFieldsStage(fields)) @@ -118,7 +123,7 @@ internal constructor( fun distinct(vararg groups: Any): Pipeline = append(DistinctStage(groups.map(Selectable::toSelectable).toTypedArray())) - fun aggregate(vararg accumulators: AccumulatorWithAlias): Pipeline = + fun aggregate(vararg accumulators: AggregateWithAlias): Pipeline = append(AggregateStage.withAccumulators(*accumulators)) fun aggregate(aggregateStage: AggregateStage): Pipeline = append(aggregateStage) diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataReader.java b/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataReader.java index b1462ed9f74..3ce7cfec87c 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataReader.java +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataReader.java @@ -19,6 +19,7 @@ import androidx.annotation.Nullable; import androidx.annotation.RestrictTo; +import com.google.common.base.Function; import com.google.firebase.firestore.FieldValue.ArrayRemoveFieldValue; import com.google.firebase.firestore.FieldValue.ArrayUnionFieldValue; import com.google.firebase.firestore.FieldValue.DeleteFieldValue; @@ -36,6 +37,7 @@ import com.google.firebase.firestore.model.mutation.FieldMask; import com.google.firebase.firestore.model.mutation.NumericIncrementTransformOperation; import com.google.firebase.firestore.model.mutation.ServerTimestampOperation; +import com.google.firebase.firestore.pipeline.Expr; import com.google.firebase.firestore.util.Assert; import com.google.firebase.firestore.util.CustomClassMapper; import com.google.firebase.firestore.util.Util; @@ -389,6 +391,12 @@ public Value parseScalarValue(Object input, ParseContext context) { return Values.NULL_VALUE; } else if (input.getClass().isArray()) { throw context.createError("Arrays are not supported; use a List instead"); + } else if (input instanceof DocumentReference) { + DocumentReference ref = (DocumentReference) input; + validateDocumentReference(ref, context::createError); + return Values.encodeValue(ref); + } else if (input instanceof Expr) { + throw context.createError("Pipeline expressions are not supported user objects"); } else { try { return Values.encodeAnyValue(input); @@ -398,6 +406,20 @@ public Value parseScalarValue(Object input, ParseContext context) { } } + public void validateDocumentReference( + DocumentReference ref, Function createError) { + DatabaseId otherDb = ref.getFirestore().getDatabaseId(); + if (!otherDb.equals(databaseId)) { + throw createError.apply( + String.format( + "Document reference is for database %s/%s but should be for database %s/%s", + otherDb.getProjectId(), + otherDb.getDatabaseId(), + databaseId.getProjectId(), + databaseId.getDatabaseId())); + } + } + private List parseArrayTransformElements(List elements) { ParseAccumulator accumulator = new ParseAccumulator(UserData.Source.Argument); diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/model/Values.kt b/firebase-firestore/src/main/java/com/google/firebase/firestore/model/Values.kt index 3b9701bc4ec..d5ae4064d95 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/model/Values.kt +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/model/Values.kt @@ -609,9 +609,7 @@ internal object Values { return Value.newBuilder() .setTimestampValue( - com.google.protobuf.Timestamp.newBuilder() - .setSeconds(timestamp.seconds) - .setNanos(truncatedNanoseconds) + Timestamp.newBuilder().setSeconds(timestamp.seconds).setNanos(truncatedNanoseconds) ) .build() } @@ -665,6 +663,11 @@ internal object Values { return Value.newBuilder().setMapValue(MapValue.newBuilder().putAllFields(map)).build() } + @JvmStatic + fun encodeValue(values: Iterable): Value { + return Value.newBuilder().setArrayValue(ArrayValue.newBuilder().addAllValues(values)).build() + } + @JvmStatic fun encodeAnyValue(value: Any?): Value { return when (value) { @@ -676,7 +679,6 @@ internal object Values { is Boolean -> encodeValue(value) is GeoPoint -> encodeValue(value) is Blob -> encodeValue(value) - is DocumentReference -> encodeValue(value) is VectorValue -> encodeValue(value) else -> throw IllegalArgumentException("Unexpected type: $value") } diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/Constant.kt b/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/Constant.kt index bc9a39a50d4..e9bcf7aac43 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/Constant.kt +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/Constant.kt @@ -18,76 +18,70 @@ import com.google.firebase.Timestamp import com.google.firebase.firestore.Blob import com.google.firebase.firestore.DocumentReference import com.google.firebase.firestore.GeoPoint +import com.google.firebase.firestore.UserDataReader import com.google.firebase.firestore.VectorValue import com.google.firebase.firestore.model.Values import com.google.firebase.firestore.model.Values.encodeValue import com.google.firestore.v1.Value import java.util.Date -class Constant internal constructor(val value: Value) : Expr() { +abstract class Constant internal constructor() : Expr() { + + private class ValueConstant(val value: Value) : Constant() { + override fun toProto(userDataReader: UserDataReader): Value = value + } companion object { - internal val NULL = Constant(Values.NULL_VALUE) - - fun of(value: Any): Constant { - return when (value) { - is String -> of(value) - is Number -> of(value) - is Date -> of(value) - is Timestamp -> of(value) - is Boolean -> of(value) - is GeoPoint -> of(value) - is Blob -> of(value) - is DocumentReference -> of(value) - is Value -> of(value) - is VectorValue -> of(value) - else -> throw IllegalArgumentException("Unknown type: $value") - } - } + internal val NULL: Constant = ValueConstant(Values.NULL_VALUE) @JvmStatic fun of(value: String): Constant { - return Constant(encodeValue(value)) + return ValueConstant(encodeValue(value)) } @JvmStatic fun of(value: Number): Constant { - return Constant(encodeValue(value)) + return ValueConstant(encodeValue(value)) } @JvmStatic fun of(value: Date): Constant { - return Constant(encodeValue(value)) + return ValueConstant(encodeValue(value)) } @JvmStatic fun of(value: Timestamp): Constant { - return Constant(encodeValue(value)) + return ValueConstant(encodeValue(value)) } @JvmStatic fun of(value: Boolean): Constant { - return Constant(encodeValue(value)) + return ValueConstant(encodeValue(value)) } @JvmStatic fun of(value: GeoPoint): Constant { - return Constant(encodeValue(value)) + return ValueConstant(encodeValue(value)) } @JvmStatic fun of(value: Blob): Constant { - return Constant(encodeValue(value)) + return ValueConstant(encodeValue(value)) } @JvmStatic - fun of(value: DocumentReference): Constant { - return Constant(encodeValue(value)) + fun of(ref: DocumentReference): Constant { + return object : Constant() { + override fun toProto(userDataReader: UserDataReader): Value { + userDataReader.validateDocumentReference(ref, ::IllegalArgumentException) + return encodeValue(ref) + } + } } @JvmStatic fun of(value: VectorValue): Constant { - return Constant(encodeValue(value)) + return ValueConstant(encodeValue(value)) } @JvmStatic @@ -97,16 +91,12 @@ class Constant internal constructor(val value: Value) : Expr() { @JvmStatic fun vector(value: DoubleArray): Constant { - return Constant(Values.encodeVectorValue(value)) + return ValueConstant(Values.encodeVectorValue(value)) } @JvmStatic fun vector(value: VectorValue): Constant { - return Constant(encodeValue(value)) + return ValueConstant(encodeValue(value)) } } - - override fun toProto(): Value { - return value - } } diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/accumulators.kt b/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/aggregates.kt similarity index 50% rename from firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/accumulators.kt rename to firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/aggregates.kt index 7ea8668b1b2..c363c28be3e 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/accumulators.kt +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/aggregates.kt @@ -14,50 +14,51 @@ package com.google.firebase.firestore.pipeline +import com.google.firebase.firestore.UserDataReader import com.google.firestore.v1.Value -class AccumulatorWithAlias -internal constructor(internal val alias: String, internal val accumulator: Accumulator) +class AggregateWithAlias +internal constructor(internal val alias: String, internal val expr: AggregateExpr) -class Accumulator +class AggregateExpr private constructor(private val name: String, private val params: Array) { private constructor(name: String) : this(name, emptyArray()) private constructor(name: String, expr: Expr) : this(name, arrayOf(expr)) private constructor(name: String, fieldName: String) : this(name, Field.of(fieldName)) companion object { - @JvmStatic fun countAll() = Accumulator("count") + @JvmStatic fun countAll() = AggregateExpr("count") - @JvmStatic fun count(fieldName: String) = Accumulator("count", fieldName) + @JvmStatic fun count(fieldName: String) = AggregateExpr("count", fieldName) - @JvmStatic fun count(expr: Expr) = Accumulator("count", expr) + @JvmStatic fun count(expr: Expr) = AggregateExpr("count", expr) - @JvmStatic fun countIf(condition: BooleanExpr) = Accumulator("countIf", condition) + @JvmStatic fun countIf(condition: BooleanExpr) = AggregateExpr("countIf", condition) - @JvmStatic fun sum(fieldName: String) = Accumulator("sum", fieldName) + @JvmStatic fun sum(fieldName: String) = AggregateExpr("sum", fieldName) - @JvmStatic fun sum(expr: Expr) = Accumulator("sum", expr) + @JvmStatic fun sum(expr: Expr) = AggregateExpr("sum", expr) - @JvmStatic fun avg(fieldName: String) = Accumulator("avg", fieldName) + @JvmStatic fun avg(fieldName: String) = AggregateExpr("avg", fieldName) - @JvmStatic fun avg(expr: Expr) = Accumulator("avg", expr) + @JvmStatic fun avg(expr: Expr) = AggregateExpr("avg", expr) - @JvmStatic fun min(fieldName: String) = Accumulator("min", fieldName) + @JvmStatic fun min(fieldName: String) = AggregateExpr("min", fieldName) - @JvmStatic fun min(expr: Expr) = Accumulator("min", expr) + @JvmStatic fun min(expr: Expr) = AggregateExpr("min", expr) - @JvmStatic fun max(fieldName: String) = Accumulator("max", fieldName) + @JvmStatic fun max(fieldName: String) = AggregateExpr("max", fieldName) - @JvmStatic fun max(expr: Expr) = Accumulator("max", expr) + @JvmStatic fun max(expr: Expr) = AggregateExpr("max", expr) } - fun `as`(alias: String) = AccumulatorWithAlias(alias, this) + fun `as`(alias: String) = AggregateWithAlias(alias, this) - fun toProto(): Value { + internal fun toProto(userDataReader: UserDataReader): Value { val builder = com.google.firestore.v1.Function.newBuilder() builder.setName(name) for (param in params) { - builder.addArgs(param.toProto()) + builder.addArgs(param.toProto(userDataReader)) } return Value.newBuilder().setFunctionValue(builder).build() } diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/expression.kt b/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/expressions.kt similarity index 92% rename from firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/expression.kt rename to firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/expressions.kt index afc796742b0..a4a6912fbce 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/expression.kt +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/expressions.kt @@ -14,21 +14,49 @@ package com.google.firebase.firestore.pipeline +import com.google.firebase.Timestamp +import com.google.firebase.firestore.Blob +import com.google.firebase.firestore.DocumentReference import com.google.firebase.firestore.FieldPath +import com.google.firebase.firestore.GeoPoint +import com.google.firebase.firestore.UserDataReader import com.google.firebase.firestore.VectorValue import com.google.firebase.firestore.model.DocumentKey -import com.google.firebase.firestore.model.FieldPath as ModelFieldPath import com.google.firebase.firestore.model.Values.encodeValue -import com.google.firestore.v1.ArrayValue +import com.google.firebase.firestore.pipeline.Constant.Companion.of +import com.google.firebase.firestore.util.CustomClassMapper import com.google.firestore.v1.MapValue import com.google.firestore.v1.Value +import java.util.Date +import kotlin.reflect.KFunction1 +import com.google.firebase.firestore.model.FieldPath as ModelFieldPath -abstract class Expr protected constructor() { +abstract class Expr internal constructor() { internal companion object { - internal fun toExprOrConstant(value: Any): Expr { + internal fun toExprOrConstant(value: Any?): Expr = toExpr(value, ::toExprOrConstant) ?: pojoToExprOrConstant(CustomClassMapper.convertToPlainJavaTypes(value)) + + private fun pojoToExprOrConstant(value: Any?): Expr = toExpr(value, ::pojoToExprOrConstant) ?: throw IllegalArgumentException("Unknown type: $value") + + private fun toExpr(value: Any?, toExpr: KFunction1): Expr? { + if (value == null) return Constant.nullValue() return when (value) { is Expr -> value - else -> Constant.of(value) + is String -> of(value) + is Number -> of(value) + is Date -> of(value) + is Timestamp -> of(value) + is Boolean -> of(value) + is GeoPoint -> of(value) + is Blob -> of(value) + is DocumentReference -> of(value) + is VectorValue -> of(value) + is Map<*, *> -> MapOfExpr(value.entries.associate { + val key = it.key + if (key is String) key to toExpr(it.value) else + throw IllegalArgumentException("Maps with non-string keys are not supported") + }) + is List<*> -> ListOfExprs(value.map(toExpr).toTypedArray()) + else -> null } } @@ -227,13 +255,13 @@ abstract class Expr protected constructor() { fun arrayLength() = Function.arrayLength(this) - fun sum() = Accumulator.sum(this) + fun sum() = AggregateExpr.sum(this) - fun avg() = Accumulator.avg(this) + fun avg() = AggregateExpr.avg(this) - fun min() = Accumulator.min(this) + fun min() = AggregateExpr.min(this) - fun max() = Accumulator.max(this) + fun max() = AggregateExpr.max(this) fun ascending() = Ordering.ascending(this) @@ -263,7 +291,7 @@ abstract class Expr protected constructor() { fun lte(other: Any) = Function.lte(this, other) - internal abstract fun toProto(): Value + internal abstract fun toProto(userDataReader: UserDataReader): Value } abstract class Selectable : Expr() { @@ -284,7 +312,7 @@ abstract class Selectable : Expr() { open class ExprWithAlias internal constructor(private val alias: String, private val expr: Expr) : Selectable() { override fun getAlias() = alias - override fun toProto(): Value = expr.toProto() + override fun toProto(userDataReader: UserDataReader): Value = expr.toProto(userDataReader) } class Field private constructor(private val fieldPath: ModelFieldPath) : @@ -310,20 +338,26 @@ class Field private constructor(private val fieldPath: ModelFieldPath) : override fun getAlias(): String = fieldPath.canonicalString() - override fun toProto() = + override fun toProto(userDataReader: UserDataReader) = toProto() + + internal fun toProto(): Value = Value.newBuilder().setFieldReferenceValue(fieldPath.canonicalString()).build() } -class ListOfExprs(private val expressions: Array) : Expr() { - override fun toProto(): Value { - val builder = ArrayValue.newBuilder() +class MapOfExpr(private val expressions: Map) : Expr() { + override fun toProto(userDataReader: UserDataReader): Value { + val builder = MapValue.newBuilder() for (expr in expressions) { - builder.addValues(expr.toProto()) + builder.putFields(expr.key, expr.value.toProto(userDataReader)) } - return Value.newBuilder().setArrayValue(builder).build() + return Value.newBuilder().setMapValue(builder).build() } } +class ListOfExprs(private val expressions: Array) : Expr() { + override fun toProto(userDataReader: UserDataReader): Value = encodeValue(expressions.map{it.toProto(userDataReader)}) +} + open class Function protected constructor(private val name: String, private val params: Array) : Expr() { private constructor(name: String, param: Expr, vararg params: Any) : this(name, arrayOf(param, *toArrayOfExprOrConstant(params))) @@ -738,11 +772,11 @@ protected constructor(private val name: String, private val params: Array) : fun not() = not(this) - fun countIf(): Accumulator = Accumulator.countIf(this) + fun countIf(): AggregateExpr = AggregateExpr.countIf(this) fun ifThen(then: Expr) = ifThen(this, then) @@ -786,12 +820,12 @@ class Ordering private constructor(private val expr: Expr, private val dir: Dire val DESCENDING = Direction("descending") } } - internal fun toProto(): Value = + internal fun toProto(userDataReader: UserDataReader): Value = Value.newBuilder() .setMapValue( MapValue.newBuilder() .putFields("direction", dir.proto) - .putFields("expression", expr.toProto()) + .putFields("expression", expr.toProto(userDataReader)) ) .build() } diff --git a/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/stage.kt b/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/stage.kt index 81fc2ef7499..fa8026fb1ad 100644 --- a/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/stage.kt +++ b/firebase-firestore/src/main/java/com/google/firebase/firestore/pipeline/stage.kt @@ -15,6 +15,7 @@ package com.google.firebase.firestore.pipeline import com.google.common.collect.ImmutableMap +import com.google.firebase.firestore.UserDataReader import com.google.firebase.firestore.model.Values.encodeValue import com.google.firebase.firestore.model.Values.encodeVectorValue import com.google.firestore.v1.Pipeline @@ -23,58 +24,101 @@ import com.google.firestore.v1.Value abstract class Stage internal constructor(private val name: String, private val options: Map) { internal constructor(name: String) : this(name, emptyMap()) - internal fun toProtoStage(): Pipeline.Stage { + internal fun toProtoStage(userDataReader: UserDataReader): Pipeline.Stage { val builder = Pipeline.Stage.newBuilder() builder.setName(name) - args().forEach { arg -> builder.addArgs(arg) } + args(userDataReader).forEach { arg -> builder.addArgs(arg) } builder.putAllOptions(options) return builder.build() } - protected abstract fun args(): Sequence + protected abstract fun args(userDataReader: UserDataReader): Sequence +} + +class GenericStage internal constructor(name: String, private val params: List) : + Stage(name) { + override fun args(userDataReader: UserDataReader): Sequence = + params.asSequence().map { it.toProto(userDataReader) } +} + +internal sealed class GenericArg { + companion object { + fun from(arg: Any?): GenericArg = + when (arg) { + is AggregateExpr -> AggregateArg(arg) + is Ordering -> OrderingArg(arg) + is Map<*, *> -> MapArg(arg.asIterable().associate { it.key as String to from(it.value) }) + is List<*> -> ListArg(arg.map(::from)) + else -> ExprArg(Expr.toExprOrConstant(arg)) + } + } + abstract fun toProto(userDataReader: UserDataReader): Value + + data class AggregateArg(val aggregate: AggregateExpr) : GenericArg() { + override fun toProto(userDataReader: UserDataReader) = aggregate.toProto(userDataReader) + } + + data class ExprArg(val expr: Expr) : GenericArg() { + override fun toProto(userDataReader: UserDataReader) = expr.toProto(userDataReader) + } + + data class OrderingArg(val ordering: Ordering) : GenericArg() { + override fun toProto(userDataReader: UserDataReader) = ordering.toProto(userDataReader) + } + + data class MapArg(val args: Map) : GenericArg() { + override fun toProto(userDataReader: UserDataReader) = + encodeValue(args.mapValues { it.value.toProto(userDataReader) }) + } + + data class ListArg(val args: List) : GenericArg() { + override fun toProto(userDataReader: UserDataReader) = + encodeValue(args.map { it.toProto(userDataReader) }) + } } class DatabaseSource : Stage("database") { - override fun args(): Sequence = emptySequence() + override fun args(userDataReader: UserDataReader): Sequence = emptySequence() } class CollectionSource internal constructor(path: String) : Stage("collection") { private val path: String = if (path.startsWith("/")) path else "/" + path - override fun args(): Sequence = + override fun args(userDataReader: UserDataReader): Sequence = sequenceOf(Value.newBuilder().setReferenceValue(path).build()) } class CollectionGroupSource internal constructor(val collectionId: String) : Stage("collection_group") { - override fun args(): Sequence = + override fun args(userDataReader: UserDataReader): Sequence = sequenceOf(Value.newBuilder().setReferenceValue("").build(), encodeValue(collectionId)) } class DocumentsSource internal constructor(private val documents: Array) : Stage("documents") { - override fun args(): Sequence = documents.asSequence().map(::encodeValue) + override fun args(userDataReader: UserDataReader): Sequence = + documents.asSequence().map(::encodeValue) } class AddFieldsStage internal constructor(private val fields: Array) : Stage("add_fields") { - override fun args(): Sequence = - sequenceOf(encodeValue(fields.associate { it.getAlias() to it.toProto() })) + override fun args(userDataReader: UserDataReader): Sequence = + sequenceOf(encodeValue(fields.associate { it.getAlias() to it.toProto(userDataReader) })) } class AggregateStage internal constructor( - private val accumulators: Map, + private val accumulators: Map, private val groups: Map ) : Stage("aggregate") { - private constructor(accumulators: Map) : this(accumulators, emptyMap()) + private constructor(accumulators: Map) : this(accumulators, emptyMap()) companion object { @JvmStatic - fun withAccumulators(vararg accumulators: AccumulatorWithAlias): AggregateStage { + fun withAccumulators(vararg accumulators: AggregateWithAlias): AggregateStage { if (accumulators.isEmpty()) { throw IllegalArgumentException( "Must specify at least one accumulator for aggregate() stage. There is a distinct() stage if only distinct group values are needed." ) } - return AggregateStage(accumulators.associate { it.alias to it.accumulator }) + return AggregateStage(accumulators.associate { it.alias to it.expr }) } } @@ -90,15 +134,16 @@ internal constructor( selectable.map(Selectable::toSelectable).associateBy(Selectable::getAlias) ) - override fun args(): Sequence = + override fun args(userDataReader: UserDataReader): Sequence = sequenceOf( - encodeValue(accumulators.mapValues { entry -> entry.value.toProto() }), - encodeValue(groups.mapValues { entry -> entry.value.toProto() }) + encodeValue(accumulators.mapValues { entry -> entry.value.toProto(userDataReader) }), + encodeValue(groups.mapValues { entry -> entry.value.toProto(userDataReader) }) ) } class WhereStage internal constructor(private val condition: BooleanExpr) : Stage("where") { - override fun args(): Sequence = sequenceOf(condition.toProto()) + override fun args(userDataReader: UserDataReader): Sequence = + sequenceOf(condition.toProto(userDataReader)) } class FindNearestStage @@ -118,8 +163,8 @@ internal constructor( } } - override fun args(): Sequence = - sequenceOf(property.toProto(), encodeVectorValue(vector), distanceMeasure.proto) + override fun args(userDataReader: UserDataReader): Sequence = + sequenceOf(property.toProto(userDataReader), encodeVectorValue(vector), distanceMeasure.proto) } class FindNearestOptions @@ -137,32 +182,36 @@ internal constructor(private val limit: Long?, private val distanceField: Field? } class LimitStage internal constructor(private val limit: Long) : Stage("limit") { - override fun args(): Sequence = sequenceOf(encodeValue(limit)) + override fun args(userDataReader: UserDataReader): Sequence = + sequenceOf(encodeValue(limit)) } class OffsetStage internal constructor(private val offset: Long) : Stage("offset") { - override fun args(): Sequence = sequenceOf(encodeValue(offset)) + override fun args(userDataReader: UserDataReader): Sequence = + sequenceOf(encodeValue(offset)) } class SelectStage internal constructor(private val fields: Array) : Stage("select") { - override fun args(): Sequence = - sequenceOf(encodeValue(fields.associate { it.getAlias() to it.toProto() })) + override fun args(userDataReader: UserDataReader): Sequence = + sequenceOf(encodeValue(fields.associate { it.getAlias() to it.toProto(userDataReader) })) } class SortStage internal constructor(private val orders: Array) : Stage("sort") { - override fun args(): Sequence = orders.asSequence().map(Ordering::toProto) + override fun args(userDataReader: UserDataReader): Sequence = + orders.asSequence().map { it.toProto(userDataReader) } } class DistinctStage internal constructor(private val groups: Array) : Stage("distinct") { - override fun args(): Sequence = - sequenceOf(encodeValue(groups.associate { it.getAlias() to it.toProto() })) + override fun args(userDataReader: UserDataReader): Sequence = + sequenceOf(encodeValue(groups.associate { it.getAlias() to it.toProto(userDataReader) })) } class RemoveFieldsStage internal constructor(private val fields: Array) : Stage("remove_fields") { - override fun args(): Sequence = fields.asSequence().map(Field::toProto) + override fun args(userDataReader: UserDataReader): Sequence = + fields.asSequence().map(Field::toProto) } class ReplaceStage internal constructor(private val field: Selectable, private val mode: Mode) : @@ -175,7 +224,8 @@ class ReplaceStage internal constructor(private val field: Selectable, private v val MERGE_PREFER_PARENT = Mode("merge_prefer_parent") } } - override fun args(): Sequence = sequenceOf(field.toProto(), mode.proto) + override fun args(userDataReader: UserDataReader): Sequence = + sequenceOf(field.toProto(userDataReader), mode.proto) } class SampleStage internal constructor(private val size: Number, private val mode: Mode) : @@ -187,16 +237,17 @@ class SampleStage internal constructor(private val size: Number, private val mod val PERCENT = Mode("percent") } } - override fun args(): Sequence = sequenceOf(encodeValue(size), mode.proto) + override fun args(userDataReader: UserDataReader): Sequence = + sequenceOf(encodeValue(size), mode.proto) } class UnionStage internal constructor(private val other: com.google.firebase.firestore.Pipeline) : Stage("union") { - override fun args(): Sequence = + override fun args(userDataReader: UserDataReader): Sequence = sequenceOf(Value.newBuilder().setPipelineValue(other.toPipelineProto()).build()) } class UnnestStage internal constructor(private val selectable: Selectable) : Stage("unnest") { - override fun args(): Sequence = - sequenceOf(encodeValue(selectable.getAlias()), selectable.toProto()) + override fun args(userDataReader: UserDataReader): Sequence = + sequenceOf(encodeValue(selectable.getAlias()), selectable.toProto(userDataReader)) }