Skip to content

Commit

Permalink
Fixing compaction for internal shingling (#279)
Browse files Browse the repository at this point in the history
* pointstore fix

* more stringent tests

* more tests

* eternal spotless issues
  • Loading branch information
sudiptoguha authored Oct 11, 2021
1 parent 544b507 commit ed1bda6
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import lombok.Getter;

import com.amazon.randomcutforest.ComponentList;
import com.amazon.randomcutforest.store.IPointStore;

/**
* The class transforms input points into the form expected by internal models,
Expand Down Expand Up @@ -59,7 +60,12 @@ protected AbstractForestUpdateExecutor(IStateCoordinator<PointReference, Point>
* @param point The point used to update the forest.
*/
public void update(double[] point) {
    // Start from the number of updates the coordinator has recorded so far.
    long internalSequenceNumber = updateCoordinator.getTotalUpdates();
    IPointStore<?> store = updateCoordinator.getStore();
    // With internal shinglings enabled the sequence number is shifted back by
    // (shingleSize - 1) before it is handed to the two-argument overload.
    if (store != null && store.isInternalShinglingEnabled()) {
        internalSequenceNumber -= store.getShingleSize() - 1;
    }
    // Delegate exactly once; a second delegation with the unadjusted count
    // would apply the same point to the forest twice.
    update(point, internalSequenceNumber);
}

public void update(double[] point, long sequenceNumber) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ public PointStoreState toState(PointStoreDouble model) {
state.setShingleSize(model.getShingleSize());
state.setDirectLocationMap(model.isDirectLocationMap());
state.setInternalShinglingEnabled(model.isInternalShinglingEnabled());
state.setLastTimeStamp(model.getNextSequenceIndex());
if (model.isInternalShinglingEnabled()) {
state.setInternalShingle(model.getInternalShingle());
state.setLastTimeStamp(model.getNextSequenceIndex());
state.setRotationEnabled(model.isInternalRotationEnabled());
}
state.setDynamicResizingEnabled(model.isDynamicResizingEnabled());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ public PointStoreState toState(PointStoreFloat model) {
state.setShingleSize(model.getShingleSize());
state.setDirectLocationMap(model.isDirectLocationMap());
state.setInternalShinglingEnabled(model.isInternalShinglingEnabled());
state.setLastTimeStamp(model.getNextSequenceIndex());
if (model.isInternalShinglingEnabled()) {
state.setInternalShingle(model.getInternalShingle());
state.setLastTimeStamp(model.getNextSequenceIndex());
state.setRotationEnabled(model.isInternalRotationEnabled());
}
state.setDynamicResizingEnabled(model.isDynamicResizingEnabled());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ public interface IPointStoreView<Point> {

boolean isInternalRotationEnabled();

boolean isInternalShinglingEnabled();

int getShingleSize();

int[] transformIndices(int[] indexList);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,6 @@ int takeIndex() {
return location;
}

/**
 * Ensures that {@code amount} more entries fit between {@code startOfFreeSegment}
 * and the end of the store ({@code currentStoreCapacity * dimensions}).
 *
 * If they do not fit, the store is first compacted; if that is still not enough,
 * the store is resized, which requires dynamic resizing to be enabled.
 *
 * NOTE(review): {@code amount} is fixed by the caller before {@code compact()}
 * runs. If the caller derived it from the current store contents it may be stale
 * afterwards -- confirm against the call site.
 *
 * @param amount number of entries that must be available in the free segment
 * @throws IllegalStateException presumably raised by {@code checkState} when
 *                               resizing is disabled or space is still
 *                               insufficient -- confirm against checkState
 */
void verifyAndMakeSpace(int amount) {
// not enough room left after the current free-segment start
if (startOfFreeSegment > currentStoreCapacity * dimensions - amount) {
// try compaction and then resizing
compact();
// re-test with the post-compaction free-segment start
if (startOfFreeSegment > currentStoreCapacity * dimensions - amount) {
checkState(dynamicResizingEnabled, " out of store, enable dynamic resizing ");
resizeStore();
// resizing must have produced enough room, otherwise the store is unusable
checkState(startOfFreeSegment + amount <= currentStoreCapacity * dimensions, "out of space");
}
}
}

/**
* Add a point to the point store and return the index of the stored point.
*
Expand All @@ -207,10 +195,10 @@ public int add(double[] point, long sequenceNum) {
"point.length must be equal to dimensions");

double[] tempPoint = point;
nextSequenceIndex++;
if (internalShinglingEnabled) {
// rotation is supported via the output and input is unchanged
tempPoint = constructShingleInPlace(internalShingle, point, false);
nextSequenceIndex++;
if (nextSequenceIndex < shingleSize) {
return INFEASIBLE_POINTSTORE_INDEX;
}
Expand All @@ -221,7 +209,18 @@ public int add(double[] point, long sequenceNum) {
// point has to be written, otherwise we need to write the full shingle
int amountToWrite = checkShingleAlignment(startOfFreeSegment, tempPoint) ? baseDimension : dimensions;

verifyAndMakeSpace(amountToWrite);
if (startOfFreeSegment > currentStoreCapacity * dimensions - amountToWrite) {
// try compaction and then resizing
compact();
// the compaction can change the array contents
amountToWrite = checkShingleAlignment(startOfFreeSegment, tempPoint) ? baseDimension : dimensions;
if (startOfFreeSegment > currentStoreCapacity * dimensions - amountToWrite) {
checkState(dynamicResizingEnabled, " out of store, enable dynamic resizing ");
resizeStore();
checkState(startOfFreeSegment + amountToWrite <= currentStoreCapacity * dimensions, "out of space");
}
}

nextIndex = takeIndex();

locationList[nextIndex] = startOfFreeSegment - dimensions + amountToWrite;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,23 @@

package com.amazon.randomcutforest;

import static com.amazon.randomcutforest.testutils.ShingledMultiDimDataWithKeys.generateShingledData;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.util.Random;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.store.PointStoreFloat;
import com.amazon.randomcutforest.testutils.MultiDimDataWithKey;
import com.amazon.randomcutforest.testutils.NormalMixtureTestData;
import com.amazon.randomcutforest.testutils.ShingledMultiDimDataWithKeys;
import com.amazon.randomcutforest.util.ShingleBuilder;

@Tag("functional")
Expand Down Expand Up @@ -93,4 +102,70 @@ public void testExtrapolateBasic() {
assertThrows(IllegalArgumentException.class,
() -> forest.extrapolateBasic(shingleBuilder.getShingle(), 4, 4, true, 2));
}

/**
 * Checks that internal and external shingling agree: a forest that is fed raw
 * points with internal shingling enabled must report the same anomaly scores
 * as forests that are fed the corresponding pre-shingled points.
 *
 * @param rotation forwarded to {@code internalRotationEnabled} and to
 *                 {@code generateShingledData}; when true the newest point is
 *                 expected in slot {@code count % shingleSize}, otherwise in
 *                 the last slot of the shingle
 */
@ParameterizedTest
@ValueSource(booleans = { true, false })
public void InternalShinglingTest(boolean rotation) {
    int sampleSize = 256;
    int baseDimensions = 2;
    int shingleSize = 2;
    int dimensions = baseDimensions * shingleSize;
    long seed = new Random().nextLong();
    // the seed is drawn at random, so every assertion reports it to keep a
    // failing run reproducible
    String context = "seed = " + seed + ", rotation = " + rotation;

    int numTrials = 1; // test is exact equality, reducing the number of trials
    int length = 4000 * sampleSize;

    for (int i = 0; i < numTrials; i++) {

        // internal shingling: receives the raw (unshingled) points
        RandomCutForest first = new RandomCutForest.Builder<>().compact(true).dimensions(dimensions)
                .precision(Precision.FLOAT_32).randomSeed(seed).internalShinglingEnabled(true)
                .internalRotationEnabled(rotation).shingleSize(shingleSize).build();

        // external shingling, declared shingle size matches the data
        RandomCutForest second = new RandomCutForest.Builder<>().compact(true).dimensions(dimensions)
                .precision(Precision.FLOAT_32).randomSeed(seed).internalShinglingEnabled(false)
                .shingleSize(shingleSize).build();

        // external shingling, shingled points treated as plain points
        RandomCutForest third = new RandomCutForest.Builder<>().compact(true).dimensions(dimensions)
                .precision(Precision.FLOAT_32).randomSeed(seed).internalShinglingEnabled(false).shingleSize(1)
                .build();

        MultiDimDataWithKey dataWithKeys = ShingledMultiDimDataWithKeys.getMultiDimData(length, 50, 100, 5,
                seed + i, baseDimensions);

        double[][] shingledData = generateShingledData(dataWithKeys.data, shingleSize, baseDimensions, rotation);

        assertEquals(dataWithKeys.data.length - shingleSize + 1, shingledData.length, context);

        int count = shingleSize - 1;
        // insert initial points
        for (int j = 0; j < shingleSize - 1; j++) {
            first.update(dataWithKeys.data[j]);
        }

        for (int j = 0; j < shingledData.length; j++) {
            // validate equality of points: the newest raw point must sit in the
            // expected slot of the externally built shingle
            int position = (rotation) ? (count % shingleSize) : shingleSize - 1;
            for (int y = 0; y < baseDimensions; y++) {
                assertEquals(dataWithKeys.data[count][y], shingledData[j][position * baseDimensions + y], 1e-10,
                        context);
            }

            double firstResult = first.getAnomalyScore(dataWithKeys.data[count]);
            first.update(dataWithKeys.data[count]);
            ++count;
            double secondResult = second.getAnomalyScore(shingledData[j]);
            second.update(shingledData[j]);
            double thirdResult = third.getAnomalyScore(shingledData[j]);
            third.update(shingledData[j]);

            assertEquals(firstResult, secondResult, 1e-10, context);
            assertEquals(secondResult, thirdResult, 1e-10, context);
        }

        // the backing array of every store must match its reported capacity
        for (RandomCutForest forest : new RandomCutForest[] { first, second, third }) {
            PointStoreFloat store = (PointStoreFloat) forest.getUpdateCoordinator().getStore();
            assertEquals(store.getCurrentStoreCapacity() * dimensions, store.getStore().length, context);
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,27 @@ public void checkRotationAndCompact() {
store.add(new double[] { -6 * shinglesize }, 6 * shinglesize - 1);
store.compact();
}

/**
 * Exercises a small internally shingled store (capacity 6, shingle size 2,
 * one value per input) and checks that the stored shingles stay correct after
 * entries are released and new points are added.
 */
@Test
void CompactionTest() {
    int shinglesize = 2;
    PointStoreDouble store = new PointStoreDouble.Builder().capacity(6).dimensions(shinglesize)
            .shingleSize(shinglesize).indexCapacity(6).directLocationEnabled(false).internalShinglingEnabled(true)
            .build();

    // feed the values 0..6 one at a time
    store.add(new double[] { 0 }, 0L);
    for (int i = 0; i < 5; i++) {
        store.add(new double[] { i + 1 }, 0L);
    }
    int finalIndex = store.add(new double[] { 4 + 2 }, 0L);
    // expected value first, so that a failure message reads correctly
    assertArrayEquals(new double[] { 5, 6 }, store.get(finalIndex));

    // release two entries; presumably the following add has to reclaim their
    // space through compaction (per the test name) -- confirm against the store
    store.decrementRefCount(1);
    store.decrementRefCount(2);
    int index = store.add(new double[] { 7 }, 0L);
    assertArrayEquals(new double[] { 6, 7 }, store.get(index));

    store.decrementRefCount(index);
    assertTrue(store.size() < store.capacity);
    // the shingle must still be built from the last two inputs
    index = store.add(new double[] { 8 }, 0L);
    assertArrayEquals(new double[] { 7, 8 }, store.get(index));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -262,4 +262,26 @@ public void checkRotationAndCompact() {
store.compact();
}

/**
 * Exercises a small internally shingled float store (capacity 6, shingle size
 * 2, one value per input) and checks that the stored shingles stay correct
 * after entries are released and new points are added.
 */
@Test
void CompactionTest() {
    int shinglesize = 2;
    PointStoreFloat store = new PointStoreFloat.Builder().capacity(6).dimensions(shinglesize)
            .shingleSize(shinglesize).indexCapacity(6).directLocationEnabled(false).internalShinglingEnabled(true)
            .build();

    // feed the values 0..6 one at a time
    store.add(new double[] { 0 }, 0L);
    for (int i = 0; i < 5; i++) {
        store.add(new double[] { i + 1 }, 0L);
    }
    int finalIndex = store.add(new double[] { 4 + 2 }, 0L);
    // expected value first, so that a failure message reads correctly
    assertArrayEquals(new float[] { 5, 6 }, store.get(finalIndex));

    // release two entries; presumably the following add has to reclaim their
    // space through compaction (per the test name) -- confirm against the store
    store.decrementRefCount(1);
    store.decrementRefCount(2);
    int index = store.add(new double[] { 7 }, 0L);
    assertArrayEquals(new float[] { 6, 7 }, store.get(index));

    store.decrementRefCount(index);
    assertTrue(store.size() < store.capacity);
    // the shingle must still be built from the last two inputs
    index = store.add(new double[] { 8 }, 0L);
    assertArrayEquals(new float[] { 7, 8 }, store.get(index));
}
}
Loading

0 comments on commit ed1bda6

Please sign in to comment.