Skip to content

Commit

Permalink
pass search extension to pipeline context
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl committed Dec 31, 2024
1 parent cfa8a38 commit 3cf89dc
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.JsonSyntaxException;
import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.InvalidJsonException;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;
import com.jayway.jsonpath.PathNotFoundException;
import com.networknt.schema.JsonSchema;
import com.networknt.schema.JsonSchemaFactory;
Expand Down Expand Up @@ -389,53 +387,73 @@ public static boolean pathExists(Object json, String path) {
/**
* Prepares nested structures in a JSON object based on the given field path.
*
* This method ensures that all intermediate nested objects exist in the JSON object
* This method ensures that all intermediate nested objects and arrays exist in the JSON object
* for a given field path. If any part of the path doesn't exist, it creates new empty objects
* (HashMaps) for those parts.
* (HashMaps) or arrays (ArrayLists) for those parts.
*
* @param jsonObject The JSON object to be updated.
* @param fieldPath The full path of the field, potentially including nested structures.
* @return The updated JSON object with necessary nested structures in place.
* The method can handle complex paths including both object properties and array indices.
* For example, it can process paths like "foo.bar[1].baz[0].qux".
*
* @throws IllegalArgumentException If there's an issue with JSON parsing or path manipulation.
* @param jsonObject The JSON object to be updated. If this is not a Map, a new Map will be created.
* @param fieldPath The full path of the field, potentially including nested structures and array indices.
* The path can optionally start with "$." which will be ignored if present.
* @return The updated JSON object with necessary nested structures in place.
* If the input was not a Map, returns the newly created Map structure.
*
* @implNote This method uses JsonPath for JSON manipulation and StringUtils for path existence checks.
* It handles paths both with and without a leading "$." notation.
* Each non-existent intermediate object in the path is created as an empty HashMap.
* @throws IllegalArgumentException If the field path is null or not a valid JSON path.
*
* @see JsonPath
* @see StringUtils
*/
public static Object prepareNestedStructures(Object jsonObject, String fieldPath) {

if (fieldPath == null) {
throw new IllegalArgumentException("the field path is null");
throw new IllegalArgumentException("The field path is null");
}
if (jsonObject == null) {
throw new IllegalArgumentException("The object is null");
}
if (!isValidJSONPath(fieldPath)) {
throw new IllegalArgumentException("the field path is not a valid json path: " + fieldPath);
throw new IllegalArgumentException("The field path is not a valid JSON path: " + fieldPath);
}

String path = fieldPath.startsWith("$.") ? fieldPath.substring(2) : fieldPath;
String[] pathParts = path.split("\\.");
Configuration suppressExceptionConfiguration = Configuration
.builder()
.options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL)
.build();
StringBuilder currentPath = new StringBuilder("$");

for (int i = 0; i < pathParts.length - 1; i++) {
currentPath.append(".").append(pathParts[i]);
if (!StringUtils.pathExists(jsonObject, currentPath.toString())) {
try {
jsonObject = JsonPath
.using(suppressExceptionConfiguration)
.parse(jsonObject)
.set(currentPath.toString(), new java.util.HashMap<>())
.json();
} catch (Exception e) {
throw new IllegalArgumentException("Error creating nested structure for path: " + currentPath, e);
String[] pathParts = path.split("(?<!\\\\)\\.");

Map<String, Object> current = (jsonObject instanceof Map) ? (Map<String, Object>) jsonObject : new HashMap<>();

for (String part : pathParts) {
if (part.contains("[")) {
// Handle array notation
String[] arrayParts = part.split("\\[");
String key = arrayParts[0];
int index = Integer.parseInt(arrayParts[1].replaceAll("\\]", ""));

if (!current.containsKey(key)) {
current.put(key, new ArrayList<>());
}
if (!(current.get(key) instanceof List)) {
return jsonObject;
}
List<Object> list = (List<Object>) current.get(key);
if (index >= list.size()) {
while (list.size() <= index) {
list.add(null);
}
list.set(index, new HashMap<>());
}
if (!(list.get(index) instanceof Map)) {
return jsonObject;
}
current = (Map<String, Object>) list.get(index);
} else {
// Handle object notation
if (!current.containsKey(part)) {
current.put(part, new HashMap<>());
} else if (!(current.get(part) instanceof Map)) {
return jsonObject;
}
current = (Map<String, Object>) current.get(part);
}
}

return jsonObject;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public void processRequestAsync(
if (request.source() == null) {
throw new IllegalArgumentException("query body is empty, cannot processor inference on empty query request.");
}

setRequestContextFromExt(request, requestContext);
String queryString = request.source().toString();
rewriteQueryString(request, queryString, requestListener, requestContext);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ public void processResponseAsync(
return;
}

setRequestContextFromExt(request, responseContext);

// if many to one, run rewriteResponseDocuments
if (!oneToOne) {
// use MLInferenceSearchResponseProcessor to allow writing to extension
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import static org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder.PARAMETER_NAME;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -19,6 +20,7 @@

import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
Expand All @@ -33,6 +35,10 @@
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
Expand Down Expand Up @@ -335,4 +341,26 @@ default List<String> writeNewDotPathForNestedObject(Object json, String dotPath)
default String convertToDotPath(String path) {
return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", "");
}

default void setRequestContextFromExt(SearchRequest request, PipelineProcessingContext requestContext) {

List<SearchExtBuilder> extBuilderList = request.source().ext();
for (SearchExtBuilder ext : extBuilderList) {
if (ext instanceof MLInferenceRequestParametersExtBuilder) {
MLInferenceRequestParametersExtBuilder mlExtBuilder = (MLInferenceRequestParametersExtBuilder) ext;
Map<String, Object> mlParams = mlExtBuilder.getRequestParameters().getParams();
mlParams
.forEach(
(key, value) -> requestContext
.setAttribute(String.format("ext.%s.%s", MLInferenceRequestParametersExtBuilder.NAME, key), value)
);
}
if (ext instanceof GenerativeQAParamExtBuilder) {
GenerativeQAParamExtBuilder qaParamExtBuilder = (GenerativeQAParamExtBuilder) ext;
Map<String, Object> mlParams = (Map<String, Object>) qaParamExtBuilder.getParams();
mlParams.forEach((key, value) -> requestContext.setAttribute(String.format("ext.%s.%s", PARAMETER_NAME, key), value));
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,109 @@ public void onFailure(Exception e) {

}

/**
* Tests the successful rewriting of a complex nested array in query extension based on the model output.
* verify the pipelineConext is set from the extension
* @throws Exception if an error occurs during the test
*/
public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception {
String modelInputField = "inputs";
String originalQueryField = "ext.ml_inference.question";
String newQueryField = "ext.ml_inference.llm_response";
String modelOutputField = "response";
MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor(
null,
modelInputField,
originalQueryField,
newQueryField,
modelOutputField,
false,
false
);

// Test model return a complex nested array
Map<String, Object> nestedResponse = new HashMap<>();
List<Map<String, String>> languageList = new ArrayList<>();
languageList.add(Collections.singletonMap("eng", "0.95"));
languageList.add(Collections.singletonMap("es", "0.67"));
nestedResponse.put("language", languageList);
nestedResponse.put("type", "bert");

ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", nestedResponse)).build();
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();

doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
return null;
}).when(client).execute(any(), any(), any());

QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo");

Map<String, Object> llmQuestion = new HashMap<>();
llmQuestion.put("question", "what language is this text in?");
MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(llmQuestion);
MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder();
mlInferenceExtBuilder.setRequestParameters(requestParameters);
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery).ext(List.of(mlInferenceExtBuilder));

SearchRequest request = new SearchRequest().source(source);

// Expecting new request with ml inference search extensions including the complex nested array
Map<String, Object> params = new HashMap<>();
params.put("question", "what language is this text in?");
params.put("llm_response", nestedResponse);
MLInferenceRequestParameters expectedRequestParameters = new MLInferenceRequestParameters(params);
MLInferenceRequestParametersExtBuilder expectedMlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder();
expectedMlInferenceExtBuilder.setRequestParameters(expectedRequestParameters);
SearchSourceBuilder expectedSource = new SearchSourceBuilder().query(incomingQuery).ext(List.of(expectedMlInferenceExtBuilder));
SearchRequest expectRequest = new SearchRequest().source(expectedSource);

ActionListener<SearchRequest> Listener = new ActionListener<>() {
@Override
public void onResponse(SearchRequest newSearchRequest) {
assertEquals(incomingQuery, newSearchRequest.source().query());
assertEquals(expectRequest.toString(), newSearchRequest.toString());

// Additional checks for the complex nested array
MLInferenceRequestParametersExtBuilder actualExtBuilder = (MLInferenceRequestParametersExtBuilder) newSearchRequest
.source()
.ext()
.get(0);
MLInferenceRequestParameters actualParams = actualExtBuilder.getRequestParameters();
Object actualResponse = actualParams.getParams().get("llm_response");

assertTrue(actualResponse instanceof Map);
Map<?, ?> actualNestedResponse = (Map<?, ?>) actualResponse;

// Check the "language" field
assertTrue(actualNestedResponse.get("language") instanceof List);
List<?> actualLanguageList = (List<?>) actualNestedResponse.get("language");
assertEquals(2, actualLanguageList.size());

Map<?, ?> engMap = (Map<?, ?>) actualLanguageList.get(0);
assertEquals("0.95", engMap.get("eng"));

Map<?, ?> esMap = (Map<?, ?>) actualLanguageList.get(1);
assertEquals("0.67", esMap.get("es"));

// Check the "type" field
assertEquals("bert", actualNestedResponse.get("type"));
verify(requestContext).setAttribute("ext.ml_inference.question", "what language is this text in?");
verify(requestContext).setAttribute("ext.ml_inference.llm_response", nestedResponse);

}

@Override
public void onFailure(Exception e) {
throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage());
}
};

requestProcessor.processRequestAsync(request, requestContext, Listener);
}

/**
* Helper method to create an instance of the MLInferenceSearchRequestProcessor with the specified parameters.
*
Expand Down
Loading

0 comments on commit 3cf89dc

Please sign in to comment.