-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
use protostuff to serialize/deserialize RCF model
Signed-off-by: Yaliang Wu <[email protected]>
- Loading branch information
Showing
15 changed files
with
322 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.algorithms.rcf; | ||
|
||
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; | ||
import com.amazon.randomcutforest.state.RandomCutForestState; | ||
import io.protostuff.LinkedBuffer; | ||
import io.protostuff.ProtostuffIOUtil; | ||
import io.protostuff.Schema; | ||
import io.protostuff.runtime.RuntimeSchema; | ||
import lombok.experimental.UtilityClass; | ||
|
||
import java.security.AccessController; | ||
import java.security.PrivilegedAction; | ||
|
||
@UtilityClass | ||
public class RCFModelSerDeSer { | ||
private static final int SERIALIZATION_BUFFER_BYTES = 512; | ||
private static final Schema<RandomCutForestState> rcfSchema = | ||
AccessController.doPrivileged((PrivilegedAction<Schema<RandomCutForestState>>) () -> | ||
RuntimeSchema.getSchema(RandomCutForestState.class)); | ||
private static final Schema<ThresholdedRandomCutForestState> trcfSchema = | ||
AccessController.doPrivileged((PrivilegedAction<Schema<ThresholdedRandomCutForestState>>) () -> | ||
RuntimeSchema.getSchema(ThresholdedRandomCutForestState.class)); | ||
|
||
public static byte[] serializeRCF(RandomCutForestState model) { | ||
return serialize(model, rcfSchema); | ||
} | ||
|
||
public static byte[] serializeTRCF(ThresholdedRandomCutForestState model) { | ||
return serialize(model, trcfSchema); | ||
} | ||
|
||
public static RandomCutForestState deserializeRCF(byte[] bytes) { | ||
return deserialize(bytes, rcfSchema); | ||
} | ||
|
||
public static ThresholdedRandomCutForestState deserializeTRCF(byte[] bytes) { | ||
return deserialize(bytes, trcfSchema); | ||
} | ||
|
||
private static <T> byte[] serialize(T model, Schema<T> schema) { | ||
LinkedBuffer buffer = LinkedBuffer.allocate(SERIALIZATION_BUFFER_BYTES); | ||
byte[] bytes = AccessController.doPrivileged((PrivilegedAction<byte[]>) () -> ProtostuffIOUtil.toByteArray(model, schema, buffer)); | ||
return bytes; | ||
} | ||
|
||
private static <T> T deserialize(byte[] bytes, Schema<T> schema) { | ||
T model = schema.newMessage(); | ||
AccessController.doPrivileged((PrivilegedAction<Void>) () -> { | ||
ProtostuffIOUtil.mergeFrom(bytes, model, schema); | ||
return null; | ||
}); | ||
return model; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
...lgorithms/src/test/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSerTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.algorithms.rcf; | ||
|
||
import com.amazon.randomcutforest.RandomCutForest; | ||
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; | ||
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; | ||
import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; | ||
import com.amazon.randomcutforest.state.RandomCutForestMapper; | ||
import com.amazon.randomcutforest.state.RandomCutForestState; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
import org.opensearch.ml.common.Model; | ||
import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; | ||
import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; | ||
|
||
import java.util.Arrays; | ||
|
||
import static org.junit.Assert.assertNotNull; | ||
import static org.junit.Assert.assertTrue; | ||
import static org.opensearch.ml.engine.helper.MLTestHelper.TIME_FIELD; | ||
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; | ||
|
||
public class RCFModelSerDeSerTest { | ||
@Rule | ||
public ExpectedException thrown = ExpectedException.none(); | ||
|
||
private final RandomCutForestMapper rcfMapper = new RandomCutForestMapper(); | ||
private final ThresholdedRandomCutForestMapper trcfMapper = new ThresholdedRandomCutForestMapper(); | ||
|
||
@Test | ||
public void testModelSerDeSerBatchRCF() { | ||
BatchRCFParams params = BatchRCFParams.builder().build(); | ||
BatchRandomCutForest batchRCF = new BatchRandomCutForest(params); | ||
Model model = batchRCF.train(constructTestDataFrame(500)); | ||
|
||
RandomCutForestState deserializedState = RCFModelSerDeSer.deserializeRCF(model.getContent()); | ||
RandomCutForest forest = rcfMapper.toModel(deserializedState); | ||
assertNotNull(forest); | ||
byte[] serializedModel = RCFModelSerDeSer.serializeRCF(deserializedState); | ||
assertTrue(Arrays.equals(serializedModel, model.getContent())); | ||
} | ||
|
||
@Test | ||
public void testModelSerDeSerFitRCF() { | ||
FitRCFParams params = FitRCFParams.builder().timeField(TIME_FIELD).build(); | ||
FixedInTimeRandomCutForest fitRCF = new FixedInTimeRandomCutForest(params); | ||
Model model = fitRCF.train(constructTestDataFrame(500, true)); | ||
|
||
ThresholdedRandomCutForestState deserializedState = RCFModelSerDeSer.deserializeTRCF(model.getContent()); | ||
ThresholdedRandomCutForest forest = trcfMapper.toModel(deserializedState); | ||
assertNotNull(forest); | ||
byte[] serializedModel = RCFModelSerDeSer.serializeTRCF(deserializedState); | ||
assertTrue(Arrays.equals(serializedModel, model.getContent())); | ||
} | ||
|
||
} |
Oops, something went wrong.