Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bump rcf to 3.0-rc2.1 #519

Merged
merged 5 commits into from
May 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -629,9 +629,9 @@ dependencies {
implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2'
implementation group: 'commons-lang', name: 'commons-lang', version: '2.6'
implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.10.0'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:3.0-rc1'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.0-rc1'
implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc1'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:3.0-rc2.1'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.0-rc2.1'
implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc2.1'

// force Jackson version to avoid version conflict issue
implementation "com.fasterxml.jackson.core:jackson-core:2.13.2"
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/ad/AnomalyDetectorPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@

import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -367,7 +367,7 @@ public Collection<Object> createComponents(
mapper.setSaveExecutorContextEnabled(true);
mapper.setSaveTreeStateEnabled(true);
mapper.setPartialTreeStateEnabled(true);
V1JsonToV2StateConverter converter = new V1JsonToV2StateConverter();
V1JsonToV3StateConverter converter = new V1JsonToV3StateConverter();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need V2 to V3 converter? How about add some comments to explain what's V1, V2 and V3?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After testing and looking at RCF code, we don't need a converter between v2 to v3 as that is dealt with accordingly on RCF side. I also added a unit test where we parse a v2 checkpoint with RCF3.0-rc2.1 as a dependency(v3) and we get the same result as we do when parsing with rc1(v2). I'll add some comments explaining v1, v2, v3 in checkPointDAO class.


double modelMaxSizePercent = AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings);

Expand Down
16 changes: 12 additions & 4 deletions src/main/java/org/opensearch/ad/ml/CheckpointDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.amazon.randomcutforest.state.RandomCutForestState;
import com.google.gson.Gson;
Expand Down Expand Up @@ -117,7 +117,15 @@ public class CheckpointDao {

private Gson gson;
private RandomCutForestMapper mapper;
private V1JsonToV2StateConverter converter;

// For further reference v1, v2 and v3 refer to the different variations of RCF models
// used by AD. v1 was originally used with the launch of OS 1.0. We later converted to v2
// which included changes requiring a specific converter from v1 to v2 for BWC.
// v2 models are created by RCF-3.0-rc1 which can be found on maven central.
// v3 is the latest model version form RCF introduced by RCF-3.0-rc2.
// Although this version has a converter method for v2 to v3, after BWC testing it was decided that
// an explicit use of the converter won't be needed as the changes between the models are indeed BWC.
private V1JsonToV3StateConverter converter;
private ThresholdedRandomCutForestMapper trcfMapper;
private Schema<ThresholdedRandomCutForestState> trcfSchema;

Expand Down Expand Up @@ -157,7 +165,7 @@ public CheckpointDao(
String indexName,
Gson gson,
RandomCutForestMapper mapper,
V1JsonToV2StateConverter converter,
V1JsonToV3StateConverter converter,
ThresholdedRandomCutForestMapper trcfMapper,
Schema<ThresholdedRandomCutForestState> trcfSchema,
Class<? extends ThresholdingModel> thresholdingModelClass,
Expand Down Expand Up @@ -556,7 +564,7 @@ public Optional<Entry<EntityModel, Instant>> fromEntityModelCheckpoint(Map<Strin
}
}

private ThresholdedRandomCutForest toTrcf(String checkpoint) {
ThresholdedRandomCutForest toTrcf(String checkpoint) {
ThresholdedRandomCutForest trcf = null;
if (checkpoint != null && !checkpoint.isEmpty()) {
try {
Expand Down
91 changes: 88 additions & 3 deletions src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,20 @@
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.time.Instant;
import java.time.Month;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
Expand Down Expand Up @@ -106,10 +110,11 @@

import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.parkservices.AnomalyDescriptor;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
Expand Down Expand Up @@ -154,7 +159,7 @@ public class CheckpointDaoTests extends OpenSearchTestCase {
private GenericObjectPool<LinkedBuffer> serializeRCFBufferPool;
private RandomCutForestMapper mapper;
private ThresholdedRandomCutForestMapper trcfMapper;
private V1JsonToV2StateConverter converter;
private V1JsonToV3StateConverter converter;
double anomalyRate;

@Before
Expand All @@ -180,7 +185,7 @@ public void setup() {
.getSchema(ThresholdedRandomCutForestState.class)
);

converter = new V1JsonToV2StateConverter();
converter = new V1JsonToV3StateConverter();

serializeRCFBufferPool = spy(AccessController.doPrivileged(new PrivilegedAction<GenericObjectPool<LinkedBuffer>>() {
@Override
Expand Down Expand Up @@ -1000,4 +1005,84 @@ private double[] getPoint(int dimensions, Random random) {
}
return point;
}

// The checkpoint used for this test is from a single-stream detector
public void testDeserializeRCFModelPreINIT() throws Exception {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are the following 2 tests used for single-stream detector's checkpoints? If so, could you add comments?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

// Model in file 1_3_0_rcf_model_pre_init.json not passed initialization yet
String filePath = getClass().getResource("1_3_0_rcf_model_pre_init.json").getPath();
String json = Files.readString(Paths.get(filePath), Charset.defaultCharset());
Map map = gson.fromJson(json, Map.class);
String model = (String) ((Map) ((Map) ((ArrayList) ((Map) map.get("hits")).get("hits")).get(0)).get("_source")).get("modelV2");
ThresholdedRandomCutForest forest = checkpointDao.toTrcf(model);
assertEquals(256, forest.getForest().getSampleSize());
assertEquals(8, forest.getForest().getShingleSize());
assertEquals(30, forest.getForest().getNumberOfTrees());
}

// The checkpoint used for this test is from a single-stream detector
public void testDeserializeRCFModelPostINIT() throws Exception {
// Model in file rc1_model_single_running is from RCF-3.0-rc1
String filePath = getClass().getResource("rc1_model_single_running.json").getPath();
String json = Files.readString(Paths.get(filePath), Charset.defaultCharset());
Map map = gson.fromJson(json, Map.class);
String model = (String) ((Map) ((Map) ((ArrayList) ((Map) map.get("hits")).get("hits")).get(0)).get("_source")).get("modelV2");
ThresholdedRandomCutForest forest = checkpointDao.toTrcf(model);
assertEquals(256, forest.getForest().getSampleSize());
assertEquals(8, forest.getForest().getShingleSize());
assertEquals(30, forest.getForest().getNumberOfTrees());
}

// This test is intended to check if given a checkpoint created by RCF-3.0-rc1 ("rc1_trcf_model_direct.json")
// and given the same sample data will rc1 and current RCF version (this test originally created when 3.0-rc2.1 is in use)
// will produce the same anomaly scores and grades.
// The scores and grades in this method were produced from AD running with RCF3.0-rc1 dependency
// and this test runs with the most recent RCF dependency that is being pulled by this project.
public void testDeserializeTRCFModel() throws Exception {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add the test's purpose?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

// Model in file rc1_model_single_running is from RCF-3.0-rc1
String filePath = getClass().getResource("rc1_trcf_model_direct.json").getPath();
String json = Files.readString(Paths.get(filePath), Charset.defaultCharset());
// For the parsing of .toTrcf to work I had to manually change "\u003d" in code back to =.
// In the byte array it doesn't seem like this is an issue but whenever reading the byte array response into a file it
// converts "=" to "\u003d" https://groups.google.com/g/google-gson/c/JDHUo9DWyyM?pli=1
// I also needed to bypass the trcf as it wasn't being read as a key value but instead part of the string
Map map = gson.fromJson(json, Map.class);
String model = (String) ((Map) ((Map) ((ArrayList) ((Map) map.get("hits")).get("hits")).get(0)).get("_source")).get("modelV2");
model = model.split(":")[1].substring(1);
ThresholdedRandomCutForest forest = checkpointDao.toTrcf(model);

List<double[]> coldStartData = new ArrayList<>();
double[] sample1 = new double[] { 57.0 };
double[] sample2 = new double[] { 1.0 };
double[] sample3 = new double[] { -19.0 };
double[] sample4 = new double[] { 13.0 };
double[] sample5 = new double[] { 41.0 };

coldStartData.add(sample1);
coldStartData.add(sample2);
coldStartData.add(sample3);
coldStartData.add(sample4);
coldStartData.add(sample5);

// This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them
// to the scores generated by the imported RCF3.0-rc2.1
List<Double> scores = new ArrayList<>();
scores.add(4.814651669367903);
scores.add(5.566968073093689);
scores.add(5.919907610660049);
scores.add(5.770278090352401);
scores.add(5.319779117320102);

List<Double> grade = new ArrayList<>();
grade.add(1.0);
grade.add(0.0);
grade.add(0.0);
grade.add(0.0);
grade.add(0.0);
for (int i = 0; i < coldStartData.size(); i++) {
forest.process(coldStartData.get(i), 0);
AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0);
assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9);
assertEquals(descriptor.getAnomalyGrade(), grade.get(i), 1e-9);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper;
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV2StateConverter;
import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter;
import com.amazon.randomcutforest.state.RandomCutForestMapper;
import com.google.gson.Gson;

Expand Down Expand Up @@ -92,7 +92,7 @@ public void setUp() throws Exception {
maxCheckpointBytes = 1_000_000;

RandomCutForestMapper mapper = mock(RandomCutForestMapper.class);
V1JsonToV2StateConverter converter = mock(V1JsonToV2StateConverter.class);
V1JsonToV3StateConverter converter = mock(V1JsonToV3StateConverter.class);

objectPool = mock(GenericObjectPool.class);
int deserializeRCFBufferSize = 512;
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.