upgrading impute (aws#294)
* upgrading impute

* cleanup of double and float

* fixes

* name changes
sudiptoguha authored Mar 9, 2022
1 parent 4b1d376 commit e7496ef
Showing 21 changed files with 943 additions and 519 deletions.
RandomCutForestBenchmark.java
@@ -15,10 +15,9 @@

package com.amazon.randomcutforest;

import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.Neighbor;
import com.amazon.randomcutforest.testutils.NormalMixtureTestData;
import java.util.List;
import java.util.Random;

import org.github.jamm.MemoryMeter;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
@@ -32,8 +31,10 @@
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.util.List;
import java.util.Random;
import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.Neighbor;
import com.amazon.randomcutforest.testutils.NormalMixtureTestData;

@Warmup(iterations = 2)
@Measurement(iterations = 5)
@@ -46,10 +47,10 @@ public class RandomCutForestBenchmark {

@State(Scope.Benchmark)
public static class BenchmarkState {
@Param({ "5" })
@Param({ "40" })
int baseDimensions;

@Param({ "8" })
@Param({ "1" })
int shingleSize;

@Param({ "30" })
@@ -187,13 +188,13 @@ public RandomCutForest basicNeighborAndUpdate(BenchmarkState state, Blackhole bl

@Benchmark
@OperationsPerInvocation(DATA_SIZE)
public RandomCutForest basicExtrapolateAndUpdate(BenchmarkState state, Blackhole blackhole) {
public RandomCutForest imputeAndUpdate(BenchmarkState state, Blackhole blackhole) {
double[][] data = state.data;
forest = state.forest;
double[] output = null;

for (int i = INITIAL_DATA_SIZE; i < data.length; i++) {
output = forest.extrapolate(1);
output = forest.imputeMissingValues(data[i], 1, new int[] { forest.dimensions - 1 });
forest.update(data[i]);
}

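The replaced extrapolation benchmark now measures imputation: each step treats the last coordinate of the incoming point as missing, asks the forest to fill it in, and then updates the forest with the true point. A standalone sketch of the same call pattern, with illustrative sizes (the builder calls follow the library's usual API and are not part of this diff):

RandomCutForest forest = RandomCutForest.builder().dimensions(40).numberOfTrees(30).build();
double[][] data = new double[1000][40];   // illustrative; the benchmark draws from NormalMixtureTestData
int[] missingIndexes = new int[] { 39 };  // treat the last of 40 dimensions as missing

for (double[] point : data) {
    // impute one missing value, then feed the true point to the forest;
    // the double[] overload is deprecated after this commit but still supported
    double[] imputed = forest.imputeMissingValues(point, 1, missingIndexes);
    forest.update(point);
}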
(second benchmark file, a shingled variant of the one above; filename not captured in this view)
@@ -15,10 +15,9 @@

package com.amazon.randomcutforest;

import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.Neighbor;
import com.amazon.randomcutforest.testutils.ShingledMultiDimDataWithKeys;
import java.util.List;
import java.util.Random;

import org.github.jamm.MemoryMeter;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
@@ -32,8 +31,10 @@
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.util.List;
import java.util.Random;
import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.Neighbor;
import com.amazon.randomcutforest.testutils.ShingledMultiDimDataWithKeys;

@Warmup(iterations = 2)
@Measurement(iterations = 5)
@@ -170,7 +171,7 @@ public RandomCutForest basicDensityAndUpdate(BenchmarkState state, Blackhole bla

@Benchmark
@OperationsPerInvocation(DATA_SIZE)
public RandomCutForest basicNeighborAndUpdate(BenchmarkState state, Blackhole blackhole) {
public RandomCutForest neighborAndUpdate(BenchmarkState state, Blackhole blackhole) {
double[][] data = state.data;
forest = state.forest;
List<Neighbor> output = null;
@@ -186,7 +187,23 @@ public RandomCutForest basicNeighborAndUpdate(BenchmarkState state, Blackhole bl

@Benchmark
@OperationsPerInvocation(DATA_SIZE)
public RandomCutForest basicExtrapolateAndUpdate(BenchmarkState state, Blackhole blackhole) {
public RandomCutForest imputeAndUpdate(BenchmarkState state, Blackhole blackhole) {
double[][] data = state.data;
forest = state.forest;
double[] output = null;

for (int i = INITIAL_DATA_SIZE; i < data.length; i++) {
output = forest.imputeMissingValues(data[i], 1, new int[] { state.baseDimensions - 1 });
forest.update(data[i]);
}

blackhole.consume(output);
return forest;
}

@Benchmark
@OperationsPerInvocation(DATA_SIZE)
public RandomCutForest extrapolateAndUpdate(BenchmarkState state, Blackhole blackhole) {
double[][] data = state.data;
forest = state.forest;
double[] output = null;
Java/core/src/main/java/com/amazon/randomcutforest/RandomCutForest.java
@@ -24,7 +24,6 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Random;
@@ -49,9 +48,12 @@
import com.amazon.randomcutforest.executor.SamplerPlusTree;
import com.amazon.randomcutforest.executor.SequentialForestTraversalExecutor;
import com.amazon.randomcutforest.executor.SequentialForestUpdateExecutor;
import com.amazon.randomcutforest.imputation.ConditionalSampleSummarizer;
import com.amazon.randomcutforest.imputation.ImputeVisitor;
import com.amazon.randomcutforest.inspect.NearNeighborVisitor;
import com.amazon.randomcutforest.interpolation.SimpleInterpolationVisitor;
import com.amazon.randomcutforest.returntypes.ConditionalSampleSummary;
import com.amazon.randomcutforest.returntypes.ConditionalTreeSample;
import com.amazon.randomcutforest.returntypes.ConvergingAccumulator;
import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.DiVector;
@@ -985,13 +987,8 @@ public DensityOutput getSimpleDensity(float[] point) {
* versus a more random estimation
* @return A point with the missing values imputed.
*/
public List<double[]> getConditionalField(double[] point, int numberOfMissingValues, int[] missingIndexes,
double centrality) {
return getConditionalField(toFloatArray(point), numberOfMissingValues, missingIndexes, centrality);
}

public List<double[]> getConditionalField(float[] point, int numberOfMissingValues, int[] missingIndexes,
double centrality) {
public List<ConditionalTreeSample> getConditionalField(float[] point, int numberOfMissingValues,
int[] missingIndexes, double centrality) {
checkArgument(numberOfMissingValues > 0, "numberOfMissingValues must be greater than 0");
checkNotNull(missingIndexes, "missingIndexes must not be null");
checkArgument(numberOfMissingValues <= missingIndexes.length,
@@ -1003,16 +1000,31 @@ public List<double[]> getConditionalField(float[] point, int numberOfMissingValu
}

int[] liftedIndices = transformIndices(missingIndexes, point.length);
IMultiVisitorFactory<double[]> visitorFactory = (tree, y) -> new ImputeVisitor(y, tree.projectToTree(y),
liftedIndices, tree.projectMissingIndices(liftedIndices), 1.0);
IMultiVisitorFactory<ConditionalTreeSample> visitorFactory = (tree, y) -> new ImputeVisitor(y,
tree.projectToTree(y), liftedIndices, tree.projectMissingIndices(liftedIndices), centrality);
return traverseForestMulti(transformToShingledPoint(point), visitorFactory, ConditionalTreeSample.collector);
}

Collector<double[], ArrayList<double[]>, ArrayList<double[]>> collector = Collector.of(ArrayList::new,
ArrayList::add, (left, right) -> {
left.addAll(right);
return left;
}, list -> list);
public ConditionalSampleSummary getConditionalFieldSummary(float[] point, int numberOfMissingValues,
int[] missingIndexes, double centrality) {
checkArgument(numberOfMissingValues >= 0, "cannot be negative");
checkNotNull(missingIndexes, "missingIndexes must not be null");
checkArgument(numberOfMissingValues <= missingIndexes.length,
"numberOfMissingValues must be less than or equal to missingIndexes.length");
checkArgument(centrality >= 0 && centrality <= 1, "centrality needs to be in range [0,1]");
checkArgument(point != null, " cannot be null");
if (!isOutputReady()) {
return new ConditionalSampleSummary(dimensions);
}

return traverseForestMulti(transformToShingledPoint(point), visitorFactory, collector);
int[] liftedIndices = transformIndices(missingIndexes, point.length);
ConditionalSampleSummarizer summarizer = new ConditionalSampleSummarizer(liftedIndices,
transformToShingledPoint(point), centrality);
return summarizer.summarize(getConditionalField(point, numberOfMissingValues, missingIndexes, centrality));
}

public float[] imputeMissingValues(float[] point, int numberOfMissingValues, int[] missingIndexes) {
return getConditionalFieldSummary(point, numberOfMissingValues, missingIndexes, 1.0).median;
}

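Imputation is now layered: getConditionalField returns the per-tree ConditionalTreeSample results, getConditionalFieldSummary condenses them through ConditionalSampleSummarizer, and imputeMissingValues simply reads the summary's median. A hedged sketch of the new surface on a small unshingled forest (values illustrative; the builder calls are assumed from the library's usual API):

RandomCutForest forest = RandomCutForest.builder().dimensions(3).build();
float[] query = new float[] { 1.0f, 2.0f, 0.0f };  // the value at index 2 is unknown
int[] missingIndexes = new int[] { 2 };

// per-tree samples; centrality 1.0 favors central (less random) conditioning
List<ConditionalTreeSample> field = forest.getConditionalField(query, 1, missingIndexes, 1.0);

// cross-tree summary; its median is exactly what imputeMissingValues returns
ConditionalSampleSummary summary = forest.getConditionalFieldSummary(query, 1, missingIndexes, 1.0);
float[] imputed = forest.imputeMissingValues(query, 1, missingIndexes);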
/**
@@ -1030,37 +1042,10 @@ public List<double[]> getConditionalField(float[] point, int numberOfMissingValu
* missing values.
* @return A point with the missing values imputed.
*/
public double[] imputeMissingValues(double[] point, int numberOfMissingValues, int[] missingIndexes) {
return imputeMissingValues(toFloatArray(point), numberOfMissingValues, missingIndexes);
}

public double[] imputeMissingValues(float[] point, int numberOfMissingValues, int[] missingIndexes) {
checkArgument(numberOfMissingValues >= 0, "numberOfMissingValues must be greater or equal than 0");
checkNotNull(missingIndexes, "missingIndexes must not be null");
checkArgument(numberOfMissingValues <= missingIndexes.length,
"numberOfMissingValues must be less than or equal to missingIndexes.length");
checkArgument(point != null, " cannot be null");

if (!isOutputReady()) {
return new double[dimensions];
}
// checks will be performed in the function call
List<double[]> conditionalField = getConditionalField(point, numberOfMissingValues, missingIndexes, 1.0);

if (numberOfMissingValues == 1) {
// when there is 1 missing value, we sort all the imputed values and return the
// median
double[] returnPoint = toDoubleArray(point);
double[] basicList = conditionalField.stream()
.mapToDouble(array -> array[transformIndices(missingIndexes, point.length)[0]]).sorted().toArray();
returnPoint[missingIndexes[0]] = basicList[numberOfTrees / 2];
return returnPoint;
} else {
// when there is more than 1 missing value, we sort the imputed points by
// anomaly score and return the point with the 25th percentile anomaly score
conditionalField.sort(Comparator.comparing(this::getAnomalyScore));
return conditionalField.get(numberOfTrees / 4);
}
@Deprecated
public double[] imputeMissingValues(double[] point, int numberOfMissingValues, int[] missingIndexes) {
return toDoubleArray(imputeMissingValues(toFloatArray(point), numberOfMissingValues, missingIndexes));
}
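
Callers still passing double[] keep working through the deprecated shim above, which converts at both ends with the toFloatArray/toDoubleArray helpers. Migration is mechanical; a hedged sketch with illustrative values:

RandomCutForest forest = RandomCutForest.builder().dimensions(3).build();
int[] missingIndexes = new int[] { 2 };
// deprecated path: double[] in, double[] out, converting at both ends
double[] legacy = forest.imputeMissingValues(new double[] { 1.0, 2.0, 0.0 }, 1, missingIndexes);
// preferred path after this commit: stay in float[] end to end
float[] current = forest.imputeMissingValues(new float[] { 1.0f, 2.0f, 0.0f }, 1, missingIndexes);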

/**
@@ -1081,16 +1066,21 @@ public double[] imputeMissingValues(float[] point, int numberOfMissingValues, in
* then this value is not used.
* @return a forecasted time series.
*/
@Deprecated
public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boolean cyclic, int shingleIndex) {
return toDoubleArray(extrapolateBasic(toFloatArray(point), horizon, blockSize, cyclic, shingleIndex));
}

public float[] extrapolateBasic(float[] point, int horizon, int blockSize, boolean cyclic, int shingleIndex) {
checkArgument(0 < blockSize && blockSize < dimensions,
"blockSize must be between 0 and dimensions (exclusive)");
checkArgument(dimensions % blockSize == 0, "dimensions must be evenly divisible by blockSize");
checkArgument(0 <= shingleIndex && shingleIndex < dimensions / blockSize,
"shingleIndex must be between 0 (inclusive) and dimensions / blockSize");

double[] result = new double[blockSize * horizon];
float[] result = new float[blockSize * horizon];
int[] missingIndexes = new int[blockSize];
double[] queryPoint = Arrays.copyOf(point, dimensions);
float[] queryPoint = Arrays.copyOf(point, dimensions);

if (cyclic) {
extrapolateBasicCyclic(result, horizon, blockSize, shingleIndex, queryPoint, missingIndexes);
@@ -1115,10 +1105,15 @@ public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boo
* sliding shingle.
* @return a forecasted time series.
*/
@Deprecated
public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boolean cyclic) {
return extrapolateBasic(point, horizon, blockSize, cyclic, 0);
}

public float[] extrapolateBasic(float[] point, int horizon, int blockSize, boolean cyclic) {
return extrapolateBasic(point, horizon, blockSize, cyclic, 0);
}

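The basic extrapolation helpers are float[] end to end as well. A hedged sketch forecasting four 2-value blocks from a 16-dimensional shingle (illustrative values; blockSize must divide the forest's dimensions evenly, as the checks above enforce):

RandomCutForest forest = RandomCutForest.builder().dimensions(16).build();
float[] shingle = new float[16];  // most recent externally built shingle (illustrative zeros)
float[] forecast = forest.extrapolateBasic(shingle, 4, 2, false);
// forecast.length == blockSize * horizon == 2 * 4 == 8: the next four 2-value blocks
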
/**
* Given a shingle builder, extrapolate the stream into the future to produce a
* forecast. This method assumes you are passing in the shingle builder used to
@@ -1129,13 +1124,13 @@ public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boo
* @param horizon The number of blocks to forecast.
* @return a forecasted time series.
*/
@Deprecated
public double[] extrapolateBasic(ShingleBuilder builder, int horizon) {
return extrapolateBasic(builder.getShingle(), horizon, builder.getInputPointSize(), builder.isCyclic(),
builder.getShingleIndex());
}

void extrapolateBasicSliding(double[] result, int horizon, int blockSize, double[] queryPoint,
int[] missingIndexes) {
void extrapolateBasicSliding(float[] result, int horizon, int blockSize, float[] queryPoint, int[] missingIndexes) {
int resultIndex = 0;

Arrays.fill(missingIndexes, 0);
@@ -1147,15 +1142,15 @@ void extrapolateBasicSliding(double[] result, int horizon, int blockSize, double
// shift all entries in the query point left by 1 block
System.arraycopy(queryPoint, blockSize, queryPoint, 0, dimensions - blockSize);

double[] imputedPoint = imputeMissingValues(queryPoint, blockSize, missingIndexes);
float[] imputedPoint = imputeMissingValues(queryPoint, blockSize, missingIndexes);
for (int y = 0; y < blockSize; y++) {
result[resultIndex++] = queryPoint[dimensions - blockSize + y] = imputedPoint[dimensions - blockSize
+ y];
}
}
}

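Each sliding step shifts the shingle left by one block and imputes the freed trailing block, exactly as the arraycopy above does. A standalone sketch of one such step with hypothetical sizes (dimensions 6, blockSize 2):

RandomCutForest forest = RandomCutForest.builder().dimensions(6).build();
float[] queryPoint = { 1f, 2f, 3f, 4f, 5f, 6f };
// drop the oldest block by shifting left: queryPoint becomes {3, 4, 5, 6, 5, 6}
System.arraycopy(queryPoint, 2, queryPoint, 0, 4);
// the stale trailing block (indices 4 and 5) is declared missing and imputed
float[] imputed = forest.imputeMissingValues(queryPoint, 2, new int[] { 4, 5 });
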
void extrapolateBasicCyclic(double[] result, int horizon, int blockSize, int shingleIndex, double[] queryPoint,
void extrapolateBasicCyclic(float[] result, int horizon, int blockSize, int shingleIndex, float[] queryPoint,
int[] missingIndexes) {

int resultIndex = 0;
@@ -1167,7 +1162,7 @@ void extrapolateBasicCyclic(double[] result, int horizon, int blockSize, int shi
missingIndexes[y] = (currentPosition + y) % dimensions;
}

double[] imputedPoint = imputeMissingValues(queryPoint, blockSize, missingIndexes);
float[] imputedPoint = imputeMissingValues(queryPoint, blockSize, missingIndexes);

for (int y = 0; y < blockSize; y++) {
result[resultIndex++] = queryPoint[(currentPosition + y)
@@ -1187,10 +1182,14 @@ void extrapolateBasicCyclic(double[] result, int horizon, int blockSize, int shi
* @return a forecasted time series.
*/
public double[] extrapolate(int horizon) {
return toDoubleArray(extrapolateFromCurrentTime(horizon));
}

public float[] extrapolateFromCurrentTime(int horizon) {
checkArgument(internalShinglingEnabled, "incorrect use");
IPointStore<?> store = stateCoordinator.getStore();
return extrapolateBasic(toDoubleArray(lastShingledPoint()), horizon, inputDimensions,
store.isInternalRotationEnabled(), ((int) nextSequenceIndex()) % shingleSize);
return extrapolateBasic(lastShingledPoint(), horizon, inputDimensions, store.isInternalRotationEnabled(),
((int) nextSequenceIndex()) % shingleSize);
}

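extrapolate(horizon) keeps its double[] signature for compatibility and now delegates to the new float[] extrapolateFromCurrentTime, which requires internal shingling. A hedged sketch; the internalShinglingEnabled builder flag is assumed from the library's builder API and does not appear in this diff:

RandomCutForest forest = RandomCutForest.builder().dimensions(8).shingleSize(4)
        .internalShinglingEnabled(true).build();
// ... after feeding the forest 2-value input points via forest.update(...) ...
float[] next = forest.extrapolateFromCurrentTime(3);  // three imputed 2-value blocks
double[] nextAsDouble = forest.extrapolate(3);        // same forecast, legacy double[] view
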
/**
23 changes: 15 additions & 8 deletions Java/core/src/main/java/com/amazon/randomcutforest/Visitor.java
@@ -37,7 +37,7 @@ public interface Visitor<R> {
void accept(INodeView node, int depthOfNode);

/**
* Visit the leaf node in the traversal path. By default this method proxies to
* Visit the leaf node in the traversal path. By default, this method proxies to
* {@link #accept(INodeView, int)}.
*
* @param leafNode the leaf node being visited
@@ -56,14 +56,21 @@ default void acceptLeaf(INodeView leafNode, final int depthOfNode) {
R getResult();

/**
* This method short-circuits the evaluation of the Visitor at nodes on the traversal path. By default, the
* accept (or acceptLeaf) method will be invoked for each Node in the traversal path. But the NodeView has to prepare
* information to support that visitor invocation. Before invocation, the value of isConverged will be checked.
* If it is true, some of that preparation can be skipped -- because the visitor would not be upodated.
* This method can be overriden to optimize visitors that do not need to visit every node on the root to leaf path
* before returning a value.
* This method short-circuits the evaluation of the Visitor at nodes on the
* traversal path. By default, the accept (or acceptLeaf) method will be invoked
* for each Node in the traversal path. But the NodeView has to prepare
* information to support that visitor invocation. Before invocation, the value
* of isConverged will be checked. If it is true, some of that preparation can
* be skipped -- because the visitor would not be updated. This method can be
* overridden to optimize visitors that do not need to visit every node on the
* root to leaf path before returning a value.
*
* Note that this convergence applies to a single visitor computation and is
* expected to be a speedup without any change in the value of the answer. This
* is different from a converging accumulator, which corresponds to sequential
* evaluation of different visitors and early stopping.
**/
default boolean isConverged() {
return false;
};
}
}
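
Under the clarified contract, a visitor that learns everything it needs at the leaf can declare convergence and spare the framework from preparing NodeView state for the ancestors on the rest of the path (RCF traversals visit the leaf first, then walk upward). A minimal sketch of a hypothetical visitor, not part of this commit:

// Returns the depth of the leaf reached on the traversal path and converges
// immediately, so the framework can skip visitor preparation for the ancestors.
class LeafDepthVisitor implements Visitor<Integer> {
    private int leafDepth;
    private boolean seenLeaf;

    @Override
    public void acceptLeaf(INodeView leafNode, int depthOfNode) {
        leafDepth = depthOfNode;
        seenLeaf = true;
    }

    @Override
    public void accept(INodeView node, int depthOfNode) {
        // nothing to do once the leaf has been seen
    }

    @Override
    public Integer getResult() {
        return leafDepth;
    }

    @Override
    public boolean isConverged() {
        return seenLeaf;
    }
}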
(attribution visitor in com.amazon.randomcutforest.anomalydetection; filename not captured in this view)
@@ -15,14 +15,14 @@

package com.amazon.randomcutforest.anomalydetection;

import java.util.Arrays;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.Visitor;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.tree.IBoundingBoxView;
import com.amazon.randomcutforest.tree.INodeView;

import java.util.Arrays;

/**
* Attribution exposes the attribution of scores produced by ScalarScoreVisitor
* corresponding to different attributes. It allows a boolean
