Skip to content

Commit

Permalink
addressing comments
Browse files Browse the repository at this point in the history
Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os committed Oct 6, 2023
1 parent 6dabfa4 commit ec79bf2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.ml.common.connector.HttpConnectorTest;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.search.SearchModule;

Expand Down Expand Up @@ -74,7 +75,6 @@ public void setUp() throws Exception {
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();

}

@Test
Expand Down Expand Up @@ -257,6 +257,59 @@ public void readInputStream_WithInternalConnector() throws IOException {
});
}

@Test
public void testMCorrInput() throws IOException {
String testString = "{\"function_name\":\"METRICS_CORRELATION\",\"name\":\"METRICS_CORRELATION\",\"version\":\"1.0.0b1\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"TORCH_SCRIPT\",\"model_config\":{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";

MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.build();

MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder()
.functionName(FunctionName.METRICS_CORRELATION)
.modelName(FunctionName.METRICS_CORRELATION.name())
.version("1.0.0b1")
.modelGroupId(modelGroupId)
.url(url)
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.modelConfig(mcorrConfig)
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();
XContentBuilder builder = XContentFactory.jsonBuilder();
mcorrInput.toXContent(builder, ToXContent.EMPTY_PARAMS);
String jsonStr = builder.toString();
assertEquals(testString, jsonStr);
}

@Test
public void readInputStream_MCorr() throws IOException {
MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.build();

MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder()
.functionName(FunctionName.METRICS_CORRELATION)
.modelName(FunctionName.METRICS_CORRELATION.name())
.version("1.0.0b1")
.modelGroupId(modelGroupId)
.url(url)
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.modelConfig(mcorrConfig)
.deployModel(true)
.modelNodeIds(new String[]{"modelNodeIds" })
.build();
readInputStream(mcorrInput, parsedInput -> {
assertEquals(parsedInput.getModelConfig().getModelType(), mcorrConfig.getModelType());
assertEquals(parsedInput.getModelConfig().getAllConfig(), mcorrConfig.getAllConfig());
assertEquals(parsedInput.getFunctionName(), FunctionName.METRICS_CORRELATION);
assertEquals(parsedInput.getModelName(), FunctionName.METRICS_CORRELATION.name());
assertEquals(parsedInput.getModelGroupId(), modelGroupId);
});
}

private void readInputStream(MLRegisterModelInput input, Consumer<MLRegisterModelInput> verify) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
input.writeTo(bytesStreamOutput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,7 @@ void registerModel(ActionListener<MLRegisterModelResponse> listener) throws Inte
log.error("Failed to Register Model", e);
listener.onFailure(e);
}));
}, e-> {
listener.onFailure(e);
}), () -> context.restore()));
}, listener::onFailure), context::restore));
} catch (IOException e) {
throw new MLException(e);
}
Expand Down

0 comments on commit ec79bf2

Please sign in to comment.