From ab046351d63474cb76f0e2b56fd400d87f3e9b2c Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 6 Feb 2023 15:06:51 -0800 Subject: [PATCH] tune model config: change pooling mode to optional (#724) Signed-off-by: Yaliang Wu --- .../model/TextEmbeddingModelConfig.java | 29 ++++++++++++------- .../model/TextEmbeddingModelConfigTests.java | 2 +- .../transport/upload/MLUploadInputTest.java | 2 +- ...ingfaceTextEmbeddingTranslatorFactory.java | 2 +- ...nceTransformerTextEmbeddingTranslator.java | 2 +- 5 files changed, 23 insertions(+), 14 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java index 6f0f54fb1c..45a1dd6c71 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java @@ -55,11 +55,7 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr } this.embeddingDimension = embeddingDimension; this.frameworkType = frameworkType; - if (poolingMode != null) { - this.poolingMode = poolingMode; - } else { - this.poolingMode = PoolingMode.MEAN; - } + this.poolingMode = poolingMode; this.normalizeResult = normalizeResult; this.modelMaxLength = modelMaxLength; } @@ -69,7 +65,7 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc Integer embeddingDimension = null; FrameworkType frameworkType = null; String allConfig = null; - PoolingMode poolingMode = PoolingMode.MEAN; + PoolingMode poolingMode = null; boolean normalizeResult = false; Integer modelMaxLength = null; @@ -117,7 +113,11 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{ super(in); embeddingDimension = in.readInt(); frameworkType = in.readEnum(FrameworkType.class); - poolingMode = in.readEnum(PoolingMode.class); + if (in.readBoolean()) { + poolingMode = in.readEnum(PoolingMode.class); + } else { + poolingMode = null; + } normalizeResult = in.readBoolean(); modelMaxLength = in.readOptionalInt(); } @@ -127,7 +127,12 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeInt(embeddingDimension); out.writeEnum(frameworkType); - out.writeEnum(poolingMode); + if (poolingMode != null) { + out.writeBoolean(true); + out.writeEnum(poolingMode); + } else { + out.writeBoolean(false); + } out.writeBoolean(normalizeResult); out.writeOptionalInt(modelMaxLength); } @@ -150,8 +155,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelMaxLength != null) { builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength); } - builder.field(POOLING_MODE_FIELD, poolingMode); - builder.field(NORMALIZE_RESULT_FIELD, normalizeResult); + if (poolingMode != null) { + builder.field(POOLING_MODE_FIELD, poolingMode); + } + if (normalizeResult) { + builder.field(NORMALIZE_RESULT_FIELD, normalizeResult); + } builder.endObject(); return builder; } diff --git a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java index 8ea37a7111..90777c8b33 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java @@ -52,7 +52,7 @@ public void toXContent() throws IOException { config.toXContent(builder, EMPTY_PARAMS); String configContent = TestHelper.xContentBuilderToString(builder); System.out.println(configContent); - assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"pooling_mode\":\"MEAN\",\"normalize_result\":false}", configContent); + assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload/MLUploadInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload/MLUploadInputTest.java index 0d29fa5772..d654de5987 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload/MLUploadInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload/MLUploadInputTest.java @@ -37,7 +37,7 @@ public class MLUploadInputTest { private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"url\":\"url\",\"model_format\":\"ONNX\"," + "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," + "\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" + - ",\"pooling_mode\":\"MEAN\",\"normalize_result\":false},\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; + "},\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; private final FunctionName functionName = FunctionName.LINEAR_REGRESSION; private final String modelName = "modelName"; private final String version = "version"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslatorFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslatorFactory.java index 64ddc50558..482bed7c81 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslatorFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslatorFactory.java @@ -37,7 +37,7 @@ public class HuggingfaceTextEmbeddingTranslatorFactory implements TranslatorFact private final boolean neuron; public HuggingfaceTextEmbeddingTranslatorFactory(TextEmbeddingModelConfig.PoolingMode poolingMode, boolean normalizeResult, String modelType, boolean neuron) { - this.poolingMode = poolingMode; + this.poolingMode = poolingMode == null ? TextEmbeddingModelConfig.PoolingMode.MEAN : poolingMode; this.normalizeResult = normalizeResult; this.modelType = modelType; this.neuron = neuron; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java index b14ea38747..9b671298ca 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java @@ -36,7 +36,7 @@ public class ONNXSentenceTransformerTextEmbeddingTranslator implements ServingTr private String modelType; public ONNXSentenceTransformerTextEmbeddingTranslator(TextEmbeddingModelConfig.PoolingMode poolingMode, boolean normalizeResult, String modelType) { - this.poolingMode = poolingMode; + this.poolingMode = poolingMode == null ? TextEmbeddingModelConfig.PoolingMode.MEAN : poolingMode; this.normalizeResult = normalizeResult; this.modelType = modelType; }