diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java b/common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java
new file mode 100644
index 0000000000..0c03d0b3be
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParameters.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.dataset;
+
+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+
+import java.io.IOException;
+import java.util.Locale;
+
+import org.opensearch.core.ParseField;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.ml.common.FunctionName;
+import org.opensearch.ml.common.annotation.MLAlgoParameter;
+import org.opensearch.ml.common.input.parameter.MLAlgoParams;
+
+import lombok.Builder;
+import lombok.Data;
+
+/**
+ * This class defines the modes of operation of an asymmetric text embedding model.
+ * Asymmetric embedding models treat the input text differently, depending on whether it is a
+ * passage or a query. One example asymmetric model, that requires different prefixes is e5
+ * (cf. https://arxiv.org/pdf/2212.03533.pdf).
+ *
+ * Use this parameter only if the model is asymmetric and has been registered with the corresponding
+ * `query_prefix` and `passage_prefix` configuration parameters.
+ */
+@Data
+@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING })
+public class AsymmetricTextEmbeddingParameters implements MLAlgoParams {
+
+ public enum EmbeddingContentType {
+ QUERY,
+ PASSAGE
+ }
+
+ public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name();
+ public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
+ MLAlgoParams.class,
+ new ParseField(PARSE_FIELD_NAME),
+ it -> parse(it)
+ );
+
+ @Builder(toBuilder = true)
+ public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType) {
+ this.embeddingContentType = embeddingContentType;
+ }
+
+ public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException {
+ this.embeddingContentType = EmbeddingContentType.valueOf(in.readOptionalString());
+ }
+
+ public static MLAlgoParams parse(XContentParser parser) throws IOException {
+ EmbeddingContentType embeddingContentType = null;
+
+ ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+ while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+ String fieldName = parser.currentName();
+ parser.nextToken();
+
+ switch (fieldName) {
+ case EMBEDDING_CONTENT_TYPE_FIELD:
+ String contentType = parser.text();
+ embeddingContentType = EmbeddingContentType.valueOf(contentType.toUpperCase(Locale.ROOT));
+ break;
+ default:
+ parser.skipChildren();
+ break;
+ }
+ }
+ return new AsymmetricTextEmbeddingParameters(embeddingContentType);
+ }
+
+ public static final String EMBEDDING_CONTENT_TYPE_FIELD = "content_type";
+
+ // The type of the content to be embedded
+ private EmbeddingContentType embeddingContentType;
+
+ @Override
+ public int getVersion() {
+ return 1;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return PARSE_FIELD_NAME;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeOptionalString(embeddingContentType.name());
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
+ xContentBuilder.startObject();
+ if (embeddingContentType != null) {
+ xContentBuilder.field(EMBEDDING_CONTENT_TYPE_FIELD, embeddingContentType.name());
+ }
+ xContentBuilder.endObject();
+ return xContentBuilder;
+ }
+
+ public EmbeddingContentType getEmbeddingContentType() {
+ return embeddingContentType;
+ }
+}
diff --git a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java
index dbb15fa2d6..b1c249da44 100644
--- a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java
+++ b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java
@@ -37,16 +37,25 @@ public class TextEmbeddingModelConfig extends MLModelConfig {
public static final String POOLING_MODE_FIELD = "pooling_mode";
public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";
+ public static final String QUERY_PREFIX = "query_prefix";
+ public static final String PASSAGE_PREFIX = "passage_prefix";
private final Integer embeddingDimension;
private final FrameworkType frameworkType;
private final PoolingMode poolingMode;
private final boolean normalizeResult;
private final Integer modelMaxLength;
+ private final String queryPrefix;
+ private final String passagePrefix;
+
+ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
+ PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
+ this(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, null, null);
+ }
@Builder(toBuilder = true)
public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
- PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) {
+ PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength, String queryPrefix, String passagePrefix) {
super(modelType, allConfig);
if (embeddingDimension == null) {
throw new IllegalArgumentException("embedding dimension is null");
@@ -59,6 +68,8 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr
this.poolingMode = poolingMode;
this.normalizeResult = normalizeResult;
this.modelMaxLength = modelMaxLength;
+ this.queryPrefix = queryPrefix;
+ this.passagePrefix = passagePrefix;
}
public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException {
@@ -69,6 +80,8 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
PoolingMode poolingMode = null;
boolean normalizeResult = false;
Integer modelMaxLength = null;
+ String queryPrefix = null;
+ String passagePrefix = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -97,12 +110,18 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
case MODEL_MAX_LENGTH_FIELD:
modelMaxLength = parser.intValue();
break;
+ case QUERY_PREFIX:
+ queryPrefix = parser.text();
+ break;
+ case PASSAGE_PREFIX:
+ passagePrefix = parser.text();
+ break;
default:
parser.skipChildren();
break;
}
}
- return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength);
+ return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, queryPrefix, passagePrefix);
}
@Override
@@ -121,6 +140,8 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{
}
normalizeResult = in.readBoolean();
modelMaxLength = in.readOptionalInt();
+ queryPrefix = in.readOptionalString();
+ passagePrefix = in.readOptionalString();
}
@Override
@@ -136,6 +157,8 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeBoolean(normalizeResult);
out.writeOptionalInt(modelMaxLength);
+ out.writeOptionalString(queryPrefix);
+ out.writeOptionalString(passagePrefix);
}
@Override
@@ -162,6 +185,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (normalizeResult) {
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
}
+ if (queryPrefix != null) {
+ builder.field(QUERY_PREFIX, queryPrefix);
+ }
+ if (passagePrefix != null) {
+ builder.field(PASSAGE_PREFIX, passagePrefix);
+ }
builder.endObject();
return builder;
}
diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java
new file mode 100644
index 0000000000..a7a27c00ee
--- /dev/null
+++ b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java
@@ -0,0 +1,83 @@
+package org.opensearch.ml.common.dataset;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.opensearch.common.io.stream.BytesStreamOutput;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.ml.common.TestHelper;
+
+import java.io.IOException;
+import java.util.function.Function;
+import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
+
+import static org.junit.Assert.assertEquals;
+import static org.opensearch.ml.common.TestHelper.contentObjectToString;
+import static org.opensearch.ml.common.TestHelper.testParseFromString;
+
+public class AsymmetricTextEmbeddingParametersTest {
+
+ @Rule
+ public ExpectedException exceptionRule = ExpectedException.none();
+
+ AsymmetricTextEmbeddingParameters params;
+ private Function function = parser -> {
+ try {
+ return (AsymmetricTextEmbeddingParameters) AsymmetricTextEmbeddingParameters.parse(parser);
+ } catch (IOException e) {
+ throw new RuntimeException("failed to parse AsymmetricTextEmbeddingParameters", e);
+ }
+ };
+
+ @Before
+ public void setUp() {
+ params = AsymmetricTextEmbeddingParameters.builder()
+ .embeddingContentType(EmbeddingContentType.QUERY)
+ .build();
+ }
+
+ @Test
+ public void parse_AsymmetricTextEmbeddingParameters() throws IOException {
+ TestHelper.testParse(params, function);
+ }
+
+ @Test
+ public void parse_AsymmetricTextEmbeddingParameters_Passage() throws IOException {
+ String paramsStr = contentObjectToString(params);
+ testParseFromString(params, paramsStr.replace("QUERY", "PASSAGE"), function);
+ }
+
+ @Test
+ public void parse_AsymmetricTextEmbeddingParameters_Invalid() throws IOException {
+ exceptionRule.expect(IllegalArgumentException.class);
+ exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU");
+ String paramsStr = contentObjectToString(params);
+ testParseFromString(params, paramsStr.replace("QUERY","fu"), function);
+ }
+
+ @Test
+ public void parse_EmptyAsymmetricTextEmbeddingParameters() throws IOException {
+ TestHelper.testParse(AsymmetricTextEmbeddingParameters.builder().build(), function);
+ }
+
+ @Test
+ public void readInputStream_Success() throws IOException {
+ readInputStream(params);
+ }
+
+ @Test
+ public void readInputStream_Success_EmptyParams() throws IOException {
+ readInputStream(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build());
+ }
+
+ private void readInputStream(AsymmetricTextEmbeddingParameters params) throws IOException {
+ BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
+ params.writeTo(bytesStreamOutput);
+
+ StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
+ AsymmetricTextEmbeddingParameters parsedParams = new AsymmetricTextEmbeddingParameters(streamInput);
+ assertEquals(params, parsedParams);
+ }
+}
diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java
index 10468d44a3..8819786fbe 100644
--- a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java
+++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java
@@ -39,9 +39,9 @@ public class TextDocsMLInputTest {
@Before
public void setUp() throws Exception {
ModelResultFilter resultFilter = ModelResultFilter.builder().returnBytes(true).returnNumber(true)
- .targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build();
+ .targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build();
MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2"))
- .resultFilter(resultFilter).build();
+ .resultFilter(resultFilter).build();
input = new TextDocsMLInput(algorithm, inputDataset);
}
@@ -68,8 +68,8 @@ public void parseTextDocsMLInput_NewWay() throws IOException {
private void parseMLInput(String jsonStr, int docSize) throws IOException {
XContentParser parser = XContentType.JSON.xContent()
- .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
- Collections.emptyList()).getNamedXContents()), null, jsonStr);
+ .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
+ Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name());
diff --git a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java
index acba744ced..9bc97f7c9f 100644
--- a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java
+++ b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java
@@ -36,6 +36,8 @@ public void setUp() {
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
.embeddingDimension(100)
+ .passagePrefix("passage: ")
+ .queryPrefix("query: ")
.build();
function = parser -> {
try {
@@ -51,7 +53,7 @@ public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
config.toXContent(builder, EMPTY_PARAMS);
String configContent = TestHelper.xContentBuilderToString(builder);
- assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent);
+ assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}", configContent);
}
@Test
@@ -83,7 +85,7 @@ public void nullFields_FrameworkType() {
@Test
public void parse() throws IOException {
- String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}";
+ String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}";
TestHelper.testParseFromString(config, content, function);
}
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java
index 1090ee15b6..1f606bcaa4 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java
@@ -135,6 +135,12 @@ public void downloadPrebuiltModelConfig(
case TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD:
configBuilder.modelMaxLength(((Double) configEntry.getValue()).intValue());
break;
+ case TextEmbeddingModelConfig.QUERY_PREFIX:
+ configBuilder.queryPrefix(configEntry.getValue().toString());
+ break;
+ case TextEmbeddingModelConfig.PASSAGE_PREFIX:
+ configBuilder.passagePrefix(configEntry.getValue().toString());
+ break;
default:
break;
}
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java
index 6c6033f2cb..e5db51324f 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java
@@ -67,6 +67,8 @@ public abstract class DLModel implements Predictable {
protected Device[] devices;
protected AtomicInteger nextDevice = new AtomicInteger(0);
+ protected MLModelConfig modelConfig;
+
@Override
public MLOutput predict(MLInput mlInput, MLModel model) {
throw new IllegalArgumentException("model not deployed");
@@ -183,6 +185,7 @@ protected void doLoadModel(
IOException,
TranslateException {
devices = Engine.getEngine(engine).getDevices();
+ this.modelConfig = modelConfig;
for (int i = 0; i < devices.length; i++) {
log.debug("load model {} to device {}: {}", modelId, i, devices[i]);
ZooModel model;
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java
index d74eee7b0f..a5a3a8fb50 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java
@@ -4,10 +4,14 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.stream.Collectors;
+import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters;
+import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
+import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.model.ModelResultFilter;
@@ -20,9 +24,15 @@
import ai.djl.translate.TranslateException;
public abstract class TextEmbeddingModel extends DLModel {
+
@Override
public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
- MLInputDataset inputDataSet = mlInput.getInputDataset();
+ MLAlgoParams mlParams = mlInput.getParameters();
+
+ MLInputDataset inputDataSet = isAsymmetricModel(mlParams)
+ ? addPrefixesToData((AsymmetricTextEmbeddingParameters) mlParams, (TextDocsInputDataSet) mlInput.getInputDataset())
+ : mlInput.getInputDataset();
+
List tensorOutputs = new ArrayList<>();
Output output;
TextDocsInputDataSet textDocsInput = (TextDocsInputDataSet) inputDataSet;
@@ -36,6 +46,48 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla
return new ModelTensorOutput(tensorOutputs);
}
+ private boolean isAsymmetricModel(MLAlgoParams mlParams) {
+ if (mlParams instanceof AsymmetricTextEmbeddingParameters) {
+ // Check for the necessary prefixes in modelConfig
+ if (modelConfig == null
+ || ((TextEmbeddingModelConfig) modelConfig).getPassagePrefix() == null
+ && ((TextEmbeddingModelConfig) modelConfig).getQueryPrefix() == null) {
+ throw new IllegalArgumentException(
+ "When passing AsymmetricTextEmbeddingParameters, the model requires to be "
+ + "registered with at least one of `query_prefix` or `passage_prefix`."
+ );
+ }
+ // Passed all checks
+ return true;
+ }
+
+ // no AsymmetricTextEmbeddingParameters passed, but the model is asymmetric.
+ if (modelConfig != null
+ && (((TextEmbeddingModelConfig) modelConfig).getPassagePrefix() != null
+ || ((TextEmbeddingModelConfig) modelConfig).getQueryPrefix() != null)) {
+ throw new IllegalArgumentException(
+ "The embedding model chosen is asymmetric. To use it, you must declare whether the input is of type `QUERY` or of type `PASSAGE`."
+ );
+ }
+
+ return false;
+ }
+
+ private TextDocsInputDataSet addPrefixesToData(AsymmetricTextEmbeddingParameters mlParams, TextDocsInputDataSet inputDataSet) {
+ // Asymmetric embedding models typically work with "mini-prompts" that prime the model to embed a text
+ // as a query or as a passage. Here we apply the prompt as defined in the model configuration. We default
+ // to query embedding.
+ TextEmbeddingModelConfig modelConfig = (TextEmbeddingModelConfig) this.modelConfig;
+ String prefix = mlParams.getEmbeddingContentType() == EmbeddingContentType.PASSAGE
+ ? modelConfig.getPassagePrefix()
+ : modelConfig.getQueryPrefix();
+ if (prefix != null) {
+ List prefixedDocs = inputDataSet.getDocs().stream().map(s -> prefix + s).collect(Collectors.toList());
+ return TextDocsInputDataSet.builder().docs(prefixedDocs).build();
+ }
+ return inputDataSet;
+ }
+
public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig;
String warmUpSentence = "warm up sentence";
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java
index 7c7e2be4b9..75bdde5bbe 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java
@@ -6,6 +6,8 @@
package org.opensearch.ml.engine.algorithms.text_embedding;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.HUGGINGFACE_TRANSFORMERS;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS;
import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE;
@@ -31,6 +33,8 @@
import org.opensearch.ResourceNotFoundException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
+import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters;
+import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
@@ -239,6 +243,257 @@ private void initModel_predict_HuggingfaceModel(
}
+ @Test
+ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts_HappyPath() throws URISyntaxException {
+ Map params = new HashMap<>();
+ params.put(MODEL_HELPER, modelHelper);
+ params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI()));
+ params.put(ML_ENGINE, mlEngine);
+
+ TextEmbeddingModelConfig asymmetricModelConfig = this.modelConfig
+ .toBuilder()
+ .embeddingDimension(768)
+ .queryPrefix("query >> ")
+ .passagePrefix("passage >> ")
+ .build();
+ MLModel asymmetricSmallModel = model.toBuilder().modelConfig(asymmetricModelConfig).build();
+ textEmbeddingDenseModel.initModel(asymmetricSmallModel, params, encryptor);
+ MLInput asymmetricMlInputQueries = MLInput
+ .builder()
+ .algorithm(FunctionName.TEXT_EMBEDDING)
+ .inputDataset(
+ TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build()
+ )
+ .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY))
+ .build();
+ MLInput asymmetricMlInputPassages = MLInput
+ .builder()
+ .algorithm(FunctionName.TEXT_EMBEDDING)
+ .inputDataset(
+ TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build()
+ )
+ .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.PASSAGE))
+ .build();
+
+ ModelTensorOutput asymmetricQueryEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(asymmetricMlInputQueries);
+ ModelTensorOutput asymmetricPassageEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(asymmetricMlInputPassages);
+
+ TextEmbeddingModelConfig symmetricModelConfig = this.modelConfig.toBuilder().embeddingDimension(768).build();
+ MLModel smallModel = model.toBuilder().modelConfig(symmetricModelConfig).build();
+ textEmbeddingDenseModel.initModel(smallModel, params, encryptor);
+ MLInput symmetricMlInputQueries = MLInput
+ .builder()
+ .algorithm(FunctionName.TEXT_EMBEDDING)
+ .inputDataset(
+ TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build()
+ )
+ .build();
+ MLInput symmetricMlInputPassages = MLInput
+ .builder()
+ .algorithm(FunctionName.TEXT_EMBEDDING)
+ .inputDataset(
+ TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build()
+ )
+ .build();
+
+ ModelTensorOutput symmetricQueryEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputQueries);
+ ModelTensorOutput symmetricPassageEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputPassages);
+
+ assertTrue(
+ "asymmetric and symmetric query embeddings should be different",
+ areTensorsDifferent(
+ asymmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0),
+ symmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0),
+ 0.1f
+ )
+ );
+ assertTrue(
+ "asymmetric and symmetric passage embeddings should be different",
+ areTensorsDifferent(
+ asymmetricPassageEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0),
+ symmetricPassageEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0),
+ 0.1f
+ )
+ );
+
+ textEmbeddingDenseModel.close();
+ }
+
+ private boolean areTensorsDifferent(ModelTensor tensor1, ModelTensor tensor2, float delta) {
+
+ if (!Arrays.equals(tensor1.getShape(), tensor2.getShape())) {
+ return true; // Tensors are different if they have different lengths
+ }
+
+ List vectorA = Arrays.asList(tensor1.getData());
+ List vectorB = Arrays.asList(tensor2.getData());
+
+ for (int i = 0; i < vectorA.size(); i++) {
+ if (Math.abs(vectorA.get(i).floatValue() - vectorB.get(i).floatValue()) > delta) {
+ return true; // Vectors are different if any pair of corresponding elements differ by more than the tolerance
+ }
+ }
+ return false; // Vectors are the same
+
+ }
+
+ @Test
+ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts_HappyPath2()
+ throws URISyntaxException {
+ // only the query embeddings need a prefix
+ Map params = new HashMap<>();
+ params.put(MODEL_HELPER, modelHelper);
+ params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI()));
+ params.put(ML_ENGINE, mlEngine);
+
+ TextEmbeddingModelConfig asymmetricModelConfig = this.modelConfig
+ .toBuilder()
+ .embeddingDimension(768)
+ .queryPrefix("query >> ")
+ .build();
+ MLModel asymmetricSmallModel = model.toBuilder().modelConfig(asymmetricModelConfig).build();
+ textEmbeddingDenseModel.initModel(asymmetricSmallModel, params, encryptor);
+ MLInput asymmetricMlInputQueries = MLInput
+ .builder()
+ .algorithm(FunctionName.TEXT_EMBEDDING)
+ .inputDataset(
+ TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build()
+ )
+ .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY))
+ .build();
+ MLInput asymmetricMlInputPassages = MLInput
+ .builder()
+ .algorithm(FunctionName.TEXT_EMBEDDING)
+ .inputDataset(
+ TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build()
+ )
+ .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.PASSAGE))
+ .build();
+
+ ModelTensorOutput asymmetricQueryEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(asymmetricMlInputQueries);
+ ModelTensorOutput asymmetricPassageEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(asymmetricMlInputPassages);
+
+ TextEmbeddingModelConfig symmetricModelConfig = this.modelConfig.toBuilder().embeddingDimension(768).build();
+ MLModel smallModel = model.toBuilder().modelConfig(symmetricModelConfig).build();
+ textEmbeddingDenseModel.initModel(smallModel, params, encryptor);
+ MLInput symmetricMlInputQueries = MLInput
+ .builder()
+ .algorithm(FunctionName.TEXT_EMBEDDING)
+ .inputDataset(
+ TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build()
+ )
+ .build();
+ MLInput symmetricMlInputPassages = MLInput
+ .builder()
+ .algorithm(FunctionName.TEXT_EMBEDDING)
+ .inputDataset(
+ TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build()
+ )
+ .build();
+
+ ModelTensorOutput symmetricQueryEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputQueries);
+ ModelTensorOutput symmetricPassageEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputPassages);
+
+ assertTrue(
+ "asymmetric and symmetric query embeddings should be different",
+ areTensorsDifferent(
+ asymmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0),
+ symmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0),
+ 0.1f
+ )
+ );
+ assertTrue(
+ "asymmetric and symmetric passage embeddings should be equal",
+ !areTensorsDifferent(
+ asymmetricPassageEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0),
+ symmetricPassageEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().get(0),
+ 0.1f
+ )
+ );
+
+ textEmbeddingDenseModel.close();
+ }
+
+ @Test
+ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts_SadPath1() throws URISyntaxException {
+ // asymmetric model, no parameter passed
+ Map params = new HashMap<>();
+ params.put(MODEL_HELPER, modelHelper);
+ params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI()));
+ params.put(ML_ENGINE, mlEngine);
+
+ TextEmbeddingModelConfig asymmetricModelConfig = this.modelConfig
+ .toBuilder()
+ .embeddingDimension(768)
+ .queryPrefix("query >> ")
+ .passagePrefix("passage >>")
+ .build();
+ MLModel asymmetricSmallModel = model.toBuilder().modelConfig(asymmetricModelConfig).build();
+ textEmbeddingDenseModel.initModel(asymmetricSmallModel, params, encryptor);
+
+ MLInput symmetricMlInputQueries = MLInput
+ .builder()
+ .algorithm(FunctionName.TEXT_EMBEDDING)
+ .inputDataset(
+ TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build()
+ )
+ .build();
+
+ try {
+ textEmbeddingDenseModel.predict(symmetricMlInputQueries);
+ } catch (MLException e) {
+ assertEquals(IllegalArgumentException.class, e.getCause().getClass());
+ assertEquals(
+ "The embedding model chosen is asymmetric. To use it, you must declare whether the input is of type `QUERY` or of type `PASSAGE`.",
+ e.getCause().getMessage()
+ );
+ return;
+ } finally {
+ textEmbeddingDenseModel.close();
+ }
+
+ fail("Expected exception not thrown");
+
+ }
+
+ @Test
+ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_Asymmetric_Prompts_SadPath2() throws URISyntaxException {
+ // symmetric model, asymmetric parameter passed
+ Map params = new HashMap<>();
+ params.put(MODEL_HELPER, modelHelper);
+ params.put(MODEL_ZIP_FILE, new File(getClass().getResource("traced_small_model.zip").toURI()));
+ params.put(ML_ENGINE, mlEngine);
+
+ TextEmbeddingModelConfig symmetricModelConfig = this.modelConfig.toBuilder().embeddingDimension(768).build();
+ MLModel symmetricSmallModel = model.toBuilder().modelConfig(symmetricModelConfig).build();
+ textEmbeddingDenseModel.initModel(symmetricSmallModel, params, encryptor);
+
+ MLInput asymmetricMlInputQueries = MLInput
+ .builder()
+ .algorithm(FunctionName.TEXT_EMBEDDING)
+ .inputDataset(
+ TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build()
+ )
+ .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY))
+ .build();
+
+ try {
+ textEmbeddingDenseModel.predict(asymmetricMlInputQueries);
+ } catch (MLException e) {
+ assertEquals(IllegalArgumentException.class, e.getCause().getClass());
+ assertEquals(
+ "When passing AsymmetricTextEmbeddingParameters, the model requires to be registered with at least one of `query_prefix` or `passage_prefix`.",
+ e.getCause().getMessage()
+ );
+ return;
+ } finally {
+ textEmbeddingDenseModel.close();
+ }
+
+ fail("Expected exception not thrown");
+
+ }
+
@Test
public void initModel_NullModelZipFile() {
exceptionRule.expect(IllegalArgumentException.class);
diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java
index fd28c47970..5a46d10574 100644
--- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java
+++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java
@@ -91,6 +91,7 @@
import org.opensearch.ml.cluster.MLCommonsClusterEventListener;
import org.opensearch.ml.cluster.MLCommonsClusterManagerEventListener;
import org.opensearch.ml.common.FunctionName;
+import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.execute.anomalylocalization.AnomalyLocalizationInput;
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput;
@@ -843,7 +844,8 @@ public List getNamedXContent() {
AnomalyLocalizationInput.XCONTENT_REGISTRY_ENTRY,
RCFSummarizeParams.XCONTENT_REGISTRY,
LogisticRegressionParams.XCONTENT_REGISTRY,
- TextEmbeddingModelConfig.XCONTENT_REGISTRY
+ TextEmbeddingModelConfig.XCONTENT_REGISTRY,
+ AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY
);
}
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java
index b3f3e3f956..2ade96cab2 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java
@@ -36,6 +36,7 @@
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
+import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.model.MLModelGetResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelAction;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
@@ -136,6 +137,21 @@ public void testRegisterModelRequest() throws Exception {
assertEquals("test_model", registerModelInput.getModelName());
assertEquals("1", registerModelInput.getVersion());
assertEquals("TORCH_SCRIPT", registerModelInput.getModelFormat().toString());
+ assertEquals(null, ((TextEmbeddingModelConfig) registerModelInput.getModelConfig()).getQueryPrefix());
+ assertEquals(null, ((TextEmbeddingModelConfig) registerModelInput.getModelConfig()).getPassagePrefix());
+ }
+
+ public void testRegisterAsymmetricModelRequest() throws Exception {
+ RestRequest request = getRestRequestAsymmetricModel();
+ restMLRegisterModelAction.handleRequest(request, channel, client);
+ ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelRequest.class);
+ verify(client, times(1)).execute(eq(MLRegisterModelAction.INSTANCE), argumentCaptor.capture(), any());
+ MLRegisterModelInput registerModelInput = argumentCaptor.getValue().getRegisterModelInput();
+ assertEquals("test_model", registerModelInput.getModelName());
+ assertEquals("1", registerModelInput.getVersion());
+ assertEquals("TORCH_SCRIPT", registerModelInput.getModelFormat().toString());
+ assertEquals("query: ", ((TextEmbeddingModelConfig) registerModelInput.getModelConfig()).getQueryPrefix());
+ assertEquals("passage: ", ((TextEmbeddingModelConfig) registerModelInput.getModelConfig()).getPassagePrefix());
}
public void testRegisterModelRequestRemoteInferenceDisabled() throws Exception {
@@ -212,6 +228,49 @@ private RestRequest getRestRequest() {
return request;
}
+ private RestRequest getRestRequestAsymmetricModel() {
+ RestRequest.Method method = RestRequest.Method.POST;
+ final Map modelConfig = Map
+ .of(
+ "model_type",
+ "bert",
+ "embedding_dimension",
+ 384,
+ "framework_type",
+ "sentence_transformers",
+ "all_config",
+ "All Config",
+ "query_prefix",
+ "query: ",
+ "passage_prefix",
+ "passage: "
+ );
+ final Map model = Map
+ .of(
+ "name",
+ "test_model",
+ "model_id",
+ "test_model_with_modelId",
+ "version",
+ "1",
+ "model_group_id",
+ "modelGroupId",
+ "url",
+ "testUrl",
+ "model_format",
+ "TORCH_SCRIPT",
+ "model_config",
+ modelConfig
+ );
+ String requestContent = new Gson().toJson(model).toString();
+ RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
+ .withMethod(method)
+ .withPath("/_plugins/_ml/models/{model_id}/{version}/_register")
+ .withContent(new BytesArray(requestContent), XContentType.JSON)
+ .build();
+ return request;
+ }
+
private RestRequest getRestRequestWithNullModelId() {
RestRequest.Method method = RestRequest.Method.POST;
final Map modelConfig = Map