Skip to content

Commit

Permalink
Generic Stage and Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersen committed Feb 28, 2025
1 parent 63cd54e commit c7605b6
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,13 @@
import static com.google.firebase.firestore.pipeline.Function.subtract;
import static com.google.firebase.firestore.pipeline.Ordering.ascending;
import static com.google.firebase.firestore.testutil.IntegrationTestUtil.waitFor;
import static java.util.Map.entry;

import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.android.gms.tasks.Task;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.truth.Correspondence;
import com.google.firebase.firestore.pipeline.Accumulator;
import com.google.firebase.firestore.pipeline.AggregateExpr;
import com.google.firebase.firestore.pipeline.AggregateStage;
import com.google.firebase.firestore.pipeline.Constant;
import com.google.firebase.firestore.pipeline.Field;
Expand Down Expand Up @@ -227,7 +226,7 @@ public void aggregateResultsCountAll() {
firestore
.pipeline()
.collection(randomCol)
.aggregate(Accumulator.countAll().as("count"))
.aggregate(AggregateExpr.countAll().as("count"))
.execute();
assertThat(waitFor(execute).getResults())
.comparingElementsUsing(DATA_CORRESPONDENCE)
Expand All @@ -243,8 +242,8 @@ public void aggregateResultsMany() {
.collection(randomCol)
.where(Function.eq("genre", "Science Fiction"))
.aggregate(
Accumulator.countAll().as("count"),
Accumulator.avg("rating").as("avgRating"),
AggregateExpr.countAll().as("count"),
AggregateExpr.avg("rating").as("avgRating"),
Field.of("rating").max().as("maxRating"))
.execute();
assertThat(waitFor(execute).getResults())
Expand All @@ -261,7 +260,7 @@ public void groupAndAccumulateResults() {
.collection(randomCol)
.where(lt(Field.of("published"), 1984))
.aggregate(
AggregateStage.withAccumulators(Accumulator.avg("rating").as("avgRating"))
AggregateStage.withAccumulators(AggregateExpr.avg("rating").as("avgRating"))
.withGroups("genre"))
.where(gt("avgRating", 4.3))
.sort(Field.of("avgRating").descending())
Expand All @@ -274,6 +273,28 @@ public void groupAndAccumulateResults() {
mapOfEntries(entry("avgRating", 4.4), entry("genre", "Science Fiction")));
}

@Test
public void groupAndAccumulateResultsGeneric() {
Task<PipelineSnapshot> execute =
firestore
.pipeline()
.collection(randomCol)
.genericStage("where", lt(Field.of("published"), 1984))
.genericStage(
"aggregate",
ImmutableMap.of("avgRating", AggregateExpr.avg("rating")),
ImmutableMap.of("genre", Field.of("genre")))
.genericStage("where", gt("avgRating", 4.3))
.genericStage("sort", Field.of("avgRating").descending())
.execute();
assertThat(waitFor(execute).getResults())
.comparingElementsUsing(DATA_CORRESPONDENCE)
.containsExactly(
mapOfEntries(entry("avgRating", 4.7), entry("genre", "Fantasy")),
mapOfEntries(entry("avgRating", 4.5), entry("genre", "Romance")),
mapOfEntries(entry("avgRating", 4.4), entry("genre", "Science Fiction")));
}

@Test
@Ignore("Not supported yet")
public void minAndMaxAccumulations() {
Expand All @@ -282,7 +303,7 @@ public void minAndMaxAccumulations() {
.pipeline()
.collection(randomCol)
.aggregate(
Accumulator.countAll().as("count"),
AggregateExpr.countAll().as("count"),
Field.of("rating").max().as("maxRating"),
Field.of("published").min().as("minPublished"))
.execute();
Expand Down Expand Up @@ -781,6 +802,30 @@ public void testMapGetWithFieldNameIncludingNotation() {
entry("nested", null)));
}

@Test
public void testListEquals() {
Task<PipelineSnapshot> execute =
randomCol
.pipeline()
.where(eq("tags", ImmutableList.of("philosophy", "crime", "redemption")))
.execute();
assertThat(waitFor(execute).getResults())
.comparingElementsUsing(ID_CORRESPONDENCE)
.containsExactly("book6");
}

@Test
public void testMapEquals() {
Task<PipelineSnapshot> execute =
randomCol
.pipeline()
.where(eq("awards", ImmutableMap.of("nobel", true, "nebula", false)))
.execute();
assertThat(waitFor(execute).getResults())
.comparingElementsUsing(ID_CORRESPONDENCE)
.containsExactly("book3");
}

static <T> Map.Entry<String, T> entry(String key, T value) {
return new Map.Entry<String, T>() {
private String k = key;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import androidx.annotation.Keep;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.RestrictTo;
import androidx.annotation.VisibleForTesting;
import com.google.android.gms.tasks.Task;
import com.google.android.gms.tasks.TaskCompletionSource;
Expand Down Expand Up @@ -855,7 +856,9 @@ DatabaseId getDatabaseId() {
return databaseId;
}

UserDataReader getUserDataReader() {
@NonNull
@RestrictTo(RestrictTo.Scope.LIBRARY)
public UserDataReader getUserDataReader() {
return userDataReader;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ import com.google.common.collect.FluentIterable
import com.google.common.collect.ImmutableList
import com.google.firebase.firestore.model.DocumentKey
import com.google.firebase.firestore.model.SnapshotVersion
import com.google.firebase.firestore.pipeline.AccumulatorWithAlias
import com.google.firebase.firestore.pipeline.AddFieldsStage
import com.google.firebase.firestore.pipeline.AggregateStage
import com.google.firebase.firestore.pipeline.AggregateWithAlias
import com.google.firebase.firestore.pipeline.BooleanExpr
import com.google.firebase.firestore.pipeline.CollectionGroupSource
import com.google.firebase.firestore.pipeline.CollectionSource
import com.google.firebase.firestore.pipeline.DatabaseSource
import com.google.firebase.firestore.pipeline.DistinctStage
import com.google.firebase.firestore.pipeline.DocumentsSource
import com.google.firebase.firestore.pipeline.Field
import com.google.firebase.firestore.pipeline.GenericArg
import com.google.firebase.firestore.pipeline.GenericStage
import com.google.firebase.firestore.pipeline.LimitStage
import com.google.firebase.firestore.pipeline.OffsetStage
import com.google.firebase.firestore.pipeline.Ordering
Expand Down Expand Up @@ -84,9 +86,12 @@ internal constructor(

internal fun toPipelineProto(): com.google.firestore.v1.Pipeline =
com.google.firestore.v1.Pipeline.newBuilder()
.addAllStages(stages.map(Stage::toProtoStage))
.addAllStages(stages.map { it.toProtoStage(firestore.userDataReader) })
.build()

fun genericStage(name: String, vararg params: Any) =
append(GenericStage(name, params.map(GenericArg::from)))

fun addFields(vararg fields: Selectable): Pipeline = append(AddFieldsStage(fields))

fun removeFields(vararg fields: Field): Pipeline = append(RemoveFieldsStage(fields))
Expand Down Expand Up @@ -118,7 +123,7 @@ internal constructor(
fun distinct(vararg groups: Any): Pipeline =
append(DistinctStage(groups.map(Selectable::toSelectable).toTypedArray()))

fun aggregate(vararg accumulators: AccumulatorWithAlias): Pipeline =
fun aggregate(vararg accumulators: AggregateWithAlias): Pipeline =
append(AggregateStage.withAccumulators(*accumulators))

fun aggregate(aggregateStage: AggregateStage): Pipeline = append(aggregateStage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import androidx.annotation.Nullable;
import androidx.annotation.RestrictTo;
import com.google.common.base.Function;
import com.google.firebase.firestore.FieldValue.ArrayRemoveFieldValue;
import com.google.firebase.firestore.FieldValue.ArrayUnionFieldValue;
import com.google.firebase.firestore.FieldValue.DeleteFieldValue;
Expand All @@ -36,6 +37,7 @@
import com.google.firebase.firestore.model.mutation.FieldMask;
import com.google.firebase.firestore.model.mutation.NumericIncrementTransformOperation;
import com.google.firebase.firestore.model.mutation.ServerTimestampOperation;
import com.google.firebase.firestore.pipeline.Expr;
import com.google.firebase.firestore.util.Assert;
import com.google.firebase.firestore.util.CustomClassMapper;
import com.google.firebase.firestore.util.Util;
Expand Down Expand Up @@ -389,6 +391,12 @@ public Value parseScalarValue(Object input, ParseContext context) {
return Values.NULL_VALUE;
} else if (input.getClass().isArray()) {
throw context.createError("Arrays are not supported; use a List instead");
} else if (input instanceof DocumentReference) {
DocumentReference ref = (DocumentReference) input;
validateDocumentReference(ref, context::createError);
return Values.encodeValue(ref);
} else if (input instanceof Expr) {
throw context.createError("Pipeline expressions are not supported user objects");
} else {
try {
return Values.encodeAnyValue(input);
Expand All @@ -398,6 +406,20 @@ public Value parseScalarValue(Object input, ParseContext context) {
}
}

public void validateDocumentReference(
DocumentReference ref, Function<String, RuntimeException> createError) {
DatabaseId otherDb = ref.getFirestore().getDatabaseId();
if (!otherDb.equals(databaseId)) {
throw createError.apply(
String.format(
"Document reference is for database %s/%s but should be for database %s/%s",
otherDb.getProjectId(),
otherDb.getDatabaseId(),
databaseId.getProjectId(),
databaseId.getDatabaseId()));
}
}

private List<Value> parseArrayTransformElements(List<Object> elements) {
ParseAccumulator accumulator = new ParseAccumulator(UserData.Source.Argument);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,7 @@ internal object Values {

return Value.newBuilder()
.setTimestampValue(
com.google.protobuf.Timestamp.newBuilder()
.setSeconds(timestamp.seconds)
.setNanos(truncatedNanoseconds)
Timestamp.newBuilder().setSeconds(timestamp.seconds).setNanos(truncatedNanoseconds)
)
.build()
}
Expand Down Expand Up @@ -665,6 +663,11 @@ internal object Values {
return Value.newBuilder().setMapValue(MapValue.newBuilder().putAllFields(map)).build()
}

@JvmStatic
fun encodeValue(values: Iterable<Value>): Value {
return Value.newBuilder().setArrayValue(ArrayValue.newBuilder().addAllValues(values)).build()
}

@JvmStatic
fun encodeAnyValue(value: Any?): Value {
return when (value) {
Expand All @@ -676,7 +679,6 @@ internal object Values {
is Boolean -> encodeValue(value)
is GeoPoint -> encodeValue(value)
is Blob -> encodeValue(value)
is DocumentReference -> encodeValue(value)
is VectorValue -> encodeValue(value)
else -> throw IllegalArgumentException("Unexpected type: $value")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,76 +18,70 @@ import com.google.firebase.Timestamp
import com.google.firebase.firestore.Blob
import com.google.firebase.firestore.DocumentReference
import com.google.firebase.firestore.GeoPoint
import com.google.firebase.firestore.UserDataReader
import com.google.firebase.firestore.VectorValue
import com.google.firebase.firestore.model.Values
import com.google.firebase.firestore.model.Values.encodeValue
import com.google.firestore.v1.Value
import java.util.Date

class Constant internal constructor(val value: Value) : Expr() {
abstract class Constant internal constructor() : Expr() {

private class ValueConstant(val value: Value) : Constant() {
override fun toProto(userDataReader: UserDataReader): Value = value
}

companion object {
internal val NULL = Constant(Values.NULL_VALUE)

fun of(value: Any): Constant {
return when (value) {
is String -> of(value)
is Number -> of(value)
is Date -> of(value)
is Timestamp -> of(value)
is Boolean -> of(value)
is GeoPoint -> of(value)
is Blob -> of(value)
is DocumentReference -> of(value)
is Value -> of(value)
is VectorValue -> of(value)
else -> throw IllegalArgumentException("Unknown type: $value")
}
}
internal val NULL: Constant = ValueConstant(Values.NULL_VALUE)

@JvmStatic
fun of(value: String): Constant {
return Constant(encodeValue(value))
return ValueConstant(encodeValue(value))
}

@JvmStatic
fun of(value: Number): Constant {
return Constant(encodeValue(value))
return ValueConstant(encodeValue(value))
}

@JvmStatic
fun of(value: Date): Constant {
return Constant(encodeValue(value))
return ValueConstant(encodeValue(value))
}

@JvmStatic
fun of(value: Timestamp): Constant {
return Constant(encodeValue(value))
return ValueConstant(encodeValue(value))
}

@JvmStatic
fun of(value: Boolean): Constant {
return Constant(encodeValue(value))
return ValueConstant(encodeValue(value))
}

@JvmStatic
fun of(value: GeoPoint): Constant {
return Constant(encodeValue(value))
return ValueConstant(encodeValue(value))
}

@JvmStatic
fun of(value: Blob): Constant {
return Constant(encodeValue(value))
return ValueConstant(encodeValue(value))
}

@JvmStatic
fun of(value: DocumentReference): Constant {
return Constant(encodeValue(value))
fun of(ref: DocumentReference): Constant {
return object : Constant() {
override fun toProto(userDataReader: UserDataReader): Value {
userDataReader.validateDocumentReference(ref, ::IllegalArgumentException)
return encodeValue(ref)
}
}
}

@JvmStatic
fun of(value: VectorValue): Constant {
return Constant(encodeValue(value))
return ValueConstant(encodeValue(value))
}

@JvmStatic
Expand All @@ -97,16 +91,12 @@ class Constant internal constructor(val value: Value) : Expr() {

@JvmStatic
fun vector(value: DoubleArray): Constant {
return Constant(Values.encodeVectorValue(value))
return ValueConstant(Values.encodeVectorValue(value))
}

@JvmStatic
fun vector(value: VectorValue): Constant {
return Constant(encodeValue(value))
return ValueConstant(encodeValue(value))
}
}

override fun toProto(): Value {
return value
}
}
Loading

0 comments on commit c7605b6

Please sign in to comment.