Skip to content

Commit

Permalink
test: add cluster tests
Browse files Browse the repository at this point in the history
Moved testing utils into AbstractAsyncClientTest class.
  • Loading branch information
bevzzz committed Nov 15, 2024
1 parent 754bc3d commit 7e91084
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 80 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package io.weaviate.integration.client.async.graphql;

import io.weaviate.client.base.Result;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.client.v1.graphql.query.fields.Field;
import io.weaviate.integration.client.graphql.AbstractClientGraphQLTest;

import java.util.Map;

import static org.junit.Assert.assertNull;
import static org.junit.jupiter.api.Assertions.assertNotNull;

class AbstractAsyncClientTest extends AbstractClientGraphQLTest {
static Field field(String name) {
return Field.builder().name(name).build();
}

static Field[] fields(String... fieldNames) {
Field[] fields = new Field[fieldNames.length];
for (int i = 0; i < fieldNames.length; i++) {
fields[i] = field(fieldNames[i]);
}
return fields;
}

static Field _additional(String... fieldNames) {
return Field.builder().name("_additional").fields(fields(fieldNames)).build();
}

static Field meta(String... fieldNames) {
return Field.builder().name("meta").fields(fields(fieldNames)).build();
}

/**
* Check that request was processed successfully and no errors are returned. Extract the part of the response body for the specified query type.
*
* @param result Result of a GraphQL query.
* @param queryType "Get", "Explore", or "Aggregate".
* @return "data" portion of the response
*/
@SuppressWarnings("unchecked")
<T> T extractQueryResult(Result<GraphQLResponse> result, String queryType) {
assertNotNull(result, "graphQL request returned null");
assertNull("GraphQL error in the response", result.getError());

GraphQLResponse resp = result.getResult();
assertNotNull(resp, "GraphQL response not returned");

Map<String, Object> data = (Map<String, Object>) resp.getData();
assertNotNull(data, "GraphQL response has no data");

T queryResult = (T) data.get(queryType);
assertNotNull(queryResult, String.format("%s query returned no result", queryType));

return queryResult;
}

<T> T extractClass(Result<GraphQLResponse> result, String queryType, String className) {
Map<String, T> queryResult = extractQueryResult(result, queryType);
return extractClass(queryResult, className);
}

<T> T extractClass(Map<String, T> queryResult, String className) {
T objects = queryResult.get(className);
assertNotNull(objects, String.format("no %ss returned", className.toLowerCase()));
return objects;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import io.weaviate.client.v1.schema.model.WeaviateClass;
import io.weaviate.integration.client.WeaviateDockerCompose;
import io.weaviate.integration.client.WeaviateTestGenerics;
import io.weaviate.integration.client.graphql.AbstractClientGraphQLTest;
import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Before;
Expand All @@ -33,17 +32,15 @@
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;

import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;
import static org.junit.jupiter.api.Assertions.*;

public class ClientGraphQLTest extends AbstractClientGraphQLTest {
private String address;
public class ClientGraphQLTest extends AbstractAsyncClientTest {
private final WeaviateTestGenerics testGenerics = new WeaviateTestGenerics();
WeaviateTestGenerics.DocumentPassageSchema passageSchema = new WeaviateTestGenerics.DocumentPassageSchema();

private final WeaviateTestGenerics.DocumentPassageSchema passageSchema = new WeaviateTestGenerics.DocumentPassageSchema();

private String address;
private WeaviateClient syncClient;
private WeaviateAsyncClient client;
private GraphQL gql;
Expand Down Expand Up @@ -794,18 +791,6 @@ public void testGraphQLGetUsingLimitAndOffset() {
assertEquals(1, pizzas.size(), "wrong number of pizzas");
}

private Result<GraphQLResponse> doGet(Consumer<Get> build) {
Get get = gql.get();
build.accept(get);
try {
return get.run()
.get();
} catch (InterruptedException | ExecutionException e) {
fail("graphQL.get(): " + e.getMessage());
return null;
}
}

@Test
public void testGraphQLGetWithGroupBy() {
Field[] hits = new Field[]{ Field.builder()
Expand Down Expand Up @@ -1059,6 +1044,18 @@ public void shouldSupportSearchByUUID() {
.returns(true, Result::getResult);
}

private Result<GraphQLResponse> doGet(Consumer<Get> build) {
Get get = gql.get();
build.accept(get);
try {
return get.run()
.get();
} catch (InterruptedException | ExecutionException e) {
fail("graphQL.get(): " + e.getMessage());
return null;
}
}

private Result<GraphQLResponse> doRaw(Consumer<Raw> build) {
Raw raw = gql.raw();
build.accept(raw);
Expand Down Expand Up @@ -1095,68 +1092,6 @@ private Result<GraphQLResponse> doAggregate(Consumer<Aggregate> build) {
}
}

/**
* Check that request was processed successfully and no errors are returned. Extract the part of the response body for the specified query type.
*
* @param result Result of a GraphQL query.
* @param queryType "Get", "Explore", or "Aggregate".
* @return "data" portion of the response
*/
@SuppressWarnings("unchecked")
private <T> T extractQueryResult(Result<GraphQLResponse> result, String queryType) {
assertNotNull(result, "graphQL request returned null");
assertNull("GraphQL error in the response", result.getError());

GraphQLResponse resp = result.getResult();
assertNotNull(resp, "GraphQL response not returned");

Map<String, Object> data = (Map<String, Object>) resp.getData();
assertNotNull(data, "GraphQL response has no data");

T queryResult = (T) data.get(queryType);
assertNotNull(queryResult, String.format("%s query returned no result", queryType));

return queryResult;
}

private <T> T extractClass(Result<GraphQLResponse> result, String queryType, String className) {
Map<String, T> queryResult = extractQueryResult(result, queryType);
return extractClass(queryResult, className);
}

private <T> T extractClass(Map<String, T> queryResult, String className) {
T objects = queryResult.get(className);
assertNotNull(objects, String.format("no %ss returned", className.toLowerCase()));
return objects;
}

private static Field field(String name) {
return Field.builder()
.name(name)
.build();
}

private static Field[] fields(String... fieldNames) {
Field[] fields = new Field[fieldNames.length];
for (int i = 0; i < fieldNames.length; i++) {
fields[i] = field(fieldNames[i]);
}
return fields;
}

private static Field _additional(String... fieldNames) {
return Field.builder()
.name("_additional")
.fields(fields(fieldNames))
.build();
}

private static Field meta(String... fieldNames) {
return Field.builder()
.name("meta")
.fields(fields(fieldNames))
.build();
}

private static WhereArgument whereText(String property, String operator, String... valueText) {
return WhereArgument.builder()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package io.weaviate.integration.client.async.graphql;

import com.google.gson.internal.LinkedTreeMap;
import com.jparams.junit4.JParamsTestRunner;
import com.jparams.junit4.data.DataMethod;
import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.async.WeaviateAsyncClient;
import io.weaviate.client.v1.async.graphql.GraphQL;
import io.weaviate.client.v1.async.graphql.api.Get;
import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
import io.weaviate.client.v1.graphql.model.GraphQLResponse;
import io.weaviate.integration.client.WeaviateDockerComposeCluster;
import io.weaviate.integration.client.WeaviateTestGenerics;
import org.junit.After;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;

import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

@RunWith(JParamsTestRunner.class)
public class ClusterGraphQLTest extends AbstractAsyncClientTest {
private static final WeaviateTestGenerics testGenerics = new WeaviateTestGenerics();
private String address;

private WeaviateClient syncClient;
private WeaviateAsyncClient client;
private GraphQL gql;

@ClassRule
public static WeaviateDockerComposeCluster compose = new WeaviateDockerComposeCluster();

@Before
public void before() {
address = compose.getHttpHost0Address();

syncClient = new WeaviateClient(new Config("http", address));
testGenerics.createReplicatedTestSchemaAndData(syncClient);

client = syncClient.async();
gql = client.graphQL();
}

public static Object[][] provideConsistencyLevels() {
return new Object[][]{ { ConsistencyLevel.ALL }, { ConsistencyLevel.QUORUM }, { ConsistencyLevel.ONE } };
}

@After
public void after() {
testGenerics.cleanupWeaviate(syncClient);
client.close();
}

@DataMethod(source = ClusterGraphQLTest.class, method = "provideConsistencyLevels")
@Test
public void testGraphQLGetUsingConsistencyLevel(String consistency) {
Result<GraphQLResponse> result = doGet(get -> get.withClassName("Pizza").withConsistencyLevel(consistency)
.withFields(field("name"), _additional("isConsistent")));

List<LinkedTreeMap<String, LinkedTreeMap<String, Boolean>>> pizzas = extractClass(result, "Get", "Pizza");
for (LinkedTreeMap<String, LinkedTreeMap<String, Boolean>> pizza : pizzas) {
assertTrue("not consistent with ConsistencyLevel=" + consistency, pizza.get("_additional").get("isConsistent"));
}
}

private Result<GraphQLResponse> doGet(Consumer<Get> build) {
Get get = gql.get();
build.accept(get);
try {
return get.run().get();
} catch (InterruptedException | ExecutionException e) {
fail("graphQL.get(): " + e.getMessage());
return null;
}
}
}

0 comments on commit 7e91084

Please sign in to comment.