Skip to content

Commit

Permalink
fixing metrics correlation algorithm (#1448) (#1460)
Browse files Browse the repository at this point in the history
* fixing metrics correlation algorithm

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comments

Signed-off-by: Dhrubo Saha <[email protected]>

* addressing comment

Signed-off-by: Dhrubo Saha <[email protected]>

---------

Signed-off-by: Dhrubo Saha <[email protected]>
(cherry picked from commit 1569011)

Co-authored-by: Dhrubo Saha <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and dhrubo-os authored Oct 6, 2023
1 parent 4ecc036 commit fba88b0
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 31 deletions.
6 changes: 5 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,11 @@ public MLModel(StreamInput input) throws IOException{
modelContentSizeInBytes = input.readOptionalLong();
modelContentHash = input.readOptionalString();
if (input.readBoolean()) {
modelConfig = new TextEmbeddingModelConfig(input);
if (algorithm.equals(FunctionName.METRICS_CORRELATION)) {
modelConfig = new MetricsCorrelationModelConfig(input);
} else {
modelConfig = new TextEmbeddingModelConfig(input);
}
}
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
Expand All @@ -28,6 +29,10 @@ public MetricsCorrelationModelConfig(String modelType, String allConfig) {
super(modelType, allConfig);
}

public MetricsCorrelationModelConfig(StreamInput in) throws IOException{
super(in);
}

@Override
public String getWriteableName() {
return PARSE_FIELD_NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

import java.io.IOException;
Expand Down Expand Up @@ -137,7 +138,11 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
this.modelFormat = in.readEnum(MLModelFormat.class);
}
if (in.readBoolean()) {
this.modelConfig = new TextEmbeddingModelConfig(in);
if (this.functionName.equals(FunctionName.METRICS_CORRELATION)) {
this.modelConfig = new MetricsCorrelationModelConfig(in);
} else {
this.modelConfig = new TextEmbeddingModelConfig(in);
}
}
this.deployModel = in.readBoolean();
this.modelNodeIds = in.readOptionalStringArray();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.model;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;
import java.util.function.Function;

import static org.junit.Assert.assertEquals;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

public class MetricsCorrelationModelConfigTests {

MetricsCorrelationModelConfig config;
Function<XContentParser, MetricsCorrelationModelConfig> function;
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Before
public void setUp() {
config = MetricsCorrelationModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.build();
function = parser -> {
try {
return MetricsCorrelationModelConfig.parse(parser);
} catch (IOException e) {
throw new RuntimeException("Failed to parse MetricsCorrelationModelConfig", e);
}
};
}

@Test
public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
config.toXContent(builder, EMPTY_PARAMS);
String configContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent);
}

@Test
public void nullFields_ModelType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("model type is null");
config = MetricsCorrelationModelConfig.builder()
.build();
}

@Test
public void parse() throws IOException {
String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}";
TestHelper.testParseFromString(config, content, function);
}

@Test
public void readInputStream_Success() throws IOException {
readInputStream(config);
}

public void readInputStream(MetricsCorrelationModelConfig config) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
config.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
MetricsCorrelationModelConfig parsedConfig = new MetricsCorrelationModelConfig(streamInput);
assertEquals(config.getModelType(), parsedConfig.getModelType());
assertEquals(config.getAllConfig(), parsedConfig.getAllConfig());
assertEquals(config.getWriteableName(), parsedConfig.getWriteableName());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.ml.common.connector.HttpConnectorTest;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.search.SearchModule;

Expand Down Expand Up @@ -74,7 +75,6 @@ public void setUp() throws Exception {
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();

}

@Test
Expand Down Expand Up @@ -257,6 +257,59 @@ public void readInputStream_WithInternalConnector() throws IOException {
});
}

@Test
public void testMCorrInput() throws IOException {
String testString = "{\"function_name\":\"METRICS_CORRELATION\",\"name\":\"METRICS_CORRELATION\",\"version\":\"1.0.0b1\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"TORCH_SCRIPT\",\"model_config\":{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";

MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.build();

MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder()
.functionName(FunctionName.METRICS_CORRELATION)
.modelName(FunctionName.METRICS_CORRELATION.name())
.version("1.0.0b1")
.modelGroupId(modelGroupId)
.url(url)
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.modelConfig(mcorrConfig)
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();
XContentBuilder builder = XContentFactory.jsonBuilder();
mcorrInput.toXContent(builder, ToXContent.EMPTY_PARAMS);
String jsonStr = builder.toString();
assertEquals(testString, jsonStr);
}

