diff --git a/avs-client-java/.gitignore b/avs-client-java/.gitignore new file mode 100644 index 0000000..89540d8 --- /dev/null +++ b/avs-client-java/.gitignore @@ -0,0 +1 @@ +/dependency-reduced-pom.xml diff --git a/avs-client-java/README.md b/avs-client-java/README.md new file mode 100644 index 0000000..5b26a0f --- /dev/null +++ b/avs-client-java/README.md @@ -0,0 +1,74 @@ +# Proximus Java Client + +This project demonstrates the use of Aerospike's vector database capabilities with the `AdminClient` and `Client` classes. The project is configured with Maven and includes sample data for testing vector search functionality. + +## Overview + - This is a demo project that illustrates how a simple image search application can be built with the Aerospike Vector database using the vector database Java client. + +## Prerequisites + +- Java 21 +- An AVS **0.9.0** instance running locally and accessible from the application. If AVS is not available locally, update `HOSTNAME` and `PORT` in `SetupUtils.java`; other connection-related settings can also be updated in that file. + +## Build and run + - Run the `mvn package` command from the `avs-client-java` directory. + - Run `java -jar target/avs-client-java-demo-0.3.0.jar`. + +## Project Structure + +``` +client-test/ + ├── src/ + │ ├── main/ + │ │ ├── java/ + │ │ │ └── com/ + │ │ │ └── aerospike/ + │ │ │ ├── App.java + │ │ │ └── SetupUtils.java + │ └── resources/ + │ └── sift/ + │ ├── siftsmall_base.fvecs + │ ├── siftsmall_groundtruth.ivecs + │ └── siftsmall_query.fvecs + └── pom.xml +``` + +## Maven Configuration + +Necessary `pom.xml` configuration for the project: + +```xml + + <dependencies> + <dependency> + <groupId>com.aerospike</groupId> + <artifactId>avs-client-java</artifactId> + <version>0.3.0</version> + </dependency> + </dependencies> +``` + + +## Application Code +`App.java`: this class runs the setup and test methods for the vector database. + +### SetupUtils.java + - This class handles the setup of the vector database and the asynchronous vector search tests. 
+ - What it does: Load Data: + - Loads base vectors, query vectors, and ground truth vectors from files. + - Initialize Clients: + - Sets up `Client` and `AdminClient` using the host (`127.0.0.1`) and port (`10000`) configured in `SetupUtils.java`. + - Create Index: Creates an index in the Aerospike database for storing vector data. + - Insert Vectors: Inserts the base vectors into the Aerospike database. + - Wait for Merge: Waits until all records are merged in the index. + - Perform Vector Search: Executes asynchronous vector searches using the query vectors. + - Computes recall metrics to evaluate the search results. + + - Key Methods: + - `setup()`: Handles data loading, client initialization, index creation, vector insertion, and waiting for records to merge. + - `testVectorSearchAsync()`: Executes the vector search tests and computes recall metrics. + + +### Javadocs +Please refer to the [javadocs](https://javadoc.io/doc/com.aerospike/avs-client-java/latest/index.html) for more details. + diff --git a/avs-client-java/pom.xml b/avs-client-java/pom.xml new file mode 100644 index 0000000..0928d3c --- /dev/null +++ b/avs-client-java/pom.xml @@ -0,0 +1,101 @@ + + 4.0.0 + com.aerospike + avs-client-java-demo + jar + 0.3.0 + avs-client-java-demo + Aerospike Vector Database Java Client Demo Project + https://aerospike.com/docs/vector + + + + The Apache License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + Aerospike Inc. + https://www.aerospike.com + + + + + rkumar-aerospike + Rahul Kumar + rkumar@aerospike.com + https://aerospike.com/docs/vector/develop/java + Aerospike Inc. 
+ https://www.aerospike.com + + developer + + America/Los_Angeles + + + + + + UTF-8 + UTF-8 + 21 + 21 + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.13.0 + + ${maven.compiler.source} + ${maven.compiler.target} + ${project.build.sourceEncoding} + + + + org.apache.maven.plugins + maven-shade-plugin + 3.2.4 + + + package + + shade + + + + + com.aerospike.App + + + + + + + + + + + + + + ch.qos.logback + logback-classic + 1.2.3 + + + + + com.aerospike + avs-client-java + 0.3.0 + + + + \ No newline at end of file diff --git a/avs-client-java/src/main/java/com/aerospike/App.java b/avs-client-java/src/main/java/com/aerospike/App.java new file mode 100644 index 0000000..7d8c18e --- /dev/null +++ b/avs-client-java/src/main/java/com/aerospike/App.java @@ -0,0 +1,19 @@ +package com.aerospike; + +public class App { + + public static void main(String[] args) { + + SetupUtils su = new SetupUtils(); + try { + su.setup(); + su.testVectorSearchAsync(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (Exception e) { + throw new RuntimeException(e); + }finally { + su.close(); + } + } +} diff --git a/avs-client-java/src/main/java/com/aerospike/SetupUtils.java b/avs-client-java/src/main/java/com/aerospike/SetupUtils.java new file mode 100644 index 0000000..394deb2 --- /dev/null +++ b/avs-client-java/src/main/java/com/aerospike/SetupUtils.java @@ -0,0 +1,320 @@ +package com.aerospike; + +import com.aerospike.vector.client.*; +import com.aerospike.vector.client.adminclient.AdminClient; +import com.aerospike.vector.client.proto.*; +import com.aerospike.vector.client.dbclient.Client; +import com.aerospike.vector.client.dbclient.VectorSearchListener; +import com.aerospike.vector.client.VectorSearchQuery; +import com.aerospike.vector.client.adminclient.AdminClient; +import com.aerospike.vector.client.proto.HnswSearchParams; +import com.aerospike.vector.client.proto.IndexId; +import com.aerospike.vector.client.proto.Neighbor; +import 
com.aerospike.vector.client.proto.VectorDistanceMetric; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; + +public class SetupUtils implements AutoCloseable { + private final Logger LOG = LoggerFactory.getLogger(SetupUtils.class); + private final float TOLERANCE = 0.00001f; + private final String BASE_PATH = "src/main/resources/sift/"; + private final int DIMENSIONS = 128; + private final int TRUTH_VECTOR_DIMENSIONS = 100; + private final int BASE_VECTOR_NUMBER = 10000; + private final int QUERY_VECTOR_NUMBER = 100; + private final String NAMESPACE = "test"; + private final String INDEX_NAME = "searchtest"; + private final String INDEX_SET_NAME = "demo"; + private final String VECTOR_BIN_NAME = "v-test-bin"; + private float[][] baseVectors; + private float[][] queryVectors; + private int[][] truthVectors; + public final String HOSTNAME = "127.0.0.1"; + public final int PORT = 10000; + private AdminClient adminClient; + private Client client; + private final String PROXIMUS_HOST = "localhost"; + private final String ADMIN = "admin"; + private final String ADMIN_PWD = "admin"; + public final List HOSTS = List.of(new HostPort(HOSTNAME, PORT)); + public final com.aerospike.vector.client.auth.PasswordCredentials ADMIN_CREDENTIALS = + new com.aerospike.vector.client.auth.PasswordCredentials(ADMIN, ADMIN_PWD); + + public SetupUtils(){ + + } + + + /** + * This function create index into the Aerospike vector DB and then inserts images, These images will be used to demonstrate vector search capabilities. 
+ * @throws Exception + */ + public void setup() throws Exception { + baseVectors = loadBaseNumpy(); + queryVectors = loadQueryNumpy(); + truthVectors = loadTruthNumpy(); + + final ClientTlsConfig tlsConfig = null; + + ConnectionConfig clientConfig = + new ConnectionConfig.Builder() + .withLoadBalancer(false) + .withCredentials(ADMIN_CREDENTIALS) + .withListenerName("client-test") + .withHosts(HOSTS) + .withTls(tlsConfig) + .build(); + ConnectionConfig adminConfig = + new ConnectionConfig.Builder() + .withLoadBalancer(false) + .withCredentials(ADMIN_CREDENTIALS) + .withTls(tlsConfig) + .withListenerName("admin-test").withHosts(HOSTS).build(); + + // Instantiate admin client and VectorDB client. Admin client is primarily sued for adminsttrative purposes like index creation etc. + // VectorDB client is used for inserting records, doing vector search etc. + client = new Client(clientConfig); + adminClient = new AdminClient(adminConfig); + LOG.info("Created clients.."); + + // Create index + adminClient.indexCreate( + IndexId.newBuilder().setName(INDEX_NAME).setNamespace(NAMESPACE).build(), + VECTOR_BIN_NAME, DIMENSIONS, VectorDistanceMetric.SQUARED_EUCLIDEAN, + INDEX_SET_NAME,null , null, Map.of(), 60_000, 1_000); + + // Insert vectors + for (int i = 0; i < baseVectors.length; i++) { + client.putAsync(NAMESPACE, INDEX_SET_NAME, String.valueOf(i), + Map.of(VECTOR_BIN_NAME, baseVectors[i]), 0); + if( i % 100 == 0) { + Thread.sleep(1000); + } + } + LOG.info("Inserted image data in AVS."); + + + // Wait for records to get merged in the vector DB index. 
+ boolean allRecordsMerged = false; + while (!allRecordsMerged) { + long unmerged = adminClient.indexStatus(IndexId.newBuilder().setNamespace(NAMESPACE).setName(INDEX_NAME).build()).getUnmergedRecordCount(); + allRecordsMerged = unmerged == 0; + LOG.warn("Waiting for index to merge, found unmerged {} records", unmerged); + Thread.sleep(2000); + } + } + + /** + * Demonstrate how to use async vector search + * @throws Exception + */ + public void testVectorSearchAsync() throws Exception { + AtomicLong counter = new AtomicLong(); + List[] results = new ArrayList[queryVectors.length]; + for (int i = 0; i < queryVectors.length; i++) { + SimpleListener listener = new SimpleListener(i, results, counter); + if (i % 2 == 0) { + VectorSearchQuery query = new VectorSearchQuery.Builder(NAMESPACE, INDEX_NAME, + Conversions.buildVectorValue(queryVectors[i]), 100).withProjection(Projection.getDefault()).build(); + client.vectorSearchAsync(query, listener); + + } else { + VectorSearchQuery query = new VectorSearchQuery.Builder(NAMESPACE, INDEX_NAME, + Conversions.buildVectorValue(queryVectors[i]), 100) + .withHnswSearchParams(HnswSearchParams.newBuilder().setEf(80).build()) + .withProjection(Projection.getDefault()) + .build(); + client.vectorSearchAsync(query, listener); + + } + } + + while (counter.get() != queryVectors.length) { + LOG.warn("Waiting for async search completion, current counter: {}, expected: {}", counter.get(), queryVectors.length); + Thread.sleep(1000); + } + + List recallForEachQuery = computeRecall(Arrays.stream(results).toList()); + assertRecallMetrics(recallForEachQuery); + LOG.info("Verified tha average recall is 95% and individual recall is 90% for each query."); + } + + //----Utility functions------- + + private float[][] loadBaseNumpy() throws Exception { + String baseFilename = BASE_PATH + "siftsmall_base.fvecs"; + Path path = Paths.get(baseFilename); + if (!Files.exists(path)) { + throw new IOException("File does not exist: " + path.toAbsolutePath()); 
+ } + byte[] baseBytes = Files.readAllBytes(path); + return parseSiftToFloatArray(baseBytes, BASE_VECTOR_NUMBER); + } + + private int[][] loadTruthNumpy() throws Exception { + String truthFilename = BASE_PATH + "siftsmall_groundtruth.ivecs"; + byte[] truthBytes = Files.readAllBytes(Paths.get(truthFilename)); + return parseSiftToIntArray(truthBytes); + } + + private float[][] loadQueryNumpy() throws Exception { + String queryFilename = BASE_PATH + "siftsmall_query.fvecs"; + byte[] queryBytes = Files.readAllBytes(Paths.get(queryFilename)); + return parseSiftToFloatArray(queryBytes, QUERY_VECTOR_NUMBER); + } + + private int[][] parseSiftToIntArray(byte[] byteBuffer) throws Exception { + int[][] numpyArray = new int[QUERY_VECTOR_NUMBER][TRUTH_VECTOR_DIMENSIONS]; + int recordLength = (TRUTH_VECTOR_DIMENSIONS * 4) + 4; + + for (int i = 0; i < QUERY_VECTOR_NUMBER; i++) { + int currentOffset = i * recordLength; + ByteBuffer buffer = ByteBuffer.wrap(byteBuffer, currentOffset, recordLength); + buffer.order(ByteOrder.LITTLE_ENDIAN); + + int readDim = buffer.getInt(); + if (readDim != TRUTH_VECTOR_DIMENSIONS) { + throw new Exception("Failed to parse byte buffer correctly, expected dimension " + TRUTH_VECTOR_DIMENSIONS + ", but got " + readDim); + } + + IntBuffer intBuffer = buffer.asIntBuffer(); + intBuffer.get(numpyArray[i]); + } + return numpyArray; + } + + private float[][] parseSiftToFloatArray(byte[] byteBuffer, int length) throws Exception { + float[][] numpyArray = new float[length][DIMENSIONS]; + int recordLength = (DIMENSIONS * 4) + 4; + + for (int i = 0; i < length; i++) { + int currentOffset = i * recordLength; + ByteBuffer buffer = ByteBuffer.wrap(byteBuffer, currentOffset, recordLength); + buffer.order(ByteOrder.LITTLE_ENDIAN); + + int readDim = buffer.getInt(); + if (readDim != DIMENSIONS) { + throw new Exception("Failed to parse byte buffer correctly, expected dimension " + DIMENSIONS + ", but got " + readDim); + } + + FloatBuffer floatBuffer = 
buffer.asFloatBuffer(); + floatBuffer.get(numpyArray[i]); + } + return numpyArray; + } + + private void assertRecallMetrics(List recallForEachQuery) { + double recallSum = recallForEachQuery.stream().mapToDouble(Double::doubleValue).sum(); + double average = recallSum / recallForEachQuery.size(); + + if (average < 0.95) { + throw new RuntimeException(String.format("Average recall is too low: %f", average)); + } + + for (Double recall : recallForEachQuery) { + if (recall < 0.9) { + throw new RuntimeException(String.format("Recall is too low for a query: %f", recall)); + } + } + } + + private List computeRecall(List> results) { + List recallForEachQuery = new ArrayList<>(); + + for (int i = 0; i < truthVectors.length; i++) { + final int[] truth = truthVectors[i]; + int truePositive = 0; + int falseNegative = 0; + List binList = new ArrayList<>(); + + for (Neighbor result : results.get(i)) { + List floatList = result.getRecord().getFields(0).getValue().getVectorValue().getFloatData().getValueList(); + float[] floats = new float[floatList.size()]; + for (int j = 0; j < floatList.size(); j++) { + floats[j] = floatList.get(j); + } + binList.add(floats); + } + + for (int idx : truth) { + float[] vector = baseVectors[idx]; + if (binList.stream().anyMatch(searchResult -> areEqual(searchResult, vector))) { + truePositive++; + } else { + falseNegative++; + } + } + + double recall = truePositive / (double) (truePositive + falseNegative); + recallForEachQuery.add(recall); + } + + return recallForEachQuery; + } + + private boolean areEqual(float[] array1, float[] array2) { + if (array1 == null || array2 == null) { + return array1 == array2; + } + + if (array1.length != array2.length) { + return false; + } + + for (int i = 0; i < array1.length; i++) { + if (Math.abs(array1[i] - array2[i]) > TOLERANCE) { + return false; + } + } + return true; + } + + private class SimpleListener implements VectorSearchListener { + List[] results; + int idx; + AtomicLong counter; + + public 
SimpleListener(int idx, List[] results, AtomicLong counter) { + this.results = results; + this.idx = idx; + this.counter = counter; + } + + List result = new ArrayList<>(); + + @Override + public void onNext(Neighbor neighbor) { + result.add(neighbor); + } + + @Override + public void onComplete() { + results[idx] = result; + counter.incrementAndGet(); + } + + @Override + public void onError(Throwable e) { + LOG.warn("Error in listener {}", e); + } + } + @Override + public void close() { + client.close(); + adminClient.close(); + } +} diff --git a/avs-client-java/src/main/resources/sift/siftsmall_base.fvecs b/avs-client-java/src/main/resources/sift/siftsmall_base.fvecs new file mode 100644 index 0000000..e3b90ae Binary files /dev/null and b/avs-client-java/src/main/resources/sift/siftsmall_base.fvecs differ diff --git a/avs-client-java/src/main/resources/sift/siftsmall_groundtruth.ivecs b/avs-client-java/src/main/resources/sift/siftsmall_groundtruth.ivecs new file mode 100644 index 0000000..9948ffa Binary files /dev/null and b/avs-client-java/src/main/resources/sift/siftsmall_groundtruth.ivecs differ diff --git a/avs-client-java/src/main/resources/sift/siftsmall_learn.fvecs b/avs-client-java/src/main/resources/sift/siftsmall_learn.fvecs new file mode 100644 index 0000000..9ea42f0 Binary files /dev/null and b/avs-client-java/src/main/resources/sift/siftsmall_learn.fvecs differ diff --git a/avs-client-java/src/main/resources/sift/siftsmall_query.fvecs b/avs-client-java/src/main/resources/sift/siftsmall_query.fvecs new file mode 100644 index 0000000..88622e3 Binary files /dev/null and b/avs-client-java/src/main/resources/sift/siftsmall_query.fvecs differ