diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 8879306773..fb8487a13b 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -34,7 +34,9 @@ import com.google.gson.JsonObject; import com.google.gson.JsonParser; import com.google.gson.JsonSyntaxException; +import com.jayway.jsonpath.InvalidJsonException; import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.PathNotFoundException; import com.networknt.schema.JsonSchema; import com.networknt.schema.JsonSchemaFactory; import com.networknt.schema.SpecVersion; @@ -347,6 +349,114 @@ public static JsonObject getJsonObjectFromString(String jsonString) { return JsonParser.parseString(jsonString).getAsJsonObject(); } + /** + * Checks if a specified JSON path exists within a given JSON object. + * + * This method attempts to read the value at the specified path in the JSON object. + * If the path exists, it returns true. If a PathNotFoundException is thrown, + * indicating that the path does not exist, it returns false. + * + * @param json The JSON object to check. This can be a Map, List, or any object + * that JsonPath can parse. + * @param path The JSON path to check for existence. This should be a valid + * JsonPath expression (e.g., "$.store.book[0].title"). + * @return true if the path exists in the JSON object, false otherwise. + * @throws IllegalArgumentException if the json object is null or if the path is null or empty. + * @throws PathNotFoundException if there's an error in parsing the JSON or the path. + */ + public static boolean pathExists(Object json, String path) { + if (json == null) { + throw new IllegalArgumentException("JSON object cannot be null"); + } + if (path == null || path.isEmpty()) { + throw new IllegalArgumentException("Path cannot be null or empty"); + } + if (!isValidJSONPath(path)) { + throw new IllegalArgumentException("the field path is not a valid json path: " + path); + } + try { + JsonPath.read(json, path); + return true; + } catch (PathNotFoundException e) { + return false; + } catch (InvalidJsonException e) { + throw new IllegalArgumentException("Invalid JSON input", e); + } + } + + /** + * Prepares nested structures in a JSON object based on the given field path. + * + * This method ensures that all intermediate nested objects and arrays exist in the JSON object + * for a given field path. If any part of the path doesn't exist, it creates new empty objects + * (HashMaps) or arrays (ArrayLists) for those parts. + * + * The method can handle complex paths including both object properties and array indices. + * For example, it can process paths like "foo.bar[1].baz[0].qux". + * + * @param jsonObject The JSON object to be updated. If this is not a Map, a new Map will be created. + * @param fieldPath The full path of the field, potentially including nested structures and array indices. + * The path can optionally start with "$." which will be ignored if present. + * @return The updated JSON object with necessary nested structures in place. + * If the input was not a Map, returns the newly created Map structure. + * + * @throws IllegalArgumentException If the field path is null or not a valid JSON path. + * + */ + public static Object prepareNestedStructures(Object jsonObject, String fieldPath) { + if (fieldPath == null) { + throw new IllegalArgumentException("The field path is null"); + } + if (jsonObject == null) { + throw new IllegalArgumentException("The object is null"); + } + if (!isValidJSONPath(fieldPath)) { + throw new IllegalArgumentException("The field path is not a valid JSON path: " + fieldPath); + } + + String path = fieldPath.startsWith("$.") ? fieldPath.substring(2) : fieldPath; + String[] pathParts = path.split("(? current = (jsonObject instanceof Map) ? (Map) jsonObject : new HashMap<>(); + + for (String part : pathParts) { + if (part.contains("[")) { + // Handle array notation + String[] arrayParts = part.split("\\["); + String key = arrayParts[0]; + int index = Integer.parseInt(arrayParts[1].replaceAll("\\]", "")); + + if (!current.containsKey(key)) { + current.put(key, new ArrayList<>()); + } + if (!(current.get(key) instanceof List)) { + return jsonObject; + } + List list = (List) current.get(key); + if (index >= list.size()) { + while (list.size() <= index) { + list.add(null); + } + list.set(index, new HashMap<>()); + } + if (!(list.get(index) instanceof Map)) { + return jsonObject; + } + current = (Map) list.get(index); + } else { + // Handle object notation + if (!current.containsKey(part)) { + current.put(part, new HashMap<>()); + } else if (!(current.get(part) instanceof Map)) { + return jsonObject; + } + current = (Map) current.get(part); + } + } + + return jsonObject; + } + public static void validateSchema(String schemaString, String instanceString) { try { // parse the schema JSON as string diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index d440c44faf..72ec6a05ba 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -6,7 +6,10 @@ package org.opensearch.ml.common.utils; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME; import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes; import static org.opensearch.ml.common.utils.StringUtils.getJsonPath; @@ -22,6 +25,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import org.apache.commons.text.StringSubstitutor; @@ -29,40 +33,42 @@ import org.junit.Test; import org.opensearch.OpenSearchParseException; +import com.jayway.jsonpath.JsonPath; + public class StringUtilsTest { @Test public void isJson_True() { - Assert.assertTrue(StringUtils.isJson("{}")); - Assert.assertTrue(StringUtils.isJson("[]")); - Assert.assertTrue(StringUtils.isJson("{\"key\": \"value\"}")); - Assert.assertTrue(StringUtils.isJson("{\"key\": 123}")); - Assert.assertTrue(StringUtils.isJson("[1, 2, 3]")); - Assert.assertTrue(StringUtils.isJson("[\"a\", \"b\"]")); - Assert.assertTrue(StringUtils.isJson("[1, \"a\"]")); - Assert.assertTrue(StringUtils.isJson("{\"key1\": \"value\", \"key2\": 123}")); - Assert.assertTrue(StringUtils.isJson("{}")); - Assert.assertTrue(StringUtils.isJson("[]")); - Assert.assertTrue(StringUtils.isJson("[ ]")); - Assert.assertTrue(StringUtils.isJson("[,]")); - Assert.assertTrue(StringUtils.isJson("[abc]")); - Assert.assertTrue(StringUtils.isJson("[\"abc\", 123]")); + assertTrue(StringUtils.isJson("{}")); + assertTrue(StringUtils.isJson("[]")); + assertTrue(StringUtils.isJson("{\"key\": \"value\"}")); + assertTrue(StringUtils.isJson("{\"key\": 123}")); + assertTrue(StringUtils.isJson("[1, 2, 3]")); + assertTrue(StringUtils.isJson("[\"a\", \"b\"]")); + assertTrue(StringUtils.isJson("[1, \"a\"]")); + assertTrue(StringUtils.isJson("{\"key1\": \"value\", \"key2\": 123}")); + assertTrue(StringUtils.isJson("{}")); + assertTrue(StringUtils.isJson("[]")); + assertTrue(StringUtils.isJson("[ ]")); + assertTrue(StringUtils.isJson("[,]")); + assertTrue(StringUtils.isJson("[abc]")); + assertTrue(StringUtils.isJson("[\"abc\", 123]")); } @Test public void isJson_False() { - Assert.assertFalse(StringUtils.isJson("{")); - Assert.assertFalse(StringUtils.isJson("[")); - Assert.assertFalse(StringUtils.isJson("{\"key\": \"value}")); - Assert.assertFalse(StringUtils.isJson("{\"key\": \"value\", \"key\": 123}")); - Assert.assertFalse(StringUtils.isJson("[1, \"a]")); - Assert.assertFalse(StringUtils.isJson("[]\"")); - Assert.assertFalse(StringUtils.isJson("[ ]\"")); - Assert.assertFalse(StringUtils.isJson("[,]\"")); - Assert.assertFalse(StringUtils.isJson("[,\"]")); - Assert.assertFalse(StringUtils.isJson("[]\"123\"")); - Assert.assertFalse(StringUtils.isJson("[abc\"]")); - Assert.assertFalse(StringUtils.isJson("[abc\n123]")); + assertFalse(StringUtils.isJson("{")); + assertFalse(StringUtils.isJson("[")); + assertFalse(StringUtils.isJson("{\"key\": \"value}")); + assertFalse(StringUtils.isJson("{\"key\": \"value\", \"key\": 123}")); + assertFalse(StringUtils.isJson("[1, \"a]")); + assertFalse(StringUtils.isJson("[]\"")); + assertFalse(StringUtils.isJson("[ ]\"")); + assertFalse(StringUtils.isJson("[,]\"")); + assertFalse(StringUtils.isJson("[,\"]")); + assertFalse(StringUtils.isJson("[]\"123\"")); + assertFalse(StringUtils.isJson("[abc\"]")); + assertFalse(StringUtils.isJson("[abc\n123]")); } @Test @@ -84,7 +90,7 @@ public void fromJson_NestedMap() { Map response = StringUtils .fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response"); assertEquals(1, response.size()); - Assert.assertTrue(response.get("key") instanceof Map); + assertTrue(response.get("key") instanceof Map); Map nestedMap = (Map) response.get("key"); assertEquals("nested_value", nestedMap.get("nested_key")); List list = (List) nestedMap.get("nested_array"); @@ -97,7 +103,7 @@ public void fromJson_NestedMap() { public void fromJson_SimpleList() { Map response = StringUtils.fromJson("[1, \"a\"]", "response"); assertEquals(1, response.size()); - Assert.assertTrue(response.get("response") instanceof List); + assertTrue(response.get("response") instanceof List); List list = (List) response.get("response"); assertEquals(1.0, list.get(0)); assertEquals("a", list.get(1)); @@ -107,12 +113,12 @@ public void fromJson_SimpleList() { public void fromJson_NestedList() { Map response = StringUtils.fromJson("[1, \"a\", [2, 3], {\"key\": \"value\"}]", "response"); assertEquals(1, response.size()); - Assert.assertTrue(response.get("response") instanceof List); + assertTrue(response.get("response") instanceof List); List list = (List) response.get("response"); assertEquals(1.0, list.get(0)); assertEquals("a", list.get(1)); - Assert.assertTrue(list.get(2) instanceof List); - Assert.assertTrue(list.get(3) instanceof Map); + assertTrue(list.get(2) instanceof List); + assertTrue(list.get(3) instanceof Map); } @Test @@ -152,23 +158,23 @@ public void processTextDocs() { List processedDocs = StringUtils.processTextDocs(Arrays.asList("abc \n\n123\"4", null, "[1.01,\"abc\"]")); assertEquals(3, processedDocs.size()); assertEquals("abc \\n\\n123\\\"4", processedDocs.get(0)); - Assert.assertNull(processedDocs.get(1)); + assertNull(processedDocs.get(1)); assertEquals("[1.01,\\\"abc\\\"]", processedDocs.get(2)); } @Test public void isEscapeUsed() { - Assert.assertFalse(StringUtils.isEscapeUsed("String escape")); - Assert.assertTrue(StringUtils.isEscapeUsed(" escape(\"abc\n123\")")); + assertFalse(StringUtils.isEscapeUsed("String escape")); + assertTrue(StringUtils.isEscapeUsed(" escape(\"abc\n123\")")); } @Test public void containsEscapeMethod() { - Assert.assertFalse(StringUtils.containsEscapeMethod("String escape")); - Assert.assertFalse(StringUtils.containsEscapeMethod("String escape()")); - Assert.assertFalse(StringUtils.containsEscapeMethod(" escape(\"abc\n123\")")); - Assert.assertTrue(StringUtils.containsEscapeMethod("String escape(def abc)")); - Assert.assertTrue(StringUtils.containsEscapeMethod("String escape(String input)")); + assertFalse(StringUtils.containsEscapeMethod("String escape")); + assertFalse(StringUtils.containsEscapeMethod("String escape()")); + assertFalse(StringUtils.containsEscapeMethod(" escape(\"abc\n123\")")); + assertTrue(StringUtils.containsEscapeMethod("String escape(def abc)")); + assertTrue(StringUtils.containsEscapeMethod("String escape(String input)")); } @Test @@ -183,7 +189,7 @@ public void addDefaultMethod_Escape() { String input = "return escape(\"abc\n123\");"; String result = StringUtils.addDefaultMethod(input); Assert.assertNotEquals(input, result); - Assert.assertTrue(result.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION)); + assertTrue(result.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION)); } @Test @@ -464,51 +470,174 @@ public void testGetJsonPath_ValidJsonPathWithoutSource() { @Test public void testisValidJSONPath_InvalidInputs() { - Assert.assertFalse(isValidJSONPath("..bar")); - Assert.assertFalse(isValidJSONPath(".")); - Assert.assertFalse(isValidJSONPath("..")); - Assert.assertFalse(isValidJSONPath("foo.bar.")); - Assert.assertFalse(isValidJSONPath(".foo.bar.")); + assertFalse(isValidJSONPath("..bar")); + assertFalse(isValidJSONPath(".")); + assertFalse(isValidJSONPath("..")); + assertFalse(isValidJSONPath("foo.bar.")); + assertFalse(isValidJSONPath(".foo.bar.")); } @Test public void testisValidJSONPath_NullInput() { - Assert.assertFalse(isValidJSONPath(null)); + assertFalse(isValidJSONPath(null)); } @Test public void testisValidJSONPath_EmptyInput() { - Assert.assertFalse(isValidJSONPath("")); + assertFalse(isValidJSONPath("")); } @Test public void testisValidJSONPath_ValidInputs() { - Assert.assertTrue(isValidJSONPath("foo")); - Assert.assertTrue(isValidJSONPath("foo.bar")); - Assert.assertTrue(isValidJSONPath("foo.bar.baz")); - Assert.assertTrue(isValidJSONPath("foo.bar.baz.qux")); - Assert.assertTrue(isValidJSONPath(".foo")); - Assert.assertTrue(isValidJSONPath("$.foo")); - Assert.assertTrue(isValidJSONPath(".foo.bar")); - Assert.assertTrue(isValidJSONPath("$.foo.bar")); + assertTrue(isValidJSONPath("foo")); + assertTrue(isValidJSONPath("foo.bar")); + assertTrue(isValidJSONPath("foo.bar.baz")); + assertTrue(isValidJSONPath("foo.bar.baz.qux")); + assertTrue(isValidJSONPath(".foo")); + assertTrue(isValidJSONPath("$.foo")); + assertTrue(isValidJSONPath(".foo.bar")); + assertTrue(isValidJSONPath("$.foo.bar")); } @Test public void testisValidJSONPath_WithFilter() { - Assert.assertTrue(isValidJSONPath("$.store['book']")); - Assert.assertTrue(isValidJSONPath("$['store']['book'][0]['title']")); - Assert.assertTrue(isValidJSONPath("$.store.book[0]")); - Assert.assertTrue(isValidJSONPath("$.store.book[1,2]")); - Assert.assertTrue(isValidJSONPath("$.store.book[-1:] ")); - Assert.assertTrue(isValidJSONPath("$.store.book[0:2]")); - Assert.assertTrue(isValidJSONPath("$.store.book[*]")); - Assert.assertTrue(isValidJSONPath("$.store.book[?(@.price < 10)]")); - Assert.assertTrue(isValidJSONPath("$.store.book[?(@.author == 'J.K. Rowling')]")); - Assert.assertTrue(isValidJSONPath("$..author")); - Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 15)]")); - Assert.assertTrue(isValidJSONPath("$.store.book[0,1]")); - Assert.assertTrue(isValidJSONPath("$['store','warehouse']")); - Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 20)].title")); + assertTrue(isValidJSONPath("$.store['book']")); + assertTrue(isValidJSONPath("$['store']['book'][0]['title']")); + assertTrue(isValidJSONPath("$.store.book[0]")); + assertTrue(isValidJSONPath("$.store.book[1,2]")); + assertTrue(isValidJSONPath("$.store.book[-1:] ")); + assertTrue(isValidJSONPath("$.store.book[0:2]")); + assertTrue(isValidJSONPath("$.store.book[*]")); + assertTrue(isValidJSONPath("$.store.book[?(@.price < 10)]")); + assertTrue(isValidJSONPath("$.store.book[?(@.author == 'J.K. Rowling')]")); + assertTrue(isValidJSONPath("$..author")); + assertTrue(isValidJSONPath("$..book[?(@.price > 15)]")); + assertTrue(isValidJSONPath("$.store.book[0,1]")); + assertTrue(isValidJSONPath("$['store','warehouse']")); + assertTrue(isValidJSONPath("$..book[?(@.price > 20)].title")); + } + + @Test + public void testPathExists_ExistingPath() { + Object json = JsonPath.parse("{\"a\":{\"b\":42}}").json(); + assertTrue(StringUtils.pathExists(json, "$.a.b")); + } + + @Test + public void testPathExists_NonExistingPath() { + Object json = JsonPath.parse("{\"a\":{\"b\":42}}").json(); + assertFalse(StringUtils.pathExists(json, "$.a.c")); + } + + @Test + public void testPathExists_EmptyObject() { + Object json = JsonPath.parse("{}").json(); + assertFalse(StringUtils.pathExists(json, "$.a")); + } + + @Test + public void testPathExists_NullJson() { + assertThrows(IllegalArgumentException.class, () -> StringUtils.pathExists(null, "$.a")); + } + + @Test + public void testPathExists_NullPath() { + Object json = JsonPath.parse("{\"a\":42}").json(); + assertThrows(IllegalArgumentException.class, () -> StringUtils.pathExists(json, null)); + } + + @Test + public void testPathExists_EmptyPath() { + Object json = JsonPath.parse("{\"a\":42}").json(); + assertThrows(IllegalArgumentException.class, () -> StringUtils.pathExists(json, "")); + } + + @Test + public void testPathExists_InvalidPath() { + Object json = JsonPath.parse("{\"a\":42}").json(); + assertThrows(IllegalArgumentException.class, () -> StringUtils.pathExists(json, "This is not a valid path")); + } + + @Test + public void testPathExists_ArrayElement() { + Object json = JsonPath.parse("{\"a\":[1,2,3]}").json(); + assertTrue(StringUtils.pathExists(json, "$.a[1]")); + assertFalse(StringUtils.pathExists(json, "$.a[3]")); + } + + @Test + public void testPathExists_NestedStructure() { + Object json = JsonPath.parse("{\"a\":{\"b\":{\"c\":{\"d\":42}}}}").json(); + assertTrue(StringUtils.pathExists(json, "$.a.b.c.d")); + assertFalse(StringUtils.pathExists(json, "$.a.b.c.e")); + } + + @Test + public void testPrepareNestedStructures_EmptyObject() { + Object jsonObject = new HashMap<>(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.b.c"); + assertTrue(JsonPath.read(result, "$.a.b") instanceof Map); + } + + @Test + public void testPrepareNestedStructures_ExistingStructure() { + Object jsonObject = JsonPath.parse("{\"a\":{\"b\":{}}}").json(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.b.c"); + assertTrue(JsonPath.read(result, "$.a.b") instanceof Map); + } + + @Test + public void testPrepareNestedStructures_PartiallyExistingStructure() { + Object jsonObject = JsonPath.parse("{\"a\":{}}").json(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.b.c.d"); + assertTrue(JsonPath.read(result, "$.a.b.c") instanceof Map); + } + + @Test + public void testPrepareNestedStructures_WithDollarSign() { + Object jsonObject = new HashMap<>(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "$.a.b.c"); + assertTrue(JsonPath.read(result, "$.a.b") instanceof Map); + } + + @Test + public void testPrepareNestedStructures_SingleLevel() { + Object jsonObject = new HashMap<>(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a"); + assertEquals(jsonObject, result); + } + + @Test + public void testPrepareNestedStructures_ExistingValue() { + Object jsonObject = JsonPath.parse("{\"a\":{\"b\":42}}").json(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.b.c"); + assertEquals(Optional.ofNullable(42), Optional.ofNullable(JsonPath.read(result, "$.a.b"))); + } + + @Test + public void testPrepareNestedStructures_NullInput() { + assertThrows(IllegalArgumentException.class, () -> StringUtils.prepareNestedStructures(null, "a.b.c")); + } + + @Test + public void testPrepareNestedStructures_NullPath() { + Object jsonObject = new HashMap<>(); + assertThrows(IllegalArgumentException.class, () -> StringUtils.prepareNestedStructures(jsonObject, null)); + } + + @Test + public void testPrepareNestedStructures_ComplexPath() { + Object jsonObject = new HashMap<>(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.b.c.d.e.f"); + assertTrue(JsonPath.read(result, "$.a.b.c.d.e") instanceof Map); + } + + @Test + public void testPrepareNestedStructures_MixedExistingAndNew() { + Object jsonObject = JsonPath.parse("{\"a\":{\"b\":42,\"c\":{}}}").json(); + Object result = StringUtils.prepareNestedStructures(jsonObject, "a.c.d.e"); + assertEquals(Optional.of(42), Optional.of(JsonPath.read(result, "$.a.b"))); + assertTrue(JsonPath.read(result, "$.a.c.d") instanceof Map); } @Test diff --git a/memory/build.gradle b/memory/build.gradle index 940a6b9621..86198c4521 100644 --- a/memory/build.gradle +++ b/memory/build.gradle @@ -41,6 +41,7 @@ dependencies { testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") testImplementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0' + testImplementation 'com.jayway.jsonpath:json-path:2.9.0' } test { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 065e0ec371..5067bdc138 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -273,6 +273,7 @@ import org.opensearch.ml.rest.RestMemorySearchInteractionsAction; import org.opensearch.ml.rest.RestMemoryUpdateConversationAction; import org.opensearch.ml.rest.RestMemoryUpdateInteractionAction; +import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLClusterLevelStat; @@ -999,6 +1000,15 @@ public List> getSearchExts() { ) ); + searchExts + .add( + new SearchPlugin.SearchExtSpec<>( + MLInferenceRequestParametersExtBuilder.NAME, + input -> new MLInferenceRequestParametersExtBuilder(input), + parser -> MLInferenceRequestParametersExtBuilder.parse(parser) + ) + ); + return searchExts; } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java index 8782addc82..ccfb2b7ab6 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java @@ -5,6 +5,7 @@ package org.opensearch.ml.processor; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.toJson; import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS; import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG; @@ -46,6 +47,7 @@ import com.jayway.jsonpath.Configuration; import com.jayway.jsonpath.JsonPath; import com.jayway.jsonpath.Option; +import com.jayway.jsonpath.PathNotFoundException; import com.jayway.jsonpath.ReadContext; /** @@ -147,10 +149,9 @@ public void processRequestAsync( if (request.source() == null) { throw new IllegalArgumentException("query body is empty, cannot processor inference on empty query request."); } - + setRequestContextFromExt(request, requestContext); String queryString = request.source().toString(); - - rewriteQueryString(request, queryString, requestListener); + rewriteQueryString(request, queryString, requestListener, requestContext); } catch (Exception e) { if (ignoreFailure) { @@ -164,13 +165,18 @@ public void processRequestAsync( /** * Rewrites the query string based on the input and output mappings and the ML model output. * - * @param request the {@link SearchRequest} to be rewritten - * @param queryString the original query string + * @param request the {@link SearchRequest} to be rewritten + * @param queryString the original query string * @param requestListener the {@link ActionListener} to be notified when the rewriting is complete + * @param requestContext * @throws IOException if an I/O error occurs during the rewriting process */ - private void rewriteQueryString(SearchRequest request, String queryString, ActionListener requestListener) - throws IOException { + private void rewriteQueryString( + SearchRequest request, + String queryString, + ActionListener requestListener, + PipelineProcessingContext requestContext + ) throws IOException { List> processInputMap = inferenceProcessorAttributes.getInputMaps(); List> processOutputMap = inferenceProcessorAttributes.getOutputMaps(); int inputMapSize = (processInputMap != null) ? processInputMap.size() : 0; @@ -198,7 +204,8 @@ private void rewriteQueryString(SearchRequest request, String queryString, Actio request, queryString, requestListener, - processOutputMap + processOutputMap, + requestContext ); GroupedActionListener> batchPredictionListener = createBatchPredictionListener( rewriteRequestListener, @@ -219,13 +226,15 @@ private void rewriteQueryString(SearchRequest request, String queryString, Actio * @param queryString the original query string * @param requestListener the {@link ActionListener} to be notified when the query string or query template is updated * @param processOutputMap the list of output mappings + * @param requestContext * @return an {@link ActionListener} that handles the response from the ML model inference */ private ActionListener> createRewriteRequestListener( SearchRequest request, String queryString, ActionListener requestListener, - List> processOutputMap + List> processOutputMap, + PipelineProcessingContext requestContext ) { return new ActionListener<>() { @Override @@ -237,12 +246,10 @@ public void onResponse(Map multipleMLOutputs) { try { if (queryTemplate == null) { Object incomeQueryObject = JsonPath.parse(queryString).read("$"); - updateIncomeQueryObject(incomeQueryObject, outputMapping, mlOutput); - SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder( - xContentRegistry, - StringUtils.toJson(incomeQueryObject) - ); + updateIncomeQueryObject(incomeQueryObject, outputMapping, mlOutput, requestContext); + SearchSourceBuilder searchSourceBuilder = getSearchSourceBuilder(xContentRegistry, toJson(incomeQueryObject)); request.source(searchSourceBuilder); + requestListener.onResponse(request); } else { String newQueryString = updateQueryTemplate(queryTemplate, outputMapping, mlOutput); @@ -273,13 +280,52 @@ public void onFailure(Exception e) { } } - private void updateIncomeQueryObject(Object incomeQueryObject, Map outputMapping, MLOutput mlOutput) { + /** + * Updates the income query object with values from the ML output based on the provided output mapping. + * + * This method iterates through the output mapping, retrieves corresponding values from the ML output, + * and updates the income query object accordingly. It also handles nested JSON structures and updates + * the request context with the new values. + * + * @param incomeQueryObject The object representing the income query to be updated. + * @param outputMapping A map containing the mapping between new query fields and model output field names. + * @param mlOutput The MLOutput object containing the results from the machine learning model. + * @param requestContext The context object for the current pipeline processing request. + * + * @throws IllegalArgumentException If a specified JSON path cannot be found in the query string. + * + * @implNote This method uses JsonPath for JSON manipulation and supports both regular and extended (ext) fields. + * For extended fields, it creates nested structures if they don't exist. + * The method also updates the request context with new field values for further processing. + * + * @see JsonPath + * @see PipelineProcessingContext + * @see MLOutput + */ + private void updateIncomeQueryObject( + Object incomeQueryObject, + Map outputMapping, + MLOutput mlOutput, + PipelineProcessingContext requestContext + ) { for (Map.Entry outputMapEntry : outputMapping.entrySet()) { - String newQueryField = outputMapEntry.getKey(); - String modelOutputFieldName = outputMapEntry.getValue(); - Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); - String jsonPathExpression = "$." + newQueryField; - JsonPath.parse(incomeQueryObject).set(jsonPathExpression, modelOutputValue); + String newQueryField = null; + try { + newQueryField = outputMapEntry.getKey(); + String modelOutputFieldName = outputMapEntry.getValue(); + Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); + + if (newQueryField.startsWith("$.ext.") || newQueryField.startsWith("ext.")) { + incomeQueryObject = StringUtils.prepareNestedStructures(incomeQueryObject, newQueryField); + } + + JsonPath.using(suppressExceptionConfiguration).parse(incomeQueryObject).set(newQueryField, modelOutputValue); + + requestContext.setAttribute(newQueryField, modelOutputValue); + + } catch (PathNotFoundException e) { + throw new IllegalArgumentException("can not find path " + newQueryField + "in query string"); + } } } @@ -300,12 +346,12 @@ private String updateQueryTemplate(String queryTemplate, Map out /** * Creates a {@link GroupedActionListener} that collects the responses from multiple ML model inferences. * - * @param rewriteRequestListner the {@link ActionListener} to be notified when all ML model inferences are complete + * @param rewriteRequestListener the {@link ActionListener} to be notified when all ML model inferences are complete * @param inputMapSize the number of input mappings * @return a {@link GroupedActionListener} that handles the responses from multiple ML model inferences */ private GroupedActionListener> createBatchPredictionListener( - ActionListener> rewriteRequestListner, + ActionListener> rewriteRequestListener, int inputMapSize ) { return new GroupedActionListener<>(new ActionListener<>() { @@ -315,13 +361,13 @@ public void onResponse(Collection> mlOutputMapCollection) for (Map mlOutputMap : mlOutputMapCollection) { mlOutputMaps.putAll(mlOutputMap); } - rewriteRequestListner.onResponse(mlOutputMaps); + rewriteRequestListener.onResponse(mlOutputMaps); } @Override public void onFailure(Exception e) { logger.error("Prediction Failed:", e); - rewriteRequestListner.onFailure(e); + rewriteRequestListener.onFailure(e); } }, Math.max(inputMapSize, 1)); } @@ -358,11 +404,12 @@ private boolean validateQueryFieldInQueryString( for (Map outputMap : processOutputMap) { for (Map.Entry entry : outputMap.entrySet()) { String queryField = entry.getKey(); - Object pathData = jsonData.read(queryField); - if (pathData == null) { - throw new IllegalArgumentException( - "cannot find field: " + queryField + " in query string: " + jsonData.jsonString() - ); + // output writing to search extension can be new field + if (!queryField.startsWith("ext.") && !queryField.startsWith("$.ext.")) { + Object pathData = jsonData.read(queryField); + if (pathData == null) { + throw new IllegalArgumentException(); + } } } } @@ -402,7 +449,7 @@ private void processPredictions( // model field as key, query field name as value String modelInputFieldName = entry.getKey(); String queryFieldName = entry.getValue(); - String queryFieldValue = StringUtils.toJson(JsonPath.parse(newQuery).read(queryFieldName)); + String queryFieldValue = toJson(JsonPath.parse(newQuery).read(queryFieldName)); modelParameters.put(modelInputFieldName, queryFieldValue); } } @@ -446,13 +493,18 @@ public void onFailure(Exception e) { /** * Creates a SearchSourceBuilder instance from the given query string. * + * This method parses the provided query string, substitutes parameters, and constructs + * a SearchSourceBuilder object. It handles JSON content and performs variable substitution + * using a StringSubstitutor. + * * @param xContentRegistry the XContentRegistry instance to be used for parsing - * @param queryString the query template string to be parsed + * @param queryString the query template string to be parsed * @return a SearchSourceBuilder instance created from the query string - * @throws IOException if an I/O error occurs during parsing + * @throws IOException if an I/O error occurs during parsing or content creation */ private static SearchSourceBuilder getSearchSourceBuilder(NamedXContentRegistry xContentRegistry, String queryString) throws IOException { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); XContentParser queryParser = XContentType.JSON @@ -461,7 +513,9 @@ private static SearchSourceBuilder getSearchSourceBuilder(NamedXContentRegistry ensureExpectedToken(XContentParser.Token.START_OBJECT, queryParser.nextToken(), queryParser); searchSourceBuilder.parseXContent(queryParser); + return searchSourceBuilder; + } /** diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java index e39b7f4b74..67df00d6ca 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -162,6 +162,8 @@ public void processResponseAsync( return; } + setRequestContextFromExt(request, responseContext); + // if many to one, run rewriteResponseDocuments if (!oneToOne) { // use MLInferenceSearchResponseProcessor to allow writing to extension diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index d32308d2ef..bd0eff429b 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.isJson; +import static org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder.PARAMETER_NAME; import java.io.IOException; import java.util.ArrayList; @@ -19,6 +20,7 @@ import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; +import org.opensearch.action.search.SearchRequest; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -33,6 +35,10 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; import com.jayway.jsonpath.Configuration; import com.jayway.jsonpath.JsonPath; @@ -105,11 +111,11 @@ default ActionRequest getMLModelInferenceRequest( * Retrieves the model output value from the given ModelTensorOutput for the specified modelOutputFieldName. * It handles cases where the output contains a single tensor or multiple tensors. * - * @param modelTensorOutput the ModelTensorOutput containing the model output - * @param modelOutputFieldName the name of the field in the model output to retrieve the value for - * @param ignoreMissing a flag indicating whether to ignore missing fields or throw an exception + * @param modelTensorOutput the ModelTensorOutput containing the model output + * @param modelOutputFieldName the name of the field in the model output to retrieve the value for + * @param ignoreMissing a flag indicating whether to ignore missing fields or throw an exception * @return the model output value as an Object - * @throws RuntimeException if there is an error retrieving the model output value + * @throws RuntimeException if there is an error retrieving the model output value */ default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String modelOutputFieldName, boolean ignoreMissing) { Object modelOutputValue; @@ -298,6 +304,7 @@ default boolean hasField(Object json, String path) { * Writes a new dot path for a nested object within the given JSON object. * This method is useful when dealing with arrays or nested objects in the JSON structure. * for example foo.*.bar.*.quk to be [foo.0.bar.0.quk, foo.0.bar.1.quk..] + * * @param json the JSON object containing the nested object * @param dotPath the dot path representing the location of the nested object * @return a list of dot paths representing the new locations of the nested object @@ -334,4 +341,26 @@ default List writeNewDotPathForNestedObject(Object json, String dotPath) default String convertToDotPath(String path) { return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", ""); } + + default void setRequestContextFromExt(SearchRequest request, PipelineProcessingContext requestContext) { + + List extBuilderList = request.source().ext(); + for (SearchExtBuilder ext : extBuilderList) { + if (ext instanceof MLInferenceRequestParametersExtBuilder) { + MLInferenceRequestParametersExtBuilder mlExtBuilder = (MLInferenceRequestParametersExtBuilder) ext; + Map mlParams = mlExtBuilder.getRequestParameters().getParams(); + mlParams + .forEach( + (key, value) -> requestContext + .setAttribute(String.format("ext.%s.%s", MLInferenceRequestParametersExtBuilder.NAME, key), value) + ); + } + if (ext instanceof GenerativeQAParamExtBuilder) { + GenerativeQAParamExtBuilder qaParamExtBuilder = (GenerativeQAParamExtBuilder) ext; + Map mlParams = (Map) qaParamExtBuilder.getParams(); + mlParams.forEach((key, value) -> requestContext.setAttribute(String.format("ext.%s.%s", PARAMETER_NAME, key), value)); + } + } + + } } diff --git a/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParameters.java b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParameters.java new file mode 100644 index 0000000000..042b3d915e --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParameters.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.ml.searchext; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@NoArgsConstructor +public class MLInferenceRequestParameters implements Writeable, ToXContentObject { + static final String ML_INFERENCE_FIELD = "ml_inference"; + + @Setter + @Getter + private Map params; + + public MLInferenceRequestParameters(Map params) { + this.params = params; + + } + + public MLInferenceRequestParameters(StreamInput input) throws IOException { + this.params = input.readMap(); + } + + /** + * Write this into the {@linkplain StreamOutput}. + * + * @param out + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(this.params); + } + + public static MLInferenceRequestParameters parse(XContentParser parser) throws IOException { + return new MLInferenceRequestParameters(parser.map()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(ML_INFERENCE_FIELD); + return builder.map(this.params); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + MLInferenceRequestParameters config = (MLInferenceRequestParameters) o; + + return params.equals(config.getParams()); + } + + @Override + public int hashCode() { + return Objects.hash(params); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilder.java b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilder.java new file mode 100644 index 0000000000..c8c9ffd8aa --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilder.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.ml.searchext; + +import static org.opensearch.ml.searchext.MLInferenceRequestParameters.ML_INFERENCE_FIELD; + +import java.io.IOException; +import java.util.Objects; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; + +public class MLInferenceRequestParametersExtBuilder extends SearchExtBuilder { + private static final Logger logger = LogManager.getLogger(MLInferenceRequestParametersExtBuilder.class); + public static final String NAME = ML_INFERENCE_FIELD; + private MLInferenceRequestParameters requestParameters; + + public MLInferenceRequestParametersExtBuilder() {} + + public MLInferenceRequestParametersExtBuilder(StreamInput input) throws IOException { + this.requestParameters = new MLInferenceRequestParameters(input); + } + + public MLInferenceRequestParameters getRequestParameters() { + return requestParameters; + } + + public void setRequestParameters(MLInferenceRequestParameters requestParameters) { + this.requestParameters = requestParameters; + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.requestParameters); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (!(obj instanceof MLInferenceRequestParametersExtBuilder)) { + return false; + } + MLInferenceRequestParametersExtBuilder o = (MLInferenceRequestParametersExtBuilder) obj; + return this.requestParameters.equals(o.requestParameters); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + requestParameters.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.value(requestParameters); + } + + public static MLInferenceRequestParametersExtBuilder parse(XContentParser parser) throws IOException { + + MLInferenceRequestParametersExtBuilder extBuilder = new MLInferenceRequestParametersExtBuilder(); + MLInferenceRequestParameters requestParameters = MLInferenceRequestParameters.parse(parser); + extBuilder.setRequestParameters(requestParameters); + return extBuilder; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtil.java b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtil.java new file mode 100644 index 0000000000..1073a55b40 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtil.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.ml.searchext; + +import java.util.List; +import java.util.stream.Collectors; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.SearchExtBuilder; + +public class MLInferenceRequestParametersUtil { + + public static MLInferenceRequestParameters getMLInferenceRequestParameters(SearchRequest searchRequest) { + MLInferenceRequestParametersExtBuilder mLInferenceRequestParametersExtBuilder = null; + if (searchRequest.source() != null && searchRequest.source().ext() != null && !searchRequest.source().ext().isEmpty()) { + List extBuilders = searchRequest + .source() + .ext() + .stream() + .filter(extBuilder -> MLInferenceRequestParametersExtBuilder.NAME.equals(extBuilder.getWriteableName())) + .collect(Collectors.toList()); + + if (!extBuilders.isEmpty()) { + mLInferenceRequestParametersExtBuilder = (MLInferenceRequestParametersExtBuilder) extBuilders.get(0); + } + } + MLInferenceRequestParameters mlInferenceRequestParameters = null; + if (mLInferenceRequestParametersExtBuilder != null) { + mlInferenceRequestParameters = mLInferenceRequestParametersExtBuilder.getRequestParameters(); + } + return mlInferenceRequestParameters; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java index 6da9cb406a..0ff85c939a 100644 --- a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java +++ b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java @@ -40,6 +40,7 @@ import org.opensearch.ml.engine.tools.MLModelTool; import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor; import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; +import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; @@ -66,9 +67,11 @@ public void setUp() { @Test public void testGetSearchExts() { List> searchExts = plugin.getSearchExts(); - assertEquals(1, searchExts.size()); - SearchPlugin.SearchExtSpec spec = searchExts.get(0); - assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec.getName().getPreferredName()); + assertEquals(2, searchExts.size()); + SearchPlugin.SearchExtSpec spec1 = searchExts.get(0); + assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec1.getName().getPreferredName()); + SearchPlugin.SearchExtSpec spec2 = searchExts.get(1); + assertEquals(MLInferenceRequestParametersExtBuilder.NAME, spec2.getName().getPreferredName()); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java index 353d2be1a3..a8fb6a2b59 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java @@ -35,6 +35,9 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; +import org.opensearch.ml.searchext.MLInferenceRequestParameters; +import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; +import org.opensearch.plugins.SearchPlugin; import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.pipeline.PipelineProcessingContext; @@ -48,15 +51,27 @@ public class MLInferenceSearchRequestProcessorTests extends AbstractBuilderTestC @Mock private PipelineProcessingContext requestContext; - static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( - new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() - ); + static public NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY; private static final String PROCESSOR_TAG = "inference"; private static final String DESCRIPTION = "inference_test"; @Before public void setup() { MockitoAnnotations.openMocks(this); + + TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, List.of(new SearchPlugin() { + @Override + public List> getSearchExts() { + return List + .of( + new SearchExtSpec<>( + MLInferenceRequestParametersExtBuilder.NAME, + MLInferenceRequestParametersExtBuilder::new, + parser -> MLInferenceRequestParametersExtBuilder.parse(parser) + ) + ); + } + })).getNamedXContents()); } /** @@ -183,7 +198,7 @@ public void onResponse(SearchRequest newSearchRequest) { @Override public void onFailure(Exception e) { - throw new RuntimeException("Failed in executing processRequestAsync."); + throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage()); } }; @@ -240,7 +255,7 @@ public void onResponse(SearchRequest newSearchRequest) { @Override public void onFailure(Exception e) { - throw new RuntimeException("Failed in executing processRequestAsync."); + throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage()); } }; @@ -1021,6 +1036,242 @@ public void onFailure(Exception ex) { } + /** + * Tests the successful rewriting of a single string in a term query based on the model output. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteTermQueryWriteToExtensionSuccess() throws Exception { + + /** + * example term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + String modelInputField = "inputs"; + String originalQueryField = "query.term.text.value"; + String newQueryField = "$.ext.ml_inference.llm_response"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "eng")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); + SearchRequest request = new SearchRequest().source(source); + + Map llmResponse = new HashMap<>(); + llmResponse.put("llm_response", "eng"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(llmResponse); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + SearchSourceBuilder expectedSource = new SearchSourceBuilder().query(incomingQuery).ext(List.of(mlInferenceExtBuilder)); + SearchRequest expectRequest = new SearchRequest().source(expectedSource); + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(incomingQuery, newSearchRequest.source().query()); + assertEquals(expectRequest.source().toString(), newSearchRequest.source().toString()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage()); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the successful rewriting of a single string in a term query based on the model output. + * + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteTermQueryReadAndWriteToExtensionSuccess() throws Exception { + + /** + * example term query: {"query":{"term":{"text":{"value":"foo","boost":1.0}}}} + */ + String modelInputField = "inputs"; + String originalQueryField = "ext.ml_inference.question"; + String newQueryField = "ext.ml_inference.llm_response"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "eng")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + + Map llmQuestion = new HashMap<>(); + llmQuestion.put("question", "what language is this text in?"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(llmQuestion); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery).ext(List.of(mlInferenceExtBuilder)); + + SearchRequest request = new SearchRequest().source(source); + + // expecting new request with ml inference search extensions + Map params = new HashMap<>(); + params.put("question", "what language is this text in?"); + params.put("llm_response", "eng"); + MLInferenceRequestParameters expectedRequestParameters = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder expectedMlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + expectedMlInferenceExtBuilder.setRequestParameters(expectedRequestParameters); + SearchSourceBuilder expectedSource = new SearchSourceBuilder().query(incomingQuery).ext(List.of(expectedMlInferenceExtBuilder)); + SearchRequest expectRequest = new SearchRequest().source(expectedSource); + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(incomingQuery, newSearchRequest.source().query()); + assertEquals(expectRequest.toString(), newSearchRequest.toString()); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage()); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + + } + + /** + * Tests the successful rewriting of a complex nested array in query extension based on the model output. + * verify the pipelineConext is set from the extension + * @throws Exception if an error occurs during the test + */ + public void testExecute_rewriteTermQueryReadAndWriteComplexNestedArrayToExtensionSuccess() throws Exception { + String modelInputField = "inputs"; + String originalQueryField = "ext.ml_inference.question"; + String newQueryField = "ext.ml_inference.llm_response"; + String modelOutputField = "response"; + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( + null, + modelInputField, + originalQueryField, + newQueryField, + modelOutputField, + false, + false + ); + + // Test model return a complex nested array + Map nestedResponse = new HashMap<>(); + List> languageList = new ArrayList<>(); + languageList.add(Collections.singletonMap("eng", "0.95")); + languageList.add(Collections.singletonMap("es", "0.67")); + nestedResponse.put("language", languageList); + nestedResponse.put("type", "bert"); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", nestedResponse)).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + + Map llmQuestion = new HashMap<>(); + llmQuestion.put("question", "what language is this text in?"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(llmQuestion); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery).ext(List.of(mlInferenceExtBuilder)); + + SearchRequest request = new SearchRequest().source(source); + + // Expecting new request with ml inference search extensions including the complex nested array + Map params = new HashMap<>(); + params.put("question", "what language is this text in?"); + params.put("llm_response", nestedResponse); + MLInferenceRequestParameters expectedRequestParameters = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder expectedMlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + expectedMlInferenceExtBuilder.setRequestParameters(expectedRequestParameters); + SearchSourceBuilder expectedSource = new SearchSourceBuilder().query(incomingQuery).ext(List.of(expectedMlInferenceExtBuilder)); + SearchRequest expectRequest = new SearchRequest().source(expectedSource); + + ActionListener Listener = new ActionListener<>() { + @Override + public void onResponse(SearchRequest newSearchRequest) { + assertEquals(incomingQuery, newSearchRequest.source().query()); + assertEquals(expectRequest.toString(), newSearchRequest.toString()); + + // Additional checks for the complex nested array + MLInferenceRequestParametersExtBuilder actualExtBuilder = (MLInferenceRequestParametersExtBuilder) newSearchRequest + .source() + .ext() + .get(0); + MLInferenceRequestParameters actualParams = actualExtBuilder.getRequestParameters(); + Object actualResponse = actualParams.getParams().get("llm_response"); + + assertTrue(actualResponse instanceof Map); + Map actualNestedResponse = (Map) actualResponse; + + // Check the "language" field + assertTrue(actualNestedResponse.get("language") instanceof List); + List actualLanguageList = (List) actualNestedResponse.get("language"); + assertEquals(2, actualLanguageList.size()); + + Map engMap = (Map) actualLanguageList.get(0); + assertEquals("0.95", engMap.get("eng")); + + Map esMap = (Map) actualLanguageList.get(1); + assertEquals("0.67", esMap.get("es")); + + // Check the "type" field + assertEquals("bert", actualNestedResponse.get("type")); + verify(requestContext).setAttribute("ext.ml_inference.question", "what language is this text in?"); + verify(requestContext).setAttribute("ext.ml_inference.llm_response", nestedResponse); + + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException("Failed in executing processRequestAsync." + e.getMessage()); + } + }; + + requestProcessor.processRequestAsync(request, requestContext, Listener); + } + /** * Helper method to create an instance of the MLInferenceSearchRequestProcessor with the specified parameters. * diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index dedae5f1bd..20c586be65 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -66,6 +66,8 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; +import org.opensearch.ml.searchext.MLInferenceRequestParameters; +import org.opensearch.ml.searchext.MLInferenceRequestParametersExtBuilder; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; @@ -518,7 +520,84 @@ public void onFailure(Exception e) { toJson(inputDataSet.getParameters()), "{\"text_docs\":\"[\\\"value 0\\\",\\\"value 1\\\",\\\"value 2\\\",\\\"value 3\\\",\\\"value 4\\\"]\"}" ); + } + + /** + * Tests the successful processing of a response with a single pair of input and output mappings. + * read the query text into model config + * with query extensions + * @throws Exception if an error occurs during the test + */ + @Test + public void testProcessResponseSuccessReadQueryTextFromExt() throws Exception { + String modelInputField = "text_docs"; + String originalDocumentField = "text"; + String newDocumentField = "similarity_score"; + String modelOutputField = "response"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, originalDocumentField); + input.put("query_text", "_request.ext.ml_inference.query_text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "text_similarity", + false, + false, + false, + "{ \"query_text\": \"${input_map.query_text}\", \"text_docs\":${input_map.text_docs}}", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + assertEquals(responseProcessor.getType(), TYPE); + SearchRequest request = getSearchRequestWithExtension("query_text", "query.term.text.value"); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals(newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField), 0.0); + assertEquals(newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField), 1.0); + assertEquals(newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField), 2.0); + assertEquals(newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField), 3.0); + assertEquals(newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField), 4.0); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, response, responseContext, listener); } /** @@ -3924,6 +4003,105 @@ public void onFailure(Exception e) { } + @Test + public void testProcessResponseAsyncSetRequestContextFromExt() throws Exception { + String documentField = "text"; + String modelInputField = "context"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, documentField); + inputMap.add(input); + + String newDocumentField = "ext.ml_inference.summary"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + modelConfig + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + SearchResponse response = getSearchResponse(5, true, documentField); + Map params = new HashMap<>(); + params.put("llm_response", "answer"); + MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse( + params, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getSuccessfulShards(), + response.getShardFailures(), + response.getClusters() + ); + + Map role = new HashMap<>(); + role.put("role", "users"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(role); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + SearchSourceBuilder source = new SearchSourceBuilder() + .query(incomingQuery) + .size(5) + .sort("text") + .ext(List.of(mlInferenceExtBuilder)); + SearchRequest request = new SearchRequest().source(source); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there is 1 value")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) newSearchResponse; + assertEquals(responseAfterProcessor.getHits().getHits().length, 5); + Map newParams = new HashMap<>(); + newParams.put("llm_response", "answer"); + newParams.put("summary", "there is 1 value"); + assertEquals(responseAfterProcessor.getParams(), newParams); + verify(responseContext).setAttribute("ext.ml_inference.role", "users"); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, mLInferenceSearchResponse, responseContext, listener); + + } + private static SearchRequest getSearchRequest() { QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery).size(5).sort("text"); @@ -3931,6 +4109,22 @@ private static SearchRequest getSearchRequest() { return request; } + private static SearchRequest getSearchRequestWithExtension(String queryText, String queryPath) { + QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); + + Map params = new HashMap<>(); + params.put(queryText, queryPath); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(params); + + MLInferenceRequestParametersExtBuilder extBuilder = new MLInferenceRequestParametersExtBuilder(); + extBuilder.setRequestParameters(requestParameters); + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery).ext(List.of(extBuilder)); + ; + SearchRequest request = new SearchRequest().source(source); + + return request; + } + private static Map generateInferenceResult(String response) { Map inferenceResult = new HashMap<>(); List> inferenceResults = new ArrayList<>(); @@ -4348,4 +4542,5 @@ public void testWriteToExtensionAndOneToOne() throws Exception { } } + } diff --git a/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilderTests.java b/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilderTests.java new file mode 100644 index 0000000000..bf705c6649 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersExtBuilderTests.java @@ -0,0 +1,306 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.ml.searchext; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.searchext.MLInferenceRequestParameters.ML_INFERENCE_FIELD; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentHelper; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.search.SearchModule; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class MLInferenceRequestParametersExtBuilderTests extends OpenSearchTestCase { + + public NamedXContentRegistry xContentRegistry = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, List.of(new SearchPlugin() { + @Override + public List> getSearchExts() { + return List + .of( + new SearchPlugin.SearchExtSpec<>( + MLInferenceRequestParametersExtBuilder.NAME, + MLInferenceRequestParametersExtBuilder::new, + parser -> MLInferenceRequestParametersExtBuilder.parse(parser) + ) + ); + } + })).getNamedXContents()); + + public void testParse() throws IOException { + String requiredJsonStr = "{\"llm_question\":\"this is test llm question\"}"; + + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, requiredJsonStr); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLInferenceRequestParametersExtBuilder builder = MLInferenceRequestParametersExtBuilder.parse(parser); + assertNotNull(builder); + assertNotNull(builder.getRequestParameters()); + MLInferenceRequestParameters params = builder.getRequestParameters(); + Assert.assertEquals("this is test llm question", params.getParams().get("llm_question")); + } + + @Test + public void testMultipleParameters() throws IOException { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + params.put("model_id", "model1"); + params.put("max_tokens", 100); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + builder.setRequestParameters(requestParameters); + + BytesStreamOutput out = new BytesStreamOutput(); + builder.writeTo(out); + + MLInferenceRequestParametersExtBuilder deserialized = new MLInferenceRequestParametersExtBuilder(out.bytes().streamInput()); + assertEquals(builder, deserialized); + assertEquals(params, deserialized.getRequestParameters().getParams()); + } + + @Test + public void testParseWithEmptyObject() throws IOException { + String emptyJsonStr = "{}"; + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, emptyJsonStr); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLInferenceRequestParametersExtBuilder builder = MLInferenceRequestParametersExtBuilder.parse(parser); + assertNotNull(builder); + assertNotNull(builder.getRequestParameters()); + assertTrue(builder.getRequestParameters().getParams().isEmpty()); + } + + @Test + public void testWriteableName() throws IOException { + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + assertEquals(builder.getWriteableName(), ML_INFERENCE_FIELD); + } + + @Test + public void testEquals() throws IOException { + MLInferenceRequestParametersExtBuilder MlInferenceParamBuilder = new MLInferenceRequestParametersExtBuilder(); + GenerativeQAParamExtBuilder qaParamExtBuilder = new GenerativeQAParamExtBuilder(); + assertEquals(MlInferenceParamBuilder.equals(qaParamExtBuilder), false); + assertEquals(MlInferenceParamBuilder.equals(null), false); + } + + @Test + public void testMLInferenceRequestParametersEqualsWithNull() { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters parameters = new MLInferenceRequestParameters(params); + assertFalse(parameters.equals(null)); + } + + @Test + public void testMLInferenceRequestParametersEqualsWithDifferentClass() { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters parameters = new MLInferenceRequestParameters(params); + assertFalse(parameters.equals("not a MLInferenceRequestParameters object")); + } + + @Test + public void testMLInferenceRequestParametersToXContentWithEmptyParams() throws IOException { + MLInferenceRequestParameters parameters = new MLInferenceRequestParameters(new HashMap<>()); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + parameters.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + assertEquals("{\"ml_inference\":{}}", builder.toString()); + } + + @Test + public void testMLInferenceRequestParametersExtBuilderToXContentWithEmptyParams() throws IOException { + MLInferenceRequestParameters parameters = new MLInferenceRequestParameters(new HashMap<>()); + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + builder.setRequestParameters(parameters); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + builder.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + xContentBuilder.endObject(); + assertEquals("{\"ml_inference\":{}}", xContentBuilder.toString()); + } + + @Test + public void testMLInferenceRequestParametersStreamRoundTripWithNullParams() throws IOException { + MLInferenceRequestParameters original = new MLInferenceRequestParameters(); + original.setParams(null); + BytesStreamOutput out = new BytesStreamOutput(); + original.writeTo(out); + MLInferenceRequestParameters deserialized = new MLInferenceRequestParameters(out.bytes().streamInput()); + assertNull(deserialized.getParams()); + } + + @Test + public void testMLInferenceRequestParametersExtBuilderStreamRoundTripWithNullParams() throws IOException { + MLInferenceRequestParametersExtBuilder original = new MLInferenceRequestParametersExtBuilder(); + original.setRequestParameters(null); + BytesStreamOutput out = new BytesStreamOutput(); + assertThrows(NullPointerException.class, () -> original.writeTo(out)); + } + + @Test + public void testEqualsAndHashCode() { + Map params1 = new HashMap<>(); + params1.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters1 = new MLInferenceRequestParameters(params1); + MLInferenceRequestParametersExtBuilder builder1 = new MLInferenceRequestParametersExtBuilder(); + builder1.setRequestParameters(requestParameters1); + + Map params2 = new HashMap<>(); + params2.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters2 = new MLInferenceRequestParameters(params2); + MLInferenceRequestParametersExtBuilder builder2 = new MLInferenceRequestParametersExtBuilder(); + builder2.setRequestParameters(requestParameters2); + + assertEquals(builder1, builder2); + assertEquals(builder1.hashCode(), builder2.hashCode()); + + Map params3 = new HashMap<>(); + params3.put("query_text", "bar"); + MLInferenceRequestParameters requestParameters3 = new MLInferenceRequestParameters(params3); + MLInferenceRequestParametersExtBuilder builder3 = new MLInferenceRequestParametersExtBuilder(); + builder3.setRequestParameters(requestParameters3); + + assertNotEquals(builder1, builder3); + assertNotEquals(builder1.hashCode(), builder3.hashCode()); + } + + @Test + public void testXContentRoundTrip() throws IOException { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + + XContentType xContentType = randomFrom(XContentType.values()); + BytesReference serialized = XContentHelper.toXContent(mlInferenceExtBuilder, xContentType, true); + + XContentParser parser = createParser(xContentType.xContent(), serialized); + + MLInferenceRequestParametersExtBuilder deserialized = MLInferenceRequestParametersExtBuilder.parse(parser); + + assertEquals(deserialized.getRequestParameters().getParams().get(ML_INFERENCE_FIELD), params); + + } + + @Test + public void testStreamRoundTrip() throws IOException { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(); + requestParameters.setParams(params); + MLInferenceRequestParametersExtBuilder mlInferenceExtBuilder = new MLInferenceRequestParametersExtBuilder(); + mlInferenceExtBuilder.setRequestParameters(requestParameters); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlInferenceExtBuilder.writeTo(bytesStreamOutput); + + MLInferenceRequestParametersExtBuilder deserialized = new MLInferenceRequestParametersExtBuilder( + bytesStreamOutput.bytes().streamInput() + ); + assertEquals(mlInferenceExtBuilder, deserialized); + } + + @Test + public void testNullRequestParameters() throws IOException { + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + assertNull(builder.getRequestParameters()); + + BytesStreamOutput out = new BytesStreamOutput(); + + // Expect NullPointerException when writing null requestParameters + assertThrows(NullPointerException.class, () -> builder.writeTo(out)); + + // Test that we can still create a new builder with null requestParameters + MLInferenceRequestParametersExtBuilder newBuilder = new MLInferenceRequestParametersExtBuilder(); + assertNull(newBuilder.getRequestParameters()); + } + + @Test + public void testEmptyRequestParameters() throws IOException { + MLInferenceRequestParameters emptyParams = new MLInferenceRequestParameters(new HashMap<>()); + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + builder.setRequestParameters(emptyParams); + + BytesStreamOutput out = new BytesStreamOutput(); + builder.writeTo(out); + + MLInferenceRequestParametersExtBuilder deserialized = new MLInferenceRequestParametersExtBuilder(out.bytes().streamInput()); + assertNotNull(deserialized.getRequestParameters()); + assertTrue(deserialized.getRequestParameters().getParams().isEmpty()); + } + + @Test + public void testToXContent() throws IOException { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder builder = new MLInferenceRequestParametersExtBuilder(); + builder.setRequestParameters(requestParameters); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + builder.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + xContentBuilder.endObject(); + + String expected = "{\"ml_inference\":{\"query_text\":\"foo\"}}"; + assertEquals(expected, xContentBuilder.toString()); + } + + @Test + public void testMLInferenceRequestParametersEqualsAndHashCode() { + Map params1 = new HashMap<>(); + params1.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters1 = new MLInferenceRequestParameters(params1); + + Map params2 = new HashMap<>(); + params2.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters2 = new MLInferenceRequestParameters(params2); + + Map params3 = new HashMap<>(); + params3.put("query_text", "bar"); + MLInferenceRequestParameters requestParameters3 = new MLInferenceRequestParameters(params3); + + assertEquals(requestParameters1, requestParameters2); + assertEquals(requestParameters1.hashCode(), requestParameters2.hashCode()); + assertNotEquals(requestParameters1, requestParameters3); + assertNotEquals(requestParameters1.hashCode(), requestParameters3.hashCode()); + } + + @Test + public void testMLInferenceRequestParametersToXContent() throws IOException { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters requestParameters = new MLInferenceRequestParameters(params); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + requestParameters.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + xContentBuilder.endObject(); + + String expected = "{\"ml_inference\":{\"query_text\":\"foo\"}}"; + assertEquals(expected, xContentBuilder.toString()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtilTests.java b/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtilTests.java new file mode 100644 index 0000000000..ea2dc55d26 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/searchext/MLInferenceRequestParametersUtilTests.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ +package org.opensearch.ml.searchext; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.builder.SearchSourceBuilder; + +public class MLInferenceRequestParametersUtilTests { + @Test + public void testExtractParameters() { + Map params = new HashMap<>(); + params.put("query_text", "foo"); + MLInferenceRequestParameters expected = new MLInferenceRequestParameters(params); + MLInferenceRequestParametersExtBuilder extBuilder = new MLInferenceRequestParametersExtBuilder(); + extBuilder.setRequestParameters(expected); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); + MLInferenceRequestParameters actual = MLInferenceRequestParametersUtil.getMLInferenceRequestParameters(request); + assertEquals(expected, actual); + } + + @Test + public void testExtractParametersWithNullSource() { + SearchRequest request = new SearchRequest(); + MLInferenceRequestParameters result = MLInferenceRequestParametersUtil.getMLInferenceRequestParameters(request); + assertNull(result); + } + + @Test + public void testExtractParametersWithEmptyExt() { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + MLInferenceRequestParameters result = MLInferenceRequestParametersUtil.getMLInferenceRequestParameters(request); + assertNull(result); + } + +}