diff --git a/src/main/java/io/weaviate/client/WeaviateClient.java b/src/main/java/io/weaviate/client/WeaviateClient.java index aea5b1bf..10be7f45 100644 --- a/src/main/java/io/weaviate/client/WeaviateClient.java +++ b/src/main/java/io/weaviate/client/WeaviateClient.java @@ -46,7 +46,7 @@ public WeaviateClient(Config config, HttpClient httpClient, AccessTokenProvider } public WeaviateAsyncClient async() { - return new WeaviateAsyncClient(config); + return new WeaviateAsyncClient(config, tokenProvider); } public Misc misc() { diff --git a/src/main/java/io/weaviate/client/v1/async/WeaviateAsyncClient.java b/src/main/java/io/weaviate/client/v1/async/WeaviateAsyncClient.java index 0ac83ac5..4d9fdf6b 100644 --- a/src/main/java/io/weaviate/client/v1/async/WeaviateAsyncClient.java +++ b/src/main/java/io/weaviate/client/v1/async/WeaviateAsyncClient.java @@ -5,6 +5,7 @@ import io.weaviate.client.base.http.async.AsyncHttpClient; import io.weaviate.client.base.util.DbVersionProvider; import io.weaviate.client.base.util.DbVersionSupport; +import io.weaviate.client.base.util.GrpcVersionSupport; import io.weaviate.client.v1.async.backup.Backup; import io.weaviate.client.v1.async.batch.Batch; import io.weaviate.client.v1.async.classifications.Classifications; @@ -13,6 +14,7 @@ import io.weaviate.client.v1.async.graphql.GraphQL; import io.weaviate.client.v1.async.misc.Misc; import io.weaviate.client.v1.async.schema.Schema; +import io.weaviate.client.v1.auth.provider.AccessTokenProvider; import io.weaviate.client.v1.misc.model.Meta; import java.util.Optional; import java.util.concurrent.ExecutionException; @@ -23,14 +25,19 @@ public class WeaviateAsyncClient implements AutoCloseable { private final Config config; private final CloseableHttpAsyncClient client; private final DbVersionSupport dbVersionSupport; + private final GrpcVersionSupport grpcVersionSupport; + private final AccessTokenProvider tokenProvider; - public WeaviateAsyncClient(Config config) { + public WeaviateAsyncClient(Config config, AccessTokenProvider tokenProvider) { this.config = config; this.client = AsyncHttpClient.create(config); // auto start the client this.start(); // init the db version provider and get the version info - this.dbVersionSupport = new DbVersionSupport(initDbVersionProvider()); + DbVersionProvider dbVersionProvider = initDbVersionProvider(); + this.dbVersionSupport = new DbVersionSupport(dbVersionProvider); + this.grpcVersionSupport = new GrpcVersionSupport(dbVersionProvider); + this.tokenProvider = tokenProvider; } public Misc misc() { @@ -38,7 +45,7 @@ public Misc misc() { } public Schema schema() { - return new Schema(client, config); + return new Schema(client, config, dbVersionSupport); } public Data data() { @@ -46,7 +53,7 @@ public Data data() { } public Batch batch() { - return new Batch(client, config, dbVersionSupport, data()); + return new Batch(client, config, dbVersionSupport, grpcVersionSupport, tokenProvider, data()); } public Cluster cluster() { diff --git a/src/main/java/io/weaviate/client/v1/async/batch/Batch.java b/src/main/java/io/weaviate/client/v1/async/batch/Batch.java index e0af3798..b2b12ecd 100644 --- a/src/main/java/io/weaviate/client/v1/async/batch/Batch.java +++ b/src/main/java/io/weaviate/client/v1/async/batch/Batch.java @@ -3,9 +3,11 @@ import io.weaviate.client.Config; import io.weaviate.client.base.util.BeaconPath; import io.weaviate.client.base.util.DbVersionSupport; +import io.weaviate.client.base.util.GrpcVersionSupport; import io.weaviate.client.v1.async.batch.api.ObjectsBatchDeleter; import io.weaviate.client.v1.async.batch.api.ObjectsBatcher; import io.weaviate.client.v1.async.data.Data; +import io.weaviate.client.v1.auth.provider.AccessTokenProvider; import io.weaviate.client.v1.batch.api.ReferencePayloadBuilder; import io.weaviate.client.v1.batch.util.ObjectsPath; import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; @@ -16,12 +18,17 @@ public class Batch { private final ObjectsPath objectsPath; private final BeaconPath beaconPath; private final Data data; + private final GrpcVersionSupport grpcVersionSupport; + private final AccessTokenProvider tokenProvider; - public Batch(CloseableHttpAsyncClient client, Config config, DbVersionSupport dbVersionSupport, Data data) { + public Batch(CloseableHttpAsyncClient client, Config config, DbVersionSupport dbVersionSupport, + GrpcVersionSupport grpcVersionSupport, AccessTokenProvider tokenProvider, Data data) { this.client = client; this.config = config; this.objectsPath = new ObjectsPath(); this.beaconPath = new BeaconPath(dbVersionSupport); + this.grpcVersionSupport = grpcVersionSupport; + this.tokenProvider = tokenProvider; this.data = data; } @@ -30,9 +37,7 @@ public ObjectsBatcher objectsBatcher() { } public ObjectsBatcher objectsBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) { - // TODO: add support for missing arguments - // return ObjectsBatcher.create(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig); - return ObjectsBatcher.create(client, config, data, objectsPath, null, null, batchRetriesConfig); + return ObjectsBatcher.create(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig); } public ObjectsBatcher objectsAutoBatcher() { @@ -58,9 +63,7 @@ public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.AutoBatchConfig autoBatc public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, ObjectsBatcher.AutoBatchConfig autoBatchConfig) { - // TODO: add support for missing arguments - // return ObjectsBatcher.create(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig); - return ObjectsBatcher.createAuto(client, config, data, objectsPath, null, null, batchRetriesConfig, autoBatchConfig); + return ObjectsBatcher.createAuto(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig, autoBatchConfig); } public ObjectsBatchDeleter objectsBatchDeleter() { diff --git a/src/main/java/io/weaviate/client/v1/async/schema/Schema.java b/src/main/java/io/weaviate/client/v1/async/schema/Schema.java index 5976d98c..fe5411b9 100644 --- a/src/main/java/io/weaviate/client/v1/async/schema/Schema.java +++ b/src/main/java/io/weaviate/client/v1/async/schema/Schema.java @@ -1,6 +1,7 @@ package io.weaviate.client.v1.async.schema; import io.weaviate.client.Config; +import io.weaviate.client.base.util.DbVersionSupport; import io.weaviate.client.v1.async.schema.api.ClassCreator; import io.weaviate.client.v1.async.schema.api.ClassDeleter; import io.weaviate.client.v1.async.schema.api.ClassExists; @@ -12,17 +13,19 @@ import io.weaviate.client.v1.async.schema.api.ShardsGetter; import io.weaviate.client.v1.async.schema.api.ShardsUpdater; import io.weaviate.client.v1.async.schema.api.TenantsCreator; -//import io.weaviate.client.v1.async.schema.api.TenantsUpdater; import io.weaviate.client.v1.async.schema.api.TenantsGetter; +import io.weaviate.client.v1.async.schema.api.TenantsUpdater; import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; public class Schema { private final CloseableHttpAsyncClient client; private final Config config; + private final DbVersionSupport dbVersionSupport; - public Schema(CloseableHttpAsyncClient client, Config config) { + public Schema(CloseableHttpAsyncClient client, Config config, DbVersionSupport dbVersionSupport) { this.client = client; this.config = config; + this.dbVersionSupport = dbVersionSupport; } public SchemaGetter getter() { @@ -69,9 +72,9 @@ public TenantsCreator tenantsCreator() { return new TenantsCreator(client, config); } -// public TenantsUpdater tenantsUpdater() { -// return new TenantsUpdater(client, config); -// } + public TenantsUpdater tenantsUpdater() { + return new TenantsUpdater(client, config, dbVersionSupport); + } public TenantsGetter tenantsGetter() { return new TenantsGetter(client, config); diff --git a/src/main/java/io/weaviate/client/v1/async/schema/api/TenantsUpdater.java b/src/main/java/io/weaviate/client/v1/async/schema/api/TenantsUpdater.java new file mode 100644 index 00000000..4de8f209 --- /dev/null +++ b/src/main/java/io/weaviate/client/v1/async/schema/api/TenantsUpdater.java @@ -0,0 +1,93 @@ +package io.weaviate.client.v1.async.schema.api; + +import io.weaviate.client.Config; +import io.weaviate.client.base.AsyncBaseClient; +import io.weaviate.client.base.AsyncClientResult; +import io.weaviate.client.base.Response; +import io.weaviate.client.base.Result; +import io.weaviate.client.base.http.async.ResponseParser; +import io.weaviate.client.base.util.DbVersionSupport; +import io.weaviate.client.base.util.UrlEncoder; +import io.weaviate.client.v1.schema.model.Tenant; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.core5.concurrent.FutureCallback; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.HttpResponse; +import org.apache.hc.core5.http.HttpStatus; + +public class TenantsUpdater extends AsyncBaseClient implements AsyncClientResult { + + private final static int BATCH_SIZE = 100; + private final DbVersionSupport dbVersionSupport; + private String className; + private Tenant[] tenants; + + public TenantsUpdater(CloseableHttpAsyncClient client, Config config, DbVersionSupport dbVersionSupport) { + super(client, config); + this.dbVersionSupport = dbVersionSupport; + } + + public TenantsUpdater withClassName(String className) { + this.className = className; + return this; + } + + public TenantsUpdater withTenants(Tenant... tenants) { + this.tenants = tenants; + return this; + } + + @Override + public Future> run(FutureCallback> callback) { + if (dbVersionSupport.supportsOnly100TenantsInOneRequest() && tenants != null && tenants.length > BATCH_SIZE) { + CompletableFuture> updateALl = CompletableFuture.supplyAsync(() -> chunkTenants(tenants, BATCH_SIZE)).thenApplyAsync(tenants -> { + for (List batch : tenants) { + try { + Result resp = updateTenants(batch.toArray(new Tenant[0]), null).get(); + if (resp.hasErrors()) { + return resp; + } + } catch (InterruptedException | ExecutionException e) { + throw new CompletionException(e); + } + } + return new Result<>(200, true, null); + }); + if (callback != null) { + return updateALl.whenComplete((booleanResult, e) -> { + callback.completed(booleanResult); + if (e != null) { + callback.failed(new Exception(e)); + } + }); + } + return updateALl; + } + return updateTenants(tenants, callback); + } + + private Future> updateTenants(Tenant[] tenants, FutureCallback> callback) { + String path = String.format("/schema/%s/tenants", UrlEncoder.encodePathParam(className)); + return sendPutRequest(path, tenants, callback, new ResponseParser() { + @Override + public Result parse(HttpResponse response, String body, ContentType contentType) { + Response resp = serializer.toResponse(response.getCode(), body, Tenant[].class); + return new Result<>(resp.getStatusCode(), resp.getStatusCode() == HttpStatus.SC_OK, resp.getErrors()); + } + }); + } + + private Collection> chunkTenants(Tenant[] tenants, int chunkSize) { + AtomicInteger counter = new AtomicInteger(); + return Stream.of(tenants).collect(Collectors.groupingBy(it -> counter.getAndIncrement() / chunkSize)).values(); + } +} diff --git a/src/test/java/io/weaviate/integration/client/async/schema/ClientSchemaMultiTenancyTest.java b/src/test/java/io/weaviate/integration/client/async/schema/ClientSchemaMultiTenancyTest.java new file mode 100644 index 00000000..d9b08fbd --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/async/schema/ClientSchemaMultiTenancyTest.java @@ -0,0 +1,64 @@ +package io.weaviate.integration.client.async.schema; + +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.async.WeaviateAsyncClient; +import io.weaviate.client.v1.schema.model.ActivityStatus; +import io.weaviate.client.v1.schema.model.Tenant; +import io.weaviate.integration.client.AssertMultiTenancy; +import io.weaviate.integration.client.WeaviateDockerCompose; +import io.weaviate.integration.client.WeaviateTestGenerics; +import static io.weaviate.integration.client.WeaviateTestGenerics.TENANT_1; +import static io.weaviate.integration.client.WeaviateTestGenerics.TENANT_2; +import java.util.Arrays; +import java.util.concurrent.ExecutionException; +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +public class ClientSchemaMultiTenancyTest { + private WeaviateClient client; + private WeaviateTestGenerics testGenerics; + private AssertMultiTenancy assertMT; + + @ClassRule + public static WeaviateDockerCompose compose = new WeaviateDockerCompose(); + + @Before + public void before() { + String httpHost = compose.getHttpHostAddress(); + Config config = new Config("http", httpHost); + + client = new WeaviateClient(config); + testGenerics = new WeaviateTestGenerics(); + assertMT = new AssertMultiTenancy(client); + } + + @After + public void after() { + testGenerics.cleanupWeaviate(client); + } + + @Test + public void shouldUpdateTenantsOfMTClass() throws ExecutionException, InterruptedException { + Tenant[] tenants = new Tenant[]{TENANT_1, TENANT_2}; + testGenerics.createSchemaPizzaForTenants(client); + testGenerics.createTenantsPizza(client, tenants); + + try (WeaviateAsyncClient asyncClient = client.async()) { + Result updateResult = asyncClient.schema().tenantsUpdater() + .withClassName("Pizza") + .withTenants(Arrays.stream(tenants) + .map(tenant -> Tenant.builder().name(tenant.getName()).activityStatus(ActivityStatus.COLD).build()) + .toArray(Tenant[]::new)) + .run().get(); + + assertThat(updateResult).isNotNull() + .returns(false, Result::hasErrors) + .returns(true, Result::getResult); + } + } +}