diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java b/common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java
new file mode 100644
index 0000000000..0c03d0b3be
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java
@@ -0,0 +1,152 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.dataset;
+
+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+
+import java.io.IOException;
+import java.util.Locale;
+
+import org.opensearch.core.ParseField;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.ml.common.FunctionName;
+import org.opensearch.ml.common.annotation.MLAlgoParameter;
+import org.opensearch.ml.common.input.parameter.MLAlgoParams;
+
+import lombok.Builder;
+import lombok.Data;
+
+/**
+ * This class defines the modes of operation of an asymmetric text embedding model.
+ * Asymmetric embedding models treat the input text differently, depending on whether it is a
+ * passage or a query. One example of an asymmetric model that requires different prefixes is e5
+ * (cf. https://arxiv.org/pdf/2212.03533.pdf).
+ * <p>
+ * Use this parameter only if the model is asymmetric and has been registered with the corresponding
+ * `query_prefix` and `passage_prefix` configuration parameters.
+ * <p>
+ * The content type is optional: {@link #parse(XContentParser)} leaves it {@code null} when the field is
+ * absent, so the stream serialization below has to tolerate a missing value.
+ */
+@Data
+@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING })
+public class AsymmetricTextEmbeddingParameters implements MLAlgoParams {
+
+    /** Whether the text to embed is a search query or a passage (document). */
+    public enum EmbeddingContentType {
+        QUERY,
+        PASSAGE
+    }
+
+    public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name();
+    public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
+        MLAlgoParams.class,
+        new ParseField(PARSE_FIELD_NAME),
+        it -> parse(it)
+    );
+
+    /**
+     * Creates the parameters for the given content type.
+     *
+     * @param embeddingContentType the kind of content to embed; may be {@code null}
+     */
+    @Builder(toBuilder = true)
+    public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType) {
+        this.embeddingContentType = embeddingContentType;
+    }
+
+    /**
+     * Reads the parameters from a transport stream written by {@link #writeTo(StreamOutput)}.
+     * <p>
+     * The content type travels as an optional string. An absent value is restored as {@code null}
+     * instead of being handed to {@code EmbeddingContentType.valueOf}, which rejects {@code null}
+     * with a {@link NullPointerException}.
+     *
+     * @param in the stream to read from
+     * @throws IOException if reading from the stream fails
+     */
+    public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException {
+        String contentType = in.readOptionalString();
+        this.embeddingContentType = contentType == null ? null : EmbeddingContentType.valueOf(contentType);
+    }
+
+    /**
+     * Parses the parameters from an object such as {@code {"content_type": "query"}}.
+     * The value is matched case-insensitively against {@link EmbeddingContentType}; unknown fields are skipped.
+     *
+     * @param parser a parser positioned on the opening token of the parameters object
+     * @return the parsed parameters; the content type is {@code null} if the field is absent
+     * @throws IOException if the content cannot be read
+     * @throws IllegalArgumentException if the value names no {@link EmbeddingContentType} constant
+     */
+    public static MLAlgoParams parse(XContentParser parser) throws IOException {
+        EmbeddingContentType embeddingContentType = null;
+
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+            String fieldName = parser.currentName();
+            parser.nextToken();
+
+            switch (fieldName) {
+                case EMBEDDING_CONTENT_TYPE_FIELD:
+                    String contentType = parser.text();
+                    embeddingContentType = EmbeddingContentType.valueOf(contentType.toUpperCase(Locale.ROOT));
+                    break;
+                default:
+                    parser.skipChildren();
+                    break;
+            }
+        }
+        return new AsymmetricTextEmbeddingParameters(embeddingContentType);
+    }
+
+    public static final String EMBEDDING_CONTENT_TYPE_FIELD = "content_type";
+
+    // The type of the content to be embedded
+    private EmbeddingContentType embeddingContentType;
+
+    @Override
+    public int getVersion() {
+        return 1;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return PARSE_FIELD_NAME;
+    }
+
+    /**
+     * Writes the parameters to a transport stream.
+     * <p>
+     * An absent content type is written as a missing optional string rather than dereferenced, so
+     * parameters without a content type can be sent between nodes.
+     *
+     * @param out the stream to write to
+     * @throws IOException if writing to the stream fails
+     */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalString(embeddingContentType == null ? null : embeddingContentType.name());
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
+        xContentBuilder.startObject();
+        if (embeddingContentType != null) {
+            xContentBuilder.field(EMBEDDING_CONTENT_TYPE_FIELD, embeddingContentType.name());
+        }
+        xContentBuilder.endObject();
+        return xContentBuilder;
+    }
+
+    public EmbeddingContentType getEmbeddingContentType() {
+        return embeddingContentType;
+    }
+}
diff --git a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java
index dbb15fa2d6..b1c249da44 100644
--- a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java
+++ b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java
@@ -37,16 +37,25 @@ public class TextEmbeddingModelConfig extends MLModelConfig {
     public static final String POOLING_MODE_FIELD = "pooling_mode";
     public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
     public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";
+    public static final String QUERY_PREFIX = "query_prefix";
+    public static final String PASSAGE_PREFIX = "passage_prefix";
 
     private final Integer embeddingDimension;
     private final FrameworkType frameworkType;
     private final PoolingMode poolingMode;
     private final boolean normalizeResult;
     private final Integer modelMaxLength;
+    private final String queryPrefix;
+    private final String passagePrefix;
+
+    public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
+                                    PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
+        this(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, null, null);
+    }
 
