forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding multi-modal pre-processor for cohere (opensearch-project#3219)
* adding multi-modal pre-processor for cohere Signed-off-by: Dhrubo Saha <[email protected]> * added javadoc Signed-off-by: Dhrubo Saha <[email protected]> --------- Signed-off-by: Dhrubo Saha <[email protected]>
- Loading branch information
Showing
4 changed files
with
145 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
...ml/common/connector/functions/preprocess/CohereMultiModalEmbeddingPreProcessFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
public class CohereMultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { | ||
|
||
public CohereMultiModalEmbeddingPreProcessFunction() { | ||
this.returnDirectlyForRemoteInferenceInput = true; | ||
} | ||
|
||
@Override | ||
public void validate(MLInput mlInput) { | ||
validateTextDocsInput(mlInput); | ||
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); | ||
if (docs.isEmpty() || (docs.size() == 1 && docs.getFirst() == null)) { | ||
throw new IllegalArgumentException("No image provided"); | ||
} | ||
} | ||
|
||
@Override | ||
public RemoteInferenceInputDataSet process(MLInput mlInput) { | ||
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); | ||
Map<String, String> parametersMap = new HashMap<>(); | ||
|
||
/** | ||
* Cohere multi-modal model expects either image or texts, not both. | ||
* For image, customer can use this pre-process function. For texts, customer can use | ||
* connector.pre_process.cohere.embedding | ||
* Cohere expects An array of image data URIs for the model to embed. Maximum number of images per call is 1. | ||
*/ | ||
parametersMap.put("images", inputData.getDocs().getFirst()); | ||
return RemoteInferenceInputDataSet | ||
.builder() | ||
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))) | ||
.build(); | ||
} | ||
} |
89 changes: 89 additions & 0 deletions
89
...ommon/connector/functions/preprocess/CohereMultiModalEmbeddingPreProcessFunctionTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.junit.Before; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
public class CohereMultiModalEmbeddingPreProcessFunctionTest { | ||
@Rule | ||
public ExpectedException exceptionRule = ExpectedException.none(); | ||
|
||
CohereMultiModalEmbeddingPreProcessFunction function; | ||
|
||
TextSimilarityInputDataSet textSimilarityInputDataSet; | ||
TextDocsInputDataSet textDocsInputDataSet; | ||
RemoteInferenceInputDataSet remoteInferenceInputDataSet; | ||
|
||
MLInput textEmbeddingInput; | ||
MLInput textSimilarityInput; | ||
MLInput remoteInferenceInput; | ||
|
||
@Before | ||
public void setUp() { | ||
function = new CohereMultiModalEmbeddingPreProcessFunction(); | ||
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(List.of("hello")).build(); | ||
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(List.of("imageString")).build(); | ||
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("images", "value2")).build(); | ||
|
||
textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); | ||
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); | ||
remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenNullInput_expectIllegalArgumentException() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Preprocess function input can't be null"); | ||
function.apply(null); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenWrongInput_expectIllegalArgumentException() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); | ||
function.apply(textSimilarityInput); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenCorrectInput_expectCorrectOutput() { | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); | ||
RemoteInferenceInputDataSet dataSet = function.apply(mlInput); | ||
assertEquals(1, dataSet.getParameters().size()); | ||
assertEquals("imageString", dataSet.getParameters().get("images")); | ||
|
||
} | ||
|
||
@Test | ||
public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("No image provided"); | ||
List<String> docs = new ArrayList<>(); | ||
docs.add(null); | ||
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(docs).build(); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); | ||
RemoteInferenceInputDataSet dataSet = function.apply(mlInput); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() { | ||
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); | ||
assertEquals(remoteInferenceInputDataSet, dataSet); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters