Skip to content

Commit

Permalink
bump rcf to 3.0-rc2.1 (#519)
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
(cherry picked from commit 8227e32)
  • Loading branch information
amitgalitz authored and github-actions[bot] committed May 3, 2022
1 parent 5b2f2c8 commit 9474ad8
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 14 deletions.
6 changes: 3 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -629,9 +629,9 @@ dependencies {
implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2'
implementation group: 'commons-lang', name: 'commons-lang', version: '2.6'
implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.10.0'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:3.0-rc1'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.0-rc1'
implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc1'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:3.0-rc2.1'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.0-rc2.1'
implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc2.1'

// force Jackson version to avoid version conflict issue
implementation "com.fasterxml.jackson.core:jackson-core:2.13.2"
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@

import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -367,7 +367,7 @@ public Collection<Object> createComponents(
mapper.setSaveExecutorContextEnabled(true);
mapper.setSaveTreeStateEnabled(true);
mapper.setPartialTreeStateEnabled(true);
V1JsonToV2StateConverter converter = new V1JsonToV2StateConverter();
V1JsonToV3StateConverter converter = new V1JsonToV3StateConverter();

double modelMaxSizePercent = AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings);

Expand Down
16 changes: 12 additions & 4 deletions src/main/java/org/opensearch/ad/ml/CheckpointDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
import com.google.gson.Gson;
Expand Down Expand Up @@ -117,7 +117,15 @@ public class CheckpointDao {

private Gson gson;
private RandomCutForestMapper mapper;
private V1JsonToV2StateConverter converter;

// For further reference v1, v2 and v3 refer to the different variations of RCF models
// used by AD. v1 was originally used with the launch of OS 1.0. We later converted to v2
// which included changes requiring a specific converter from v1 to v2 for BWC.
// v2 models are created by RCF-3.0-rc1 which can be found on maven central.
// v3 is the latest model version form RCF introduced by RCF-3.0-rc2.
// Although this version has a converter method for v2 to v3, after BWC testing it was decided that
// an explicit use of the converter won't be needed as the changes between the models are indeed BWC.
private V1JsonToV3StateConverter converter;
private ThresholdedRandomCutForestMapper trcfMapper;
private Schema<ThresholdedRandomCutForestState> trcfSchema;

Expand Down Expand Up @@ -157,7 +165,7 @@ public CheckpointDao(
String indexName,
Gson gson,
RandomCutForestMapper mapper,
V1JsonToV2StateConverter converter,
V1JsonToV3StateConverter converter,
ThresholdedRandomCutForestMapper trcfMapper,
Schema<ThresholdedRandomCutForestState> trcfSchema,
Class<? extends ThresholdingModel> thresholdingModelClass,
Expand Down Expand Up @@ -556,7 +564,7 @@ public Optional<Entry<EntityModel, Instant>> fromEntityModelCheckpoint(Map<Strin
}
}

private ThresholdedRandomCutForest toTrcf(String checkpoint) {
ThresholdedRandomCutForest toTrcf(String checkpoint) {
ThresholdedRandomCutForest trcf = null;
if (checkpoint != null && !checkpoint.isEmpty()) {
try {
Expand Down
91 changes: 88 additions & 3 deletions src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,20 @@
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.time.Instant;
import java.time.Month;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
Expand Down Expand Up @@ -106,10 +110,11 @@

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.parkservices.AnomalyDescriptor;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
Expand Down Expand Up @@ -154,7 +159,7 @@ public class CheckpointDaoTests extends OpenSearchTestCase {
private GenericObjectPool<LinkedBuffer> serializeRCFBufferPool;
private RandomCutForestMapper mapper;
private ThresholdedRandomCutForestMapper trcfMapper;
private V1JsonToV2StateConverter converter;
private V1JsonToV3StateConverter converter;
double anomalyRate;

@Before
Expand All @@ -180,7 +185,7 @@ public void setup() {
.getSchema(ThresholdedRandomCutForestState.class)
);

converter = new V1JsonToV2StateConverter();
converter = new V1JsonToV3StateConverter();

serializeRCFBufferPool = spy(AccessController.doPrivileged(new PrivilegedAction<GenericObjectPool<LinkedBuffer>>() {
@Override
Expand Down Expand Up @@ -1000,4 +1005,84 @@ private double[] getPoint(int dimensions, Random random) {
}
return point;
}

// The checkpoint used for this test is from a single-stream detector
public void testDeserializeRCFModelPreINIT() throws Exception {
// Model in file 1_3_0_rcf_model_pre_init.json not passed initialization yet
String filePath = getClass().getResource("1_3_0_rcf_model_pre_init.json").getPath();
String json = Files.readString(Paths.get(filePath), Charset.defaultCharset());
Map map = gson.fromJson(json, Map.class);
String model = (String) ((Map) ((Map) ((ArrayList) ((Map) map.get("hits")).get("hits")).get(0)).get("_source")).get("modelV2");
ThresholdedRandomCutForest forest = checkpointDao.toTrcf(model);
assertEquals(256, forest.getForest().getSampleSize());
assertEquals(8, forest.getForest().getShingleSize());
assertEquals(30, forest.getForest().getNumberOfTrees());
}

// The checkpoint used for this test is from a single-stream detector
public void testDeserializeRCFModelPostINIT() throws Exception {
// Model in file rc1_model_single_running is from RCF-3.0-rc1
String filePath = getClass().getResource("rc1_model_single_running.json").getPath();
String json = Files.readString(Paths.get(filePath), Charset.defaultCharset());
Map map = gson.fromJson(json, Map.class);
String model = (String) ((Map) ((Map) ((ArrayList) ((Map) map.get("hits")).get("hits")).get(0)).get("_source")).get("modelV2");
ThresholdedRandomCutForest forest = checkpointDao.toTrcf(model);
assertEquals(256, forest.getForest().getSampleSize());
assertEquals(8, forest.getForest().getShingleSize());
assertEquals(30, forest.getForest().getNumberOfTrees());
}

// This test is intended to check if given a checkpoint created by RCF-3.0-rc1 ("rc1_trcf_model_direct.json")
// and given the same sample data will rc1 and current RCF version (this test originally created when 3.0-rc2.1 is in use)
// will produce the same anomaly scores and grades.
// The scores and grades in this method were produced from AD running with RCF3.0-rc1 dependency
// and this test runs with the most recent RCF dependency that is being pulled by this project.
public void testDeserializeTRCFModel() throws Exception {
// Model in file rc1_model_single_running is from RCF-3.0-rc1
String filePath = getClass().getResource("rc1_trcf_model_direct.json").getPath();
String json = Files.readString(Paths.get(filePath), Charset.defaultCharset());
// For the parsing of .toTrcf to work I had to manually change "\u003d" in code back to =.
// In the byte array it doesn't seem like this is an issue but whenever reading the byte array response into a file it
// converts "=" to "\u003d" https://groups.google.com/g/google-gson/c/JDHUo9DWyyM?pli=1
// I also needed to bypass the trcf as it wasn't being read as a key value but instead part of the string
Map map = gson.fromJson(json, Map.class);
String model = (String) ((Map) ((Map) ((ArrayList) ((Map) map.get("hits")).get("hits")).get(0)).get("_source")).get("modelV2");
model = model.split(":")[1].substring(1);
ThresholdedRandomCutForest forest = checkpointDao.toTrcf(model);

List<double[]> coldStartData = new ArrayList<>();
double[] sample1 = new double[] { 57.0 };
double[] sample2 = new double[] { 1.0 };
double[] sample3 = new double[] { -19.0 };
double[] sample4 = new double[] { 13.0 };
double[] sample5 = new double[] { 41.0 };

coldStartData.add(sample1);
coldStartData.add(sample2);
coldStartData.add(sample3);
coldStartData.add(sample4);
coldStartData.add(sample5);

// This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them
// to the scores generated by the imported RCF3.0-rc2.1
List<Double> scores = new ArrayList<>();
scores.add(4.814651669367903);
scores.add(5.566968073093689);
scores.add(5.919907610660049);
scores.add(5.770278090352401);
scores.add(5.319779117320102);

List<Double> grade = new ArrayList<>();
grade.add(1.0);
grade.add(0.0);
grade.add(0.0);
grade.add(0.0);
grade.add(0.0);
for (int i = 0; i < coldStartData.size(); i++) {
forest.process(coldStartData.get(i), 0);
AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0);
assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9);
assertEquals(descriptor.getAnomalyGrade(), grade.get(i), 1e-9);
}
}
}
4 changes: 2 additions & 2 deletions src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.google.gson.Gson;

Expand Down Expand Up @@ -92,7 +92,7 @@ public void setUp() throws Exception {
maxCheckpointBytes = 1_000_000;

RandomCutForestMapper mapper = mock(RandomCutForestMapper.class);
V1JsonToV2StateConverter converter = mock(V1JsonToV2StateConverter.class);
V1JsonToV3StateConverter converter = mock(V1JsonToV3StateConverter.class);

objectPool = mock(GenericObjectPool.class);
int deserializeRCFBufferSize = 512;
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

0 comments on commit 9474ad8

Please sign in to comment.