@Builder(toBuilder = true) public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig, - PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) { + PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength, String queryPrefix, String passagePrefix) { super(modelType, allConfig); if (embeddingDimension == null) { throw new IllegalArgumentException("embedding dimension is null"); @@ -59,6 +68,8 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr this.poolingMode = poolingMode; this.normalizeResult = normalizeResult; this.modelMaxLength = modelMaxLength; + this.queryPrefix = queryPrefix; + this.passagePrefix = passagePrefix; } public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException { @@ -69,6 +80,8 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc PoolingMode poolingMode = null; boolean normalizeResult = false; Integer modelMaxLength = null; + String queryPrefix = null; + String passagePrefix = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -97,12 +110,18 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc case MODEL_MAX_LENGTH_FIELD: modelMaxLength = parser.intValue(); break; + case QUERY_PREFIX: + queryPrefix = parser.text(); + break; + case PASSAGE_PREFIX: + passagePrefix = parser.text(); + break; default: parser.skipChildren(); break; } } - return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength); + return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, queryPrefix, passagePrefix); } @Override @@ -121,6 +140,8 @@ public TextEmbeddingModelConfig(StreamInput in) throws 
IOException{ } normalizeResult = in.readBoolean(); modelMaxLength = in.readOptionalInt(); + queryPrefix = in.readOptionalString(); + passagePrefix = in.readOptionalString(); } @Override @@ -136,6 +157,8 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeBoolean(normalizeResult); out.writeOptionalInt(modelMaxLength); + out.writeOptionalString(queryPrefix); + out.writeOptionalString(passagePrefix); } @Override @@ -162,6 +185,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (normalizeResult) { builder.field(NORMALIZE_RESULT_FIELD, normalizeResult); } + if (queryPrefix != null) { + builder.field(QUERY_PREFIX, queryPrefix); + } + if (passagePrefix != null) { + builder.field(PASSAGE_PREFIX, passagePrefix); + } builder.endObject(); return builder; } diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java new file mode 100644 index 0000000000..a7a27c00ee --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java @@ -0,0 +1,83 @@ +package org.opensearch.ml.common.dataset; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; +import java.util.function.Function; +import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; + +public class AsymmetricTextEmbeddingParametersTest 
{ + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + AsymmetricTextEmbeddingParameters params; + private Function function = parser -> { + try { + return (AsymmetricTextEmbeddingParameters) AsymmetricTextEmbeddingParameters.parse(parser); + } catch (IOException e) { + throw new RuntimeException("failed to parse AsymmetricTextEmbeddingParameters", e); + } + }; + + @Before + public void setUp() { + params = AsymmetricTextEmbeddingParameters.builder() + .embeddingContentType(EmbeddingContentType.QUERY) + .build(); + } + + @Test + public void parse_AsymmetricTextEmbeddingParameters() throws IOException { + TestHelper.testParse(params, function); + } + + @Test + public void parse_AsymmetricTextEmbeddingParameters_Passage() throws IOException { + String paramsStr = contentObjectToString(params); + testParseFromString(params, paramsStr.replace("QUERY", "PASSAGE"), function); + } + + @Test + public void parse_AsymmetricTextEmbeddingParameters_Invalid() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU"); + String paramsStr = contentObjectToString(params); + testParseFromString(params, paramsStr.replace("QUERY","fu"), function); + } + + @Test + public void parse_EmptyAsymmetricTextEmbeddingParameters() throws IOException { + TestHelper.testParse(AsymmetricTextEmbeddingParameters.builder().build(), function); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(params); + } + + @Test + public void readInputStream_Success_EmptyParams() throws IOException { + readInputStream(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()); + } + + private void readInputStream(AsymmetricTextEmbeddingParameters params) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + 
params.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + AsymmetricTextEmbeddingParameters parsedParams = new AsymmetricTextEmbeddingParameters(streamInput); + assertEquals(params, parsedParams); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java index 10468d44a3..8819786fbe 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java @@ -39,9 +39,9 @@ public class TextDocsMLInputTest { @Before public void setUp() throws Exception { ModelResultFilter resultFilter = ModelResultFilter.builder().returnBytes(true).returnNumber(true) - .targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build(); + .targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build(); MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2")) - .resultFilter(resultFilter).build(); + .resultFilter(resultFilter).build(); input = new TextDocsMLInput(algorithm, inputDataset); } @@ -68,8 +68,8 @@ public void parseTextDocsMLInput_NewWay() throws IOException { private void parseMLInput(String jsonStr, int docSize) throws IOException { XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); diff --git a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java 
b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java index acba744ced..9bc97f7c9f 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java @@ -36,6 +36,8 @@ public void setUp() { .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) .embeddingDimension(100) + .passagePrefix("passage: ") + .queryPrefix("query: ") .build(); function = parser -> { try { @@ -51,7 +53,7 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); config.toXContent(builder, EMPTY_PARAMS); String configContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); + assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}", configContent); } @Test @@ -83,7 +85,7 @@ public void nullFields_FrameworkType() { @Test public void parse() throws IOException { - String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: 
\",\"passage_prefix\":\"passage: \"}"; TestHelper.testParseFromString(config, content, function); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index 1090ee15b6..1f606bcaa4 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -135,6 +135,12 @@ public void downloadPrebuiltModelConfig( case TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD: configBuilder.modelMaxLength(((Double) configEntry.getValue()).intValue()); break; + case TextEmbeddingModelConfig.QUERY_PREFIX: + configBuilder.queryPrefix(configEntry.getValue().toString()); + break; + case TextEmbeddingModelConfig.PASSAGE_PREFIX: + configBuilder.passagePrefix(configEntry.getValue().toString()); + break; default: break; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java index 6c6033f2cb..e5db51324f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java @@ -67,6 +67,8 @@ public abstract class DLModel implements Predictable { protected Device[] devices; protected AtomicInteger nextDevice = new AtomicInteger(0); + protected MLModelConfig modelConfig; + @Override public MLOutput predict(MLInput mlInput, MLModel model) { throw new IllegalArgumentException("model not deployed"); @@ -183,6 +185,7 @@ protected void doLoadModel( IOException, TranslateException { devices = Engine.getEngine(engine).getDevices(); + this.modelConfig = modelConfig; for (int i = 0; i < devices.length; i++) { log.debug("load model {} to device {}: {}", modelId, i, devices[i]); ZooModel model; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java 
b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java index d74eee7b0f..a5a3a8fb50 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java @@ -4,10 +4,14 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -20,9 +24,15 @@ import ai.djl.translate.TranslateException; public abstract class TextEmbeddingModel extends DLModel { + @Override public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException { - MLInputDataset inputDataSet = mlInput.getInputDataset(); + MLAlgoParams mlParams = mlInput.getParameters(); + + MLInputDataset inputDataSet = isAsymmetricModel(mlParams) + ? 
addPrefixesToData((AsymmetricTextEmbeddingParameters) mlParams, (TextDocsInputDataSet) mlInput.getInputDataset())
+            : mlInput.getInputDataset();
+
         List tensorOutputs = new ArrayList<>();
         Output output;
         TextDocsInputDataSet textDocsInput = (TextDocsInputDataSet) inputDataSet;
@@ -36,6 +46,69 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla
         return new ModelTensorOutput(tensorOutputs);
     }
 
+    /**
+     * Decides whether the request must go through the asymmetric (prefixing) path and rejects
+     * combinations of request parameters and model configuration that do not fit together.
+     *
+     * @param mlParams the algorithm parameters of the request, possibly {@code null}
+     * @return {@code true} if asymmetric parameters were passed and the model defines at least one prefix;
+     *         {@code false} if neither the request nor the model configuration is asymmetric
+     * @throws IllegalArgumentException if asymmetric parameters are passed to a model without prefixes,
+     *         or if a model with prefixes is called without asymmetric parameters
+     */
+    private boolean isAsymmetricModel(MLAlgoParams mlParams) {
+        if (mlParams instanceof AsymmetricTextEmbeddingParameters) {
+            // Check for the necessary prefixes in modelConfig
+            if (modelConfig == null
+                || ((TextEmbeddingModelConfig) modelConfig).getPassagePrefix() == null
+                    && ((TextEmbeddingModelConfig) modelConfig).getQueryPrefix() == null) {
+                throw new IllegalArgumentException(
+                    "When passing AsymmetricTextEmbeddingParameters, the model requires to be "
+                        + "registered with at least one of `query_prefix` or `passage_prefix`."
+                );
+            }
+            // Passed all checks
+            return true;
+        }
+
+        // no AsymmetricTextEmbeddingParameters passed, but the model is asymmetric.
+        if (modelConfig != null
+            && (((TextEmbeddingModelConfig) modelConfig).getPassagePrefix() != null
+                || ((TextEmbeddingModelConfig) modelConfig).getQueryPrefix() != null)) {
+            throw new IllegalArgumentException(
+                "The embedding model chosen is asymmetric. To use it, you must declare whether the input is of type `QUERY` or of type `PASSAGE`."
+            );
+        }
+
+        return false;
+    }
+
+    /**
+     * Prepends the configured query or passage prefix to every document of the request.
+     * <p>
+     * Asymmetric embedding models typically work with "mini-prompts" that prime the model to embed a text
+     * as a query or as a passage. The prompt defined in the model configuration is applied here; any
+     * content type other than {@code PASSAGE} is treated as a query.
+     *
+     * @param mlParams the asymmetric parameters carrying the content type
+     * @param inputDataSet the documents as sent by the caller
+     * @return a data set with prefixed documents and the caller's result filter, or {@code inputDataSet}
+     *         itself if the model defines no prefix for the requested content type
+     */
+    private TextDocsInputDataSet addPrefixesToData(AsymmetricTextEmbeddingParameters mlParams, TextDocsInputDataSet inputDataSet) {
+        TextEmbeddingModelConfig modelConfig = (TextEmbeddingModelConfig) this.modelConfig;
+        String prefix = mlParams.getEmbeddingContentType() == EmbeddingContentType.PASSAGE
+            ? modelConfig.getPassagePrefix()
+            : modelConfig.getQueryPrefix();
+        if (prefix == null) {
+            return inputDataSet;
+        }
+        List<String> prefixedDocs = inputDataSet.getDocs().stream().map(s -> prefix + s).collect(Collectors.toList());
+        // Carry the result filter over: predict continues with the returned data set, so a filter
+        // that is not copied here would be lost for every asymmetric request.
+        return TextDocsInputDataSet.builder().docs(prefixedDocs).resultFilter(inputDataSet.getResultFilter()).build();
+    }
+
     public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
         TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig;
         String warmUpSentence = "warm up sentence";
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java
index 7c7e2be4b9..75bdde5bbe 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java
@@ -6,6 +6,8 @@
 package org.opensearch.ml.engine.algorithms.text_embedding;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.HUGGINGFACE_TRANSFORMERS;
 import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS;
 import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE;
@@ -31,6 +33,8 @@
 import org.opensearch.ResourceNotFoundException;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.MLModel;
+import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters;
+import
org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; @@ -239,6 +243,257 @@ private void initModel_predict_HuggingfaceModel( } + @Test + public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts_HappyPath() throws URISyntaxException { + Map params = new HashMap<>(); + params.put(MODEL_HELPER, modelHelper); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI())); + params.put(ML_ENGINE, mlEngine); + + TextEmbeddingModelConfig asymmetricModelConfig = this.modelConfig + .toBuilder() + .embeddingDimension(768) + .queryPrefix("query >> ") + .passagePrefix("passage >> ") + .build(); + MLModel asymmetricSmallModel = model.toBuilder().modelConfig(asymmetricModelConfig).build(); + textEmbeddingDenseModel.initModel(asymmetricSmallModel, params, encryptor); + MLInput asymmetricMlInputQueries = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset( + TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build() + ) + .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY)) + .build(); + MLInput asymmetricMlInputPassages = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset( + TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build() + ) + .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.PASSAGE)) + .build(); + + ModelTensorOutput asymmetricQueryEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(asymmetricMlInputQueries); + ModelTensorOutput asymmetricPassageEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(asymmetricMlInputPassages); + + TextEmbeddingModelConfig 
symmetricModelConfig = this.modelConfig.toBuilder().embeddingDimension(768).build(); + MLModel smallModel = model.toBuilder().modelConfig(symmetricModelConfig).build(); + textEmbeddingDenseModel.initModel(smallModel, params, encryptor); + MLInput symmetricMlInputQueries = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset( + TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build() + ) + .build(); + MLInput symmetricMlInputPassages = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset( + TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build() + ) + .build(); + + ModelTensorOutput symmetricQueryEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputQueries); + ModelTensorOutput symmetricPassageEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputPassages); + + assertTrue( + "asymmetric and symmetric query embeddings should be different", + areTensorsDifferent( + asymmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0), + symmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0), + 0.1f + ) + ); + assertTrue( + "asymmetric and symmetric passage embeddings should be different", + areTensorsDifferent( + asymmetricPassageEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0), + symmetricPassageEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0), + 0.1f + ) + ); + + textEmbeddingDenseModel.close(); + } + + private boolean areTensorsDifferent(ModelTensor tensor1, ModelTensor tensor2, float delta) { + + if (!Arrays.equals(tensor1.getShape(), tensor2.getShape())) { + return true; // Tensors are different if they have different lengths + } + + List vectorA = Arrays.asList(tensor1.getData()); + List vectorB = Arrays.asList(tensor2.getData()); + + for (int i = 0; i < 
vectorA.size(); i++) { + if (Math.abs(vectorA.get(i).floatValue() - vectorB.get(i).floatValue()) > delta) { + return true; // Vectors are different if any pair of corresponding elements differ by more than the tolerance + } + } + return false; // Vectors are the same + + } + + @Test + public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts_HappyPath2() + throws URISyntaxException { + // only the query embeddings need a prefix + Map params = new HashMap<>(); + params.put(MODEL_HELPER, modelHelper); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI())); + params.put(ML_ENGINE, mlEngine); + + TextEmbeddingModelConfig asymmetricModelConfig = this.modelConfig + .toBuilder() + .embeddingDimension(768) + .queryPrefix("query >> ") + .build(); + MLModel asymmetricSmallModel = model.toBuilder().modelConfig(asymmetricModelConfig).build(); + textEmbeddingDenseModel.initModel(asymmetricSmallModel, params, encryptor); + MLInput asymmetricMlInputQueries = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset( + TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build() + ) + .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY)) + .build(); + MLInput asymmetricMlInputPassages = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset( + TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build() + ) + .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.PASSAGE)) + .build(); + + ModelTensorOutput asymmetricQueryEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(asymmetricMlInputQueries); + ModelTensorOutput asymmetricPassageEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(asymmetricMlInputPassages); + + TextEmbeddingModelConfig symmetricModelConfig = 
this.modelConfig.toBuilder().embeddingDimension(768).build(); + MLModel smallModel = model.toBuilder().modelConfig(symmetricModelConfig).build(); + textEmbeddingDenseModel.initModel(smallModel, params, encryptor); + MLInput symmetricMlInputQueries = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset( + TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build() + ) + .build(); + MLInput symmetricMlInputPassages = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset( + TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build() + ) + .build(); + + ModelTensorOutput symmetricQueryEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputQueries); + ModelTensorOutput symmetricPassageEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputPassages); + + assertTrue( + "asymmetric and symmetric query embeddings should be different", + areTensorsDifferent( + asymmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0), + symmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0), + 0.1f + ) + ); + assertTrue( + "asymmetric and symmetric passage embeddings should be equal", + !areTensorsDifferent( + asymmetricPassageEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0), + symmetricPassageEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0), + 0.1f + ) + ); + + textEmbeddingDenseModel.close(); + } + + /** Sad path: the model is initialised with a config that sets both a query prefix and a passage prefix, but the input carries no AsymmetricTextEmbeddingParameters. predict is expected to throw an MLException whose cause is an IllegalArgumentException with the message asserted below; the model is closed in the finally block and the test fails if no MLException is thrown. */ @Test + public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts_SadPath1() throws URISyntaxException { + // asymmetric model, no parameter passed + Map params = new HashMap<>(); + params.put(MODEL_HELPER, modelHelper); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI())); + params.put(ML_ENGINE, mlEngine); + + 
TextEmbeddingModelConfig asymmetricModelConfig = this.modelConfig + .toBuilder() + .embeddingDimension(768) + .queryPrefix("query >> ") + .passagePrefix("passage >>") + .build(); + MLModel asymmetricSmallModel = model.toBuilder().modelConfig(asymmetricModelConfig).build(); + textEmbeddingDenseModel.initModel(asymmetricSmallModel, params, encryptor); + + MLInput symmetricMlInputQueries = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset( + TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build() + ) + .build(); + + try { + textEmbeddingDenseModel.predict(symmetricMlInputQueries); + } catch (MLException e) { + assertEquals(IllegalArgumentException.class, e.getCause().getClass()); + assertEquals( + "The embedding model chosen is asymmetric. To use it, you must declare whether the input is of type `QUERY` or of type `PASSAGE`.", + e.getCause().getMessage() + ); + return; + } finally { + textEmbeddingDenseModel.close(); + } + + fail("Expected exception not thrown"); + + } + + /** Sad path: the model is initialised with a config derived from this.modelConfig where only embeddingDimension is overridden (the symmetric case per the inline comment), while the input carries AsymmetricTextEmbeddingParameters with EmbeddingContentType.QUERY. predict is expected to throw an MLException whose cause is an IllegalArgumentException with the message asserted below; the model is closed in the finally block and the test fails if no MLException is thrown. */ @Test + public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts_SadPath2() throws URISyntaxException { + // symmetric model, asymmetric parameter passed + Map params = new HashMap<>(); + params.put(MODEL_HELPER, modelHelper); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI())); + params.put(ML_ENGINE, mlEngine); + + TextEmbeddingModelConfig symmetricModelConfig = this.modelConfig.toBuilder().embeddingDimension(768).build(); + MLModel symmetricSmallModel = model.toBuilder().modelConfig(symmetricModelConfig).build(); + textEmbeddingDenseModel.initModel(symmetricSmallModel, params, encryptor); + + MLInput asymmetricMlInputQueries = MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset( + TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us 
open")).build() + ) + .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY)) + .build(); + + try { + textEmbeddingDenseModel.predict(asymmetricMlInputQueries); + } catch (MLException e) { + assertEquals(IllegalArgumentException.class, e.getCause().getClass()); + assertEquals( + "When passing AsymmetricTextEmbeddingParameters, the model requires to be registered with at least one of `query_prefix` or `passage_prefix`.", + e.getCause().getMessage() + ); + return; + } finally { + textEmbeddingDenseModel.close(); + } + + fail("Expected exception not thrown"); + + } + @Test public void initModel_NullModelZipFile() { exceptionRule.expect(IllegalArgumentException.class); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index fd28c47970..5a46d10574 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -91,6 +91,7 @@ import org.opensearch.ml.cluster.MLCommonsClusterEventListener; import org.opensearch.ml.cluster.MLCommonsClusterManagerEventListener; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.execute.anomalylocalization.AnomalyLocalizationInput; import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; @@ -843,7 +844,8 @@ public List getNamedXContent() { AnomalyLocalizationInput.XCONTENT_REGISTRY_ENTRY, RCFSummarizeParams.XCONTENT_REGISTRY, LogisticRegressionParams.XCONTENT_REGISTRY, - TextEmbeddingModelConfig.XCONTENT_REGISTRY + TextEmbeddingModelConfig.XCONTENT_REGISTRY, + AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY ); } diff --git 
a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java index b3f3e3f956..2ade96cab2 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java @@ -36,6 +36,7 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.common.transport.register.MLRegisterModelAction; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -136,6 +137,21 @@ public void testRegisterModelRequest() throws Exception { assertEquals("test_model", registerModelInput.getModelName()); assertEquals("1", registerModelInput.getVersion()); assertEquals("TORCH_SCRIPT", registerModelInput.getModelFormat().toString()); + assertEquals(null, ((TextEmbeddingModelConfig) registerModelInput.getModelConfig()).getQueryPrefix()); + assertEquals(null, ((TextEmbeddingModelConfig) registerModelInput.getModelConfig()).getPassagePrefix()); + } + + /** Sends the request built by getRestRequestAsymmetricModel through the REST action, captures the MLRegisterModelRequest passed to client.execute (verified to happen exactly once), and asserts the model name, version, model format and the query and passage prefixes of its TextEmbeddingModelConfig. */ public void testRegisterAsymmetricModelRequest() throws Exception { + RestRequest request = getRestRequestAsymmetricModel(); + restMLRegisterModelAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelRequest.class); + verify(client, times(1)).execute(eq(MLRegisterModelAction.INSTANCE), argumentCaptor.capture(), any()); + MLRegisterModelInput registerModelInput = argumentCaptor.getValue().getRegisterModelInput(); + assertEquals("test_model", registerModelInput.getModelName()); + assertEquals("1", registerModelInput.getVersion()); + assertEquals("TORCH_SCRIPT", 
registerModelInput.getModelFormat().toString()); + assertEquals("query: ", ((TextEmbeddingModelConfig) registerModelInput.getModelConfig()).getQueryPrefix()); + assertEquals("passage: ", ((TextEmbeddingModelConfig) registerModelInput.getModelConfig()).getPassagePrefix()); } public void testRegisterModelRequestRemoteInferenceDisabled() throws Exception { @@ -212,6 +228,49 @@ private RestRequest getRestRequest() { return request; } + /** Builds a FakeRestRequest (POST, JSON body serialised with Gson) for the register path below, whose model_config map contains the model_type, embedding_dimension, framework_type, all_config, query_prefix and passage_prefix entries. */ private RestRequest getRestRequestAsymmetricModel() { + RestRequest.Method method = RestRequest.Method.POST; + final Map modelConfig = Map + .of( + "model_type", + "bert", + "embedding_dimension", + 384, + "framework_type", + "sentence_transformers", + "all_config", + "All Config", + "query_prefix", + "query: ", + "passage_prefix", + "passage: " + ); + final Map model = Map + .of( + "name", + "test_model", + "model_id", + "test_model_with_modelId", + "version", + "1", + "model_group_id", + "modelGroupId", + "url", + "testUrl", + "model_format", + "TORCH_SCRIPT", + "model_config", + modelConfig + ); + /* NOTE(review): Gson.toJson presumably already returns a String here, so the trailing toString() looks redundant - confirm before removing. */ String requestContent = new Gson().toJson(model).toString(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/models/{model_id}/{version}/_register") + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + private RestRequest getRestRequestWithNullModelId() { RestRequest.Method method = RestRequest.Method.POST; final Map modelConfig = Map