Skip to content

Commit

Permalink
cleanup of bounding box (#292)
Browse files Browse the repository at this point in the history
Migrate bounding box code to use float instead of double
  • Loading branch information
sudiptoguha authored Jan 28, 2022
1 parent 6b9d8ad commit c76e6aa
Show file tree
Hide file tree
Showing 68 changed files with 761 additions and 3,073 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@

package com.amazon.randomcutforest;

import java.util.Random;

import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.Neighbor;
import com.amazon.randomcutforest.testutils.ShingledMultiDimDataWithKeys;
import org.github.jamm.MemoryMeter;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
Expand All @@ -29,9 +32,8 @@
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.testutils.NormalMixtureTestData;
import java.util.List;
import java.util.Random;

@Warmup(iterations = 2)
@Measurement(iterations = 5)
Expand All @@ -43,28 +45,38 @@ public class RandomCutForestBenchmark {

@State(Scope.Benchmark)
public static class BenchmarkState {
@Param({ "10" })
int dimensions;
@Param({ "5" })
int baseDimensions;

@Param({ "50" })
@Param({ "8" })
int shingleSize;

@Param({ "30" })
int numberOfTrees;

@Param({ "1.0", "0.9", "0.8", "0.7", "0.6", "0.5", "0.4", "0.3", "0.2", "0.1", "0.0" })
double boundingBoxCacheFraction;

@Param({ "false", "true" })
boolean compact;
boolean parallel;

double[][] data;
RandomCutForest forest;

@Setup(Level.Trial)
public void setUpData() {
NormalMixtureTestData testData = new NormalMixtureTestData();
data = testData.generateTestData(DATA_SIZE, dimensions);
int dimensions = baseDimensions * shingleSize;
int sampleSize = 256;
int dataSize = 100 * sampleSize;
data = ShingledMultiDimDataWithKeys.getMultiDimData(dataSize + shingleSize - 1, 50, 100, 5, 17,
baseDimensions).data;
}

@Setup(Level.Invocation)
public void setUpForest() {
forest = RandomCutForest.builder().numberOfTrees(numberOfTrees).dimensions(dimensions)
.parallelExecutionEnabled(false).compact(compact).randomSeed(99).build();
forest = RandomCutForest.builder().numberOfTrees(numberOfTrees).dimensions(baseDimensions * shingleSize)
.internalShinglingEnabled(true).shingleSize(shingleSize).parallelExecutionEnabled(parallel)
.boundingBoxCacheFraction(boundingBoxCacheFraction).randomSeed(99).build();
}
}

Expand Down Expand Up @@ -115,6 +127,10 @@ public RandomCutForest scoreAndUpdate(BenchmarkState state, Blackhole blackhole)
}

blackhole.consume(score);
if (!forest.parallelExecutionEnabled) {
MemoryMeter meter = new MemoryMeter();
System.out.println(" forest size " + meter.measureDeep(forest));
}
return forest;
}

Expand Down Expand Up @@ -149,4 +165,36 @@ public RandomCutForest basicDensityAndUpdate(BenchmarkState state, Blackhole bla
blackhole.consume(output);
return forest;
}

/**
 * Benchmark: for every row of the benchmark data, query the forest for near
 * neighbors in its sample and then update the forest with that same row.
 *
 * NOTE(review): DATA_SIZE is declared outside this view, while the loop runs
 * once per row of state.data — confirm the two still agree.
 *
 * @param state     benchmark state supplying the data rows and the forest
 * @param blackhole receives the result of the final query (null if there are
 *                  no rows)
 * @return the forest after all rows have been processed
 */
@Benchmark
@OperationsPerInvocation(DATA_SIZE)
public RandomCutForest basicNeighborAndUpdate(BenchmarkState state, Blackhole blackhole) {
    forest = state.forest;
    List<Neighbor> lastNeighbors = null;

    for (double[] point : state.data) {
        // query first, then update, so each query runs before its own point is added
        lastNeighbors = forest.getNearNeighborsInSample(point);
        forest.update(point);
    }

    blackhole.consume(lastNeighbors);
    return forest;
}

/**
 * Benchmark: for every row of the benchmark data, call forest.extrapolate(1)
 * and then update the forest with that row.
 *
 * NOTE(review): extrapolate(1) is presumably a one-step-ahead forecast —
 * confirm against RandomCutForest.extrapolate. DATA_SIZE is declared outside
 * this view, while the loop runs once per row of state.data — confirm the two
 * still agree.
 *
 * @param state     benchmark state supplying the data rows and the forest
 * @param blackhole receives the result of the final extrapolate call (null if
 *                  there are no rows)
 * @return the forest after all rows have been processed
 */
@Benchmark
@OperationsPerInvocation(DATA_SIZE)
public RandomCutForest basicExtrapolateAndUpdate(BenchmarkState state, Blackhole blackhole) {
    forest = state.forest;
    double[] lastExtrapolation = null;

    for (double[] point : state.data) {
        // the extrapolation does not take the current row; the row is only used for the update
        lastExtrapolation = forest.extrapolate(1);
        forest.update(point);
    }

    blackhole.consume(lastExtrapolation);
    return forest;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public static <T> T checkNotNull(T object, String message) {
* @return the probability of separation choosing a random cut
*/

public static double getProbabilityOfSeparation(final IBoundingBoxView boundingBox, double[] queryPoint) {
public static double getProbabilityOfSeparation(final IBoundingBoxView boundingBox, float[] queryPoint) {
double sumOfNewRange = 0d;
double sumOfDifferenceInRange = 0d;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@FunctionalInterface
public interface IMultiVisitorFactory<R> {
MultiVisitor<R> newVisitor(ITree<?, ?> tree, double[] point);
MultiVisitor<R> newVisitor(ITree<?, ?> tree, float[] point);

default R liftResult(ITree<?, ?> tree, R result) {
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@FunctionalInterface
public interface IVisitorFactory<R> {
Visitor<R> newVisitor(ITree<?, ?> tree, double[] point);
Visitor<R> newVisitor(ITree<?, ?> tree, float[] point);

default R liftResult(ITree<?, ?> tree, R result) {
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@
*/

public class MultiVisitorFactory<R> implements IMultiVisitorFactory<R> {
private final BiFunction<ITree<?, ?>, double[], MultiVisitor<R>> newVisitor;
private final BiFunction<ITree<?, ?>, float[], MultiVisitor<R>> newVisitor;
private final BiFunction<ITree<?, ?>, R, R> liftResult;

public MultiVisitorFactory(BiFunction<ITree<?, ?>, double[], MultiVisitor<R>> newVisitor,
public MultiVisitorFactory(BiFunction<ITree<?, ?>, float[], MultiVisitor<R>> newVisitor,
BiFunction<ITree<?, ?>, R, R> liftResult) {
this.newVisitor = newVisitor;
this.liftResult = liftResult;
}

public MultiVisitorFactory(BiFunction<ITree<?, ?>, double[], MultiVisitor<R>> newVisitor) {
public MultiVisitorFactory(BiFunction<ITree<?, ?>, float[], MultiVisitor<R>> newVisitor) {
this(newVisitor, (tree, x) -> x);
}

@Override
public MultiVisitor<R> newVisitor(ITree<?, ?> tree, double[] point) {
public MultiVisitor<R> newVisitor(ITree<?, ?> tree, float[] point) {
return newVisitor.apply(tree, point);
}

Expand Down
Loading

0 comments on commit c76e6aa

Please sign in to comment.