Skip to content

Commit

Permalink
enable add query_text to model_config
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl committed Sep 3, 2024
1 parent 88fd3e7 commit d909b02
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,42 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
return errorMessage + " Model ID: " + modelId;
}
}

/**
* Checks if the given input string matches the dot path format.
*
* <p>The dot path format is a string that consists of one or more word characters
* (letters, digits, or underscores) separated by dots. The string can optionally
* start or end with a dot followed by one or more word characters.
*
* <p>Examples of valid dot path format strings:
* <ul>
* <li>"foo"</li>
* <li>"foo.bar"</li>
* <li>"foo.bar.baz"</li>
* <li>"foo.bar.baz.qux"</li>
* <li>".foo"</li>
* <li>"foo."</li>
* <li>".foo.bar"</li>
* </ul>
*
* <p>Examples of invalid dot path format strings:
* <ul>
* <li>"foo..bar"</li>
* <li>"."</li>
* <li>".."</li>
* </ul>
*
* @param input the input string to be checked
* @return true if the input string matches the dot path format, false otherwise
*/
public static boolean isValidJSONPath(String input) {
if (input == null) {
return false;
}
String DOT_PATH_REGEX = "^(\\$|\\@)?((\\.?([\\w\\[\\]\\*'\\\\])*)?(\\(\\?\\(.*?\\)\\))?)*$";
Pattern DOT_PATH_PATTERN = Pattern.compile(DOT_PATH_REGEX);
return DOT_PATH_PATTERN.matcher(input).matches();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

package org.opensearch.ml.common.utils;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.*;
import static org.opensearch.ml.common.utils.StringUtils.isValidJSONPath;

import java.util.Arrays;
import java.util.HashMap;
Expand All @@ -21,36 +22,36 @@ public class StringUtilsTest {

@Test
public void isJson_True() {
Assert.assertTrue(StringUtils.isJson("{}"));
Assert.assertTrue(StringUtils.isJson("[]"));
Assert.assertTrue(StringUtils.isJson("{\"key\": \"value\"}"));
Assert.assertTrue(StringUtils.isJson("{\"key\": 123}"));
Assert.assertTrue(StringUtils.isJson("[1, 2, 3]"));
Assert.assertTrue(StringUtils.isJson("[\"a\", \"b\"]"));
Assert.assertTrue(StringUtils.isJson("[1, \"a\"]"));
Assert.assertTrue(StringUtils.isJson("{\"key1\": \"value\", \"key2\": 123}"));
Assert.assertTrue(StringUtils.isJson("{}"));
Assert.assertTrue(StringUtils.isJson("[]"));
Assert.assertTrue(StringUtils.isJson("[ ]"));
Assert.assertTrue(StringUtils.isJson("[,]"));
Assert.assertTrue(StringUtils.isJson("[abc]"));
Assert.assertTrue(StringUtils.isJson("[\"abc\", 123]"));
assertTrue(StringUtils.isJson("{}"));
assertTrue(StringUtils.isJson("[]"));
assertTrue(StringUtils.isJson("{\"key\": \"value\"}"));
assertTrue(StringUtils.isJson("{\"key\": 123}"));
assertTrue(StringUtils.isJson("[1, 2, 3]"));
assertTrue(StringUtils.isJson("[\"a\", \"b\"]"));
assertTrue(StringUtils.isJson("[1, \"a\"]"));
assertTrue(StringUtils.isJson("{\"key1\": \"value\", \"key2\": 123}"));
assertTrue(StringUtils.isJson("{}"));
assertTrue(StringUtils.isJson("[]"));
assertTrue(StringUtils.isJson("[ ]"));
assertTrue(StringUtils.isJson("[,]"));
assertTrue(StringUtils.isJson("[abc]"));
assertTrue(StringUtils.isJson("[\"abc\", 123]"));
}

@Test
public void isJson_False() {
Assert.assertFalse(StringUtils.isJson("{"));
Assert.assertFalse(StringUtils.isJson("["));
Assert.assertFalse(StringUtils.isJson("{\"key\": \"value}"));
Assert.assertFalse(StringUtils.isJson("{\"key\": \"value\", \"key\": 123}"));
Assert.assertFalse(StringUtils.isJson("[1, \"a]"));
Assert.assertFalse(StringUtils.isJson("[]\""));
Assert.assertFalse(StringUtils.isJson("[ ]\""));
Assert.assertFalse(StringUtils.isJson("[,]\""));
Assert.assertFalse(StringUtils.isJson("[,\"]"));
Assert.assertFalse(StringUtils.isJson("[]\"123\""));
Assert.assertFalse(StringUtils.isJson("[abc\"]"));
Assert.assertFalse(StringUtils.isJson("[abc\n123]"));
assertFalse(StringUtils.isJson("{"));
assertFalse(StringUtils.isJson("["));
assertFalse(StringUtils.isJson("{\"key\": \"value}"));
assertFalse(StringUtils.isJson("{\"key\": \"value\", \"key\": 123}"));
assertFalse(StringUtils.isJson("[1, \"a]"));
assertFalse(StringUtils.isJson("[]\""));
assertFalse(StringUtils.isJson("[ ]\""));
assertFalse(StringUtils.isJson("[,]\""));
assertFalse(StringUtils.isJson("[,\"]"));
assertFalse(StringUtils.isJson("[]\"123\""));
assertFalse(StringUtils.isJson("[abc\"]"));
assertFalse(StringUtils.isJson("[abc\n123]"));
}

@Test
Expand All @@ -72,7 +73,7 @@ public void fromJson_NestedMap() {
Map<String, Object> response = StringUtils
.fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response");
assertEquals(1, response.size());
Assert.assertTrue(response.get("key") instanceof Map);
assertTrue(response.get("key") instanceof Map);
Map nestedMap = (Map) response.get("key");
assertEquals("nested_value", nestedMap.get("nested_key"));
List list = (List) nestedMap.get("nested_array");
Expand All @@ -85,7 +86,7 @@ public void fromJson_NestedMap() {
public void fromJson_SimpleList() {
Map<String, Object> response = StringUtils.fromJson("[1, \"a\"]", "response");
assertEquals(1, response.size());
Assert.assertTrue(response.get("response") instanceof List);
assertTrue(response.get("response") instanceof List);
List list = (List) response.get("response");
assertEquals(1.0, list.get(0));
assertEquals("a", list.get(1));
Expand All @@ -95,12 +96,12 @@ public void fromJson_SimpleList() {
public void fromJson_NestedList() {
Map<String, Object> response = StringUtils.fromJson("[1, \"a\", [2, 3], {\"key\": \"value\"}]", "response");
assertEquals(1, response.size());
Assert.assertTrue(response.get("response") instanceof List);
assertTrue(response.get("response") instanceof List);
List list = (List) response.get("response");
assertEquals(1.0, list.get(0));
assertEquals("a", list.get(1));
Assert.assertTrue(list.get(2) instanceof List);
Assert.assertTrue(list.get(3) instanceof Map);
assertTrue(list.get(2) instanceof List);
assertTrue(list.get(3) instanceof Map);
}

@Test
Expand Down Expand Up @@ -146,17 +147,17 @@ public void processTextDocs() {

@Test
public void isEscapeUsed() {
Assert.assertFalse(StringUtils.isEscapeUsed("String escape"));
Assert.assertTrue(StringUtils.isEscapeUsed(" escape(\"abc\n123\")"));
assertFalse(StringUtils.isEscapeUsed("String escape"));
assertTrue(StringUtils.isEscapeUsed(" escape(\"abc\n123\")"));
}

@Test
public void containsEscapeMethod() {
Assert.assertFalse(StringUtils.containsEscapeMethod("String escape"));
Assert.assertFalse(StringUtils.containsEscapeMethod("String escape()"));
Assert.assertFalse(StringUtils.containsEscapeMethod(" escape(\"abc\n123\")"));
Assert.assertTrue(StringUtils.containsEscapeMethod("String escape(def abc)"));
Assert.assertTrue(StringUtils.containsEscapeMethod("String escape(String input)"));
assertFalse(StringUtils.containsEscapeMethod("String escape"));
assertFalse(StringUtils.containsEscapeMethod("String escape()"));
assertFalse(StringUtils.containsEscapeMethod(" escape(\"abc\n123\")"));
assertTrue(StringUtils.containsEscapeMethod("String escape(def abc)"));
assertTrue(StringUtils.containsEscapeMethod("String escape(String input)"));
}

@Test
Expand All @@ -171,7 +172,7 @@ public void addDefaultMethod_Escape() {
String input = "return escape(\"abc\n123\");";
String result = StringUtils.addDefaultMethod(input);
Assert.assertNotEquals(input, result);
Assert.assertTrue(result.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION));
assertTrue(result.startsWith(StringUtils.DEFAULT_ESCAPE_FUNCTION));
}

@Test
Expand Down Expand Up @@ -218,4 +219,35 @@ public void testGetErrorMessageWhenHiddenNull() {
// Assert
assertEquals(expected, result);
}

@Test
public void testisValidJSONPath_ValidInputs() {
assertTrue(isValidJSONPath("foo"));
assertTrue(isValidJSONPath("foo.bar"));
assertTrue(isValidJSONPath("foo.bar.baz"));
assertTrue(isValidJSONPath("foo.bar.baz.qux"));
assertTrue(isValidJSONPath(".foo"));
assertTrue(isValidJSONPath("foo."));
assertTrue(isValidJSONPath(".foo.bar"));
assertTrue(isValidJSONPath("$.foo.bar"));
}

@Test
public void testisValidJSONPath_InvalidInputs() {
assertFalse(isValidJSONPath("foo..bar"));
assertFalse(isValidJSONPath("."));
assertFalse(isValidJSONPath(".."));
assertFalse(isValidJSONPath("foo.bar."));
assertFalse(isValidJSONPath(".foo.bar."));
}

@Test
public void testisValidJSONPath_NullInput() {
assertFalse(isValidJSONPath(null));
}

@Test
public void testisValidJSONPath_EmptyInput() {
assertFalse(isValidJSONPath(""));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.processor;

import static java.lang.Math.max;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG;
Expand Down Expand Up @@ -151,14 +152,16 @@ public void processResponseAsync(
try {
SearchHit[] hits = response.getHits().getHits();
// skip processing when there is no hit

String queryString = request.source().toString();
if (hits.length == 0) {
responseListener.onResponse(response);
return;
}

// if many to one, run rewriteResponseDocuments
if (!oneToOne) {
rewriteResponseDocuments(response, responseListener);
rewriteResponseDocuments(response, responseListener, queryString);
} else {
// if one to one, make one hit search response and run rewriteResponseDocuments
GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener(
Expand All @@ -173,7 +176,7 @@ public void processResponseAsync(
newHits[0] = hit;
SearchResponse oneHitResponse = SearchResponseUtil.replaceHits(newHits, response);
ActionListener<SearchResponse> oneHitListener = getOneHitListener(combineResponseListener, isOneHitListenerFailed);
rewriteResponseDocuments(oneHitResponse, oneHitListener);
rewriteResponseDocuments(oneHitResponse, oneHitListener, queryString);
// if any OneHitListener failure, try stop the rest of the predictions
if (isOneHitListenerFailed.get()) {
break;
Expand Down Expand Up @@ -280,9 +283,11 @@ public void onFailure(Exception e) {
*
* @param response the search response
* @param responseListener the listener to be notified when the response is processed
* @param queryString
* @throws IOException if an I/O error occurs during the rewriting process
*/
private void rewriteResponseDocuments(SearchResponse response, ActionListener<SearchResponse> responseListener) throws IOException {
private void rewriteResponseDocuments(SearchResponse response, ActionListener<SearchResponse> responseListener, String queryString)
throws IOException {
List<Map<String, String>> processInputMap = inferenceProcessorAttributes.getInputMaps();
List<Map<String, String>> processOutputMap = inferenceProcessorAttributes.getOutputMaps();
int inputMapSize = (processInputMap == null) ? 0 : processInputMap.size();
Expand All @@ -304,7 +309,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
);
SearchHit[] hits = response.getHits().getHits();
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions, queryString);
}
}

Expand All @@ -316,22 +321,42 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
* @param inputMapIndex the index of the input mapping to process
* @param batchPredictionListener the listener to be notified when the predictions are processed
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @param queryString
* @throws IOException if an I/O error occurs during the prediction process
*/
private void processPredictions(
SearchHit[] hits,
List<Map<String, String>> processInputMap,
int inputMapIndex,
GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener,
Map<Integer, Integer> hitCountInPredictions
Map<Integer, Integer> hitCountInPredictions,
String queryString
) throws IOException {

Map<String, String> modelParameters = new HashMap<>();
Map<String, String> modelConfigs = new HashMap<>();

if (inferenceProcessorAttributes.getModelConfigMaps() != null) {
modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
Map<String, String> modelConfigMapsInput = inferenceProcessorAttributes.getModelConfigMaps();

for (Map.Entry<String, String> entry : modelConfigMapsInput.entrySet()) {
String modelConfigKey = entry.getKey();
String modelConfigValue = entry.getValue();
if (StringUtils.isValidJSONPath(modelConfigValue)) {
Object queryJson = JsonPath.parse(queryString).read("$");
Configuration configuration = Configuration
.builder()
.options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL)
.build();
Object querySubString = JsonPath.using(configuration).parse(queryJson).read(modelConfigValue);
if (querySubString != null) {
modelConfigMapsInput.put(modelConfigKey, toJson(querySubString));
}
}
}
modelParameters.putAll(modelConfigMapsInput);
modelConfigs.putAll(modelConfigMapsInput);

}

Map<String, Object> modelInputParameters = new HashMap<>();
Expand Down
Loading

0 comments on commit d909b02

Please sign in to comment.