diff --git a/src/main/java/io/weaviate/client/v1/async/WeaviateAsyncClient.java b/src/main/java/io/weaviate/client/v1/async/WeaviateAsyncClient.java index c1e4dd90..0ac83ac5 100644 --- a/src/main/java/io/weaviate/client/v1/async/WeaviateAsyncClient.java +++ b/src/main/java/io/weaviate/client/v1/async/WeaviateAsyncClient.java @@ -10,6 +10,7 @@ import io.weaviate.client.v1.async.classifications.Classifications; import io.weaviate.client.v1.async.cluster.Cluster; import io.weaviate.client.v1.async.data.Data; +import io.weaviate.client.v1.async.graphql.GraphQL; import io.weaviate.client.v1.async.misc.Misc; import io.weaviate.client.v1.async.schema.Schema; import io.weaviate.client.v1.misc.model.Meta; @@ -60,6 +61,10 @@ public Backup backup() { return new Backup(client, config); } + public GraphQL graphQL() { + return new GraphQL(client, config); + } + private DbVersionProvider initDbVersionProvider() { DbVersionProvider.VersionGetter getter = () -> Optional.ofNullable(this.getMeta()) diff --git a/src/main/java/io/weaviate/client/v1/async/graphql/GraphQL.java b/src/main/java/io/weaviate/client/v1/async/graphql/GraphQL.java new file mode 100644 index 00000000..544115e6 --- /dev/null +++ b/src/main/java/io/weaviate/client/v1/async/graphql/GraphQL.java @@ -0,0 +1,38 @@ +package io.weaviate.client.v1.async.graphql; + +import io.weaviate.client.Config; +import io.weaviate.client.v1.async.graphql.api.Aggregate; +import io.weaviate.client.v1.async.graphql.api.Explore; +import io.weaviate.client.v1.async.graphql.api.Get; +import io.weaviate.client.v1.async.graphql.api.Raw; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; + +public class GraphQL { + private final Config config; + private final CloseableHttpAsyncClient client; + + public GraphQL(CloseableHttpAsyncClient client, Config config) { + this.client = client; + this.config = config; + } + + public Get get() { + return new Get(client, config); + } + + public Raw raw() { + return new Raw(client, config); + } + + public Explore explore() { + return new Explore(client, config); + } + + public Aggregate aggregate() { + return new Aggregate(client, config); + } + + public io.weaviate.client.v1.graphql.GraphQL.Arguments arguments() { + return new io.weaviate.client.v1.graphql.GraphQL.Arguments(); + } +} diff --git a/src/main/java/io/weaviate/client/v1/async/graphql/api/Aggregate.java b/src/main/java/io/weaviate/client/v1/async/graphql/api/Aggregate.java new file mode 100644 index 00000000..4e9889d2 --- /dev/null +++ b/src/main/java/io/weaviate/client/v1/async/graphql/api/Aggregate.java @@ -0,0 +1,126 @@ +package io.weaviate.client.v1.async.graphql.api; + +import io.weaviate.client.Config; +import io.weaviate.client.base.AsyncBaseClient; +import io.weaviate.client.base.AsyncClientResult; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.filters.WhereFilter; +import io.weaviate.client.v1.graphql.model.GraphQLQuery; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.graphql.query.argument.*; +import io.weaviate.client.v1.graphql.query.builder.AggregateBuilder; +import io.weaviate.client.v1.graphql.query.fields.Field; +import io.weaviate.client.v1.graphql.query.fields.Fields; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.core5.concurrent.FutureCallback; + +import java.util.concurrent.Future; + +public class Aggregate extends AsyncBaseClient implements AsyncClientResult { + private final AggregateBuilder.AggregateBuilderBuilder aggregateBuilder; + + public Aggregate(CloseableHttpAsyncClient client, Config config) { + super(client, config); + aggregateBuilder = AggregateBuilder.builder(); + } + + public Aggregate withClassName(String className) { + aggregateBuilder.className(className); + return this; + } + + public Aggregate withFields(Field... fields) { + aggregateBuilder.fields(Fields.builder() + .fields(fields) + .build()); + return this; + } + + @Deprecated + public Aggregate withWhere(WhereFilter where) { + return withWhere(WhereArgument.builder() + .filter(where) + .build()); + } + + public Aggregate withWhere(WhereArgument where) { + aggregateBuilder.withWhereFilter(where); + return this; + } + + public Aggregate withGroupBy(String propertyName) { + aggregateBuilder.groupByClausePropertyName(propertyName); + return this; + } + + public Aggregate withAsk(AskArgument ask) { + aggregateBuilder.withAskArgument(ask); + return this; + } + + public Aggregate withNearText(NearTextArgument withNearTextFilter) { + aggregateBuilder.withNearTextFilter(withNearTextFilter); + return this; + } + + public Aggregate withNearObject(NearObjectArgument withNearObjectFilter) { + aggregateBuilder.withNearObjectFilter(withNearObjectFilter); + return this; + } + + public Aggregate withNearVector(NearVectorArgument withNearVectorFilter) { + aggregateBuilder.withNearVectorFilter(withNearVectorFilter); + return this; + } + + public Aggregate withNearImage(NearImageArgument nearImage) { + aggregateBuilder.withNearImageFilter(nearImage); + return this; + } + + public Aggregate withNearAudio(NearAudioArgument nearAudio) { + aggregateBuilder.withNearAudioFilter(nearAudio); + return this; + } + + public Aggregate withNearVideo(NearVideoArgument nearVideo) { + aggregateBuilder.withNearVideoFilter(nearVideo); + return this; + } + + public Aggregate withNearDepth(NearDepthArgument nearDepth) { + aggregateBuilder.withNearDepthFilter(nearDepth); + return this; + } + + public Aggregate withNearThermal(NearThermalArgument nearThermal) { + aggregateBuilder.withNearThermalFilter(nearThermal); + return this; + } + + public Aggregate withNearImu(NearImuArgument nearImu) { + aggregateBuilder.withNearImuFilter(nearImu); + return this; + } + + public Aggregate withObjectLimit(Integer objectLimit) { + aggregateBuilder.objectLimit(objectLimit); + return this; + } + + public Aggregate withTenant(String tenant) { + aggregateBuilder.tenant(tenant); + return this; + } + + @Override + public Future> run(FutureCallback> callback) { + String aggregateQuery = aggregateBuilder.build() + .buildQuery(); + GraphQLQuery query = GraphQLQuery.builder() + .query(aggregateQuery) + .build(); + return sendPostRequest("/graphql", query, GraphQLResponse.class, callback); + } + +} diff --git a/src/main/java/io/weaviate/client/v1/async/graphql/api/Explore.java b/src/main/java/io/weaviate/client/v1/async/graphql/api/Explore.java new file mode 100644 index 00000000..58b60e8a --- /dev/null +++ b/src/main/java/io/weaviate/client/v1/async/graphql/api/Explore.java @@ -0,0 +1,99 @@ +package io.weaviate.client.v1.async.graphql.api; + +import io.weaviate.client.Config; +import io.weaviate.client.base.AsyncBaseClient; +import io.weaviate.client.base.AsyncClientResult; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.graphql.model.ExploreFields; +import io.weaviate.client.v1.graphql.model.GraphQLQuery; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.graphql.query.argument.*; +import io.weaviate.client.v1.graphql.query.builder.ExploreBuilder; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.core5.concurrent.FutureCallback; + +import java.util.concurrent.Future; + +public class Explore extends AsyncBaseClient implements AsyncClientResult { + private final ExploreBuilder.ExploreBuilderBuilder exploreBuilder; + + public Explore(CloseableHttpAsyncClient client, Config config) { + super(client, config); + exploreBuilder = ExploreBuilder.builder(); + } + + public Explore withFields(ExploreFields... fields) { + exploreBuilder.fields(fields); + return this; + } + + public Explore withLimit(Integer limit) { + exploreBuilder.limit(limit); + return this; + } + + public Explore withOffset(Integer offset) { + exploreBuilder.offset(offset); + return this; + } + + public Explore withAsk(AskArgument ask) { + exploreBuilder.withAskArgument(ask); + return this; + } + + public Explore withNearText(NearTextArgument nearText) { + exploreBuilder.withNearText(nearText); + return this; + } + + public Explore withNearObject(NearObjectArgument nearObject) { + exploreBuilder.withNearObjectFilter(nearObject); + return this; + } + + public Explore withNearVector(NearVectorArgument nearVector) { + exploreBuilder.withNearVectorFilter(nearVector); + return this; + } + + public Explore withNearImage(NearImageArgument nearImage) { + exploreBuilder.withNearImageFilter(nearImage); + return this; + } + + public Explore withNearAudio(NearAudioArgument nearAudio) { + exploreBuilder.withNearAudioFilter(nearAudio); + return this; + } + + public Explore withNearVideo(NearVideoArgument nearVideo) { + exploreBuilder.withNearVideoFilter(nearVideo); + return this; + } + + public Explore withNearDepth(NearDepthArgument nearDepth) { + exploreBuilder.withNearDepthFilter(nearDepth); + return this; + } + + public Explore withNearThermal(NearThermalArgument nearThermal) { + exploreBuilder.withNearThermalFilter(nearThermal); + return this; + } + + public Explore withNearImu(NearImuArgument nearImu) { + exploreBuilder.withNearImuFilter(nearImu); + return this; + } + + @Override + public Future> run(FutureCallback> callback) { + String exploreQuery = exploreBuilder.build() + .buildQuery(); + GraphQLQuery query = GraphQLQuery.builder() + .query(exploreQuery) + .build(); + return sendPostRequest("/graphql", query, GraphQLResponse.class, callback); + } +} diff --git a/src/main/java/io/weaviate/client/v1/async/graphql/api/Get.java b/src/main/java/io/weaviate/client/v1/async/graphql/api/Get.java new file mode 100644 index 00000000..f7daf26c --- /dev/null +++ b/src/main/java/io/weaviate/client/v1/async/graphql/api/Get.java @@ -0,0 +1,173 @@ +package io.weaviate.client.v1.async.graphql.api; + +import io.weaviate.client.Config; +import io.weaviate.client.base.AsyncBaseClient; +import io.weaviate.client.base.AsyncClientResult; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.filters.WhereFilter; +import io.weaviate.client.v1.graphql.model.GraphQLQuery; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.graphql.query.argument.*; +import io.weaviate.client.v1.graphql.query.builder.GetBuilder; +import io.weaviate.client.v1.graphql.query.fields.Field; +import io.weaviate.client.v1.graphql.query.fields.Fields; +import io.weaviate.client.v1.graphql.query.fields.GenerativeSearchBuilder; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.core5.concurrent.FutureCallback; + +import java.util.concurrent.Future; + +public class Get extends AsyncBaseClient implements AsyncClientResult { + private final GetBuilder.GetBuilderBuilder getBuilder; + + public Get(CloseableHttpAsyncClient client, Config config) { + super(client, config); + getBuilder = GetBuilder.builder(); + } + + public Get withClassName(String className) { + getBuilder.className(className); + return this; + } + + public Get withFields(Field... fields) { + getBuilder.fields(Fields.builder() + .fields(fields) + .build()); + return this; + } + + @Deprecated + public Get withWhere(WhereFilter where) { + return withWhere(WhereArgument.builder() + .filter(where) + .build()); + } + + public Get withWhere(WhereArgument where) { + getBuilder.withWhereFilter(where); + return this; + } + + public Get withLimit(Integer limit) { + getBuilder.limit(limit); + return this; + } + + public Get withOffset(Integer offset) { + getBuilder.offset(offset); + return this; + } + + public Get withAfter(String after) { + getBuilder.after(after); + return this; + } + + public Get withBm25(Bm25Argument bm25) { + getBuilder.withBm25Filter(bm25); + return this; + } + + public Get withHybrid(HybridArgument hybrid) { + getBuilder.withHybridFilter(hybrid); + return this; + } + + public Get withAsk(AskArgument ask) { + getBuilder.withAskArgument(ask); + return this; + } + + public Get withNearText(NearTextArgument nearText) { + getBuilder.withNearTextFilter(nearText); + return this; + } + + public Get withNearObject(NearObjectArgument nearObject) { + getBuilder.withNearObjectFilter(nearObject); + return this; + } + + public Get withNearVector(NearVectorArgument nearVector) { + getBuilder.withNearVectorFilter(nearVector); + return this; + } + + public Get withNearImage(NearImageArgument nearImage) { + getBuilder.withNearImageFilter(nearImage); + return this; + } + + public Get withNearAudio(NearAudioArgument nearAudio) { + getBuilder.withNearAudioFilter(nearAudio); + return this; + } + + public Get withNearVideo(NearVideoArgument nearVideo) { + getBuilder.withNearVideoFilter(nearVideo); + return this; + } + + public Get withNearDepth(NearDepthArgument nearDepth) { + getBuilder.withNearDepthFilter(nearDepth); + return this; + } + + public Get withNearThermal(NearThermalArgument nearThermal) { + getBuilder.withNearThermalFilter(nearThermal); + return this; + } + + public Get withNearImu(NearImuArgument nearImu) { + getBuilder.withNearImuFilter(nearImu); + return this; + } + + public Get withGroup(GroupArgument group) { + getBuilder.withGroupArgument(group); + return this; + } + + public Get withSort(SortArgument... sort) { + getBuilder.withSortArguments(SortArguments.builder() + .sort(sort) + .build()); + return this; + } + + public Get withGenerativeSearch(GenerativeSearchBuilder generativeSearch) { + getBuilder.withGenerativeSearch(generativeSearch); + return this; + } + + public Get withConsistencyLevel(String level) { + getBuilder.withConsistencyLevel(level); + return this; + } + + public Get withGroupBy(GroupByArgument groupBy) { + getBuilder.withGroupByArgument(groupBy); + return this; + } + + public Get withTenant(String tenant) { + getBuilder.tenant(tenant); + return this; + } + + public Get withAutocut(Integer autocut) { + getBuilder.autocut(autocut); + return this; + } + + @Override + public Future> run(FutureCallback> callback) { + String getQuery = getBuilder.build() + .buildQuery(); + GraphQLQuery query = GraphQLQuery.builder() + .query(getQuery) + .build(); + return sendPostRequest("/graphql", query, GraphQLResponse.class, callback); + } +} diff --git a/src/main/java/io/weaviate/client/v1/async/graphql/api/Raw.java b/src/main/java/io/weaviate/client/v1/async/graphql/api/Raw.java new file mode 100644 index 00000000..7e943a4b --- /dev/null +++ b/src/main/java/io/weaviate/client/v1/async/graphql/api/Raw.java @@ -0,0 +1,33 @@ +package io.weaviate.client.v1.async.graphql.api; + +import io.weaviate.client.Config; +import io.weaviate.client.base.AsyncBaseClient; +import io.weaviate.client.base.AsyncClientResult; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.graphql.model.GraphQLQuery; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.core5.concurrent.FutureCallback; + +import java.util.concurrent.Future; + +public class Raw extends AsyncBaseClient implements AsyncClientResult { + private String query; + + public Raw(CloseableHttpAsyncClient client, Config config) { + super(client, config); + } + + public Raw withQuery(String query) { + this.query = query; + return this; + } + + @Override + public Future> run(FutureCallback> callback) { + GraphQLQuery query = GraphQLQuery.builder() + .query(this.query) + .build(); + return sendPostRequest("/graphql", query, GraphQLResponse.class, callback); + } +} diff --git a/src/main/java/io/weaviate/client/v1/graphql/GraphQL.java b/src/main/java/io/weaviate/client/v1/graphql/GraphQL.java index b259aa31..7ff814a9 100644 --- a/src/main/java/io/weaviate/client/v1/graphql/GraphQL.java +++ b/src/main/java/io/weaviate/client/v1/graphql/GraphQL.java @@ -27,7 +27,7 @@ public class GraphQL { private Config config; private HttpClient httpClient; - public class Arguments { + public static class Arguments { public NearTextArgument.NearTextArgumentBuilder nearTextArgBuilder() { return NearTextArgument.builder(); } diff --git a/src/test/java/io/weaviate/integration/client/async/graphql/AbstractAsyncClientTest.java b/src/test/java/io/weaviate/integration/client/async/graphql/AbstractAsyncClientTest.java new file mode 100644 index 00000000..c21a0c35 --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/async/graphql/AbstractAsyncClientTest.java @@ -0,0 +1,102 @@ +package io.weaviate.integration.client.async.graphql; + +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.filters.WhereFilter; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.graphql.query.argument.WhereArgument; +import io.weaviate.client.v1.graphql.query.fields.Field; +import io.weaviate.integration.client.graphql.AbstractClientGraphQLTest; + +import java.util.Date; +import java.util.Map; + +import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +class AbstractAsyncClientTest extends AbstractClientGraphQLTest { + static Field field(String name) { + return Field.builder().name(name).build(); + } + + static Field[] fields(String... fieldNames) { + Field[] fields = new Field[fieldNames.length]; + for (int i = 0; i < fieldNames.length; i++) { + fields[i] = field(fieldNames[i]); + } + return fields; + } + + static Field _additional(String... fieldNames) { + return Field.builder().name("_additional").fields(fields(fieldNames)).build(); + } + + static Field meta(String... fieldNames) { + return Field.builder().name("meta").fields(fields(fieldNames)).build(); + } + + + static WhereArgument whereText(String property, String operator, String... valueText) { + return WhereArgument.builder() + .filter(WhereFilter.builder() + .path(property) + .operator(operator) + .valueText(valueText) + .build()) + .build(); + } + + static WhereArgument whereDate(String property, String operator, Date... valueDate) { + return WhereArgument.builder() + .filter(WhereFilter.builder() + .path(property) + .operator(operator) + .valueDate(valueDate) + .build()) + .build(); + } + + static WhereArgument whereNumber(String property, String operator, Double... valueNumber) { + return WhereArgument.builder() + .filter(WhereFilter.builder() + .path(property) + .operator(operator) + .valueNumber(valueNumber) + .build()) + .build(); + } + + /** + * Check that request was processed successfully and no errors are returned. Extract the part of the response body for the specified query type. + * + * @param result Result of a GraphQL query. + * @param queryType "Get", "Explore", or "Aggregate". + * @return "data" portion of the response + */ + @SuppressWarnings("unchecked") + T extractQueryResult(Result result, String queryType) { + assertNotNull(result, "graphQL request returned null"); + assertNull("GraphQL error in the response", result.getError()); + + GraphQLResponse resp = result.getResult(); + assertNotNull(resp, "GraphQL response not returned"); + + Map data = (Map) resp.getData(); + assertNotNull(data, "GraphQL response has no data"); + + T queryResult = (T) data.get(queryType); + assertNotNull(queryResult, String.format("%s query returned no result", queryType)); + + return queryResult; + } + + T extractClass(Result result, String queryType, String className) { + Map queryResult = extractQueryResult(result, queryType); + return extractClass(queryResult, className); + } + + T extractClass(Map queryResult, String className) { + T objects = queryResult.get(className); + assertNotNull(objects, String.format("no %ss returned", className.toLowerCase())); + return objects; + } +} diff --git a/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLMultiTargetSearchTest.java b/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLMultiTargetSearchTest.java new file mode 100644 index 00000000..8b3ae714 --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLMultiTargetSearchTest.java @@ -0,0 +1,367 @@ +package io.weaviate.integration.client.async.graphql; + +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.async.WeaviateAsyncClient; +import io.weaviate.client.v1.async.graphql.GraphQL; +import io.weaviate.client.v1.async.graphql.api.Get; +import io.weaviate.client.v1.batch.model.ObjectGetResponse; +import io.weaviate.client.v1.data.model.WeaviateObject; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.graphql.query.argument.NearObjectArgument; +import io.weaviate.client.v1.graphql.query.argument.NearTextArgument; +import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument; +import io.weaviate.client.v1.graphql.query.argument.Targets; +import io.weaviate.client.v1.graphql.query.fields.Field; +import io.weaviate.client.v1.misc.model.BQConfig; +import io.weaviate.client.v1.misc.model.PQConfig; +import io.weaviate.client.v1.misc.model.SQConfig; +import io.weaviate.client.v1.misc.model.VectorIndexConfig; +import io.weaviate.client.v1.schema.model.DataType; +import io.weaviate.client.v1.schema.model.Property; +import io.weaviate.client.v1.schema.model.WeaviateClass; +import io.weaviate.integration.client.WeaviateDockerCompose; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +import java.util.*; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.ARRAY; +import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.fail; + +public class ClientGraphQLMultiTargetSearchTest extends AbstractAsyncClientTest { + private String httpHost; + private String grpcHost; + + private WeaviateClient syncClient; + private WeaviateAsyncClient client; + private GraphQL gql; + + @ClassRule + public static WeaviateDockerCompose compose = new WeaviateDockerCompose(); + + @Before + public void before() { + httpHost = compose.getHttpHostAddress(); + + Config config = new Config("http", httpHost); + config.setGRPCSecured(false); + config.setGRPCHost(grpcHost); + syncClient = new WeaviateClient(config); + + client = syncClient.async(); + gql = client.graphQL(); + } + + @After + public void after() { + client.close(); + } + + private final String id1 = "00000000-0000-0000-0000-000000000001"; + private final String id2 = "00000000-0000-0000-0000-000000000002"; + private final String id3 = "00000000-0000-0000-0000-000000000003"; + + private final String titleAndContent = "titleAndContent"; + private final String title1 = "title1"; + private final String title2 = "title2"; + private final String bringYourOwnVector = "bringYourOwnVector"; + private final String bringYourOwnVector2 = "bringYourOwnVector2"; + + @Test + public void shouldPerformMultiTargetSearch() throws InterruptedException { + String className = "MultiTargetSearch"; + setupDB(className); + Field _additional = _additional("id", "distance"); + // nearText + Map weights = new HashMap<>(); + weights.put(titleAndContent, 0.1f); + weights.put(title1, 0.6f); + weights.put(title2, 0.3f); + Targets targets = Targets.builder() + .targetVectors(new String[]{ titleAndContent, title1, title2 }) + .combinationMethod(Targets.CombinationMethod.manualWeights) + .weights(weights) + .build(); + NearTextArgument nearText = gql.arguments().nearTextArgBuilder() + .concepts(new String[]{ "Water black" }) + .targets(targets) + .build(); + Result response = doGet(get -> get + .withClassName(className) + .withNearText(nearText) + .withFields(_additional)); + assertGetContainsIds(response, className, id1, id2, id3); + // nearVector with single vector-per-target + Map vectorPerTarget = new HashMap<>(); + vectorPerTarget.put(bringYourOwnVector, new Float[]{ .99f, .88f, .77f }); + vectorPerTarget.put(bringYourOwnVector2, new Float[]{ .11f, .22f, .33f }); + weights = new HashMap() {{ + this.put(bringYourOwnVector, 0.1f); + this.put(bringYourOwnVector2, 0.6f); + }}; + targets = Targets.builder() + .targetVectors(new String[]{ bringYourOwnVector, bringYourOwnVector2 }) + .combinationMethod(Targets.CombinationMethod.manualWeights) + .weights(weights) + .build(); + final NearVectorArgument nearVector1 = gql.arguments().nearVectorArgBuilder() + .vectorPerTarget(vectorPerTarget) + .targets(targets).build(); + response = doGet(get -> get + .withClassName(className) + .withNearVector(nearVector1) + .withFields(_additional)); + assertNull("check error in response:", response.getError()); + assertGetContainsIds(response, className, id2, id3); + // nearVector with multiple vector-per-target + Map vectorsPerTarget = new HashMap<>(); + vectorsPerTarget.put(bringYourOwnVector, new Float[][]{ new Float[]{ .99f, .88f, .77f }, new Float[]{ .99f, .88f, .77f } }); + vectorsPerTarget.put(bringYourOwnVector2, new Float[][]{ new Float[]{ .11f, .22f, .33f } }); + Map weightsMulti = new HashMap<>(); + weightsMulti.put(bringYourOwnVector, new Float[]{ 0.5f, 0.5f }); + weightsMulti.put(bringYourOwnVector2, new Float[]{ 0.6f }); + targets = Targets.builder() + .targetVectors(new String[]{ bringYourOwnVector, bringYourOwnVector2 }) + .combinationMethod(Targets.CombinationMethod.manualWeights) + .weightsMulti(weightsMulti) + .build(); + final NearVectorArgument nearVector2 = gql.arguments().nearVectorArgBuilder() + .vectorsPerTarget(vectorsPerTarget) + .targets(targets).build(); + response = doGet(get -> get + .withClassName(className) + .withNearVector(nearVector2) + .withFields(_additional)); + assertNull("check error in response:", response.getError()); + assertGetContainsIds(response, className, id2, id3); + // nearObject + targets = Targets.builder() + .targetVectors(new String[]{ bringYourOwnVector, bringYourOwnVector2, titleAndContent, title1, title2 }) + .combinationMethod(Targets.CombinationMethod.average) + .build(); + NearObjectArgument nearObject = gql.arguments().nearObjectArgBuilder() + .id(id3).targets(targets).build(); + response = doGet(get -> get + .withClassName(className) + .withNearObject(nearObject) + .withFields(_additional)); + assertGetContainsIds(response, className, id2, id3); + } + + private void setupDB(String className) { + // clean + Result delete = syncClient.schema().allDeleter().run(); + assertThat(delete).isNotNull() + .returns(false, Result::hasErrors) + .returns(true, Result::getResult); + // create class + List properties = Arrays.asList( + Property.builder() + .name("title") + .dataType(Collections.singletonList(DataType.TEXT)) + .build(), + Property.builder() + .name("content") + .dataType(Collections.singletonList(DataType.TEXT)) + .build(), + Property.builder() + .name("title1") + .dataType(Collections.singletonList(DataType.TEXT)) + .build(), + Property.builder() + .name("title2") + .dataType(Collections.singletonList(DataType.TEXT)) + .build() + ); + Map vectorConfig = new HashMap<>(); + vectorConfig.put(titleAndContent, getTitleAndContentVectorConfig()); + vectorConfig.put(title1, getTitle1VectorConfig()); + vectorConfig.put(title2, getTitle2VectorConfig()); + vectorConfig.put(bringYourOwnVector, getBringYourOwnVectorVectorConfig()); + vectorConfig.put(bringYourOwnVector2, getBringYourOwnVectorVectorConfig2()); + Result createResult = syncClient.schema().classCreator() + .withClass(WeaviateClass.builder() + .className(className) + .properties(properties) + .vectorConfig(vectorConfig) + .build() + ) + .run(); + assertThat(createResult).isNotNull() + .returns(false, Result::hasErrors) + .returns(true, Result::getResult); + // add data + // obj1 + Map props1 = new HashMap<>(); + props1.put("title", "The Lord of the Rings"); + props1.put("content", "A great fantasy novel"); + props1.put("title1", "J.R.R. Tolkien The Lord of the Rings"); + props1.put("title2", "Rings"); + Float[] vector1a = new Float[]{ 0.77f, 0.88f, 0.77f }; + Map vectors1 = new HashMap<>(); + vectors1.put("bringYourOwnVector", vector1a); + // don't add vector for bringYourOwnVector2 + // obj2 + Map props2 = new HashMap<>(); + props2.put("title", "Black Oceans"); + props2.put("content", "A great science fiction book"); + props2.put("title1", "Jacek Dukaj Black Oceans"); + props2.put("title2", "Water"); + Float[] vector2a = new Float[]{ 0.11f, 0.22f, 0.33f }; + Float[] vector2b = new Float[]{ 0.11f, 0.11f, 0.11f }; + Map vectors2 = new HashMap<>(); + vectors2.put("bringYourOwnVector", vector2a); + vectors2.put("bringYourOwnVector2", vector2b); + // obj2 + Map props3 = new HashMap<>(); + props3.put("title", "Into the Water"); + props3.put("content", "New York Times bestseller and global phenomenon The Girl on the Train returns with Into the Water"); + props3.put("title1", "Paula Hawkins Into the Water"); + props3.put("title2", "Water go into it"); + Float[] vector3a = new Float[]{ 0.99f, 0.88f, 0.77f }; + Float[] vector3b = new Float[]{ 0.99f, 0.88f, 0.77f }; + Map vectors3 = new HashMap<>(); + vectors3.put("bringYourOwnVector", vector3a); + vectors3.put("bringYourOwnVector2", vector3b); + + WeaviateObject obj1 = createObject(id1, className, props1, vectors1); + WeaviateObject obj2 = createObject(id2, className, props2, vectors2); + WeaviateObject obj3 = createObject(id3, className, props3, vectors3); + + Result result = syncClient.batch().objectsBatcher() + .withObjects(obj1, obj2, obj3) + .run(); + + assertThat(result).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).asInstanceOf(ARRAY) + .hasSize(3); + } + + private WeaviateClass.VectorConfig getTitleAndContentVectorConfig() { + Map titleAndContent = new HashMap<>(); + Map text2vecContextionarySettings = new HashMap<>(); + text2vecContextionarySettings.put("vectorizeClassName", false); + text2vecContextionarySettings.put("properties", new String[]{ "title", "content" }); + titleAndContent.put("text2vec-contextionary", text2vecContextionarySettings); + return getHNSWSQVectorConfig(titleAndContent); + } + + private WeaviateClass.VectorConfig getTitle1VectorConfig() { + Map titleAndContent = new HashMap<>(); + Map text2vecContextionarySettings = new HashMap<>(); + text2vecContextionarySettings.put("vectorizeClassName", false); + text2vecContextionarySettings.put("properties", new String[]{ "title1" }); + titleAndContent.put("text2vec-contextionary", text2vecContextionarySettings); + return getHNSWPQVectorConfig(titleAndContent); + } + + private WeaviateClass.VectorConfig getTitle2VectorConfig() { + Map titleAndContent = new HashMap<>(); + Map text2vecContextionarySettings = new HashMap<>(); + text2vecContextionarySettings.put("vectorizeClassName", false); + text2vecContextionarySettings.put("properties", new String[]{ "title2" }); + titleAndContent.put("text2vec-contextionary", text2vecContextionarySettings); + return getHNSWVectorConfig(titleAndContent); + } + + private WeaviateClass.VectorConfig getBringYourOwnVectorVectorConfig() { + Map byov = new HashMap<>(); + byov.put("none", new Object()); + return getFlatBQVectorConfig(byov); + } + + private WeaviateClass.VectorConfig getBringYourOwnVectorVectorConfig2() { + Map byov = new HashMap<>(); + byov.put("none", new Object()); + return getFlatVectorConfig(byov); + } + + private WeaviateClass.VectorConfig getFlatBQVectorConfig(Map vectorizerConfig) { + return WeaviateClass.VectorConfig.builder() + .vectorIndexType("flat") + .vectorizer(vectorizerConfig) + .vectorIndexConfig(VectorIndexConfig.builder() + .bq(BQConfig.builder().enabled(true).build()) + .build()) + .build(); + } + + private WeaviateClass.VectorConfig getFlatVectorConfig(Map vectorizerConfig) { + return WeaviateClass.VectorConfig.builder() + .vectorIndexType("flat") + .vectorizer(vectorizerConfig) + .build(); + } + + private WeaviateClass.VectorConfig getHNSWVectorConfig(Map vectorizerConfig) { + return WeaviateClass.VectorConfig.builder() + .vectorIndexType("hnsw") + .vectorizer(vectorizerConfig) + .build(); + } + + private WeaviateClass.VectorConfig getHNSWPQVectorConfig(Map vectorizerConfig) { + return WeaviateClass.VectorConfig.builder() + .vectorIndexType("hnsw") + .vectorizer(vectorizerConfig) + .vectorIndexConfig(VectorIndexConfig.builder() + .pq(PQConfig.builder().enabled(true).build()) + .build()) + .build(); + } + + private WeaviateClass.VectorConfig getHNSWSQVectorConfig(Map vectorizerConfig) { + return WeaviateClass.VectorConfig.builder() + .vectorIndexType("hnsw") + .vectorizer(vectorizerConfig) + .vectorIndexConfig(VectorIndexConfig.builder() + .sq(SQConfig.builder().enabled(true).build()) + .build()) + .build(); + } + + private WeaviateObject createObject(String id, String className, Map props, Map vectors) { + WeaviateObject.WeaviateObjectBuilder obj = WeaviateObject.builder() + .id(id) + .className(className) + .properties(props); + if (vectors != null) { + obj = obj.vectors(vectors); + } + return obj.build(); + } + + @SuppressWarnings("unchecked") + private void assertGetContainsIds(Result response, String className, String... expectedIds) { + assertThat(response).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).isNotNull() + .extracting(GraphQLResponse::getData).isInstanceOf(Map.class) + .extracting(data -> ((Map) data).get("Get")).isInstanceOf(Map.class) + .extracting(get -> ((Map) get).get(className)).isInstanceOf(List.class).asList() + .hasSize(expectedIds.length) + .extracting(obj -> ((Map) obj).get("_additional")) + .extracting(add -> ((Map) add).get("id")) + .containsExactlyInAnyOrder((Object[]) expectedIds); + } + + private Result doGet(Consumer build) { + Get get = gql.get(); + build.accept(get); + try { + return get.run().get(); + } catch (InterruptedException | ExecutionException e) { + fail("graphQL.get(): " + e.getMessage()); + return null; + } + } +} diff --git a/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLMultiTenancyTest.java b/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLMultiTenancyTest.java new file mode 100644 index 00000000..1e5aed31 --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLMultiTenancyTest.java @@ -0,0 +1,269 @@ +package io.weaviate.integration.client.async.graphql; + +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.async.WeaviateAsyncClient; +import io.weaviate.client.v1.async.graphql.GraphQL; +import io.weaviate.client.v1.async.graphql.api.Aggregate; +import io.weaviate.client.v1.async.graphql.api.Get; +import io.weaviate.client.v1.filters.Operator; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.graphql.query.fields.Field; +import io.weaviate.integration.client.WeaviateDockerCompose; +import io.weaviate.integration.client.WeaviateTestGenerics; +import org.assertj.core.api.AbstractObjectAssert; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; + +import static io.weaviate.integration.client.WeaviateTestGenerics.TENANT_1; +import static io.weaviate.integration.client.WeaviateTestGenerics.TENANT_2; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + +public class ClientGraphQLMultiTenancyTest extends AbstractAsyncClientTest { + private static final WeaviateTestGenerics testGenerics = new WeaviateTestGenerics(); + private String address; + + private WeaviateClient syncClient; + private WeaviateAsyncClient client; + private GraphQL gql; + + @ClassRule + public static WeaviateDockerCompose compose = new WeaviateDockerCompose(); + + @Before + public void before() { + address = compose.getHttpHostAddress(); + + syncClient = new WeaviateClient(new Config("http", address)); + + testGenerics.createSchemaPizzaForTenants(syncClient); + testGenerics.createTenantsPizza(syncClient, TENANT_1, TENANT_2); + testGenerics.createDataPizzaQuattroFormaggiForTenants(syncClient, TENANT_1.getName()); + testGenerics.createDataPizzaFruttiDiMareForTenants(syncClient, TENANT_1.getName()); + testGenerics.createDataPizzaHawaiiForTenants(syncClient, TENANT_2.getName()); + testGenerics.createDataPizzaDoenerForTenants(syncClient, TENANT_2.getName()); + + client = syncClient.async(); + gql = client.graphQL(); + } + + @After + public void after() { + testGenerics.cleanupWeaviate(syncClient); + client.close(); + } + + @Test + public void shouldGetAllDataForTenant() { + Map expectedIdsByTenant = new HashMap<>(); + expectedIdsByTenant.put(TENANT_1.getName(), new String[]{ + WeaviateTestGenerics.PIZZA_QUATTRO_FORMAGGI_ID, + WeaviateTestGenerics.PIZZA_FRUTTI_DI_MARE_ID, + }); + expectedIdsByTenant.put(TENANT_2.getName(), new String[]{ + WeaviateTestGenerics.PIZZA_HAWAII_ID, + WeaviateTestGenerics.PIZZA_DOENER_ID, + }); + + expectedIdsByTenant.forEach((tenant, expectedIds) -> { + Result response = doGet(get -> get + .withTenant(tenant) + .withClassName("Pizza") + .withFields(_additional("id"))); + + assertGetContainsIds(response, "Pizza", expectedIds); + }); + } + + @Test + public void shouldGetLimitedDataForTenant() { + Map expectedIdsByTenant = new HashMap<>(); + expectedIdsByTenant.put(TENANT_1.getName(), new String[]{ + WeaviateTestGenerics.PIZZA_QUATTRO_FORMAGGI_ID, + }); + expectedIdsByTenant.put(TENANT_2.getName(), new String[]{ + WeaviateTestGenerics.PIZZA_HAWAII_ID, + }); + + expectedIdsByTenant.forEach((tenant, expectedIds) -> { + Result response = doGet(get -> get + .withTenant(tenant) + .withClassName("Pizza") + .withLimit(1) + .withFields(_additional("id"))); + + assertGetContainsIds(response, "Pizza", expectedIds); + }); + } + + @Test + public void shouldGetFilteredDataForTenant() { + Map expectedIdsByTenant = new HashMap<>(); + expectedIdsByTenant.put(TENANT_1.getName(), new String[]{ + WeaviateTestGenerics.PIZZA_FRUTTI_DI_MARE_ID, + }); + expectedIdsByTenant.put(TENANT_2.getName(), new String[]{ + }); + + expectedIdsByTenant.forEach((tenant, expectedIds) -> { + Result response = doGet(get -> get + .withTenant(tenant) + .withClassName("Pizza") + .withWhere(whereNumber("price", Operator.GreaterThan, 2.0d)) + .withFields(_additional("id"))); + + assertGetContainsIds(response, "Pizza", expectedIds); + }); + } + + @Test + public void shouldAggregateAllDataForTenant() { + Map> expectedAggValuesByTenant = new HashMap<>(); + expectedAggValuesByTenant.put(TENANT_1.getName(), new HashMap() {{ + put("count", 2.0); + put("maximum", 2.5); + put("minimum", 1.4); + put("median", 1.95); + put("mean", 1.95); + put("mode", 1.4); + put("sum", 3.9); + }}); + expectedAggValuesByTenant.put(TENANT_2.getName(), new HashMap() {{ + put("count", 2.0); + put("maximum", 1.2); + put("minimum", 1.1); + put("median", 1.15); + put("mean", 1.15); + put("mode", 1.1); + put("sum", 2.3); + }}); + + expectedAggValuesByTenant.forEach((tenant, expectedAggValues) -> { + Result response = doAggregate(aggregate -> aggregate + .withTenant(tenant) + .withClassName("Pizza") + .withFields(Field.builder() + .name("price") + .fields(fields( + "count", + "maximum", + "minimum", + "median", + "mean", + "mode", + "sum" + )).build())); + + assertAggregateNumFieldHasValues(response, "Pizza", "price", expectedAggValues); + }); + } + + @Test + public void shouldAggregateFilteredDataForTenant() { + Map> expectedAggValuesByTenant = new HashMap<>(); + expectedAggValuesByTenant.put(TENANT_1.getName(), new HashMap() {{ + put("count", 1.0); + put("maximum", 2.5); + put("minimum", 2.5); + put("median", 2.5); + put("mean", 2.5); + put("mode", 2.5); + put("sum", 2.5); + }}); + expectedAggValuesByTenant.put(TENANT_2.getName(), new HashMap() {{ + put("count", 0.0); + put("maximum", null); + put("minimum", null); + put("median", null); + put("mean", null); + put("mode", null); + put("sum", null); + }}); + + expectedAggValuesByTenant.forEach((tenant, expectedAggValues) -> { + Result response = doAggregate(aggregate -> aggregate + .withTenant(tenant) + .withClassName("Pizza") + .withWhere(whereNumber("price", Operator.GreaterThan, 2.0d)) + .withFields(Field.builder() + .name("price") + .fields(fields( + "count", + "maximum", + "minimum", + "median", + "mean", + "mode", + "sum" + )).build())); + + assertAggregateNumFieldHasValues(response, "Pizza", "price", expectedAggValues); + }); + } + + @SuppressWarnings("unchecked") + private void assertGetContainsIds(Result response, String className, String... expectedIds) { + assertThat(response).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).isNotNull() + .extracting(GraphQLResponse::getData).isInstanceOf(Map.class) + .extracting(data -> ((Map) data).get("Get")).isInstanceOf(Map.class) + .extracting(get -> ((Map) get).get(className)).isInstanceOf(List.class).asList() + .hasSize(expectedIds.length) + .extracting(obj -> ((Map) obj).get("_additional")) + .extracting(add -> ((Map) add).get("id")) + .containsExactlyInAnyOrder((Object[]) expectedIds); + } + + @SuppressWarnings("unchecked") + private void assertAggregateNumFieldHasValues( + Result response, String className, String fieldName, + Map expectedAggValues + ) { + AbstractObjectAssert aggregate = assertThat(response).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).isNotNull() + .extracting(GraphQLResponse::getData).isInstanceOf(Map.class) + .extracting(data -> ((Map) data).get("Aggregate")).isInstanceOf(Map.class) + .extracting(agg -> ((Map) agg).get(className)).isInstanceOf(List.class).asList() + .hasSize(1) + .first() + .extracting(obj -> ((Map) obj).get(fieldName)).isInstanceOf(Map.class); + + expectedAggValues.forEach((name, value) -> aggregate.returns(value, map -> ((Map) map).get(name))); + } + + private Result doGet(Consumer build) { + Get get = gql.get(); + build.accept(get); + try { + return get.run() + .get(); + } catch (InterruptedException | ExecutionException e) { + fail("graphQL.get(): " + e.getMessage()); + return null; + } + } + + private Result doAggregate(Consumer build) { + Aggregate aggregate = gql.aggregate(); + build.accept(aggregate); + try { + return aggregate.run() + .get(); + } catch (InterruptedException | ExecutionException e) { + fail("graphQL.aggregate(): " + e.getMessage()); + return null; + } + } +} diff --git a/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLTest.java b/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLTest.java new file mode 100644 index 00000000..1346ab15 --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLTest.java @@ -0,0 +1,1147 @@ +package io.weaviate.integration.client.async.graphql; + +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.async.WeaviateAsyncClient; +import io.weaviate.client.v1.async.graphql.GraphQL; +import io.weaviate.client.v1.async.graphql.api.Aggregate; +import io.weaviate.client.v1.async.graphql.api.Explore; +import io.weaviate.client.v1.async.graphql.api.Get; +import io.weaviate.client.v1.async.graphql.api.Raw; +import io.weaviate.client.v1.batch.model.ObjectGetResponse; +import io.weaviate.client.v1.data.model.WeaviateObject; +import io.weaviate.client.v1.filters.Operator; +import io.weaviate.client.v1.filters.WhereFilter; +import io.weaviate.client.v1.graphql.model.ExploreFields; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.graphql.query.argument.*; +import io.weaviate.client.v1.graphql.query.fields.Field; +import io.weaviate.client.v1.schema.model.DataType; +import io.weaviate.client.v1.schema.model.Property; +import io.weaviate.client.v1.schema.model.WeaviateClass; +import io.weaviate.integration.client.WeaviateDockerCompose; +import io.weaviate.integration.client.WeaviateTestGenerics; +import org.assertj.core.api.Assertions; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +import java.util.*; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; + +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeTrue; +import static org.junit.jupiter.api.Assertions.*; + +public class ClientGraphQLTest extends AbstractAsyncClientTest { + private final WeaviateTestGenerics testGenerics = new WeaviateTestGenerics(); + private final WeaviateTestGenerics.DocumentPassageSchema passageSchema = new WeaviateTestGenerics.DocumentPassageSchema(); + + private String address; + private WeaviateClient syncClient; + private WeaviateAsyncClient client; + private GraphQL gql; + + @ClassRule + public static WeaviateDockerCompose compose = new WeaviateDockerCompose(); + + @Before + public void before() { + address = compose.getHttpHostAddress(); + + syncClient = new WeaviateClient(new Config("http", address)); + testGenerics.createTestSchemaAndData(syncClient); + + client = syncClient.async(); + gql = client.graphQL(); + } + + @After + public void after() { + testGenerics.cleanupWeaviate(syncClient); + client.close(); + } + + @Test + public void testGraphQLGet() { + Result result = doGet(get -> get.withClassName("Pizza") + .withFields(field("name"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(4, pizzas.size(), "wrong number of pizzas returned"); + } + + @Test + public void testGraphQLRaw() { + String query = "{Get{Pizza{_additional{id}}}}"; + + Result result = doRaw(raw -> raw.withQuery(query)); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(4, pizzas.size(), "wrong number of pizzas returned"); + } + + @Test + public void testGraphQLGetWithNearObjectAndCertainty() { + String newObjID = "6baed48e-2afe-4be4-a09d-b00a955d962b"; + NearObjectArgument nearObjectArgument = gql.arguments() + .nearObjectArgBuilder() + .id(newObjID) + .certainty(0.99f) + .build(); + + WeaviateObject soupWithID = WeaviateObject.builder() + .className("Soup") + .id(newObjID) + .properties(new HashMap() { + { + put("name", "JustSoup"); + put("description", "soup with id"); + } + }) + .build(); + + // Insert additional test data + Result insert = syncClient.batch() + .objectsBatcher() + .withObjects(soupWithID) + .run(); + assumeTrue("all test objects inserted successfully", insert.getResult().length == 1); + + Result result = doGet(get -> get.withClassName("Soup") + .withNearObject(nearObjectArgument) + .withFields(field("name"), _additional("certainty"))); + + List soups = extractClass(result, "Get", "Soup"); + assertEquals(1, soups.size(), "wrong number of soups"); + } + + @Test + public void testGraphQLGetWithNearObjectAndDistance() { + String newObjID = "6baed48e-2afe-4be4-a09d-b00a955d962b"; + NearObjectArgument nearObjectArgument = gql.arguments() + .nearObjectArgBuilder() + .id(newObjID) + .distance(0.1f) + .build(); + + WeaviateObject soupWithID = WeaviateObject.builder() + .className("Soup") + .id(newObjID) + .properties(new HashMap() { + { + put("name", "JustSoup"); + put("description", "soup with id"); + } + }) + .build(); + + // Insert additional test data + syncClient.batch() + .objectsBatcher() + .withObjects(soupWithID) + .run(); + + Result result = doGet(get -> get.withClassName("Soup") + .withNearObject(nearObjectArgument) + .withFields(field("name"), _additional("distance"))); + + List soups = extractClass(result, "Get", "Soup"); + assertEquals(1, soups.size(), "wrong number of soups"); + } + + @Test + @SuppressWarnings("unchecked") + public void testBm25() { + Bm25Argument bm25 = gql.arguments() + .bm25ArgBuilder() + .query("innovation") + .properties(new String[]{ "description" }) + .build(); + + Result result = doGet(get -> get.withClassName("Pizza") + .withBm25(bm25) + .withFields(field("description"), _additional("id", "distance"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(1, pizzas.size(), "wrong number of pizzas"); + + Map pizza = (Map) pizzas.get(0); + assertTrue(((String) pizza.get("description")).contains("innovation"), "wrong Pizza description"); + } + + @Test + public void testHybrid() { + HybridArgument hybrid = gql.arguments() + .hybridArgBuilder() + .query("some say revolution") + .alpha(0.8f) + .build(); + + Result result = doGet(get -> get.withClassName("Pizza") + .withHybrid(hybrid) + .withFields(field("description"), _additional("id"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertFalse(pizzas.isEmpty(), "didn't get any pizzas"); + } + + @Test + public void testGraphQLGetWithNearTextAndCertainty() { + NearTextMoveParameters moveAway = NearTextMoveParameters.builder() + .concepts(new String[]{ "Universally" }) + .force(0.8f) + .build(); + NearTextArgument nearText = gql.arguments() + .nearTextArgBuilder() + .concepts(new String[]{ "some say revolution" }) + .moveAwayFrom(moveAway) + .certainty(0.8f) + .build(); + + Result result = doGet(get -> get.withClassName("Pizza") + .withNearText(nearText) + .withFields(field("name"), _additional("certainty"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(1, pizzas.size(), "wrong number of pizzas"); + } + + @Test + public void testGraphQLGetWithNearTextAndDistance() { + NearTextMoveParameters moveAway = NearTextMoveParameters.builder() + .concepts(new String[]{ "Universally" }) + .force(0.8f) + .build(); + NearTextArgument nearText = gql.arguments() + .nearTextArgBuilder() + .concepts(new String[]{ "some say revolution" }) + .moveAwayFrom(moveAway) + .distance(0.4f) + .build(); + + Result result = doGet(get -> get.withClassName("Pizza") + .withNearText(nearText) + .withFields(field("name"), _additional("distance"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(1, pizzas.size(), "wrong number of pizzas"); + } + + @Test + public void testGraphQLGetWithNearTextAndMoveParamsAndCertainty() { + String newObjID1 = "6baed48e-2afe-4be4-a09d-b00a955d962b"; + String newObjID2 = "6baed48e-2afe-4be4-a09d-b00a955d962a"; + WeaviateObject pizzaWithID = WeaviateObject.builder() + .className("Pizza") + .id(newObjID1) + .properties(new HashMap() { + { + put("name", "JustPizza1"); + put("description", "Universally pizza with id"); + } + }) + .build(); + WeaviateObject pizzaWithID2 = WeaviateObject.builder() + .className("Pizza") + .id(newObjID2) + .properties(new HashMap() { + { + put("name", "JustPizza2"); + put("description", "Universally pizza with some other id"); + } + }) + .build(); + + NearTextMoveParameters moveAway = NearTextMoveParameters.builder() + .objects(new NearTextMoveParameters.ObjectMove[]{ NearTextMoveParameters.ObjectMove.builder() + .id(newObjID1).build() + }) + .force(0.9f) + .build(); + NearTextMoveParameters moveTo = NearTextMoveParameters.builder() + .objects(new NearTextMoveParameters.ObjectMove[]{ NearTextMoveParameters.ObjectMove.builder() + .id(newObjID2).build() + }) + .force(0.9f) + .build(); + NearTextArgument nearText = gql.arguments() + .nearTextArgBuilder() + .concepts(new String[]{ "Universally pizza with id" }) + .moveAwayFrom(moveAway) + .moveTo(moveTo) + .certainty(0.4f) + .build(); + + // Insert additional test data + Result insert = syncClient.batch() + .objectsBatcher() + .withObjects(pizzaWithID, pizzaWithID2) + .run(); + assumeTrue("all test objects inserted successfully", insert.getResult().length == 2); + + Result result = doGet(get -> get.withClassName("Pizza") + .withNearText(nearText) + .withFields(field("name"), _additional("certainty"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(6, pizzas.size(), "wrong number of pizzas"); + } + + @Test + public void testGraphQLGetWithNearTextAndMoveParamsAndDistance() { + String newObjID1 = "6baed48e-2afe-4be4-a09d-b00a955d962b"; + String newObjID2 = "6baed48e-2afe-4be4-a09d-b00a955d962a"; + WeaviateObject pizzaWithID = WeaviateObject.builder() + .className("Pizza") + .id(newObjID1) + .properties(new HashMap() { + { + put("name", "JustPizza1"); + put("description", "Universally pizza with id"); + } + }) + .build(); + WeaviateObject pizzaWithID2 = WeaviateObject.builder() + .className("Pizza") + .id(newObjID2) + .properties(new HashMap() { + { + put("name", "JustPizza2"); + put("description", "Universally pizza with some other id"); + } + }) + .build(); + + NearTextMoveParameters moveAway = NearTextMoveParameters.builder() + .objects(new NearTextMoveParameters.ObjectMove[]{ NearTextMoveParameters.ObjectMove.builder() + .id(newObjID1).build() + }) + .force(0.9f) + .build(); + NearTextMoveParameters moveTo = NearTextMoveParameters.builder() + .objects(new NearTextMoveParameters.ObjectMove[]{ NearTextMoveParameters.ObjectMove.builder() + .id(newObjID2).build() + }) + .force(0.9f) + .build(); + NearTextArgument nearText = gql.arguments() + .nearTextArgBuilder() + .concepts(new String[]{ "Universally pizza with id" }) + .moveAwayFrom(moveAway) + .moveTo(moveTo) + .distance(0.6f) + .build(); + + // Insert additional test data + Result insert = syncClient.batch() + .objectsBatcher() + .withObjects(pizzaWithID, pizzaWithID2) + .run(); + assumeTrue("all test objects inserted successfully", insert.getResult().length == 2); + + Result result = doGet(get -> get.withClassName("Pizza") + .withNearText(nearText) + .withFields(field("name"), _additional("distance"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(6, pizzas.size(), "wrong number of pizzas"); + } + + @Test + public void testGraphQLGetWithNearTextAndLimitAndCertainty() { + NearTextArgument nearText = gql.arguments() + .nearTextArgBuilder() + .concepts(new String[]{ "some say revolution" }) + .certainty(0.8f) + .build(); + + Result result = doGet(get -> get.withClassName("Pizza") + .withNearText(nearText) + .withLimit(1) + .withFields(field("name"), _additional("certainty"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(1, pizzas.size(), "wrong number of pizzas"); + } + + @Test + public void testGraphQLGetWithNearTextAndLimitAndDistance() { + NearTextArgument nearText = gql.arguments() + .nearTextArgBuilder() + .concepts(new String[]{ "some say revolution" }) + .distance(0.4f) + .build(); + + Result result = doGet(get -> get.withClassName("Pizza") + .withNearText(nearText) + .withLimit(1) + .withFields(field("name"), _additional("distance"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(1, pizzas.size(), "wrong number of pizzas"); + } + + @Test + public void testGraphQLGetWithWhereByFieldTokenizedProperty() { + Field name = field("name"); + WhereArgument whereFullString = whereText("name", Operator.Equal, "Frutti di Mare"); + WhereArgument wherePartString = whereText("name", Operator.Equal, "Frutti"); + WhereArgument whereFullText = whereText("description", Operator.Equal, "Universally accepted to be the best pizza ever created."); + WhereArgument wherePartText = whereText("description", Operator.Equal, "Universally"); + // when + Result resultFullString = doGet(get -> get.withWhere(whereFullString) + .withClassName("Pizza") + .withFields(name)); + Result resultPartString = doGet(get -> get.withWhere(wherePartString) + .withClassName("Pizza") + .withFields(name)); + Result resultFullText = doGet(get -> get.withWhere(whereFullText) + .withClassName("Pizza") + .withFields(name)); + Result resultPartText = doGet(get -> get.withWhere(wherePartText) + .withClassName("Pizza") + .withFields(name)); + // then + assertWhereResultSize(1, resultFullString, "Pizza"); + assertWhereResultSize(0, resultPartString, "Pizza"); + assertWhereResultSize(1, resultFullText, "Pizza"); + assertWhereResultSize(1, resultPartText, "Pizza"); + } + + @Test + public void shouldSupportDeprecatedValueString() { + WhereArgument whereString = whereText("name", Operator.Equal, "Frutti di Mare"); + + Result result = doGet(get -> get.withWhere(whereString) + .withClassName("Pizza") + .withFields(field("name"))); + + assertWhereResultSize(1, result, "Pizza"); + } + + @Test + public void testGraphQLGetWithWhereByDate() { + Calendar cal = Calendar.getInstance(); + cal.set(2022, Calendar.FEBRUARY, 1, 0, 0, 0); + WhereArgument whereDate = whereDate("bestBefore", Operator.GreaterThan, cal.getTime()); + + Result resultDate = doGet(get -> get.withWhere(whereDate) + .withClassName("Pizza") + .withFields(field("name"))); + + List> result = extractClass(resultDate, "Get", "Pizza"); + Assertions.assertThat(result) + .hasSize(3) + .extracting(el -> (String) el.get("name")) + .contains("Frutti di Mare", "Hawaii", "Doener"); + } + + @Test + public void testGraphQLExploreWithCertainty() { + ExploreFields[] fields = new ExploreFields[]{ ExploreFields.CERTAINTY, ExploreFields.BEACON, ExploreFields.CLASS_NAME }; + String[] concepts = new String[]{ "pineapple slices", "ham" }; + NearTextMoveParameters moveTo = gql.arguments() + .nearTextMoveParameterBuilder() + .concepts(new String[]{ "Pizza" }) + .force(0.3f) + .build(); + NearTextMoveParameters moveAwayFrom = gql.arguments() + .nearTextMoveParameterBuilder() + .concepts(new String[]{ "toast", "bread" }) + .force(0.4f) + .build(); + NearTextArgument withNearText = gql.arguments() + .nearTextArgBuilder() + .concepts(concepts) + .certainty(0.40f) + .moveTo(moveTo) + .moveAwayFrom(moveAwayFrom) + .build(); + + Result result = doExplore(explore -> explore.withFields(fields) + .withNearText(withNearText)); + + List got = extractQueryResult(result, "Explore"); + assertEquals(6, got.size()); + } + + @Test + public void testGraphQLExploreWithDistance() { + ExploreFields[] fields = new ExploreFields[]{ ExploreFields.CERTAINTY, ExploreFields.BEACON, ExploreFields.CLASS_NAME }; + String[] concepts = new String[]{ "pineapple slices", "ham" }; + NearTextMoveParameters moveTo = gql.arguments() + .nearTextMoveParameterBuilder() + .concepts(new String[]{ "Pizza" }) + .force(0.3f) + .build(); + NearTextMoveParameters moveAwayFrom = gql.arguments() + .nearTextMoveParameterBuilder() + .concepts(new String[]{ "toast", "bread" }) + .force(0.4f) + .build(); + NearTextArgument withNearText = gql.arguments() + .nearTextArgBuilder() + .concepts(concepts) + .distance(0.80f) + .moveTo(moveTo) + .moveAwayFrom(moveAwayFrom) + .build(); + + Result result = doExplore(explore -> explore.withFields(fields) + .withNearText(withNearText)); + + List got = extractQueryResult(result, "Explore"); + assertEquals(6, got.size()); + } + + @Test + public void testGraphQLAggregate() { + Result result = doAggregate(aggregate -> aggregate.withFields(meta("count")) + .withClassName("Pizza")); + + assertAggregateMetaCount(result, "Pizza", 1, 4.0d); + } + + @Test + public void testGraphQLAggregateWithWhereFilter() { + String newObjID = "6baed48e-2afe-4be4-a09d-b00a955d96ee"; + WeaviateObject pizzaWithID = WeaviateObject.builder() + .className("Pizza") + .id(newObjID) + .properties(new HashMap() { + { + put("name", "JustPizza"); + put("description", "pizza with id"); + } + }) + .build(); + + // Insert additional test data + Result insert = syncClient.batch() + .objectsBatcher() + .withObjects(pizzaWithID) + .run(); + assumeTrue("all test objects inserted successfully", insert.getResult().length == 1); + + Result result = doAggregate(aggregate -> aggregate.withFields(meta("count")) + .withClassName("Pizza") + .withWhere(whereText("id", Operator.Equal, newObjID))); + + assertAggregateMetaCount(result, "Pizza", 1, 1.0d); + } + + @Test + public void testGraphQLAggregateWithGroupedByAndWhere() { + // given + String newObjID = "6baed48e-2afe-4be4-a09d-b00a955d96ee"; + WeaviateObject pizzaWithID = WeaviateObject.builder() + .className("Pizza") + .id(newObjID) + .properties(new HashMap() { + { + put("name", "JustPizza"); + put("description", "pizza with id"); + } + }) + .build(); + + // Insert additional test objects + Result insert = syncClient.batch() + .objectsBatcher() + .withObjects(pizzaWithID) + .run(); + assumeTrue("all test objects inserted successfully", insert.getResult().length == 1); + + Result result = doAggregate(aggregate -> aggregate.withFields(meta("count")) + .withClassName("Pizza") + .withGroupBy("name") + .withWhere(whereText("id", Operator.Equal, newObjID))); + + assertAggregateMetaCount(result, "Pizza", 1, 1.0d); + } + + @Test + public void testGraphQLAggregateWithGroupedBy() { + String newObjID = "6baed48e-2afe-4be4-a09d-b00a955d96ee"; + WeaviateObject pizzaWithID = WeaviateObject.builder() + .className("Pizza") + .id(newObjID) + .properties(new HashMap() { + { + put("name", "JustPizza"); + put("description", "pizza with id"); + } + }) + .build(); + + // Insert additional test data + Result insert = syncClient.batch() + .objectsBatcher() + .withObjects(pizzaWithID) + .run(); + assumeTrue("all test objects inserted successfully", insert.getResult().length == 1); + + Result result = doAggregate(aggregate -> aggregate.withClassName("Pizza") + .withFields(meta("count")) + .withGroupBy("name")); + + assertAggregateMetaCount(result, "Pizza", 5, 1.0d); + } + + @Test + public void testGraphQLAggregateWithNearVector() { + Result getVector = doGet(get -> get.withClassName("Pizza") + .withFields(_additional("vector"))); + Float[] vector = extractVector(getVector, "Get", "Pizza"); + NearVectorArgument nearVector = NearVectorArgument.builder() + .certainty(0.7f) + .vector(vector) + .build(); + + Result result = doAggregate(aggregate -> aggregate.withClassName("Pizza") + .withFields(meta("count")) + .withNearVector(nearVector)); + + assertAggregateMetaCount(result, "Pizza", 1, 4.0d); + } + + @Test + public void testGraphQLAggregateWithNearObjectAndCertainty() { + Result getId = doGet(get -> get.withClassName("Pizza") + .withFields(_additional("id"))); + String id = extractAdditional(getId, "Get", "Pizza", "id"); + + // when + NearObjectArgument nearObject = NearObjectArgument.builder() + .certainty(0.7f) + .id(id) + .build(); + Result result = doAggregate(aggregate -> aggregate.withClassName("Pizza") + .withFields(meta("count")) + .withNearObject(nearObject)); + + assertAggregateMetaCount(result, "Pizza", 1, 4.0d); + } + + @Test + public void testGraphQLAggregateWithNearObjectAndDistance() { + Result getId = doGet(get -> get.withClassName("Pizza") + .withFields(_additional("id"))); + String id = extractAdditional(getId, "Get", "Pizza", "id"); + + NearObjectArgument nearObject = NearObjectArgument.builder() + .distance(0.3f) + .id(id) + .build(); + Result result = doAggregate(aggregate -> aggregate.withFields(meta("count")) + .withClassName("Pizza") + .withNearObject(nearObject)); + + assertAggregateMetaCount(result, "Pizza", 1, 4.0d); + } + + @Test + public void testGraphQLAggregateWithNearTextAndCertainty() { + NearTextArgument nearText = NearTextArgument.builder() + .certainty(0.7f) + .concepts(new String[]{ "pizza" }) + .build(); + + Result result = doAggregate(aggregate -> aggregate.withClassName("Pizza") + .withFields(meta("count")) + .withNearText(nearText)); + + assertAggregateMetaCount(result, "Pizza", 1, 4.0d); + } + + @Test + public void testGraphQLAggregateWithNearTextAndDistance() { + NearTextArgument nearText = NearTextArgument.builder() + .distance(0.6f) + .concepts(new String[]{ "pizza" }) + .build(); + + Result result = doAggregate(aggregate -> aggregate.withClassName("Pizza") + .withFields(meta("count")) + .withNearText(nearText)); + + assertAggregateMetaCount(result, "Pizza", 1, 4.0d); + } + + @Test + public void testGraphQLAggregateWithObjectLimitAndCertainty() { + int limit = 1; + NearTextArgument nearText = NearTextArgument.builder() + .certainty(0.7f) + .concepts(new String[]{ "pizza" }) + .build(); + + Result result = doAggregate(aggregate -> aggregate.withClassName("Pizza") + .withFields(meta("count")) + .withNearText(nearText) + .withObjectLimit(limit)); + + assertAggregateMetaCount(result, "Pizza", 1, (double) limit); + } + + @Test + public void testGraphQLAggregateWithObjectLimitAndDistance() { + int limit = 1; + NearTextArgument nearText = NearTextArgument.builder() + .distance(0.3f) + .concepts(new String[]{ "pizza" }) + .build(); + + Result result = doAggregate(aggregate -> aggregate.withClassName("Pizza") + .withFields(meta("count")) + .withNearText(nearText) + .withObjectLimit(limit)); + + assertAggregateMetaCount(result, "Pizza", 1, (double) limit); + } + + @Test + public void testGraphQLGetWithGroup() { + GroupArgument group = gql.arguments() + .groupArgBuilder() + .type(GroupType.merge) + .force(1.0f) + .build(); + + Result result = doGet(get -> get.withClassName("Soup") + .withFields(field("name")) + .withGroup(group) + .withLimit(7)); + + List soups = extractClass(result, "Get", "Soup"); + assertEquals(1, soups.size(), "wrong number of soups"); + } + + @Test + public void testGraphQLGetWithSort() { + SortArgument byNameDesc = sort(SortOrder.desc, "name"); + String[] expectedByNameDesc = new String[]{ "Quattro Formaggi", "Hawaii", "Frutti di Mare", "Doener" }; + + SortArgument byPriceAsc = sort(SortOrder.asc, "price"); + String[] expectedByPriceAsc = new String[]{ "Hawaii", "Doener", "Quattro Formaggi", "Frutti di Mare" }; + + Field name = field("name"); + + Result resultByNameDesc = doGet(get -> get.withClassName("Pizza") + .withSort(byNameDesc) + .withFields(name)); + Result resultByDescriptionAsc = doGet(get -> get.withClassName("Pizza") + .withSort(byPriceAsc) + .withFields(name)); + Result resultByNameDescByPriceAsc = doGet(get -> get.withClassName("Pizza") + .withSort(byNameDesc, byPriceAsc) + .withFields(name)); + + assertObjectNamesEqual(resultByNameDesc, "Get", "Pizza", expectedByNameDesc); + assertObjectNamesEqual(resultByDescriptionAsc, "Get", "Pizza", expectedByPriceAsc); + assertObjectNamesEqual(resultByNameDescByPriceAsc, "Get", "Pizza", expectedByNameDesc); + } + + @Test + public void testGraphQLGetWithTimestampFilters() { + Field additional = _additional("id", "creationTimeUnix", "lastUpdateTimeUnix"); + Result expected = doGet(get -> get.withClassName("Pizza") + .withFields(additional)); + + String expectedCreateTime = extractAdditional(expected, "Get", "Pizza", "creationTimeUnix"); + String expectedUpdateTime = extractAdditional(expected, "Get", "Pizza", "lastUpdateTimeUnix"); + + Result createTimeResult = doGet(get -> get.withClassName("Pizza") + .withWhere(whereText("_creationTimeUnix", Operator.Equal, expectedCreateTime)) + .withFields(additional)); + Result updateTimeResult = doGet(get -> get.withClassName("Pizza") + .withWhere(whereText("_lastUpdateTimeUnix", Operator.Equal, expectedCreateTime)) + .withFields(additional)); + + String resultCreateTime = extractAdditional(createTimeResult, "Get", "Pizza", "creationTimeUnix"); + assertEquals(expectedCreateTime, resultCreateTime); + + String resultUpdateTime = extractAdditional(updateTimeResult, "Get", "Pizza", "lastUpdateTimeUnix"); + assertEquals(expectedUpdateTime, resultUpdateTime); + } + + @Test + public void testGraphQLGetUsingCursorAPI() { + Result result = doGet(get -> get.withClassName("Pizza") + .withAfter("00000000-0000-0000-0000-000000000000") + .withLimit(10) + .withFields(field("name"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(3, pizzas.size(), "wrong number of pizzas"); + } + + @Test + public void testGraphQLGetUsingLimitAndOffset() { + Result result = doGet(get -> get.withClassName("Pizza") + .withOffset(3) + .withLimit(4) + .withFields(field("name"))); + + List pizzas = extractClass(result, "Get", "Pizza"); + assertEquals(1, pizzas.size(), "wrong number of pizzas"); + } + + @Test + public void testGraphQLGetWithGroupBy() { + Field[] hits = new Field[]{ Field.builder() + .name("ofDocument") + .fields(new Field[]{ Field.builder() + .name("... on Document") + .fields(new Field[]{ Field.builder() + .name("_additional{id}").build() }).build() + }).build(), Field.builder() + .name("_additional{id distance}").build(), + }; + + Field group = Field.builder() + .name("group") + .fields(new Field[]{ Field.builder() + .name("id").build(), Field.builder() + .name("groupedBy") + .fields(new Field[]{ Field.builder() + .name("value").build(), Field.builder() + .name("path").build(), + }).build(), Field.builder() + .name("count").build(), Field.builder() + .name("maxDistance").build(), Field.builder() + .name("minDistance").build(), Field.builder() + .name("hits") + .fields(hits).build(), + }) + .build(); + + Field _additional = Field.builder() + .name("_additional") + .fields(new Field[]{ group }) + .build(); + Field ofDocument = Field.builder() + .name("ofDocument{__typename}") + .build(); // Property that we group by + + GroupByArgument groupBy = client.graphQL() + .arguments() + .groupByArgBuilder() + .path(new String[]{ "ofDocument" }) + .groups(3) + .objectsPerGroup(10) + .build(); + NearObjectArgument nearObject = client.graphQL() + .arguments() + .nearObjectArgBuilder() + .id("00000000-0000-0000-0000-000000000001") + .build(); + + passageSchema.createAndInsertData(syncClient); + + try { + Result result = doGet(get -> get.withClassName(passageSchema.PASSAGE) + .withNearObject(nearObject) + .withGroupBy(groupBy) + .withFields(ofDocument, _additional)); + + List> passages = extractClass(result, "Get", passageSchema.PASSAGE); + assertEquals(3, passages.size(), "wrong number of passages"); + + // This part of assertions is almost verbatim from package io.weaviate.integration.client.graphql.ClientGraphQLTest + // because it involves a lot of inner classes that we don't won't to redefine here. + List groups = getGroups(passages); + Assertions.assertThat(groups) + .isNotNull() + .hasSize(3); + for (int i = 0; i < 3; i++) { + Assertions.assertThat(groups.get(i).minDistance) + .isEqualTo(groups.get(i) + .getHits() + .get(0) + .get_additional() + .getDistance()); + Assertions.assertThat(groups.get(i).maxDistance) + .isEqualTo(groups.get(i) + .getHits() + .get(groups.get(i) + .getHits() + .size() - 1) + .get_additional() + .getDistance()); + } + checkGroupElements(expectedHitsA, groups.get(0) + .getHits()); + checkGroupElements(expectedHitsB, groups.get(1) + .getHits()); + } finally { + passageSchema.cleanupWeaviate(syncClient); + } + } + + @Test + public void testGraphQLGetWithGroupByWithHybrid() { + Field[] hits = new Field[]{ Field.builder() + .name("content").build(), Field.builder() + .name("_additional{id distance}").build(), + }; + Field group = Field.builder() + .name("group") + .fields(new Field[]{ Field.builder() + .name("id").build(), Field.builder() + .name("groupedBy") + .fields(new Field[]{ Field.builder() + .name("value").build(), Field.builder() + .name("path").build(), + }).build(), Field.builder() + .name("count").build(), Field.builder() + .name("maxDistance").build(), Field.builder() + .name("minDistance").build(), Field.builder() + .name("hits") + .fields(hits).build(), + }) + .build(); + Field _additional = Field.builder() + .name("_additional") + .fields(new Field[]{ group }) + .build(); + Field content = Field.builder() + .name("content") + .build(); // Property that we group by + GroupByArgument groupBy = client.graphQL() + .arguments() + .groupByArgBuilder() + .path(new String[]{ "content" }) + .groups(3) + .objectsPerGroup(10) + .build(); + + NearTextArgument nearText = NearTextArgument.builder() + .concepts(new String[]{ "Passage content 2" }) + .build(); + HybridArgument hybrid = HybridArgument.builder() + .searches(HybridArgument.Searches.builder() + .nearText(nearText) + .build()) + .query("Passage content 2") + .alpha(0.9f) + .build(); + + passageSchema.createAndInsertData(syncClient); + + try { + Result groupByResult = doGet(get -> get.withClassName(passageSchema.PASSAGE) + .withHybrid(hybrid) + .withGroupBy(groupBy) + .withFields(content, _additional)); + + List> result = extractClass(groupByResult, "Get", passageSchema.PASSAGE); + Assertions.assertThat(result) + .isNotNull() + .hasSize(3); + List groups = getGroups(result); + Assertions.assertThat(groups) + .isNotNull() + .hasSize(3); + for (int i = 0; i < 3; i++) { + if (i == 0) { + Assertions.assertThat(groups.get(i).groupedBy.value) + .isEqualTo("Passage content 2"); + } + Assertions.assertThat(groups.get(i).minDistance) + .isEqualTo(groups.get(i) + .getHits() + .get(0) + .get_additional() + .getDistance()); + Assertions.assertThat(groups.get(i).maxDistance) + .isEqualTo(groups.get(i) + .getHits() + .get(groups.get(i) + .getHits() + .size() - 1) + .get_additional() + .getDistance()); + } + } finally { + passageSchema.cleanupWeaviate(syncClient); + } + } + + @Test + public void shouldSupportSearchByUUID() { + String className = "ClassUUID"; + WeaviateClass clazz = WeaviateClass.builder() + .className(className) + .description("class with uuid properties") + .properties(Arrays.asList( + Property.builder() + .dataType(Collections.singletonList(DataType.UUID)) + .name("uuidProp") + .build(), Property.builder() + .dataType(Collections.singletonList(DataType.UUID_ARRAY)) + .name("uuidArrayProp") + .build() + )) + .build(); + + String id = "abefd256-8574-442b-9293-9205193737ee"; + Map properties = new HashMap<>(); + properties.put("uuidProp", "7aaa79d3-a564-45db-8fa8-c49e20b8a39a"); + properties.put("uuidArrayProp", new String[]{ "f70512a3-26cb-4ae4-9369-204555917f15", "9e516f40-fd54-4083-a476-f4675b2b5f92" + }); + + Result createStatus = syncClient.schema() + .classCreator() + .withClass(clazz) + .run(); + Assertions.assertThat(createStatus) + .isNotNull() + .returns(false, Result::hasErrors) + .returns(true, Result::getResult); + + Result objectStatus = syncClient.data() + .creator() + .withClassName(className) + .withID(id) + .withProperties(properties) + .run(); + Assertions.assertThat(objectStatus) + .isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult) + .isNotNull(); + + Field fieldId = _additional("id"); + WhereArgument whereUuid = whereText("uuidProp", Operator.Equal, "7aaa79d3-a564-45db-8fa8-c49e20b8a39a"); + WhereArgument whereUuidArray1 = whereText("uuidArrayProp", Operator.Equal, "f70512a3-26cb-4ae4-9369-204555917f15"); + WhereArgument whereUuidArray2 = whereText("uuidArrayProp", Operator.Equal, "9e516f40-fd54-4083-a476-f4675b2b5f92"); + + Result resultUuid = doGet(get -> get.withWhere(whereUuid) + .withClassName(className) + .withFields(fieldId)); + Result resultUuidArray1 = doGet(get -> get.withWhere(whereUuidArray1) + .withClassName(className) + .withFields(fieldId)); + Result resultUuidArray2 = doGet(get -> get.withWhere(whereUuidArray2) + .withClassName(className) + .withFields(fieldId)); + + assertIds(className, resultUuid, new String[]{ id }); + assertIds(className, resultUuidArray1, new String[]{ id }); + assertIds(className, resultUuidArray2, new String[]{ id }); + + Result deleteStatus = syncClient.schema() + .allDeleter() + .run(); + Assertions.assertThat(deleteStatus) + .isNotNull() + .returns(false, Result::hasErrors) + .returns(true, Result::getResult); + } + + private Result doGet(Consumer build) { + Get get = gql.get(); + build.accept(get); + try { + return get.run() + .get(); + } catch (InterruptedException | ExecutionException e) { + fail("graphQL.get(): " + e.getMessage()); + return null; + } + } + + private Result doRaw(Consumer build) { + Raw raw = gql.raw(); + build.accept(raw); + try { + return raw.run() + .get(); + } catch (InterruptedException | ExecutionException e) { + fail("graphQL.raw(): " + e.getMessage()); + return null; + } + } + + private Result doExplore(Consumer build) { + Explore explore = gql.explore(); + build.accept(explore); + try { + return explore.run() + .get(); + } catch (InterruptedException | ExecutionException e) { + fail("graphQL.explore(): " + e.getMessage()); + return null; + } + } + + private Result doAggregate(Consumer build) { + Aggregate aggregate = gql.aggregate(); + build.accept(aggregate); + try { + return aggregate.run() + .get(); + } catch (InterruptedException | ExecutionException e) { + fail("graphQL.aggregate(): " + e.getMessage()); + return null; + } + } + + private SortArgument sort(SortOrder ord, String... properties) { + return gql.arguments() + .sortArgBuilder() + .path(properties) + .order(ord) + .build(); + } + + private void assertWhereResultSize(int expectedSize, Result result, String className) { + List getClass = extractClass(result, "Get", className); + assertEquals(expectedSize, getClass.size()); + } + + + @SuppressWarnings("unchecked") + private T extractAdditional(Result result, String queryType, String className, String fieldName) { + List objects = extractClass(result, queryType, className); + + Map> firstObject = (Map>) objects.get(0); + Map additional = firstObject.get("_additional"); + + return (T) additional.get(fieldName); + } + + private Float[] extractVector(Result result, String queryType, String className) { + ArrayList vector = extractAdditional(result, queryType, className, "vector"); + Float[] out = new Float[vector.size()]; + for (int i = 0; i < vector.size(); i++) { + out[i] = vector.get(i) + .floatValue(); + } + return out; + } + + @SuppressWarnings("unchecked") + private void assertAggregateMetaCount(Result result, String className, int wantObjects, Double wantCount) { + List objects = extractClass(result, "Aggregate", className); + + assertEquals(wantObjects, objects.size(), "wrong number of objects"); + Map> firstObject = (Map>) objects.get(0); + Map meta = firstObject.get("meta"); + assertEquals(wantCount, meta.get("count"), "wrong meta:count"); + } + + private void assertObjectNamesEqual(Result result, String queryType, String className, String[] want) { + List> objects = extractClass(result, queryType, className); + assertEquals(want.length, objects.size()); + for (int i = 0; i < want.length; i++) { + assertEquals(want[i], objects.get(i) + .get("name"), String.format("%s[%d] has wrong name", className.toLowerCase(), i)); + } + } +} diff --git a/src/test/java/io/weaviate/integration/client/async/graphql/ClusterGraphQLTest.java b/src/test/java/io/weaviate/integration/client/async/graphql/ClusterGraphQLTest.java new file mode 100644 index 00000000..56d42cb3 --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/async/graphql/ClusterGraphQLTest.java @@ -0,0 +1,84 @@ +package io.weaviate.integration.client.async.graphql; + +import com.google.gson.internal.LinkedTreeMap; +import com.jparams.junit4.JParamsTestRunner; +import com.jparams.junit4.data.DataMethod; +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.async.WeaviateAsyncClient; +import io.weaviate.client.v1.async.graphql.GraphQL; +import io.weaviate.client.v1.async.graphql.api.Get; +import io.weaviate.client.v1.data.replication.model.ConsistencyLevel; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.integration.client.WeaviateDockerComposeCluster; +import io.weaviate.integration.client.WeaviateTestGenerics; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; + +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +@RunWith(JParamsTestRunner.class) +public class ClusterGraphQLTest extends AbstractAsyncClientTest { + private static final WeaviateTestGenerics testGenerics = new WeaviateTestGenerics(); + private String address; + + private WeaviateClient syncClient; + private WeaviateAsyncClient client; + private GraphQL gql; + + @ClassRule + public static WeaviateDockerComposeCluster compose = new WeaviateDockerComposeCluster(); + + @Before + public void before() { + address = compose.getHttpHost0Address(); + + syncClient = new WeaviateClient(new Config("http", address)); + testGenerics.createReplicatedTestSchemaAndData(syncClient); + + client = syncClient.async(); + gql = client.graphQL(); + } + + @After + public void after() { + testGenerics.cleanupWeaviate(syncClient); + client.close(); + } + + public static Object[][] provideConsistencyLevels() { + return new Object[][]{ { ConsistencyLevel.ALL }, { ConsistencyLevel.QUORUM }, { ConsistencyLevel.ONE } }; + } + + @DataMethod(source = ClusterGraphQLTest.class, method = "provideConsistencyLevels") + @Test + public void testGraphQLGetUsingConsistencyLevel(String consistency) { + Result result = doGet(get -> get.withClassName("Pizza").withConsistencyLevel(consistency) + .withFields(field("name"), _additional("isConsistent"))); + + List>> pizzas = extractClass(result, "Get", "Pizza"); + for (LinkedTreeMap> pizza : pizzas) { + assertTrue("not consistent with ConsistencyLevel=" + consistency, pizza.get("_additional").get("isConsistent")); + } + } + + private Result doGet(Consumer build) { + Get get = gql.get(); + build.accept(get); + try { + return get.run().get(); + } catch (InterruptedException | ExecutionException e) { + fail("graphQL.get(): " + e.getMessage()); + return null; + } + } +} diff --git a/src/test/java/io/weaviate/integration/client/graphql/AbstractClientGraphQLTest.java b/src/test/java/io/weaviate/integration/client/graphql/AbstractClientGraphQLTest.java new file mode 100644 index 00000000..acdea371 --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/graphql/AbstractClientGraphQLTest.java @@ -0,0 +1,133 @@ +package io.weaviate.integration.client.graphql; + +import io.weaviate.client.base.Result; +import io.weaviate.client.base.Serializer; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.integration.client.WeaviateTestGenerics; +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.assertj.core.api.Assertions; + +import java.util.*; +import java.util.stream.Collectors; + +/** AbstractClientGraphQLTest has fixtures and assertion utils used for both sync and async tests. */ +public abstract class AbstractClientGraphQLTest { + protected static final WeaviateTestGenerics.DocumentPassageSchema testData = new WeaviateTestGenerics.DocumentPassageSchema(); + + @Getter + @AllArgsConstructor + protected static class AdditionalOfDocument { + String id; + } + + @Getter + protected static class Additional { + Group group; + } + + @Getter + protected static class AdditionalGroupByAdditional { + Additional _additional; + } + + @Getter + @AllArgsConstructor + protected static class AdditionalGroupHit { + String id; + Float distance; + } + + + @Getter + @AllArgsConstructor + protected static class GroupHitOfDocument { + AdditionalOfDocument _additional; + } + + @Getter + @AllArgsConstructor + protected static class GroupHit { + AdditionalGroupHit _additional; + List ofDocument; + } + + @Getter + @AllArgsConstructor + protected static class GroupedBy { + public String value; + public String[] path; + } + + @Getter + @AllArgsConstructor + protected static class Group { + public String id; + public GroupedBy groupedBy; + public Integer count; + public Float maxDistance; + public Float minDistance; + public List hits; + } + + protected static final List ofDocumentA = Collections.singletonList( + new GroupHitOfDocument(new AdditionalOfDocument(testData.DOCUMENT_IDS[0])) + ); + protected static final List ofDocumentB = Collections.singletonList( + new GroupHitOfDocument(new AdditionalOfDocument(testData.DOCUMENT_IDS[1])) + ); + + protected static final List expectedHitsA = new ArrayList() { + { + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[0], 4.172325e-7f), ofDocumentA)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[8], 0.0023148656f), ofDocumentA)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[6], 0.0023562312f), ofDocumentA)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[7], 0.0025092363f), ofDocumentA)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[5], 0.002709806f), ofDocumentA)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[9], 0.002762556f), ofDocumentA)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[4], 0.0028533936f), ofDocumentA)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[3], 0.0033442378f), ofDocumentA)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[2], 0.004181564f), ofDocumentA)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[1], 0.0057129264f), ofDocumentA)); + } + }; + + protected static final List expectedHitsB = new ArrayList() { + { + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[10], 0.0025351048f), ofDocumentB)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[12], 0.00288558f), ofDocumentB)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[11], 0.0033002496f), ofDocumentB)); + this.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[13], 0.004168868f), ofDocumentB)); + } + }; + + protected void assertIds(String className, Result gqlResult, String[] expectedIds) { + Assertions.assertThat(gqlResult).isNotNull().returns(false, Result::hasErrors) + .extracting(Result::getResult).isNotNull() + .extracting(GraphQLResponse::getData).isInstanceOf(Map.class) + .extracting(data -> ((Map) data).get("Get")).isInstanceOf(Map.class) + .extracting(get -> ((Map) get).get(className)).isInstanceOf(List.class).asList().hasSize(expectedIds.length); + + List> results = (List>) ((Map) (((Map) + (gqlResult.getResult().getData())).get("Get"))).get(className); + String[] resultIds = results.stream().map(m -> m.get("_additional")).map(a -> ((Map) a).get("id")).toArray(String[]::new); + Assertions.assertThat(resultIds).containsExactlyInAnyOrder(expectedIds); + } + + protected List getGroups(List> result) { + Serializer serializer = new Serializer(); + String jsonString = serializer.toJsonString(result); + AdditionalGroupByAdditional[] response = serializer.toObject(jsonString, AdditionalGroupByAdditional[].class); + Assertions.assertThat(response).isNotNull().hasSize(3); + return Arrays.stream(response).map(AdditionalGroupByAdditional::get_additional).map(Additional::getGroup).collect(Collectors.toList()); + } + + protected void checkGroupElements(List expected, List actual) { + Assertions.assertThat(expected).hasSameSizeAs(actual); + for (int i = 0; i < actual.size(); i++) { + Assertions.assertThat(actual.get(i).get_additional().getId()).isEqualTo(expected.get(i).get_additional().getId()); + Assertions.assertThat(actual.get(i).getOfDocument().get(0).get_additional().getId()) + .isEqualTo(expected.get(i).getOfDocument().get(0).get_additional().getId()); + } + } +} diff --git a/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTest.java b/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTest.java index 10527f91..e9693f5f 100644 --- a/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTest.java +++ b/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTest.java @@ -62,9 +62,10 @@ import org.junit.Ignore; import org.junit.Test; -public class ClientGraphQLTest { +public class ClientGraphQLTest extends AbstractClientGraphQLTest { private String address; private String openAIApiKey; + private static final WeaviateTestGenerics.DocumentPassageSchema testData = new WeaviateTestGenerics.DocumentPassageSchema(); @ClassRule public static WeaviateDockerCompose compose = new WeaviateDockerCompose(); @@ -110,7 +111,6 @@ public void testRawGraphQL() { Config config = new Config("http", address); WeaviateClient client = new WeaviateClient(config); WeaviateTestGenerics testGenerics = new WeaviateTestGenerics(); - Field name = Field.builder().name("name").build(); // when testGenerics.createTestSchemaAndData(client); Result result = client.graphQL().raw().withQuery("{Get{Pizza{_additional{id}}}}").run(); @@ -1491,89 +1491,12 @@ private void expectPizzaNamesOrder(Result result, String[] expe } } - @Getter - @AllArgsConstructor - private static class AdditionalGroupHit { - String id; - Float distance; - } - - @Getter - @AllArgsConstructor - private static class AdditionalOfDocument { - String id; - } - - @Getter - @AllArgsConstructor - private static class GroupHitOfDocument { - AdditionalOfDocument _additional; - } - - @Getter - @AllArgsConstructor - private static class GroupHit { - AdditionalGroupHit _additional; - List ofDocument; - } - - @Getter - @AllArgsConstructor - private static class GroupedBy { - String value; - String[] path; - } - - @Getter - @AllArgsConstructor - private static class Group { - String id; - GroupedBy groupedBy; - Integer count; - Float maxDistance; - Float minDistance; - List hits; - } - - @Getter - private static class Additional { - Group group; - } - - @Getter - private static class AdditionalGroupByAdditional { - Additional _additional; - } - @Test public void testGraphQLGetWithGroupBy() { // given Config config = new Config("http", address); WeaviateClient client = new WeaviateClient(config); - WeaviateTestGenerics.DocumentPassageSchema testData = new WeaviateTestGenerics.DocumentPassageSchema(); - - List ofDocumentA = Collections.singletonList( - new GroupHitOfDocument(new AdditionalOfDocument(testData.DOCUMENT_IDS[0])) - ); - List ofDocumentB = Collections.singletonList( - new GroupHitOfDocument(new AdditionalOfDocument(testData.DOCUMENT_IDS[1])) - ); - List expectedHits1 = new ArrayList<>(); - expectedHits1.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[0], 4.172325e-7f), ofDocumentA)); - expectedHits1.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[8], 0.0023148656f), ofDocumentA)); - expectedHits1.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[6], 0.0023562312f), ofDocumentA)); - expectedHits1.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[7], 0.0025092363f), ofDocumentA)); - expectedHits1.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[5], 0.002709806f), ofDocumentA)); - expectedHits1.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[9], 0.002762556f), ofDocumentA)); - expectedHits1.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[4], 0.0028533936f), ofDocumentA)); - expectedHits1.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[3], 0.0033442378f), ofDocumentA)); - expectedHits1.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[2], 0.004181564f), ofDocumentA)); - expectedHits1.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[1], 0.0057129264f), ofDocumentA)); - List expectedHits2 = new ArrayList<>(); - expectedHits2.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[10], 0.0025351048f), ofDocumentB)); - expectedHits2.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[12], 0.00288558f), ofDocumentB)); - expectedHits2.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[11], 0.0033002496f), ofDocumentB)); - expectedHits2.add(new GroupHit(new AdditionalGroupHit(testData.PASSAGE_IDS[13], 0.004168868f), ofDocumentB)); + // hits Field[] hits = new Field[]{ Field.builder() @@ -1628,8 +1551,8 @@ public void testGraphQLGetWithGroupBy() { assertThat(groups.get(i).minDistance).isEqualTo(groups.get(i).getHits().get(0).get_additional().getDistance()); assertThat(groups.get(i).maxDistance).isEqualTo(groups.get(i).getHits().get(groups.get(i).getHits().size() - 1).get_additional().getDistance()); } - checkGroupElements(expectedHits1, groups.get(0).getHits()); - checkGroupElements(expectedHits2, groups.get(1).getHits()); + checkGroupElements(expectedHitsA, groups.get(0).getHits()); + checkGroupElements(expectedHitsB, groups.get(1).getHits()); } @Test @@ -1637,7 +1560,7 @@ public void testGraphQLGetWithGroupByWithHybrid() { // given Config config = new Config("http", address); WeaviateClient client = new WeaviateClient(config); - WeaviateTestGenerics.DocumentPassageSchema testData = new WeaviateTestGenerics.DocumentPassageSchema(); + // hits Field[] hits = new Field[]{ Field.builder().name("content").build(), @@ -1696,25 +1619,6 @@ public void testGraphQLGetWithGroupByWithHybrid() { } } - private void checkGroupElements(List expected, List actual) { - assertThat(expected).hasSameSizeAs(actual); - for (int i = 0; i < actual.size(); i++) { - assertThat(actual.get(i).get_additional().getId()).isEqualTo(expected.get(i).get_additional().getId()); - assertThat(actual.get(i).getOfDocument().get(0).get_additional().getId()).isEqualTo(expected.get(i).getOfDocument().get(0).get_additional().getId()); - } - } - - private List getGroups(List> result) { - Serializer serializer = new Serializer(); - String jsonString = serializer.toJsonString(result); - AdditionalGroupByAdditional[] response = serializer.toObject(jsonString, AdditionalGroupByAdditional[].class); - assertThat(response).isNotNull().hasSize(3); - return Arrays.stream(response) - .map(AdditionalGroupByAdditional::get_additional) - .map(Additional::getGroup) - .collect(Collectors.toList()); - } - private void assertPizzaName(String name, List pizzas, int position) { assertTrue(pizzas.get(position) instanceof Map); Map pizza = (Map) pizzas.get(position); @@ -2251,7 +2155,7 @@ public void shouldSupportSearchWithContains() { new String[]{id1, id2, id3}); } - private void assertIds(String className, Result gqlResult, String[] expectedIds) { + protected void assertIds(String className, Result gqlResult, String[] expectedIds) { assertThat(gqlResult).isNotNull() .returns(false, Result::hasErrors) .extracting(Result::getResult).isNotNull()