Skip to content

Commit

Permalink
adding multi-modal pre-processor for cohere (opensearch-project#3219)
Browse files Browse the repository at this point in the history
* adding multi-modal pre-processor for cohere

Signed-off-by: Dhrubo Saha <[email protected]>

* added javadoc

Signed-off-by: Dhrubo Saha <[email protected]>

---------

Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os authored Nov 18, 2024
1 parent 9c1b8a8 commit 7041c22
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereMultiModalEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
Expand All @@ -21,6 +22,7 @@ public class MLPreProcessFunction {

private static final Map<String, Function<MLInput, RemoteInferenceInputDataSet>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT = "connector.pre_process.cohere.multimodal_embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding";
Expand All @@ -37,7 +39,10 @@ public class MLPreProcessFunction {
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
CohereMultiModalEmbeddingPreProcessFunction cohereMultiModalEmbeddingPreProcessFunction =
new CohereMultiModalEmbeddingPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT, cohereMultiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

public class CohereMultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {

public CohereMultiModalEmbeddingPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = true;
}

@Override
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
if (docs.isEmpty() || (docs.size() == 1 && docs.getFirst() == null)) {
throw new IllegalArgumentException("No image provided");
}
}

@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, String> parametersMap = new HashMap<>();

/**
* Cohere multi-modal model expects either image or texts, not both.
* For image, customer can use this pre-process function. For texts, customer can use
* connector.pre_process.cohere.embedding
* Cohere expects An array of image data URIs for the model to embed. Maximum number of images per call is 1.
*/
parametersMap.put("images", inputData.getDocs().getFirst());
return RemoteInferenceInputDataSet
.builder()
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import static org.junit.Assert.assertEquals;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

public class CohereMultiModalEmbeddingPreProcessFunctionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

CohereMultiModalEmbeddingPreProcessFunction function;

TextSimilarityInputDataSet textSimilarityInputDataSet;
TextDocsInputDataSet textDocsInputDataSet;
RemoteInferenceInputDataSet remoteInferenceInputDataSet;

MLInput textEmbeddingInput;
MLInput textSimilarityInput;
MLInput remoteInferenceInput;

@Before
public void setUp() {
function = new CohereMultiModalEmbeddingPreProcessFunction();
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(List.of("hello")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(List.of("imageString")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("images", "value2")).build();

textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build();
}

@Test
public void testProcess_whenNullInput_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Preprocess function input can't be null");
function.apply(null);
}

@Test
public void testProcess_whenWrongInput_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet");
function.apply(textSimilarityInput);
}

@Test
public void testProcess_whenCorrectInput_expectCorrectOutput() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
assertEquals(1, dataSet.getParameters().size());
assertEquals("imageString", dataSet.getParameters().get("images"));

}

@Test
public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("No image provided");
List<String> docs = new ArrayList<>();
docs.add(null);
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(docs).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
}

@Test
public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() {
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
assertEquals(remoteInferenceInputDataSet, dataSet);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class MultiModalConnectorPreProcessFunctionTest {
@Before
public void setUp() {
function = new MultiModalConnectorPreProcessFunction();
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(List.of("hello")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet
.builder()
Expand Down

0 comments on commit 7041c22

Please sign in to comment.