Commit

fix for Jackson bindings in benchmarks (#281)
* fix for Jackson bindings in benchmarks

* changes + shortening tests
sudiptoguha authored Oct 12, 2021
1 parent ed1bda6 commit c8b5039
Showing 16 changed files with 37 additions and 60 deletions.
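
The changes below follow one pattern: derived enum getters such as getPrecisionEnumValue(), getForestModeEnumValue(), getTransformMethodEnumValue(), and getImputationMethodEnumValue() are removed from the Lombok @Data state classes, and the string-to-enum conversion moves to the call sites (Precision.valueOf(state.getPrecision()) and so on). A likely reason, not spelled out in the commit message, is that Jackson's default bean introspection exposes any public getXxx() method as a property, so the derived getters add extra write-only fields to the serialized state that then fail to bind when the JSON is read back. The sketch below is a hypothetical, self-contained illustration of that failure mode; the class names are stand-ins, not code from this repository.

    import com.fasterxml.jackson.databind.ObjectMapper;

    public class JacksonBindingSketch {

        // Stand-in for the library's Precision enum.
        public enum Precision { FLOAT_32, FLOAT_64 }

        // Stand-in for a state class that stores precision as a String.
        public static class State {
            private String precision = "FLOAT_32";

            public String getPrecision() { return precision; }
            public void setPrecision(String precision) { this.precision = precision; }

            // Derived getter: Jackson picks this up as an extra "precisionEnumValue"
            // property even though no matching field or setter exists.
            public Precision getPrecisionEnumValue() { return Precision.valueOf(precision); }
        }

        public static void main(String[] args) throws Exception {
            ObjectMapper mapper = new ObjectMapper();

            // Writes {"precision":"FLOAT_32","precisionEnumValue":"FLOAT_32"}
            String json = mapper.writeValueAsString(new State());
            System.out.println(json);

            // With default settings this throws UnrecognizedPropertyException,
            // because there is no setter for "precisionEnumValue".
            mapper.readValue(json, State.class);
        }
    }

Keeping the state classes as plain string-typed data holders, and calling Precision.valueOf(...), ForestMode.valueOf(...), TransformMethod.valueOf(...), or ImputationMethod.valueOf(...) only inside the mappers, keeps the JSON shape aligned with the declared fields, which is what the diff below does.
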
@@ -32,9 +32,13 @@ public enum ImputationMethod {
* last known value in each input dimension
*/
PREVIOUS,
/**
* linear interpolation
*/
LINEAR,
/**
* use the RCF imputation; but would often require a minimum number of
* observations and would use PREVIOUS till that point
* observations and would use a default (often LINEAR) till that point
*/
RCF;
}
@@ -271,7 +271,7 @@ public RandomCutForest toModel(RandomCutForestState state, ExecutionContext exec
.internalShinglingEnabled(state.isInternalShinglingEnabled()).randomSeed(seed);

if (state.isCompact()) {
if (state.getPrecisionEnumValue() == Precision.FLOAT_32) {
if (Precision.valueOf(state.getPrecision()) == Precision.FLOAT_32) {
return singlePrecisionForest(builder, state, null, null, null);
} else {
return doublePrecisionForest(builder, state, null, null, null);
@@ -21,7 +21,6 @@

import lombok.Data;

import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.state.sampler.CompactSamplerState;
import com.amazon.randomcutforest.state.store.PointStoreState;
import com.amazon.randomcutforest.state.tree.CompactRandomCutTreeState;
@@ -81,7 +80,4 @@ public class RandomCutForestState {

private boolean saveCoordinatorStateEnabled;

public Precision getPrecisionEnumValue() {
return Precision.valueOf(precision);
}
}
@@ -53,7 +53,7 @@ public NodeStore toModel(NodeStoreState state, long seed) {
int capacity = state.getCapacity();
int[] cutDimension = ArrayPacking.unpackInts(state.getCutDimension(), state.isCompressed());
double[] cutValue;
if (state.getPrecisionEnumValue() == Precision.FLOAT_32) {
if (Precision.valueOf(state.getPrecision()) == Precision.FLOAT_32) {
cutValue = CommonUtils.toDoubleArray(ArrayPacking.unpackFloats(state.getCutValueData()));
} else {
cutValue = ArrayPacking.unpackDoubles(state.getCutValueData());
@@ -107,7 +107,7 @@ public NodeStoreState toState(NodeStore model) {
}

state.setCutDimension(ArrayPacking.pack(model.getCutDimension(), state.isCompressed()));
if (state.getPrecisionEnumValue() == Precision.FLOAT_32) {
if (Precision.valueOf(state.getPrecision()) == Precision.FLOAT_32) {
state.setCutValueData(ArrayPacking.pack(CommonUtils.toFloatArray(model.getCutValue())));
} else {
state.setCutValueData(ArrayPacking.pack(model.getCutValue(), model.getCutValue().length));
@@ -19,8 +19,6 @@

import lombok.Data;

import com.amazon.randomcutforest.config.Precision;

@Data
public class NodeStoreState {

@@ -47,7 +45,4 @@ public class NodeStoreState {
private int[] leafMass;
private int[] leafPointIndex;

public Precision getPrecisionEnumValue() {
return Precision.valueOf(precision);
}
}
@@ -41,7 +41,8 @@ public class PointStoreDoubleMapper implements IStateMapper<PointStoreDouble, Po
public PointStoreDouble toModel(PointStoreState state, long seed) {
checkNotNull(state.getRefCount(), "refCount must not be null");
checkNotNull(state.getPointData(), "pointData must not be null");
checkArgument(state.getPrecisionEnumValue() == Precision.FLOAT_64, "precision must be " + Precision.FLOAT_64);
checkArgument(Precision.valueOf(state.getPrecision()) == Precision.FLOAT_64,
"precision must be " + Precision.FLOAT_64);
int indexCapacity = state.getIndexCapacity();
int dimensions = state.getDimensions();
double[] store = ArrayPacking.unpackDoubles(state.getPointData(), state.getCurrentStoreCapacity() * dimensions);
@@ -42,7 +42,8 @@ public class PointStoreFloatMapper implements IStateMapper<PointStoreFloat, Poin
public PointStoreFloat toModel(PointStoreState state, long seed) {
checkNotNull(state.getRefCount(), "refCount must not be null");
checkNotNull(state.getPointData(), "pointData must not be null");
checkArgument(state.getPrecisionEnumValue() == Precision.FLOAT_32, "precision must be " + Precision.FLOAT_32);
checkArgument(Precision.valueOf(state.getPrecision()) == Precision.FLOAT_32,
"precision must be " + Precision.FLOAT_32);
int indexCapacity = state.getIndexCapacity();
int dimensions = state.getDimensions();
float[] store = ArrayPacking.unpackFloats(state.getPointData(), state.getCurrentStoreCapacity() * dimensions);
@@ -19,8 +19,6 @@

import lombok.Data;

import com.amazon.randomcutforest.config.Precision;

/**
* A class for storing the state of a
* {@link com.amazon.randomcutforest.store.PointStoreDouble} or a
@@ -112,7 +110,4 @@ public class PointStoreState {
*/
private int indexCapacity;

public Precision getPrecisionEnumValue() {
return Precision.valueOf(precision);
}
}
@@ -49,7 +49,8 @@ public class SmallNodeStoreMapper implements IStateMapper<SmallNodeStore, NodeSt

@Override
public SmallNodeStore toModel(NodeStoreState state, long seed) {
checkState(state.getPrecisionEnumValue() == Precision.FLOAT_32, " incorrect invocation of SmallNodeStore");
checkState(Precision.valueOf(state.getPrecision()) == Precision.FLOAT_32,
" incorrect invocation of SmallNodeStore");
int capacity = state.getCapacity();
short[] cutDimension = ArrayPacking.unpackShorts(state.getCutDimension(), state.isCompressed());
float[] cutValue = ArrayPacking.unpackFloats(state.getCutValueData());
@@ -43,8 +43,9 @@ public CompactRandomCutTreeFloat toModel(CompactRandomCutTreeState state, Compac

INodeStore nodeStore;

if (AbstractCompactRandomCutTree.canUseSmallNodeStore(state.getNodeStoreState().getPrecisionEnumValue(),
state.getMaxSize(), state.getDimensions())) {
if (AbstractCompactRandomCutTree.canUseSmallNodeStore(
Precision.valueOf(state.getNodeStoreState().getPrecision()), state.getMaxSize(),
state.getDimensions())) {
SmallNodeStoreMapper nodeStoreMapper = new SmallNodeStoreMapper();
nodeStoreMapper.setPartialTreeStateEnabled(state.isPartialTreeState());
nodeStore = nodeStoreMapper.toModel(state.getNodeStoreState());
@@ -111,6 +111,7 @@ public void InternalShinglingTest(boolean rotation) {
int shingleSize = 2;
int dimensions = baseDimensions * shingleSize;
long seed = new Random().nextLong();
System.out.println(seed);

int numTrials = 1; // test is exact equality, reducing the number of trials
int length = 4000 * sampleSize;
@@ -143,7 +144,7 @@ public void InternalShinglingTest(boolean rotation) {
}

for (int j = 0; j < shingledData.length; j++) {
// validate eaulity of points
// validate equality of points
for (int y = 0; y < baseDimensions; y++) {
int position = (rotation) ? (count % shingleSize) : shingleSize - 1;
assertEquals(dataWithKeys.data[count][y], shingledData[j][position * baseDimensions + y], 1e-10);
@@ -19,6 +19,8 @@
import lombok.Setter;

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.preprocessor.Preprocessor;
import com.amazon.randomcutforest.parkservices.state.preprocessor.PreprocessorMapper;
@@ -58,9 +60,9 @@ public ThresholdedRandomCutForest toModel(ThresholdedRandomCutForestState state,
tForest.setLastAnomalyAttribution(new DiVectorMapper().toModel(state.getLastAnomalyAttribution()));
tForest.setLastAnomalyPoint(state.getLastAnomalyPoint());
tForest.setLastExpectedPoint(state.getLastExpectedPoint());
tForest.setForestMode(state.getForestModeEnumValue());
tForest.setForestMode(ForestMode.valueOf(state.getForestMode()));
tForest.setLastRelativeIndex(state.getLastRelativeIndex());
tForest.setTransformMethod(state.getTransformMethodEnumValue());
tForest.setTransformMethod(TransformMethod.valueOf(state.getTransformMethod()));
return tForest;
}

@@ -19,8 +19,6 @@

import lombok.Data;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.parkservices.state.preprocessor.PreprocessorState;
import com.amazon.randomcutforest.parkservices.state.threshold.BasicThresholderState;
import com.amazon.randomcutforest.state.RandomCutForestState;
@@ -52,11 +50,4 @@ public class ThresholdedRandomCutForestState {
private int lastRelativeIndex;
private int lastReset;

public TransformMethod getTransformMethodEnumValue() {
return TransformMethod.valueOf(transformMethod);
}

public ForestMode getForestModeEnumValue() {
return ForestMode.valueOf(forestMode);
}
}
@@ -18,6 +18,9 @@
import lombok.Getter;
import lombok.Setter;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.ImputationMethod;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.parkservices.preprocessor.Preprocessor;
import com.amazon.randomcutforest.parkservices.state.statistics.DeviationMapper;
import com.amazon.randomcutforest.parkservices.state.statistics.DeviationState;
@@ -41,12 +44,13 @@ public Preprocessor toModel(PreprocessorState state, long seed) {
}
}
Preprocessor.Builder<?> preprocessorBuilder = new Preprocessor.Builder<>()
.forestMode(state.getForestModeEnumValue()).shingleSize(state.getShingleSize())
.dimensions(state.getDimensions()).imputationMethod(state.getImputationMethodEnumValue())
.forestMode(ForestMode.valueOf(state.getForestMode())).shingleSize(state.getShingleSize())
.dimensions(state.getDimensions())
.imputationMethod(ImputationMethod.valueOf(state.getImputationMethod()))
.fillValues(state.getDefaultFill()).inputLength(state.getInputLength()).weights(state.getWeights())
.transformMethod(state.getTransformMethodEnumValue()).startNormalization(state.getStartNormalization())
.useImputedFraction(state.getUseImputedFraction()).timeDeviation(timeStampDeviation)
.dataQuality(dataQuality);
.transformMethod(TransformMethod.valueOf(state.getTransformMethod()))
.startNormalization(state.getStartNormalization()).useImputedFraction(state.getUseImputedFraction())
.timeDeviation(timeStampDeviation).dataQuality(dataQuality);

if (deviations != null) {
preprocessorBuilder.deviations(deviations);
@@ -19,9 +19,6 @@

import lombok.Data;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.ImputationMethod;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.parkservices.state.statistics.DeviationState;

@Data
@@ -54,16 +51,4 @@ public class PreprocessorState {
private DeviationState timeStampDeviationState;
private DeviationState[] deviationStates;

public TransformMethod getTransformMethodEnumValue() {
return TransformMethod.valueOf(transformMethod);
}

public ForestMode getForestModeEnumValue() {
return ForestMode.valueOf(forestMode);
}

public ImputationMethod getImputationMethodEnumValue() {
return ImputationMethod.valueOf(imputationMethod);
}

}
@@ -68,8 +68,8 @@ public void InternalShinglingTest() {
@Test
public void ExternalShinglingTest() {
int sampleSize = 256;
int baseDimensions = 2;
int shingleSize = 8;
int baseDimensions = 1;
int shingleSize = 4;
int dimensions = baseDimensions * shingleSize;
long seed = new Random().nextLong();

@@ -122,8 +122,8 @@ public void ExternalShinglingTest() {
@Test
public void MixedShinglingTest() {
int sampleSize = 256;
int baseDimensions = 2;
int shingleSize = 8;
int baseDimensions = 1;
int shingleSize = 4;
int dimensions = baseDimensions * shingleSize;
long seed = new Random().nextLong();

@@ -158,7 +158,7 @@ public void MixedShinglingTest() {
}

for (int j = 0; j < length; j++) {
// validate eaulity of points
// validate equality of points
for (int y = 0; y < baseDimensions; y++) {
assertEquals(dataWithKeys.data[count][y], shingledData[j][(shingleSize - 1) * baseDimensions + y],
1e-10);