@Test
public void readInputStream_MCorr() throws IOException {
MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.build();

MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder()
.functionName(FunctionName.METRICS_CORRELATION)
.modelName(FunctionName.METRICS_CORRELATION.name())
.version("1.0.0b1")
.modelGroupId(modelGroupId)
.url(url)
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.modelConfig(mcorrConfig)
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();
readInputStream(mcorrInput, parsedInput -> {
assertEquals(parsedInput.getModelConfig().getModelType(), mcorrConfig.getModelType());
assertEquals(parsedInput.getModelConfig().getAllConfig(), mcorrConfig.getAllConfig());
assertEquals(parsedInput.getFunctionName(), FunctionName.METRICS_CORRELATION);
assertEquals(parsedInput.getModelName(), FunctionName.METRICS_CORRELATION.name());
assertEquals(parsedInput.getModelGroupId(), modelGroupId);
});
}

private void readInputStream(MLRegisterModelInput input, Consumer<MLRegisterModelInput> verify) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
input.writeTo(bytesStreamOutput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public void close() {
* @param modelId id of the model
* @param modelName name of the model
* @param version version of the model
* @param engine engine where model will be run. For now we are supporting only pytorch engine only.
* @param engine engine where model will be run. For now, we are supporting only pytorch engine only.
*/
private void loadModel(File modelZipFile, String modelId, String modelName, String version,
String engine) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {

if (modelId == null) {
boolean hasModelGroupIndex = clusterService.state().getMetadata().hasIndex(ML_MODEL_GROUP_INDEX);
if (!hasModelGroupIndex) { // Create model group index if doesn't exist
if (!hasModelGroupIndex) { // Create model group index if it doesn't exist
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
CreateIndexRequest request = new CreateIndexRequest(ML_MODEL_GROUP_INDEX).mapping(ML_MODEL_GROUP_INDEX_MAPPING);
CreateIndexResponse createIndexResponse = client.admin().indices().create(request).actionGet(1000);
Expand Down Expand Up @@ -162,7 +162,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
}, e-> {
log.error("Failed to get model", e);
});
client.get(getModelRequest, ActionListener.runBefore(listener, () -> context.restore()));
client.get(getModelRequest, ActionListener.runBefore(listener, context::restore));
}
}
} else {
Expand All @@ -177,10 +177,17 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
waitUntil(() -> {
if (modelId != null) {
MLModelState modelState = getModel(modelId).getModelState();
return modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED;
if (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED){
log.info("Model deployed: " + modelState);
return true;
} else if (modelState == MLModelState.UNDEPLOYED || modelState == MLModelState.DEPLOY_FAILED) {
log.info("Model not deployed: " + modelState);
deployModel(modelId, ActionListener.wrap(deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(), e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e)));
return false;
}
}
return false;
}, 10, TimeUnit.SECONDS);
}, 120, TimeUnit.SECONDS);

Output djlOutput;
try {
Expand Down Expand Up @@ -230,9 +237,7 @@ void registerModel(ActionListener<MLRegisterModelResponse> listener) throws Inte
log.error("Failed to Register Model", e);
listener.onFailure(e);
}));
}, e-> {
listener.onFailure(e);
}), () -> context.restore()));
}, listener::onFailure), context::restore));
} catch (IOException e) {
throw new MLException(e);
}
Expand Down Expand Up @@ -300,6 +305,8 @@ public static boolean waitUntil(BooleanSupplier breakSupplier, long maxWaitTime,
}
sum += timeInMillis;
timeInMillis = Math.min(AWAIT_BUSY_THRESHOLD, timeInMillis * 2);

log.info("Waiting... Time elapsed: " + sum + "ms");
}
timeInMillis = maxTimeInMillis - sum;
try {
Expand Down
Loading

0 comments on commit fba88b0

Please sign in to comment.