Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add neural search default processor for non OpenAI/Cohere scenario #1274

Merged
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@

package org.opensearch.ml.common.connector;


import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -20,17 +31,6 @@
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,61 +5,64 @@

package org.opensearch.ml.common.connector;

import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class MLPostProcessFunction {

private static Map<String, String> POST_PROCESS_FUNCTIONS;
public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding";
public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding";

public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";

private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();

private static final Map<String, Function<List<List<Float>>, List<ModelTensor>>> POST_PROCESS_FUNCTIONS = new HashMap<>();


static {
POST_PROCESS_FUNCTIONS = new HashMap<>();
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, "\n def name = \"sentence_embedding\";\n" +
" def dataType = \"FLOAT32\";\n" +
" if (params.embeddings == null || params.embeddings.length == 0) {\n" +
" return null;\n" +
" }\n" +
" def embeddings = params.embeddings;\n" +
" StringBuilder builder = new StringBuilder(\"[\");\n" +
" for (int i=0; i<embeddings.length; i++) {\n" +
" def shape = [embeddings[i].length];\n" +
" def json = \"{\" +\n" +
" \"\\\"name\\\":\\\"\" + name + \"\\\",\" +\n" +
" \"\\\"data_type\\\":\\\"\" + dataType + \"\\\",\" +\n" +
" \"\\\"shape\\\":\" + shape + \",\" +\n" +
" \"\\\"data\\\":\" + embeddings[i] +\n" +
" \"}\";\n" +
" builder.append(json);\n" +
" if (i < embeddings.length - 1) {\n" +
" builder.append(\",\");\n" +
" }\n" +
" }\n" +
" builder.append(\"]\");\n" +
" \n" +
" return builder.toString();\n ");
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
JSON_PATH_EXPRESSION.put(DEFAULT_EMBEDDING, "$[*]");
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, buildModelTensorList());
POST_PROCESS_FUNCTIONS.put(DEFAULT_EMBEDDING, buildModelTensorList());
}

POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, "\n def name = \"sentence_embedding\";\n" +
" def dataType = \"FLOAT32\";\n" +
" if (params.data == null || params.data.length == 0) {\n" +
" return null;\n" +
" }\n" +
" def shape = [params.data[0].embedding.length];\n" +
" def json = \"{\" +\n" +
" \"\\\"name\\\":\\\"\" + name + \"\\\",\" +\n" +
" \"\\\"data_type\\\":\\\"\" + dataType + \"\\\",\" +\n" +
" \"\\\"shape\\\":\" + shape + \",\" +\n" +
" \"\\\"data\\\":\" + params.data[0].embedding +\n" +
" \"}\";\n" +
" return json;\n ");
public static Function<List<List<Float>>, List<ModelTensor>> buildModelTensorList() {
return embeddings -> {
List<ModelTensor> modelTensors = new ArrayList<>();
if (embeddings == null) {
throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function.");
}
embeddings.forEach(embedding -> modelTensors.add(
ModelTensor
.builder()
.name("sentence_embedding")
.dataType(MLResultDataType.FLOAT32)
.shape(new long[]{embedding.size()})
.data(embedding.toArray(new Number[0]))
.build()
));
return modelTensors;
};
}

public static boolean contains(String functionName) {
return POST_PROCESS_FUNCTIONS.containsKey(functionName);
public static String getResponseFilter(String postProcessFunction) {
return JSON_PATH_EXPRESSION.get(postProcessFunction);
}

public static String get(String postProcessFunction) {
public static Function<List<List<Float>>, List<ModelTensor>> get(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.get(postProcessFunction);
}

public static boolean contains(String postProcessFunction) {
return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,37 @@
package org.opensearch.ml.common.connector;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class MLPreProcessFunction {

private static Map<String, String> PRE_PROCESS_FUNCTIONS;
private static final Map<String, Function<List<String>, Map<String, Object>>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";

public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";

private static Function<List<String>, Map<String, Object>> cohereTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("texts", inputs));
}

private static Function<List<String>, Map<String, Object>> openAiTextEmbeddingPreProcess() {
return inputs -> Map.of("parameters", Map.of("input", inputs));
}

static {
PRE_PROCESS_FUNCTIONS = new HashMap<>();
//TODO: change to java for openAI, embedding and Titan
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" +
" builder.append(\"[\");\n" +
" for (int i=0; i< params.text_docs.length; i++) {\n" +
" builder.append(\"\\\"\");\n" +
" builder.append(params.text_docs[i]);\n" +
" builder.append(\"\\\"\");\n" +
" if (i < params.text_docs.length - 1) {\n" +
" builder.append(\",\")\n" +
" }\n" +
" }\n" +
" builder.append(\"]\");\n" +
" def parameters = \"{\" +\"\\\"prompt\\\":\" + builder + \"}\";\n" +
" return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";");

PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" +
" builder.append(\"\\\"\");\n" +
" builder.append(params.text_docs[0]);\n" +
" builder.append(\"\\\"\");\n" +
" def parameters = \"{\" +\"\\\"input\\\":\" + builder + \"}\";\n" +
" return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";");
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess());
}

public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static String get(String postProcessFunction) {
public static Function<List<String>, Map<String, Object>> get(String postProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,9 @@
import org.opensearch.ml.common.utils.StringUtils;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;

@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE})
public class RemoteInferenceMLInput extends MLInput {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
public class StringUtils {

public static final Gson gson;

static {
gson = new Gson();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,21 @@
package org.opensearch.ml.common.connector;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING;

public class MLPostProcessFunctionTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Test
public void contains() {
Assert.assertTrue(MLPostProcessFunction.contains(OPENAI_EMBEDDING));
Expand All @@ -23,4 +32,24 @@ public void get() {
Assert.assertNotNull(MLPostProcessFunction.get(OPENAI_EMBEDDING));
Assert.assertNull(MLPostProcessFunction.get("wrong value"));
}

@Test
public void test_getResponseFilter() {
assert null != MLPostProcessFunction.getResponseFilter(OPENAI_EMBEDDING);
assert null == MLPostProcessFunction.getResponseFilter("wrong value");
}

@Test
public void test_buildModelTensorList() {
Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList());
List<List<Float>> numbersList = new ArrayList<>();
numbersList.add(Collections.singletonList(1.0f));
Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList().apply(numbersList));
}

@Test
public void test_buildModelTensorList_exception() {
exceptionRule.expect(IllegalArgumentException.class);
MLPostProcessFunction.buildModelTensorList().apply(null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import ai.djl.training.util.DownloadUtils;
import ai.djl.training.util.ProgressBar;
import com.google.gson.Gson;
import com.google.gson.stream.JsonReader;
import lombok.extern.log4j.Log4j2;
import org.opensearch.core.action.ActionListener;
Expand All @@ -32,6 +31,7 @@
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash;
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.engine.utils.FileUtils.splitFileIntoChunks;
Expand All @@ -48,11 +48,9 @@ public class ModelHelper {
public static final String PYTORCH_ENGINE = "PyTorch";
public static final String ONNX_ENGINE = "OnnxRuntime";
private final MLEngine mlEngine;
private Gson gson;

public ModelHelper(MLEngine mlEngine) {
this.mlEngine = mlEngine;
gson = new Gson();
}

public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput registerModelInput, ActionListener<MLRegisterModelInput> listener) {
Expand Down
Loading
Loading