Skip to content

Commit

Permalink
add convenience methods for creating empty index.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jelmer Kuperus committed Jan 2, 2025
1 parent c5f19d7 commit 2bda445
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.github.jelmerk.knn.Item;
import com.github.jelmerk.knn.SearchResult;
import com.github.jelmerk.knn.util.ClassLoaderObjectInputStream;
import com.github.jelmerk.knn.util.DummyComparator;

import java.io.*;
import java.nio.file.Files;
Expand All @@ -26,6 +27,7 @@ public class BruteForceIndex<TId, TVector, TItem extends Item<TId, TVector>, TDi

private static final long serialVersionUID = 1L;

private final boolean immutable;
private final int dimensions;
private final DistanceFunction<TVector, TDistance> distanceFunction;
private final Comparator<TDistance> distanceComparator;
Expand All @@ -34,6 +36,7 @@ public class BruteForceIndex<TId, TVector, TItem extends Item<TId, TVector>, TDi
private final Map<TId, Long> deletedItemVersions;

private BruteForceIndex(BruteForceIndex.Builder<TVector, TDistance> builder) {
this.immutable = builder.immutable;
this.dimensions = builder.dimensions;
this.distanceFunction = builder.distanceFunction;
this.distanceComparator = builder.distanceComparator;
Expand Down Expand Up @@ -79,6 +82,9 @@ public int getDimensions() {
*/
@Override
public boolean add(TItem item) {
if (immutable) {
throw new UnsupportedOperationException("Index is immutable");
}
if (item.dimensions() != dimensions) {
throw new IllegalArgumentException("Item does not have dimensionality of : " + dimensions);
}
Expand Down Expand Up @@ -286,7 +292,7 @@ public static <TId, TVector, TItem extends Item<TId, TVector>, TDistance> BruteF
Builder <TVector, TDistance> newBuilder(int dimensions, DistanceFunction<TVector, TDistance> distanceFunction) {

Comparator<TDistance> distanceComparator = Comparator.naturalOrder();
return new Builder<>(dimensions, distanceFunction, distanceComparator);
return new Builder<>(false, dimensions, distanceFunction, distanceComparator);
}

/**
Expand All @@ -301,7 +307,23 @@ Builder <TVector, TDistance> newBuilder(int dimensions, DistanceFunction<TVector
*/
public static <TVector, TDistance> Builder <TVector, TDistance> newBuilder(int dimensions, DistanceFunction<TVector, TDistance> distanceFunction, Comparator<TDistance> distanceComparator) {

return new Builder<>(dimensions, distanceFunction, distanceComparator);
return new Builder<>(false, dimensions, distanceFunction, distanceComparator);
}

/**
 * Builds an index that holds no items and is flagged as immutable, so {@code add} throws
 * {@link UnsupportedOperationException}.
 *
 * <p>The index is created with a dimensionality of 0. Its distance function and its distance comparator
 * are placeholders that throw {@link UnsupportedOperationException} when invoked.
 *
 * @param <TId> Type of the external identifier of an item
 * @param <TVector> Type of the vector to perform distance calculation on
 * @param <TItem> Type of items stored in the index
 * @param <TDistance> Type of distance between items (expect any numeric type: float, double, int, ..)
 * @return the empty index
 */
public static <TId, TVector, TItem extends Item<TId, TVector>, TDistance> BruteForceIndex<TId, TVector, TItem, TDistance> empty() {
    // Placeholder distance function: it only ever throws.
    DistanceFunction<TVector, TDistance> unsupportedDistance = (u, v) -> {
        throw new UnsupportedOperationException();
    };
    // Placeholder comparator: it only ever throws.
    Comparator<TDistance> unsupportedComparator = new DummyComparator<>();

    // immutable = true, dimensions = 0
    return new BruteForceIndex.Builder<TVector, TDistance>(true, 0, unsupportedDistance, unsupportedComparator)
            .build();
}

/**
Expand All @@ -318,7 +340,10 @@ public static class Builder <TVector, TDistance> {

private final Comparator<TDistance> distanceComparator;

Builder(int dimensions, DistanceFunction<TVector, TDistance> distanceFunction, Comparator<TDistance> distanceComparator) {
private final boolean immutable;

Builder(boolean immutable, int dimensions, DistanceFunction<TVector, TDistance> distanceFunction, Comparator<TDistance> distanceComparator) {
this.immutable = immutable;
this.dimensions = dimensions;
this.distanceFunction = distanceFunction;
this.distanceComparator = distanceComparator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance
implements Index<TId, TVector, TItem, TDistance> {

private static final byte VERSION_1 = 0x01;
private static final byte VERSION_2 = 0x02;

private static final long serialVersionUID = 1L;

Expand All @@ -42,6 +43,7 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance
private Comparator<TDistance> distanceComparator;
private MaxValueComparator<TDistance> maxValueDistanceComparator;

private boolean immutable;
private int dimensions;
private int maxItemCount;
private int m;
Expand Down Expand Up @@ -74,6 +76,7 @@ public class HnswIndex<TId, TVector, TItem extends Item<TId, TVector>, TDistance

private HnswIndex(RefinedBuilder<TId, TVector, TItem, TDistance> builder) {

this.immutable = builder.immutable;
this.dimensions = builder.dimensions;
this.maxItemCount = builder.maxItemCount;
this.distanceFunction = builder.distanceFunction;
Expand Down Expand Up @@ -202,6 +205,9 @@ public boolean remove(TId id, long version) {
*/
@Override
public boolean add(TItem item) {
if (immutable) {
throw new UnsupportedOperationException("Index is immutable");
}
if (item.dimensions() != dimensions) {
throw new IllegalArgumentException("Item does not have dimensionality of : " + dimensions);
}
Expand Down Expand Up @@ -757,7 +763,7 @@ public void save(OutputStream out) throws IOException {
}

private void writeObject(ObjectOutputStream oos) throws IOException {
oos.writeByte(VERSION_1);
oos.writeByte(VERSION_2);
oos.writeInt(dimensions);
oos.writeObject(distanceFunction);
oos.writeObject(distanceComparator);
Expand All @@ -776,6 +782,7 @@ private void writeObject(ObjectOutputStream oos) throws IOException {
writeMutableObjectLongMap(oos, deletedItemVersions);
writeNodesArray(oos, nodes);
oos.writeInt(entryPoint == null ? -1 : entryPoint.id);
oos.writeBoolean(immutable);
}

@SuppressWarnings("unchecked")
Expand All @@ -802,6 +809,8 @@ private void readObject(ObjectInputStream ois) throws IOException, ClassNotFound
this.nodes = readNodesArray(ois, itemSerializer, maxM0, maxM);

int entrypointNodeId = ois.readInt();

this.immutable = version != VERSION_1 && ois.readBoolean();
this.entryPoint = entrypointNodeId == -1 ? null : nodes.get(entrypointNodeId);

this.globalLock = new ReentrantLock();
Expand Down Expand Up @@ -1069,7 +1078,26 @@ public static <TVector, TDistance extends Comparable<TDistance>> Builder<TVector

Comparator<TDistance> distanceComparator = Comparator.naturalOrder();

return new Builder<>(dimensions, distanceFunction, distanceComparator, maxItemCount);
return new Builder<>(false, dimensions, distanceFunction, distanceComparator, maxItemCount);
}

/**
 * Builds an index that holds no items and is flagged as immutable, so {@code add} throws
 * {@link UnsupportedOperationException}.
 *
 * <p>The index is created with a dimensionality of 0 and a maximum item count of 0. Its distance function
 * and its distance comparator are placeholders that throw {@link UnsupportedOperationException} when invoked.
 *
 * @param <TId> Type of the external identifier of an item
 * @param <TVector> Type of the vector to perform distance calculation on
 * @param <TItem> Type of items stored in the index
 * @param <TDistance> Type of distance between items (expect any numeric type: float, double, int, ..)
 * @return the empty index
 */
public static <TId, TVector, TItem extends Item<TId, TVector>, TDistance> HnswIndex<TId, TVector, TItem, TDistance> empty() {
    // Placeholder distance function: it only ever throws.
    DistanceFunction<TVector, TDistance> unsupportedDistance = new DistanceFunction<TVector, TDistance>() {
        @Override
        public TDistance distance(TVector u, TVector v) {
            throw new UnsupportedOperationException();
        }
    };
    // Placeholder comparator: it only ever throws.
    Comparator<TDistance> unsupportedComparator = new DummyComparator<>();

    // immutable = true, dimensions = 0, maxItemCount = 0
    return new Builder<TVector, TDistance>(true, 0, unsupportedDistance, unsupportedComparator, 0)
            .build();
}

/**
Expand All @@ -1089,7 +1117,7 @@ public static <TVector, TDistance> Builder<TVector, TDistance> newBuilder(
Comparator<TDistance> distanceComparator,
int maxItemCount) {

return new Builder<>(dimensions, distanceFunction, distanceComparator, maxItemCount);
return new Builder<>(false, dimensions, distanceFunction, distanceComparator, maxItemCount);
}

private int assignLevel(TId value, double lambda) {
Expand Down Expand Up @@ -1318,6 +1346,7 @@ public static abstract class BuilderBase<TBuilder extends BuilderBase<TBuilder,
public static final int DEFAULT_EF_CONSTRUCTION = 200;
public static final boolean DEFAULT_REMOVE_ENABLED = false;

boolean immutable;
int dimensions;
DistanceFunction<TVector, TDistance> distanceFunction;
Comparator<TDistance> distanceComparator;
Expand All @@ -1329,11 +1358,12 @@ public static abstract class BuilderBase<TBuilder extends BuilderBase<TBuilder,
int efConstruction = DEFAULT_EF_CONSTRUCTION;
boolean removeEnabled = DEFAULT_REMOVE_ENABLED;

BuilderBase(int dimensions,
BuilderBase(boolean immutable,
int dimensions,
DistanceFunction<TVector, TDistance> distanceFunction,
Comparator<TDistance> distanceComparator,
int maxItemCount) {

this.immutable = immutable;
this.dimensions = dimensions;
this.distanceFunction = distanceFunction;
this.distanceComparator = distanceComparator;
Expand Down Expand Up @@ -1417,12 +1447,13 @@ public static class Builder<TVector, TDistance> extends BuilderBase<Builder<TVec
* @param distanceFunction the distance function
* @param maxItemCount the maximum number of elements in the index
*/
Builder(int dimensions,
Builder(boolean immutable,
int dimensions,
DistanceFunction<TVector, TDistance> distanceFunction,
Comparator<TDistance> distanceComparator,
int maxItemCount) {

super(dimensions, distanceFunction, distanceComparator, maxItemCount);
super(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount);
}

@Override
Expand All @@ -1440,7 +1471,7 @@ Builder<TVector, TDistance> self() {
* @return the builder
*/
public <TId, TItem extends Item<TId, TVector>> RefinedBuilder<TId, TVector, TItem, TDistance> withCustomSerializers(ObjectSerializer<TId> itemIdSerializer, ObjectSerializer<TItem> itemSerializer) {
return new RefinedBuilder<>(dimensions, distanceFunction, distanceComparator, maxItemCount, m, ef, efConstruction,
return new RefinedBuilder<>(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount, m, ef, efConstruction,
removeEnabled, itemIdSerializer, itemSerializer);
}

Expand Down Expand Up @@ -1475,7 +1506,8 @@ public static class RefinedBuilder<TId, TVector, TItem extends Item<TId, TVector
private ObjectSerializer<TId> itemIdSerializer;
private ObjectSerializer<TItem> itemSerializer;

RefinedBuilder(int dimensions,
RefinedBuilder(boolean immutable,
int dimensions,
DistanceFunction<TVector, TDistance> distanceFunction,
Comparator<TDistance> distanceComparator,
int maxItemCount,
Expand All @@ -1486,7 +1518,7 @@ public static class RefinedBuilder<TId, TVector, TItem extends Item<TId, TVector
ObjectSerializer<TId> itemIdSerializer,
ObjectSerializer<TItem> itemSerializer) {

super(dimensions, distanceFunction, distanceComparator, maxItemCount);
super(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount);

this.m = m;
this.ef = ef;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.github.jelmerk.knn.util;

import java.io.Serializable;
import java.util.Comparator;

/**
 * Implementation of {@link Comparator} that is serializable and throws {@link UnsupportedOperationException}
 * when {@link #compare(Object, Object)} is called. Useful as a placeholder when a comparator is required
 * but you know it will never be invoked.
 *
 * @param <T> the type of objects that may be compared by this comparator
 */
public class DummyComparator<T> implements Comparator<T>, Serializable {

    // Explicit version id so the serialized form does not depend on the compiler-computed default.
    private static final long serialVersionUID = 1L;

    /**
     * Always fails; this comparator is a placeholder and does not define an ordering.
     *
     * @param o1 ignored
     * @param o2 ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public int compare(T o1, T o2) {
        throw new UnsupportedOperationException();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import java.io.IOException;
import java.util.*;

import com.github.jelmerk.knn.hnsw.HnswIndex;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;

import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

class BruteForceIndexTest {

Expand Down Expand Up @@ -152,6 +154,20 @@ void saveAndLoadIndex() throws IOException {
assertThat(loadedIndex.size(), is(1));
}

@Test
void createEmptyIndex() {
    // The factory takes no arguments; the type parameters are inferred from the declared type.
    BruteForceIndex<String, float[], TestItem, Float> emptyIndex = BruteForceIndex.empty();

    // Adding an item must be rejected.
    assertThrows(UnsupportedOperationException.class, () -> emptyIndex.add(item1), "Index should be immutable");

    // The index must report no items and a dimensionality of 0.
    assertThat(emptyIndex.size(), is(0));
    assertThat(emptyIndex.getDimensions(), is(0));
}


}

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

class HnswIndexTest {

Expand Down Expand Up @@ -215,4 +216,18 @@ void saveAndLoadIndex() throws IOException {

assertThat(loadedIndex.size(), is(1));
}

@Test
void emptyIndexIsImmutable() {
    // The factory takes no arguments; the type parameters are inferred from the declared type.
    HnswIndex<String, float[], TestItem, Float> emptyIndex = HnswIndex.empty();

    // Adding an item must be rejected.
    assertThrows(UnsupportedOperationException.class, () -> emptyIndex.add(item1), "Index should be immutable");

    // The index must report no items and a dimensionality of 0.
    assertThat(emptyIndex.size(), is(0));
    assertThat(emptyIndex.getDimensions(), is(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,20 @@ object BruteForceIndex {

new BruteForceIndex[TId, TVector, TItem, TDistance](jIndex)
}

/**
  * Wraps the immutable empty index produced by the Java implementation's `empty()` factory.
  *
  * @tparam TId Type of the external identifier of an item
  * @tparam TVector Type of the vector to perform distance calculation on
  * @tparam TItem Type of items stored in the index
  * @tparam TDistance Type of distance between items (expect any numeric type: float, double, int, ..)
  * @return the empty index
  */
def empty[TId, TVector, TItem <: Item[TId, TVector], TDistance]: BruteForceIndex[TId, TVector, TItem, TDistance] =
  new BruteForceIndex[TId, TVector, TItem, TDistance](
    JBruteForceIndex.empty[TId, TVector, TItem, TDistance]()
  )
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,21 @@ object HnswIndex {
if(removeEnabled) builder.withRemoveEnabled().build()
else builder.build()

new HnswIndex[TId, TVector, TItem, TDistance](jIndex)
new HnswIndex(jIndex)
}

/**
  * Wraps the immutable empty index produced by the Java implementation's `empty()` factory.
  *
  * @tparam TId Type of the external identifier of an item
  * @tparam TVector Type of the vector to perform distance calculation on
  * @tparam TItem Type of items stored in the index
  * @tparam TDistance Type of distance between items (expect any numeric type: float, double, int, ..)
  * @return the empty index
  */
def empty[TId, TVector, TItem <: Item[TId, TVector], TDistance]: HnswIndex[TId, TVector, TItem, TDistance] =
  new HnswIndex[TId, TVector, TItem, TDistance](
    JHnswIndex.empty[TId, TVector, TItem, TDistance]()
  )

}
Expand All @@ -140,13 +153,18 @@ class HnswIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] private (d
/**
* This distance function.
*/
val distanceFunction: DistanceFunction[TVector, TDistance] = delegate
.getDistanceFunction.asInstanceOf[ScalaDistanceFunctionAdapter[TVector, TDistance]].scalaFunction
val distanceFunction: DistanceFunction[TVector, TDistance] = delegate.getDistanceFunction match {
case a: ScalaDistanceFunctionAdapter[TVector, TDistance] => a.scalaFunction
case f => (v1: TVector, v2: TVector) => f.distance(v1, v2)
}

/**
* The ordering used to compare distances
*/
val distanceOrdering: Ordering[TDistance] = delegate.getDistanceComparator.asInstanceOf[Ordering[TDistance]]
val distanceOrdering: Ordering[TDistance] = delegate.getDistanceComparator match {
case ordering: Ordering[TDistance] => ordering
case c => (x: TDistance, y: TDistance) => c.compare(x, y)
}

/**
* The maximum number of items the index can hold.
Expand Down
Loading

0 comments on commit 2bda445

Please sign in to comment.